diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 20553ebab..62a9f27ad 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -16,6 +16,7 @@ import ( "syscall" deviceConfig "github.com/joshuar/go-hass-agent/internal/config" + "github.com/joshuar/go-hass-agent/internal/hass/api" "github.com/joshuar/go-hass-agent/internal/agent/config" "github.com/joshuar/go-hass-agent/internal/agent/ui" @@ -79,22 +80,26 @@ func Run(options AgentOptions) { agent := newAgent(options.ID, options.Headless) defer close(agent.done) - agentCtx, cancelFunc := context.WithCancel(context.Background()) - agent.setupLogging(agentCtx) - + agent.setupLogging() registrationDone := make(chan struct{}) - go agent.registrationProcess(agentCtx, "", "", options.Register, options.Headless, registrationDone) + configDone := make(chan struct{}) + + go agent.registrationProcess(context.Background(), "", "", options.Register, options.Headless, registrationDone) var workerWg sync.WaitGroup + var ctx context.Context + var cancelFunc context.CancelFunc go func() { <-registrationDone var err error if err = UpgradeConfig(agent.config); err != nil { - log.Warn().Err(err).Msg("Could not start.") + log.Warn().Err(err).Msg("Could not upgrade config.") } if err = ValidateConfig(agent.config); err != nil { - log.Fatal().Err(err).Msg("Could not start.") + log.Fatal().Err(err).Msg("Could not validate config.") } + ctx, cancelFunc = agent.setupContext() + close(configDone) if agent.sensors, err = tracker.NewSensorTracker(agent); err != nil { log.Fatal().Err(err).Msg("Could not start.") } @@ -104,24 +109,26 @@ func Run(options AgentOptions) { workerWg.Add(1) go func() { - device.StartWorkers(agentCtx, agent.sensors, sensorWorkers...) + device.StartWorkers(ctx, agent.sensors, sensorWorkers...) }() workerWg.Add(1) go func() { defer workerWg.Done() - agent.runNotificationsWorker(agentCtx, options) + agent.runNotificationsWorker(ctx, options) }() }() + + <-configDone agent.handleSignals(cancelFunc) - agent.handleShutdown(agentCtx) + agent.handleShutdown(ctx) // If we are not running in headless mode, show a tray icon if !options.Headless { - agent.ui.DisplayTrayIcon(agentCtx, agent) + agent.ui.DisplayTrayIcon(ctx, agent) agent.ui.Run() } workerWg.Wait() - <-agentCtx.Done() + <-ctx.Done() } // Register runs a registration flow. It either prompts the user for needed @@ -175,7 +182,7 @@ func ShowInfo(options AgentOptions) { // setupLogging will attempt to create and then write logging to a file. If it // cannot do this, logging will only be available on stdout -func (agent *Agent) setupLogging(ctx context.Context) { +func (agent *Agent) setupLogging() { logFile, err := agent.config.StoragePath("go-hass-app.log") if err != nil { log.Error().Err(err). @@ -189,14 +196,22 @@ func (agent *Agent) setupLogging(ctx context.Context) { consoleWriter := zerolog.ConsoleWriter{Out: os.Stdout} multiWriter := zerolog.MultiLevelWriter(consoleWriter, logWriter) log.Logger = log.Output(multiWriter) - go func() { - <-ctx.Done() - logWriter.Close() - }() } } } +func (agent *Agent) setupContext() (context.Context, context.CancelFunc) { + SharedConfig := &api.APIConfig{} + if err := agent.config.Get(config.PrefAPIURL, &SharedConfig.APIURL); err != nil { + log.Fatal().Err(err).Msg("Could not export apiURL.") + } + if err := agent.config.Get(config.PrefSecret, &SharedConfig.Secret); err != nil && SharedConfig.Secret != "NOTSET" { + log.Fatal().Err(err).Msg("Could not export secret.") + } + ctx := api.NewContext(context.Background(), SharedConfig) + return context.WithCancel(ctx) +} + // handleSignals will handle Ctrl-C of the agent func (agent *Agent) handleSignals(cancelFunc context.CancelFunc) { c := make(chan os.Signal, 1) diff --git a/internal/agent/ui/fyneUI.go b/internal/agent/ui/fyneUI.go index 2097c6367..9e8b04a40 100644 --- a/internal/agent/ui/fyneUI.go +++ b/internal/agent/ui/fyneUI.go @@ -162,7 +162,7 @@ func (ui *fyneUI) DisplayRegistrationWindow(ctx context.Context, agent Agent, do // about the agent, such as version numbers. func (ui *fyneUI) aboutWindow(ctx context.Context, agent Agent, t *translations.Translator) fyne.Window { var widgets []fyne.CanvasObject - if hassConfig, err := hass.GetHassConfig(ctx, agent); err != nil { + if hassConfig, err := hass.GetHassConfig(ctx); err != nil { widgets = append(widgets, widget.NewLabel(t.Translate( "App Version: %s", agent.AppVersion()))) } else { diff --git a/internal/hass/api/requestType.go b/internal/hass/api/apiTypes.go similarity index 52% rename from internal/hass/api/requestType.go rename to internal/hass/api/apiTypes.go index f21279600..cae1f27bc 100644 --- a/internal/hass/api/requestType.go +++ b/internal/hass/api/apiTypes.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=RequestType -output requestType.go -linecomment"; DO NOT EDIT. +// Code generated by "stringer -type=RequestType,ResponseType -output apiTypes.go -linecomment"; DO NOT EDIT. package api @@ -26,3 +26,22 @@ func (i RequestType) String() string { } return _RequestType_name[_RequestType_index[i]:_RequestType_index[i+1]] } +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ResponseTypeRegistration-6] + _ = x[ResponseTypeUpdate-7] +} + +const _ResponseType_name = "registrationupdate" + +var _ResponseType_index = [...]uint8{0, 12, 18} + +func (i ResponseType) String() string { + i -= 6 + if i < 0 || i >= ResponseType(len(_ResponseType_index)-1) { + return "ResponseType(" + strconv.FormatInt(int64(i+6), 10) + ")" + } + return _ResponseType_name[_ResponseType_index[i]:_ResponseType_index[i+1]] +} diff --git a/internal/hass/api/config.go b/internal/hass/api/config.go index 0a9dcfd5d..c4bab97e0 100644 --- a/internal/hass/api/config.go +++ b/internal/hass/api/config.go @@ -5,7 +5,7 @@ package api -//go:generate moq -out mock_AgentConfig_test.go . AgentConfig +//go:generate moq -out mock_Agent_test.go . Agent type Agent interface { GetConfig(string, interface{}) error } diff --git a/internal/hass/api/context.go b/internal/hass/api/context.go new file mode 100644 index 000000000..4f883cc20 --- /dev/null +++ b/internal/hass/api/context.go @@ -0,0 +1,33 @@ +// Copyright (c) 2023 Joshua Rich +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + +package api + +import "context" + +type APIConfig struct { + APIURL string + Secret string +} + +// key is an unexported type for keys defined in this package. +// This prevents collisions with keys defined in other packages. +type key int + +// userKey is the key for user.User values in Contexts. It is +// unexported; clients use user.NewContext and user.FromContext +// instead of using this key directly. +var userKey key + +// NewContext returns a new Context that carries value u. +func NewContext(ctx context.Context, c *APIConfig) context.Context { + return context.WithValue(ctx, userKey, c) +} + +// FromContext returns the User value stored in ctx, if any. +func FromContext(ctx context.Context) (*APIConfig, bool) { + u, ok := ctx.Value(userKey).(*APIConfig) + return u, ok +} diff --git a/internal/hass/api/mock_Agent_test.go b/internal/hass/api/mock_Agent_test.go new file mode 100644 index 000000000..d9ec9f798 --- /dev/null +++ b/internal/hass/api/mock_Agent_test.go @@ -0,0 +1,80 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package api + +import ( + "sync" +) + +// Ensure, that AgentMock does implement Agent. +// If this is not the case, regenerate this file with moq. +var _ Agent = &AgentMock{} + +// AgentMock is a mock implementation of Agent. +// +// func TestSomethingThatUsesAgent(t *testing.T) { +// +// // make and configure a mocked Agent +// mockedAgent := &AgentMock{ +// GetConfigFunc: func(s string, ifaceVal interface{}) error { +// panic("mock out the GetConfig method") +// }, +// } +// +// // use mockedAgent in code that requires Agent +// // and then make assertions. +// +// } +type AgentMock struct { + // GetConfigFunc mocks the GetConfig method. + GetConfigFunc func(s string, ifaceVal interface{}) error + + // calls tracks calls to the methods. + calls struct { + // GetConfig holds details about calls to the GetConfig method. + GetConfig []struct { + // S is the s argument value. + S string + // IfaceVal is the ifaceVal argument value. + IfaceVal interface{} + } + } + lockGetConfig sync.RWMutex +} + +// GetConfig calls GetConfigFunc. +func (mock *AgentMock) GetConfig(s string, ifaceVal interface{}) error { + if mock.GetConfigFunc == nil { + panic("AgentMock.GetConfigFunc: method is nil but Agent.GetConfig was just called") + } + callInfo := struct { + S string + IfaceVal interface{} + }{ + S: s, + IfaceVal: ifaceVal, + } + mock.lockGetConfig.Lock() + mock.calls.GetConfig = append(mock.calls.GetConfig, callInfo) + mock.lockGetConfig.Unlock() + return mock.GetConfigFunc(s, ifaceVal) +} + +// GetConfigCalls gets all the calls that were made to GetConfig. +// Check the length with: +// +// len(mockedAgent.GetConfigCalls()) +func (mock *AgentMock) GetConfigCalls() []struct { + S string + IfaceVal interface{} +} { + var calls []struct { + S string + IfaceVal interface{} + } + mock.lockGetConfig.RLock() + calls = mock.calls.GetConfig + mock.lockGetConfig.RUnlock() + return calls +} diff --git a/internal/hass/api/mock_RegistrationInfo_test.go b/internal/hass/api/mock_RegistrationInfo_test.go new file mode 100644 index 000000000..f7173823c --- /dev/null +++ b/internal/hass/api/mock_RegistrationInfo_test.go @@ -0,0 +1,104 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package api + +import ( + "sync" +) + +// Ensure, that RegistrationInfoMock does implement RegistrationInfo. +// If this is not the case, regenerate this file with moq. +var _ RegistrationInfo = &RegistrationInfoMock{} + +// RegistrationInfoMock is a mock implementation of RegistrationInfo. +// +// func TestSomethingThatUsesRegistrationInfo(t *testing.T) { +// +// // make and configure a mocked RegistrationInfo +// mockedRegistrationInfo := &RegistrationInfoMock{ +// ServerFunc: func() string { +// panic("mock out the Server method") +// }, +// TokenFunc: func() string { +// panic("mock out the Token method") +// }, +// } +// +// // use mockedRegistrationInfo in code that requires RegistrationInfo +// // and then make assertions. +// +// } +type RegistrationInfoMock struct { + // ServerFunc mocks the Server method. + ServerFunc func() string + + // TokenFunc mocks the Token method. + TokenFunc func() string + + // calls tracks calls to the methods. + calls struct { + // Server holds details about calls to the Server method. + Server []struct { + } + // Token holds details about calls to the Token method. + Token []struct { + } + } + lockServer sync.RWMutex + lockToken sync.RWMutex +} + +// Server calls ServerFunc. +func (mock *RegistrationInfoMock) Server() string { + if mock.ServerFunc == nil { + panic("RegistrationInfoMock.ServerFunc: method is nil but RegistrationInfo.Server was just called") + } + callInfo := struct { + }{} + mock.lockServer.Lock() + mock.calls.Server = append(mock.calls.Server, callInfo) + mock.lockServer.Unlock() + return mock.ServerFunc() +} + +// ServerCalls gets all the calls that were made to Server. +// Check the length with: +// +// len(mockedRegistrationInfo.ServerCalls()) +func (mock *RegistrationInfoMock) ServerCalls() []struct { +} { + var calls []struct { + } + mock.lockServer.RLock() + calls = mock.calls.Server + mock.lockServer.RUnlock() + return calls +} + +// Token calls TokenFunc. +func (mock *RegistrationInfoMock) Token() string { + if mock.TokenFunc == nil { + panic("RegistrationInfoMock.TokenFunc: method is nil but RegistrationInfo.Token was just called") + } + callInfo := struct { + }{} + mock.lockToken.Lock() + mock.calls.Token = append(mock.calls.Token, callInfo) + mock.lockToken.Unlock() + return mock.TokenFunc() +} + +// TokenCalls gets all the calls that were made to Token. +// Check the length with: +// +// len(mockedRegistrationInfo.TokenCalls()) +func (mock *RegistrationInfoMock) TokenCalls() []struct { +} { + var calls []struct { + } + mock.lockToken.RLock() + calls = mock.calls.Token + mock.lockToken.RUnlock() + return calls +} diff --git a/internal/hass/api/mock_Request_test.go b/internal/hass/api/mock_Request_test.go index 9b57d6ef6..9b82845b9 100644 --- a/internal/hass/api/mock_Request_test.go +++ b/internal/hass/api/mock_Request_test.go @@ -4,7 +4,6 @@ package api import ( - "bytes" "encoding/json" "sync" ) @@ -25,9 +24,6 @@ var _ Request = &RequestMock{} // RequestTypeFunc: func() RequestType { // panic("mock out the RequestType method") // }, -// ResponseHandlerFunc: func(buffer bytes.Buffer, responseCh chan Response) { -// panic("mock out the ResponseHandler method") -// }, // } // // // use mockedRequest in code that requires Request @@ -41,9 +37,6 @@ type RequestMock struct { // RequestTypeFunc mocks the RequestType method. RequestTypeFunc func() RequestType - // ResponseHandlerFunc mocks the ResponseHandler method. - ResponseHandlerFunc func(buffer bytes.Buffer, responseCh chan Response) - // calls tracks calls to the methods. calls struct { // RequestData holds details about calls to the RequestData method. @@ -52,17 +45,9 @@ type RequestMock struct { // RequestType holds details about calls to the RequestType method. RequestType []struct { } - // ResponseHandler holds details about calls to the ResponseHandler method. - ResponseHandler []struct { - // Buffer is the buffer argument value. - Buffer bytes.Buffer - // ResponseCh is the responseCh argument value. - ResponseCh chan Response - } } - lockRequestData sync.RWMutex - lockRequestType sync.RWMutex - lockResponseHandler sync.RWMutex + lockRequestData sync.RWMutex + lockRequestType sync.RWMutex } // RequestData calls RequestDataFunc. @@ -118,39 +103,3 @@ func (mock *RequestMock) RequestTypeCalls() []struct { mock.lockRequestType.RUnlock() return calls } - -// ResponseHandler calls ResponseHandlerFunc. -func (mock *RequestMock) ResponseHandler(buffer bytes.Buffer, responseCh chan Response) { - if mock.ResponseHandlerFunc == nil { - panic("RequestMock.ResponseHandlerFunc: method is nil but Request.ResponseHandler was just called") - } - callInfo := struct { - Buffer bytes.Buffer - ResponseCh chan Response - }{ - Buffer: buffer, - ResponseCh: responseCh, - } - mock.lockResponseHandler.Lock() - mock.calls.ResponseHandler = append(mock.calls.ResponseHandler, callInfo) - mock.lockResponseHandler.Unlock() - mock.ResponseHandlerFunc(buffer, responseCh) -} - -// ResponseHandlerCalls gets all the calls that were made to ResponseHandler. -// Check the length with: -// -// len(mockedRequest.ResponseHandlerCalls()) -func (mock *RequestMock) ResponseHandlerCalls() []struct { - Buffer bytes.Buffer - ResponseCh chan Response -} { - var calls []struct { - Buffer bytes.Buffer - ResponseCh chan Response - } - mock.lockResponseHandler.RLock() - calls = mock.calls.ResponseHandler - mock.lockResponseHandler.RUnlock() - return calls -} diff --git a/internal/hass/api/request.go b/internal/hass/api/request.go index 083b2252a..7286c2c03 100644 --- a/internal/hass/api/request.go +++ b/internal/hass/api/request.go @@ -10,33 +10,36 @@ import ( "context" "encoding/json" "errors" + "sync" "time" "github.com/carlmjohnson/requests" - "github.com/joshuar/go-hass-agent/internal/agent/config" ) -//go:generate stringer -type=RequestType -output requestType.go -linecomment +//go:generate stringer -type=RequestType,ResponseType -output apiTypes.go -linecomment const ( RequestTypeEncrypted RequestType = iota + 1 // encrypted RequestTypeGetConfig // get_config RequestTypeUpdateLocation // update_location RequestTypeRegisterSensor // register_sensor RequestTypeUpdateSensorStates // update_sensor_states + + ResponseTypeRegistration ResponseType = iota + 1 // registration + ResponseTypeUpdate // update ) type RequestType int +type ResponseType int //go:generate moq -out mock_Request_test.go . Request type Request interface { RequestType() RequestType RequestData() json.RawMessage - ResponseHandler(bytes.Buffer, chan Response) } func marshalJSON(request Request, secret string) ([]byte, error) { if request.RequestType() == RequestTypeEncrypted { - if secret != "" { + if secret != "" && secret != "NOTSET" { return json.Marshal(&EncryptedRequest{ Type: RequestTypeEncrypted.String(), Encrypted: true, @@ -64,41 +67,47 @@ type EncryptedRequest struct { Encrypted bool `json:"encrypted"` } -func ExecuteRequest(ctx context.Context, request Request, agent Agent, responseCh chan Response) { - var res bytes.Buffer - +func ExecuteRequest(ctx context.Context, request Request, responseCh chan interface{}) { defer close(responseCh) - - var apiURL, secret string - if err := agent.GetConfig(config.PrefAPIURL, &apiURL); err != nil { - responseCh <- NewGenericResponse(err, request.RequestType()) + cfg, ok := FromContext(ctx) + if !ok { + responseCh <- errors.New("no config found in context") return } - if request.RequestType() == RequestTypeEncrypted { - if err := agent.GetConfig(config.PrefSecret, &secret); err != nil { - responseCh <- NewGenericResponse(err, request.RequestType()) - return - } - } - reqJSON, err := marshalJSON(request, secret) + reqJSON, err := marshalJSON(request, cfg.Secret) if err != nil { - responseCh <- NewGenericResponse(err, request.RequestType()) + responseCh <- err return } requestCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = requests. - URL(apiURL). - BodyBytes(reqJSON). - ToBytesBuffer(&res). - Fetch(requestCtx) - if err != nil { - cancel() - responseCh <- NewGenericResponse(err, request.RequestType()) - return - } - request.ResponseHandler(res, responseCh) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var rBuf bytes.Buffer + err = requests. + URL(cfg.APIURL). + BodyBytes(reqJSON). + ToBytesBuffer(&rBuf). + Fetch(requestCtx) + if err != nil { + cancel() + responseCh <- err + return + } else { + response, err := parseResponse(request.RequestType(), &rBuf) + if err != nil { + responseCh <- err + return + } else { + responseCh <- response + return + } + } + }() + wg.Wait() } diff --git a/internal/hass/api/request_test.go b/internal/hass/api/request_test.go index 7888c0c29..0b94dbe53 100644 --- a/internal/hass/api/request_test.go +++ b/internal/hass/api/request_test.go @@ -6,20 +6,33 @@ package api import ( - bytes "bytes" "context" "encoding/json" - "errors" "net/http" "net/http/httptest" "reflect" - "sync" "testing" - "github.com/joshuar/go-hass-agent/internal/agent/config" "github.com/stretchr/testify/assert" ) +func mockServer(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + req := &UnencryptedRequest{} + err := json.NewDecoder(r.Body).Decode(&req) + assert.Nil(t, err) + switch req.Type { + case "register_sensor": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"success":true}`)) + case "encrypted": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"success":true}`)) + } + })) +} + func Test_marshalJSON(t *testing.T) { mockReq := &RequestMock{ RequestDataFunc: func() json.RawMessage { @@ -29,7 +42,6 @@ func Test_marshalJSON(t *testing.T) { return RequestTypeUpdateSensorStates }, } - mockEncReq := &RequestMock{ RequestDataFunc: func() json.RawMessage { return json.RawMessage(`{"someField": "someValue"}`) @@ -70,173 +82,50 @@ func Test_marshalJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := marshalJSON(tt.args.request, tt.args.secret) if (err != nil) != tt.wantErr { - t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("marshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("MarshalJSON() = %v, want %v", got, tt.want) + t.Errorf("marshalJSON() = %v, want %v", got, tt.want) } }) } } -func mockServer(t *testing.T) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req := &UnencryptedRequest{} - err := json.NewDecoder(r.Body).Decode(&req) - assert.Nil(t, err) - switch req.Type { - case "register_sensor": - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"success":true}`)) - case "encrypted": - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"success":true}`)) - } - })) -} - -type req struct { - reqType RequestType - data json.RawMessage -} - -func (r *req) RequestType() RequestType { - return r.reqType -} - -func (r *req) RequestData() json.RawMessage { - return r.data -} - -func (r *req) ResponseHandler(b bytes.Buffer, resp chan Response) { - resp <- NewGenericResponse(nil, r.reqType) -} - -type encReq struct { -} - -func (r *encReq) RequestType() RequestType { - return RequestTypeEncrypted -} - -func (r *encReq) RequestData() json.RawMessage { - return json.RawMessage(`{"someField": "someValue"}`) -} - -func (r *encReq) ResponseHandler(b bytes.Buffer, resp chan Response) { - resp <- NewGenericResponse(nil, RequestTypeEncrypted) - -} - func TestExecuteRequest(t *testing.T) { - server := mockServer(t) - defer server.Close() - - goodConfig := &AgentConfigMock{ - GetFunc: func(s string, ifaceVal interface{}) error { - v := ifaceVal.(*string) - switch s { - case config.PrefAPIURL: - *v = server.URL - return nil - case config.PrefSecret: - *v = "aSecret" - return nil - default: - return errors.New("not found") - } - }, + mockServer := mockServer(t) + defer mockServer.Close() + mockConfig := &APIConfig{ + APIURL: mockServer.URL, } - - badConfig := &AgentConfigMock{ - GetFunc: func(s string, ifaceVal interface{}) error { - v := ifaceVal.(*string) - switch s { - case config.PrefAPIURL: - *v = server.URL - return nil - case config.PrefSecret: - *v = "" - return nil - default: - return errors.New("not found") - } + ctx := NewContext(context.TODO(), mockConfig) + mockReq := &RequestMock{ + RequestDataFunc: func() json.RawMessage { + return json.RawMessage(`{"someField": "someValue"}`) + }, + RequestTypeFunc: func() RequestType { + return RequestTypeUpdateSensorStates }, } + responseCh := make(chan interface{}, 1) type args struct { ctx context.Context request Request - config Agent - responseCh chan Response + responseCh chan interface{} } tests := []struct { - name string - args args - wantErr bool + name string + args args }{ { - name: "good request", - args: args{ - ctx: context.Background(), - request: &req{reqType: RequestTypeRegisterSensor}, - config: goodConfig, - responseCh: make(chan Response, 1), - }, - wantErr: false, - }, - { - name: "bad encrypted request, missing secret", - args: args{ - ctx: context.Background(), - request: &encReq{}, - config: badConfig, - responseCh: make(chan Response, 1), - }, - wantErr: true, - }, - { - name: "good encrypted request", - args: args{ - ctx: context.Background(), - request: &encReq{}, - config: goodConfig, - responseCh: make(chan Response, 1), - }, - wantErr: false, - }, - { - name: "bad json", - args: args{ - ctx: context.Background(), - request: &req{ - reqType: RequestTypeRegisterSensor, - data: json.RawMessage(`sdgasghsdag`), - }, - config: goodConfig, - responseCh: make(chan Response, 1), - }, - wantErr: true, + name: "default test", + args: args{ctx: ctx, request: mockReq, responseCh: responseCh}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - resp := <-tt.args.responseCh - if err := resp.Error(); (err != nil) != tt.wantErr { - t.Errorf("api.TestExecuteRequest() error = %v, wantErr %v", err, tt.wantErr) - } - }() - wg.Add(1) - go func() { - defer wg.Done() - ExecuteRequest(tt.args.ctx, tt.args.request, tt.args.config, tt.args.responseCh) - }() - wg.Wait() + ExecuteRequest(tt.args.ctx, tt.args.request, tt.args.responseCh) }) } } diff --git a/internal/hass/api/response.go b/internal/hass/api/response.go index 465b4441a..c1b040c37 100644 --- a/internal/hass/api/response.go +++ b/internal/hass/api/response.go @@ -5,46 +5,115 @@ package api -//go:generate moq -out mock_Response_test.go . Response -type Response interface { - SensorRegistrationResponse - SensorUpdateResponse - Error() error - Type() RequestType +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + + "github.com/rs/zerolog/log" +) + +type SensorResponse struct { + responseType ResponseType + disabled bool + registered bool } -type SensorRegistrationResponse interface { - Registered() bool +func (r *SensorResponse) Type() ResponseType { + return r.responseType } -type SensorUpdateResponse interface { - Disabled() bool +func (r *SensorResponse) Disabled() bool { + return r.disabled } -type GenericResponse struct { - error - requestType RequestType +func (r *SensorResponse) Registered() bool { + return r.registered } -func (e *GenericResponse) Error() error { - return e.error +func parseRegistrationResponse(buf *bytes.Buffer) (*SensorResponse, error) { + r, err := parseAsMap(buf) + if err != nil { + return nil, err + } + if _, ok := r["success"]; ok { + if success, err := assertAs[bool](r["success"]); err != nil || !success { + return nil, errors.New("unsuccessful registration") + } + return &SensorResponse{registered: true, responseType: ResponseTypeRegistration}, nil + } + return nil, errors.New("unknown response structure") } -func (e *GenericResponse) Type() RequestType { - return e.requestType +func parseUpdateResponse(buf *bytes.Buffer) (*SensorResponse, error) { + r, err := parseAsMap(buf) + if err != nil { + return nil, err + } + for k, v := range r { + log.Trace().Str("id", k).Msg("Parsing response for sensor.") + r, err := assertAs[map[string]interface{}](v) + if err != nil { + return nil, err + } + if _, ok := r["success"]; ok { + if success, err := assertAs[bool](r["success"]); err != nil || !success { + if err != nil { + return nil, err + } + log.Trace().Str("id", k).Msg("Unsuccessful response.") + responseErr, err := assertAs[map[string]interface{}](r["error"]) + if err != nil { + return nil, errors.New("unknown error") + } else { + return nil, fmt.Errorf("code %s: %s", responseErr["code"], responseErr["message"]) + } + } + } + if _, ok := r["is_disabled"]; ok { + log.Trace().Str("id", k).Bool("disabled", true).Msg("Successful response.") + return &SensorResponse{disabled: true, responseType: ResponseTypeUpdate}, nil + } else { + log.Trace().Str("id", k).Bool("disabled", false).Msg("Successful response.") + return &SensorResponse{disabled: false, responseType: ResponseTypeUpdate}, nil + } + } + return nil, errors.New("unknown response structure") } -func (e *GenericResponse) Disabled() bool { - return false +func parseResponse(t RequestType, buf *bytes.Buffer) (interface{}, error) { + switch t { + case RequestTypeUpdateLocation: + return buf.Bytes(), nil + case RequestTypeGetConfig: + return buf.Bytes(), nil + case RequestTypeRegisterSensor: + return parseRegistrationResponse(buf) + case RequestTypeUpdateSensorStates: + return parseUpdateResponse(buf) + default: + return nil, errors.New("unknown response") + } } -func (e *GenericResponse) Registered() bool { - return false +func parseAsMap(buf *bytes.Buffer) (map[string]interface{}, error) { + var r interface{} + err := json.Unmarshal(buf.Bytes(), &r) + if err != nil { + return nil, fmt.Errorf("could not unmarshal response (%s)", buf.String()) + } + rMap, ok := r.(map[string]interface{}) + if !ok { + return nil, errors.New("could not parse response as map") + } + return rMap, nil } -func NewGenericResponse(e error, t RequestType) *GenericResponse { - return &GenericResponse{ - error: e, - requestType: t, +func assertAs[T any](thing interface{}) (T, error) { + if asT, ok := thing.(T); !ok { + return *new(T), errors.New("could not assert value") + } else { + return asT, nil } } diff --git a/internal/hass/config.go b/internal/hass/config.go index 1a8c7aed6..fb2473da7 100644 --- a/internal/hass/config.go +++ b/internal/hass/config.go @@ -6,12 +6,12 @@ package hass import ( - "bytes" "context" "encoding/json" - "errors" "sync" + "github.com/rs/zerolog/log" + "github.com/joshuar/go-hass-agent/internal/hass/api" "github.com/perimeterx/marshmallow" ) @@ -52,38 +52,32 @@ func (h *haConfig) RequestData() json.RawMessage { return nil } -func (h *haConfig) ResponseHandler(resp bytes.Buffer, respCh chan api.Response) { - if resp.Bytes() == nil { - err := errors.New("no response returned") - response := api.NewGenericResponse(err, api.RequestTypeGetConfig) - respCh <- response +func (h *haConfig) extractConfig(b []byte) { + if b == nil { + log.Warn().Msg("No config returned.") return } h.mu.Lock() - result, err := marshmallow.Unmarshal(resp.Bytes(), &h.haConfigProps) + result, err := marshmallow.Unmarshal(b, &h.haConfigProps) if err != nil { - response := api.NewGenericResponse(err, api.RequestTypeGetConfig) - respCh <- response - return + log.Warn().Msg("Could not extract config structure.") } h.rawConfigProps = result h.mu.Unlock() - response := api.NewGenericResponse(nil, api.RequestTypeGetConfig) - respCh <- response } -func GetHassConfig(ctx context.Context, c api.Agent) (*haConfig, error) { +func GetHassConfig(ctx context.Context) (*haConfig, error) { h := new(haConfig) - respCh := make(chan api.Response, 1) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - api.ExecuteRequest(ctx, h, c, respCh) - }() + respCh := make(chan interface{}, 1) + go api.ExecuteRequest(ctx, h, respCh) response := <-respCh - if response.Error() != nil { - return nil, response.Error() + switch r := response.(type) { + case []byte: + h.extractConfig(response.([]byte)) + case error: + log.Warn().Err(r).Msg("Failed to fetch Home Assistant config.") + default: + log.Warn().Msgf("Unknown response type %T", r) } return h, nil } diff --git a/internal/hass/location.go b/internal/hass/location.go index 49ac3c0d2..6258c21e8 100644 --- a/internal/hass/location.go +++ b/internal/hass/location.go @@ -6,12 +6,9 @@ package hass import ( - "bytes" "encoding/json" - "errors" "github.com/joshuar/go-hass-agent/internal/hass/api" - "github.com/rs/zerolog/log" ) // LocationUpdate represents the location information that can be sent to HA @@ -38,35 +35,3 @@ func (l *LocationUpdate) RequestData() json.RawMessage { raw := json.RawMessage(data) return raw } - -func (l *LocationUpdate) ResponseHandler(res bytes.Buffer, respCh chan api.Response) { - response := new(locationResponse) - if res.Len() == 0 { - response.err = errors.New("no response data") - respCh <- response - } else { - respCh <- response - } -} - -type locationResponse struct { - err error -} - -func (l locationResponse) Registered() bool { - log.Debug().Msg("Registered should not be called for location response.") - return false -} - -func (l locationResponse) Disabled() bool { - log.Debug().Msg("Disabled should not be called for location response.") - return false -} - -func (l locationResponse) Error() error { - return l.err -} - -func (l locationResponse) Type() api.RequestType { - return api.RequestTypeUpdateLocation -} diff --git a/internal/hass/sensor/sensor.go b/internal/hass/sensor/sensor.go index d8ea2ea01..7273cc7e8 100644 --- a/internal/hass/sensor/sensor.go +++ b/internal/hass/sensor/sensor.go @@ -6,10 +6,7 @@ package sensor import ( - "bytes" "encoding/json" - "errors" - "fmt" "github.com/joshuar/go-hass-agent/internal/hass/api" "github.com/rs/zerolog/log" @@ -49,57 +46,6 @@ func (reg *SensorRegistrationInfo) RequestData() json.RawMessage { return data } -func (reg *SensorRegistrationInfo) ResponseHandler(res bytes.Buffer, respCh chan api.Response) { - response, err := marshalResponse(res) - if err != nil { - respCh <- &SensorUpdateResponse{ - err: err, - } - } - respCh <- NewSensorRegistrationResponse(response) -} - -type SensorRegistrationResponse struct { - err error - registered bool -} - -func (r SensorRegistrationResponse) Error() error { - return r.err -} - -func (r SensorRegistrationResponse) Type() api.RequestType { - return api.RequestTypeRegisterSensor -} - -func (r SensorRegistrationResponse) Disabled() bool { - return false -} - -func (r SensorRegistrationResponse) Registered() bool { - return r.registered -} - -func NewSensorRegistrationResponse(r map[string]interface{}) *SensorRegistrationResponse { - s := new(SensorRegistrationResponse) - if v, ok := r["success"]; ok { - success, err := assertAs[bool](v) - if err != nil { - s.err = err - return s - } else { - if success { - s.registered = true - return s - } else { - s.err = errors.New("unsuccessful registration") - return s - } - } - } - return s -} - // SensorUpdateInfo is the JSON structure required to update HA with the current // sensor state. type SensorUpdateInfo struct { @@ -123,90 +69,3 @@ func (upd *SensorUpdateInfo) RequestData() json.RawMessage { } return data } - -func (upd *SensorUpdateInfo) ResponseHandler(res bytes.Buffer, respCh chan api.Response) { - response, err := marshalResponse(res) - if err != nil { - respCh <- &SensorUpdateResponse{ - err: err, - } - } - respCh <- NewSensorUpdateResponse(upd.UniqueID, response) -} - -type SensorUpdateResponse struct { - err error - disabled bool -} - -func (r SensorUpdateResponse) Error() error { - return r.err -} - -func (r SensorUpdateResponse) Type() api.RequestType { - return api.RequestTypeUpdateSensorStates -} - -func (r SensorUpdateResponse) Disabled() bool { - return r.disabled -} - -func (r SensorUpdateResponse) Registered() bool { - return true -} - -func NewSensorUpdateResponse(i string, r map[string]interface{}) *SensorUpdateResponse { - s := new(SensorUpdateResponse) - if v, ok := r[i]; ok { - status, err := assertAs[map[string]interface{}](v) - if err != nil { - s.err = err - return s - } - success, err := assertAs[bool](status["success"]) - if err != nil { - s.err = err - return s - } else { - if !success { - hassErr, err := assertAs[map[string]interface{}](status["error"]) - if err != nil { - s.err = errors.New("unknown error") - return s - } else { - s.err = fmt.Errorf("code %s: %s", hassErr["code"], hassErr["message"]) - return s - } - } - if _, ok := status["is_disabled"]; ok { - s.disabled = true - } else { - s.disabled = false - } - } - } - - return s -} - -func marshalResponse(raw bytes.Buffer) (map[string]interface{}, error) { - var r interface{} - err := json.Unmarshal(raw.Bytes(), &r) - if err != nil { - return nil, fmt.Errorf("could not unmarshal response (%s)", raw.String()) - } - response, ok := r.(map[string]interface{}) - if !ok { - return nil, errors.New("could not assert response as map") - } - return response, nil -} - -func assertAs[T any](thing interface{}) (T, error) { - if asT, ok := thing.(T); !ok { - return *new(T), errors.New("could not assert value") - } else { - return asT, nil - } - -} diff --git a/internal/hass/sensor/sensor_test.go b/internal/hass/sensor/sensor_test.go index dce3067b6..8792dc901 100644 --- a/internal/hass/sensor/sensor_test.go +++ b/internal/hass/sensor/sensor_test.go @@ -6,7 +6,6 @@ package sensor import ( - "bytes" "encoding/json" "reflect" "testing" @@ -14,7 +13,7 @@ import ( "github.com/joshuar/go-hass-agent/internal/hass/api" ) -func TestSensorRegistrationInfo_RequestData(t *testing.T) { +func TestSensorRegistrationInfo_RequestType(t *testing.T) { type fields struct { State interface{} StateAttributes interface{} @@ -31,17 +30,11 @@ func TestSensorRegistrationInfo_RequestData(t *testing.T) { tests := []struct { name string fields fields - want json.RawMessage + want api.RequestType }{ { - name: "successful test", - fields: fields{ - Name: "aSensor", - Type: "aType", - State: "someState", - UniqueID: "anID", - }, - want: json.RawMessage(`{"state":"someState","unique_id":"anID","type":"aType","name":"aSensor"}`), + name: "default test", + want: api.RequestTypeRegisterSensor, }, } for _, tt := range tests { @@ -59,14 +52,14 @@ func TestSensorRegistrationInfo_RequestData(t *testing.T) { DeviceClass: tt.fields.DeviceClass, Disabled: tt.fields.Disabled, } - if got := reg.RequestData(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("SensorRegistrationInfo.RequestData() = %v, want %v", got, tt.want) + if got := reg.RequestType(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SensorRegistrationInfo.RequestType() = %v, want %v", got, tt.want) } }) } } -func TestSensorRegistrationInfo_ResponseHandler(t *testing.T) { +func TestSensorRegistrationInfo_RequestData(t *testing.T) { type fields struct { State interface{} StateAttributes interface{} @@ -80,16 +73,21 @@ func TestSensorRegistrationInfo_ResponseHandler(t *testing.T) { DeviceClass string Disabled bool } - type args struct { - res bytes.Buffer - respCh chan api.Response - } tests := []struct { name string fields fields - args args + want json.RawMessage }{ - // TODO: Add test cases. + { + name: "successful test", + fields: fields{ + Name: "aSensor", + Type: "aType", + State: "someState", + UniqueID: "anID", + }, + want: json.RawMessage(`{"state":"someState","unique_id":"anID","type":"aType","name":"aSensor"}`), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -106,32 +104,14 @@ func TestSensorRegistrationInfo_ResponseHandler(t *testing.T) { DeviceClass: tt.fields.DeviceClass, Disabled: tt.fields.Disabled, } - reg.ResponseHandler(tt.args.res, tt.args.respCh) - }) - } -} - -func TestNewSensorRegistrationResponse(t *testing.T) { - type args struct { - r map[string]interface{} - } - tests := []struct { - name string - args args - want *SensorRegistrationResponse - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewSensorRegistrationResponse(tt.args.r); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewSensorRegistrationResponse() = %v, want %v", got, tt.want) + if got := reg.RequestData(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SensorRegistrationInfo.RequestData() = %v, want %v", got, tt.want) } }) } } -func TestSensorUpdateInfo_RequestData(t *testing.T) { +func TestSensorUpdateInfo_RequestType(t *testing.T) { type fields struct { StateAttributes interface{} State interface{} @@ -142,9 +122,12 @@ func TestSensorUpdateInfo_RequestData(t *testing.T) { tests := []struct { name string fields fields - want json.RawMessage + want api.RequestType }{ - // TODO: Add test cases. + { + name: "default test", + want: api.RequestTypeUpdateSensorStates, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -155,14 +138,14 @@ func TestSensorUpdateInfo_RequestData(t *testing.T) { Type: tt.fields.Type, UniqueID: tt.fields.UniqueID, } - if got := upd.RequestData(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("SensorUpdateInfo.RequestData() = %v, want %v", got, tt.want) + if got := upd.RequestType(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SensorUpdateInfo.RequestType() = %v, want %v", got, tt.want) } }) } } -func TestSensorUpdateInfo_ResponseHandler(t *testing.T) { +func TestSensorUpdateInfo_RequestData(t *testing.T) { type fields struct { StateAttributes interface{} State interface{} @@ -170,16 +153,20 @@ func TestSensorUpdateInfo_ResponseHandler(t *testing.T) { Type string UniqueID string } - type args struct { - res bytes.Buffer - respCh chan api.Response - } tests := []struct { name string fields fields - args args + want json.RawMessage }{ - // TODO: Add test cases. + { + name: "successful test", + fields: fields{ + Type: "aType", + State: "someState", + UniqueID: "anID", + }, + want: json.RawMessage(`{"state":"someState","type":"aType","unique_id":"anID"}`), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -190,179 +177,8 @@ func TestSensorUpdateInfo_ResponseHandler(t *testing.T) { Type: tt.fields.Type, UniqueID: tt.fields.UniqueID, } - upd.ResponseHandler(tt.args.res, tt.args.respCh) - }) - } -} - -func TestSensorUpdateResponse_Error(t *testing.T) { - type fields struct { - err error - disabled bool - } - tests := []struct { - name string - fields fields - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := SensorUpdateResponse{ - err: tt.fields.err, - disabled: tt.fields.disabled, - } - if err := r.Error(); (err != nil) != tt.wantErr { - t.Errorf("SensorUpdateResponse.Error() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestSensorUpdateResponse_Type(t *testing.T) { - type fields struct { - err error - disabled bool - } - tests := []struct { - name string - fields fields - want api.RequestType - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := SensorUpdateResponse{ - err: tt.fields.err, - disabled: tt.fields.disabled, - } - if got := r.Type(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("SensorUpdateResponse.Type() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSensorUpdateResponse_Disabled(t *testing.T) { - type fields struct { - err error - disabled bool - } - tests := []struct { - name string - fields fields - want bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := SensorUpdateResponse{ - err: tt.fields.err, - disabled: tt.fields.disabled, - } - if got := r.Disabled(); got != tt.want { - t.Errorf("SensorUpdateResponse.Disabled() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSensorUpdateResponse_Registered(t *testing.T) { - type fields struct { - err error - disabled bool - } - tests := []struct { - name string - fields fields - want bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := SensorUpdateResponse{ - err: tt.fields.err, - disabled: tt.fields.disabled, - } - if got := r.Registered(); got != tt.want { - t.Errorf("SensorUpdateResponse.Registered() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestNewSensorUpdateResponse(t *testing.T) { - type args struct { - i string - r map[string]interface{} - } - tests := []struct { - name string - args args - want *SensorUpdateResponse - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewSensorUpdateResponse(tt.args.i, tt.args.r); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewSensorUpdateResponse() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_marshalResponse(t *testing.T) { - type args struct { - raw bytes.Buffer - } - tests := []struct { - name string - args args - want map[string]interface{} - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := marshalResponse(tt.args.raw) - if (err != nil) != tt.wantErr { - t.Errorf("marshalResponse() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("marshalResponse() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_assertAs(t *testing.T) { - type args struct { - thing interface{} - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := assertAs[string](tt.args.thing) - if (err != nil) != tt.wantErr { - t.Errorf("assertAs() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("assertAs() = %v, want %v", got, tt.want) + if got := upd.RequestData(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SensorUpdateInfo.RequestData() = %v, want %v", got, tt.want) } }) } diff --git a/internal/tracker/location.go b/internal/tracker/location.go index d7a68c229..90d9315c3 100644 --- a/internal/tracker/location.go +++ b/internal/tracker/location.go @@ -7,7 +7,6 @@ package tracker import ( "context" - "sync" "github.com/joshuar/go-hass-agent/internal/hass" "github.com/joshuar/go-hass-agent/internal/hass/api" @@ -41,24 +40,16 @@ func marshalLocationUpdate(l Location) *hass.LocationUpdate { } } -func updateLocation(ctx context.Context, a agent, l Location) { - respCh := make(chan api.Response, 1) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - api.ExecuteRequest(ctx, marshalLocationUpdate(l), a, respCh) - }() - wg.Add(1) - go func() { - defer wg.Done() - // defer close(respCh) - response := <-respCh - if response.Error() != nil { - log.Error().Err(response.Error()). - Msg("Failed to update location.") - } else { - log.Debug().Msg("Location Updated.") - } - }() +func updateLocation(ctx context.Context, l Location) { + respCh := make(chan interface{}, 1) + go api.ExecuteRequest(ctx, marshalLocationUpdate(l), respCh) + response := <-respCh + switch r := response.(type) { + case []byte: + log.Debug().Msg("Location Updated.") + case error: + log.Warn().Err(r).Msg("Failed to update location.") + default: + log.Warn().Msgf("Unknown response type %T", r) + } } diff --git a/internal/tracker/mock_agent_test.go b/internal/tracker/mock_agent_test.go index 61fceb1b9..9f18b700a 100644 --- a/internal/tracker/mock_agent_test.go +++ b/internal/tracker/mock_agent_test.go @@ -17,9 +17,6 @@ var _ agent = &agentMock{} // // // make and configure a mocked agent // mockedagent := &agentMock{ -// GetConfigFunc: func(s string, ifaceVal interface{}) error { -// panic("mock out the GetConfig method") -// }, // StoragePathFunc: func(s string) (string, error) { // panic("mock out the StoragePath method") // }, @@ -30,67 +27,20 @@ var _ agent = &agentMock{} // // } type agentMock struct { - // GetConfigFunc mocks the GetConfig method. - GetConfigFunc func(s string, ifaceVal interface{}) error - // StoragePathFunc mocks the StoragePath method. StoragePathFunc func(s string) (string, error) // calls tracks calls to the methods. calls struct { - // GetConfig holds details about calls to the GetConfig method. - GetConfig []struct { - // S is the s argument value. - S string - // IfaceVal is the ifaceVal argument value. - IfaceVal interface{} - } // StoragePath holds details about calls to the StoragePath method. StoragePath []struct { // S is the s argument value. S string } } - lockGetConfig sync.RWMutex lockStoragePath sync.RWMutex } -// GetConfig calls GetConfigFunc. -func (mock *agentMock) GetConfig(s string, ifaceVal interface{}) error { - if mock.GetConfigFunc == nil { - panic("agentMock.GetConfigFunc: method is nil but agent.GetConfig was just called") - } - callInfo := struct { - S string - IfaceVal interface{} - }{ - S: s, - IfaceVal: ifaceVal, - } - mock.lockGetConfig.Lock() - mock.calls.GetConfig = append(mock.calls.GetConfig, callInfo) - mock.lockGetConfig.Unlock() - return mock.GetConfigFunc(s, ifaceVal) -} - -// GetConfigCalls gets all the calls that were made to GetConfig. -// Check the length with: -// -// len(mockedagent.GetConfigCalls()) -func (mock *agentMock) GetConfigCalls() []struct { - S string - IfaceVal interface{} -} { - var calls []struct { - S string - IfaceVal interface{} - } - mock.lockGetConfig.RLock() - calls = mock.calls.GetConfig - mock.lockGetConfig.RUnlock() - return calls -} - // StoragePath calls StoragePathFunc. func (mock *agentMock) StoragePath(s string) (string, error) { if mock.StoragePathFunc == nil { diff --git a/internal/hass/api/mock_Response_test.go b/internal/tracker/mock_apiResponse_test.go similarity index 54% rename from internal/hass/api/mock_Response_test.go rename to internal/tracker/mock_apiResponse_test.go index f4a19d3e6..aa72aea73 100644 --- a/internal/hass/api/mock_Response_test.go +++ b/internal/tracker/mock_apiResponse_test.go @@ -1,61 +1,53 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package api +package tracker import ( + "github.com/joshuar/go-hass-agent/internal/hass/api" "sync" ) -// Ensure, that ResponseMock does implement Response. +// Ensure, that apiResponseMock does implement apiResponse. // If this is not the case, regenerate this file with moq. -var _ Response = &ResponseMock{} +var _ apiResponse = &apiResponseMock{} -// ResponseMock is a mock implementation of Response. +// apiResponseMock is a mock implementation of apiResponse. // -// func TestSomethingThatUsesResponse(t *testing.T) { +// func TestSomethingThatUsesapiResponse(t *testing.T) { // -// // make and configure a mocked Response -// mockedResponse := &ResponseMock{ +// // make and configure a mocked apiResponse +// mockedapiResponse := &apiResponseMock{ // DisabledFunc: func() bool { // panic("mock out the Disabled method") // }, -// ErrorFunc: func() error { -// panic("mock out the Error method") -// }, // RegisteredFunc: func() bool { // panic("mock out the Registered method") // }, -// TypeFunc: func() RequestType { +// TypeFunc: func() api.ResponseType { // panic("mock out the Type method") // }, // } // -// // use mockedResponse in code that requires Response +// // use mockedapiResponse in code that requires apiResponse // // and then make assertions. // // } -type ResponseMock struct { +type apiResponseMock struct { // DisabledFunc mocks the Disabled method. DisabledFunc func() bool - // ErrorFunc mocks the Error method. - ErrorFunc func() error - // RegisteredFunc mocks the Registered method. RegisteredFunc func() bool // TypeFunc mocks the Type method. - TypeFunc func() RequestType + TypeFunc func() api.ResponseType // calls tracks calls to the methods. calls struct { // Disabled holds details about calls to the Disabled method. Disabled []struct { } - // Error holds details about calls to the Error method. - Error []struct { - } // Registered holds details about calls to the Registered method. Registered []struct { } @@ -64,15 +56,14 @@ type ResponseMock struct { } } lockDisabled sync.RWMutex - lockError sync.RWMutex lockRegistered sync.RWMutex lockType sync.RWMutex } // Disabled calls DisabledFunc. -func (mock *ResponseMock) Disabled() bool { +func (mock *apiResponseMock) Disabled() bool { if mock.DisabledFunc == nil { - panic("ResponseMock.DisabledFunc: method is nil but Response.Disabled was just called") + panic("apiResponseMock.DisabledFunc: method is nil but apiResponse.Disabled was just called") } callInfo := struct { }{} @@ -85,8 +76,8 @@ func (mock *ResponseMock) Disabled() bool { // DisabledCalls gets all the calls that were made to Disabled. // Check the length with: // -// len(mockedResponse.DisabledCalls()) -func (mock *ResponseMock) DisabledCalls() []struct { +// len(mockedapiResponse.DisabledCalls()) +func (mock *apiResponseMock) DisabledCalls() []struct { } { var calls []struct { } @@ -96,37 +87,10 @@ func (mock *ResponseMock) DisabledCalls() []struct { return calls } -// Error calls ErrorFunc. -func (mock *ResponseMock) Error() error { - if mock.ErrorFunc == nil { - panic("ResponseMock.ErrorFunc: method is nil but Response.Error was just called") - } - callInfo := struct { - }{} - mock.lockError.Lock() - mock.calls.Error = append(mock.calls.Error, callInfo) - mock.lockError.Unlock() - return mock.ErrorFunc() -} - -// ErrorCalls gets all the calls that were made to Error. -// Check the length with: -// -// len(mockedResponse.ErrorCalls()) -func (mock *ResponseMock) ErrorCalls() []struct { -} { - var calls []struct { - } - mock.lockError.RLock() - calls = mock.calls.Error - mock.lockError.RUnlock() - return calls -} - // Registered calls RegisteredFunc. -func (mock *ResponseMock) Registered() bool { +func (mock *apiResponseMock) Registered() bool { if mock.RegisteredFunc == nil { - panic("ResponseMock.RegisteredFunc: method is nil but Response.Registered was just called") + panic("apiResponseMock.RegisteredFunc: method is nil but apiResponse.Registered was just called") } callInfo := struct { }{} @@ -139,8 +103,8 @@ func (mock *ResponseMock) Registered() bool { // RegisteredCalls gets all the calls that were made to Registered. // Check the length with: // -// len(mockedResponse.RegisteredCalls()) -func (mock *ResponseMock) RegisteredCalls() []struct { +// len(mockedapiResponse.RegisteredCalls()) +func (mock *apiResponseMock) RegisteredCalls() []struct { } { var calls []struct { } @@ -151,9 +115,9 @@ func (mock *ResponseMock) RegisteredCalls() []struct { } // Type calls TypeFunc. -func (mock *ResponseMock) Type() RequestType { +func (mock *apiResponseMock) Type() api.ResponseType { if mock.TypeFunc == nil { - panic("ResponseMock.TypeFunc: method is nil but Response.Type was just called") + panic("apiResponseMock.TypeFunc: method is nil but apiResponse.Type was just called") } callInfo := struct { }{} @@ -166,8 +130,8 @@ func (mock *ResponseMock) Type() RequestType { // TypeCalls gets all the calls that were made to Type. // Check the length with: // -// len(mockedResponse.TypeCalls()) -func (mock *ResponseMock) TypeCalls() []struct { +// len(mockedapiResponse.TypeCalls()) +func (mock *apiResponseMock) TypeCalls() []struct { } { var calls []struct { } diff --git a/internal/tracker/sensor.go b/internal/tracker/sensor.go index 5caf5caad..50a0479a1 100644 --- a/internal/tracker/sensor.go +++ b/internal/tracker/sensor.go @@ -6,6 +6,9 @@ package tracker import ( + "fmt" + "strings" + "github.com/joshuar/go-hass-agent/internal/hass/sensor" ) @@ -27,6 +30,15 @@ type Sensor interface { Attributes() interface{} } +func prettyPrintState(s Sensor) string { + var b strings.Builder + fmt.Fprintf(&b, "%v", s.State()) + if s.Units() != "" { + fmt.Fprintf(&b, " %s", s.Units()) + } + return b.String() +} + func marshalSensorUpdate(s Sensor) *sensor.SensorUpdateInfo { return &sensor.SensorUpdateInfo{ StateAttributes: s.Attributes(), diff --git a/internal/tracker/tracker.go b/internal/tracker/tracker.go index 72e340475..ef3c55f6e 100644 --- a/internal/tracker/tracker.go +++ b/internal/tracker/tracker.go @@ -8,7 +8,6 @@ package tracker import ( "context" "errors" - "fmt" "sort" "sync" @@ -32,10 +31,16 @@ type Registry interface { //go:generate moq -out mock_agent_test.go . agent type agent interface { - GetConfig(string, interface{}) error StoragePath(string) (string, error) } +//go:generate moq -out mock_apiResponse_test.go . apiResponse +type apiResponse interface { + Registered() bool + Disabled() bool + Type() api.ResponseType +} + type SensorTracker struct { registry Registry agentConfig agent @@ -85,11 +90,12 @@ func (t *SensorTracker) SensorList() []string { // send will send a sensor update to HA, checking to ensure the sensor is not // disabled. It will also update the local registry state based on the response. -func (t *SensorTracker) send(ctx context.Context, config agent, sensorUpdate Sensor) { - var wg sync.WaitGroup +func (t *SensorTracker) send(ctx context.Context, sensorUpdate Sensor) { var req api.Request if disabled := <-t.registry.IsDisabled(sensorUpdate.ID()); disabled { - log.Debug().Msgf("Sensor %s is disabled. Ignoring update.", sensorUpdate.ID()) + log.Debug().Str("id", sensorUpdate.ID()). + Msg("Sensor is disabled. Ignoring update.") + return } registered := <-t.registry.IsRegistered(sensorUpdate.ID()) switch registered { @@ -98,73 +104,66 @@ func (t *SensorTracker) send(ctx context.Context, config agent, sensorUpdate Sen case false: req = marshalSensorRegistration(sensorUpdate) } - responseCh := make(chan api.Response, 1) - wg.Add(1) - go func() { - defer wg.Done() - response := <-responseCh - t.handle(response, sensorUpdate) - }() - wg.Add(1) - go func() { - defer wg.Done() - api.ExecuteRequest(ctx, req, config, responseCh) - }() - wg.Wait() + responseCh := make(chan interface{}, 1) + go api.ExecuteRequest(ctx, req, responseCh) + response := <-responseCh + switch r := response.(type) { + case apiResponse: + t.handle(r, sensorUpdate) + case error: + log.Warn().Err(r).Str("id", sensorUpdate.ID()). + Msg("Failed to send sensor data to Home Assistant.") + default: + log.Warn().Msgf("Unknown response type %T", r) + } } // handle will take the response sent back by the Home Assistant API and run // appropriate actions. This includes recording registration or setting disabled // status. -func (t *SensorTracker) handle(response api.Response, sensorUpdate Sensor) { - if response.Error() != nil { - log.Error().Err(response.Error()). - Str("name", sensorUpdate.Name()). - Msg("Failed to send sensor data to Home Assistant.") - } else { - log.Debug(). +func (t *SensorTracker) handle(response apiResponse, sensorUpdate Sensor) { + log.Debug(). + Str("name", sensorUpdate.Name()). + Str("id", sensorUpdate.ID()). + Str("state", prettyPrintState(sensorUpdate)). + Msg("Sensor updated.") + if err := t.add(sensorUpdate); err != nil { + log.Warn().Err(err). Str("name", sensorUpdate.Name()). - Str("id", sensorUpdate.ID()). - Str("state", fmt.Sprintf("%v %s", sensorUpdate.State(), sensorUpdate.Units())). - Msg("Sensor updated.") - if err := t.add(sensorUpdate); err != nil { - log.Warn().Err(err). - Str("name", sensorUpdate.Name()). - Msg("Unable to add state for sensor to tracker.") - } - if response.Type() == api.RequestTypeUpdateSensorStates { - switch { - case response.Disabled(): - if err := t.registry.SetDisabled(sensorUpdate.ID(), true); err != nil { - log.Warn().Err(err). - Str("name", sensorUpdate.Name()). - Msg("Unable to set as disabled in registry.") - } else { - log.Debug(). - Str("name", sensorUpdate.Name()). - Msg("Sensor set to disabled.") - } - case !response.Disabled() && <-t.registry.IsDisabled(sensorUpdate.ID()): - if err := t.registry.SetDisabled(sensorUpdate.ID(), false); err != nil { - log.Warn().Err(err). - Str("name", sensorUpdate.Name()). - Msg("Unable to set as not disabled in registry.") - } - } - } - if response.Type() == api.RequestTypeRegisterSensor && response.Registered() { - if err := t.registry.SetRegistered(sensorUpdate.ID(), true); err != nil { + Msg("Unable to add state for sensor to tracker.") + } + if response.Type() == api.ResponseTypeUpdate { + switch { + case response.Disabled(): + if err := t.registry.SetDisabled(sensorUpdate.ID(), true); err != nil { log.Warn().Err(err). Str("name", sensorUpdate.Name()). - Msg("Unable to set as registered in registry.") + Msg("Unable to set as disabled in registry.") } else { log.Debug(). Str("name", sensorUpdate.Name()). - Str("id", sensorUpdate.ID()). - Msg("Sensor registered in Home Assistant.") + Msg("Sensor set to disabled.") + } + case !response.Disabled() && <-t.registry.IsDisabled(sensorUpdate.ID()): + if err := t.registry.SetDisabled(sensorUpdate.ID(), false); err != nil { + log.Warn().Err(err). + Str("name", sensorUpdate.Name()). + Msg("Unable to set as not disabled in registry.") } } } + if response.Type() == api.ResponseTypeRegistration && response.Registered() { + if err := t.registry.SetRegistered(sensorUpdate.ID(), true); err != nil { + log.Warn().Err(err). + Str("name", sensorUpdate.Name()). + Msg("Unable to set as registered in registry.") + } else { + log.Debug(). + Str("name", sensorUpdate.Name()). + Str("id", sensorUpdate.ID()). + Msg("Sensor registered in Home Assistant.") + } + } } // UpdateSensors is the externally exposed method that devices can use to send a @@ -183,9 +182,9 @@ func (t *SensorTracker) UpdateSensors(ctx context.Context, sensors ...interface{ for s := range sensorData { switch sensor := s.(type) { case Sensor: - t.send(ctx, t.agentConfig, sensor) + t.send(ctx, sensor) case Location: - updateLocation(ctx, t.agentConfig, sensor) + updateLocation(ctx, sensor) } i++ } @@ -208,9 +207,8 @@ func NewSensorTracker(agentConfig agent) (*SensorTracker, error) { return nil, errors.New("unable to create a sensor tracker") } sensorTracker := &SensorTracker{ - registry: db, - sensor: make(map[string]Sensor), - agentConfig: agentConfig, + registry: db, + sensor: make(map[string]Sensor), } return sensorTracker, nil } diff --git a/internal/tracker/tracker_test.go b/internal/tracker/tracker_test.go index e75e068ab..5bfb7a600 100644 --- a/internal/tracker/tracker_test.go +++ b/internal/tracker/tracker_test.go @@ -8,28 +8,53 @@ package tracker import ( "context" "encoding/json" - "errors" "net/http" "net/http/httptest" - "strings" + "reflect" "sync" "testing" - "github.com/joshuar/go-hass-agent/internal/agent/config" "github.com/joshuar/go-hass-agent/internal/hass/api" "github.com/joshuar/go-hass-agent/internal/hass/sensor" "github.com/stretchr/testify/assert" ) +func mockServer(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + raw := struct { + Type string `json:"type"` + }{} + err := json.NewDecoder(r.Body).Decode(&raw) + assert.Nil(t, err) + switch raw.Type { + case "update_sensor_states": + upd := &sensor.SensorUpdateInfo{} + assert.Nil(t, err) + json.NewDecoder(r.Body).Decode(&upd) + assert.Nil(t, err) + resp := "{" + `"` + upd.UniqueID + `"` + `:{"success":true}}` + w.WriteHeader(http.StatusOK) + w.Write([]byte(resp)) + case "register_sensor": + reg := &sensor.SensorRegistrationInfo{} + json.NewDecoder(r.Body).Decode(®) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"success":true}`)) + } + })) +} + func TestSensorTracker_add(t *testing.T) { mockSensor := &SensorMock{ IDFunc: func() string { return "sensorID" }, } type fields struct { - registry Registry - sensor map[string]Sensor - mu sync.RWMutex + registry Registry + agentConfig agent + sensor map[string]Sensor + mu sync.RWMutex } type args struct { s Sensor @@ -56,12 +81,13 @@ func TestSensorTracker_add(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := &SensorTracker{ - registry: tt.fields.registry, - sensor: tt.fields.sensor, - mu: tt.fields.mu, + tr := &SensorTracker{ + registry: tt.fields.registry, + agentConfig: tt.fields.agentConfig, + sensor: tt.fields.sensor, + mu: tt.fields.mu, } - if err := tracker.add(tt.args.s); (err != nil) != tt.wantErr { + if err := tr.add(tt.args.s); (err != nil) != tt.wantErr { t.Errorf("SensorTracker.add() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -74,9 +100,10 @@ func TestSensorTracker_Get(t *testing.T) { mockMap["sensorID"] = mockSensor type fields struct { - registry Registry - sensor map[string]Sensor - mu sync.RWMutex + registry Registry + agentConfig agent + sensor map[string]Sensor + mu sync.RWMutex } type args struct { id string @@ -93,6 +120,7 @@ func TestSensorTracker_Get(t *testing.T) { fields: fields{sensor: mockMap}, args: args{id: "sensorID"}, wantErr: false, + want: mockSensor, }, { name: "unsuccessful get", @@ -103,153 +131,264 @@ func TestSensorTracker_Get(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := &SensorTracker{ - registry: tt.fields.registry, - sensor: tt.fields.sensor, - mu: tt.fields.mu, + tr := &SensorTracker{ + registry: tt.fields.registry, + agentConfig: tt.fields.agentConfig, + sensor: tt.fields.sensor, + mu: tt.fields.mu, } - _, err := tracker.Get(tt.args.id) + got, err := tr.Get(tt.args.id) if (err != nil) != tt.wantErr { t.Errorf("SensorTracker.Get() error = %v, wantErr %v", err, tt.wantErr) return } - // if !reflect.DeepEqual(got, tt.want) { - // t.Errorf("SensorTracker.Get() = %v, want %v", got, tt.want) - // } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SensorTracker.Get() = %v, want %v", got, tt.want) + } }) } } -// func NewMockConfig(t *testing.T) *mockConfig { -// path, err := os.MkdirTemp("/tmp", "go-hass-agent-test") -// assert.Nil(t, err) -// return &mockConfig{ -// storage: path, -// } -// } - -func TestSensorTracker_Update(t *testing.T) { - mockServer := func(t *testing.T) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req := &api.UnencryptedRequest{} - err := json.NewDecoder(r.Body).Decode(&req) - assert.Nil(t, err) - switch req.Type { - case "update_sensor_states": - switch { - case strings.Contains(string(req.Data), "bad"): - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"sensorID":{"success":false,"error":{"code":"invalid_format","message": "Unexpected value for type"}}}`)) - case strings.Contains(string(req.Data), "disabled"): - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"sensorID":{"success":true,"is_disabled":true}}`)) - default: - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"sensorID":{"success":true}}`)) - } - case "register_sensor": - w.WriteHeader(http.StatusCreated) - w.Write([]byte(`{"success":true}`)) - } - })) +func TestSensorTracker_SensorList(t *testing.T) { + mockSensor := &SensorMock{ + StateFunc: func() interface{} { return "aState" }, } - server := mockServer(t) - defer server.Close() + mockMap := make(map[string]Sensor) + mockMap["sensorID"] = mockSensor - mockExistingRegistry := &RegistryMock{ - IsRegisteredFunc: func(s string) chan bool { - valueCh := make(chan bool, 1) - valueCh <- true - return valueCh - }, - IsDisabledFunc: func(s string) chan bool { - valueCh := make(chan bool, 1) - valueCh <- false - return valueCh - }, - SetDisabledFunc: func(s string, b bool) error { - return nil - }, + type fields struct { + registry Registry + agentConfig agent + sensor map[string]Sensor + mu sync.RWMutex } - - mockNewRegistry := &RegistryMock{ - IsRegisteredFunc: func(s string) chan bool { - valueCh := make(chan bool, 1) - valueCh <- false - return valueCh + tests := []struct { + name string + fields fields + want []string + }{ + { + name: "with sensors", + fields: fields{sensor: mockMap}, + want: []string{"sensorID"}, }, - IsDisabledFunc: func(s string) chan bool { - valueCh := make(chan bool, 1) - valueCh <- false - return valueCh + { + name: "without sensors", + want: nil, }, - SetDisabledFunc: func(s string, b bool) error { return nil }, - SetRegisteredFunc: func(s string, b bool) error { return nil }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &SensorTracker{ + registry: tt.fields.registry, + agentConfig: tt.fields.agentConfig, + sensor: tt.fields.sensor, + mu: tt.fields.mu, + } + if got := tr.SensorList(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SensorTracker.SensorList() = %v, want %v", got, tt.want) + } + }) + } +} - mockSensorUpdate := &SensorMock{ +func TestSensorTracker_send(t *testing.T) { + mockServer := mockServer(t) + defer mockServer.Close() + mockConfig := &api.APIConfig{ + APIURL: mockServer.URL, + } + ctx := api.NewContext(context.TODO(), mockConfig) + mockUpdate := &SensorMock{ + IDFunc: func() string { return "updateID" }, + NameFunc: func() string { return "Update Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, + AttributesFunc: func() interface{} { return nil }, + IconFunc: func() string { return "anIcon" }, + SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, + } + mockRegistration := &SensorMock{ + IDFunc: func() string { return "regID" }, + NameFunc: func() string { return "Registration Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, AttributesFunc: func() interface{} { return nil }, - StateFunc: func() interface{} { return "goodState" }, - IconFunc: func() string { return "mdi:icon" }, + IconFunc: func() string { return "anIcon" }, SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, - IDFunc: func() string { return "sensorID" }, - NameFunc: func() string { return "sensorName" }, - UnitsFunc: func() string { return "units" }, DeviceClassFunc: func() sensor.SensorDeviceClass { return sensor.Duration }, StateClassFunc: func() sensor.SensorStateClass { return sensor.StateMeasurement }, CategoryFunc: func() string { return "" }, } - - mockBadSensorUpdate := &SensorMock{ + mockDisabled := &SensorMock{ + IDFunc: func() string { return "disabledID" }, + NameFunc: func() string { return "Update Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, AttributesFunc: func() interface{} { return nil }, - StateFunc: func() interface{} { return "badState" }, - IconFunc: func() string { return "mdi:icon" }, + IconFunc: func() string { return "anIcon" }, SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, - IDFunc: func() string { return "sensorID" }, - NameFunc: func() string { return "sensorName" }, - UnitsFunc: func() string { return "units" }, DeviceClassFunc: func() sensor.SensorDeviceClass { return sensor.Duration }, StateClassFunc: func() sensor.SensorStateClass { return sensor.StateMeasurement }, CategoryFunc: func() string { return "" }, } + mockMap := make(map[string]Sensor) + mockMap["updateID"] = mockUpdate + mockMap["regID"] = mockRegistration + mockMap["disabledID"] = mockDisabled + mockRegistry := &RegistryMock{ + IsDisabledFunc: func(s string) chan bool { + d := make(chan bool, 1) + switch s { + case "disabledID": + d <- true + default: + d <- false + } + close(d) + return d + }, + IsRegisteredFunc: func(s string) chan bool { + d := make(chan bool, 1) + switch s { + case "updateID": + d <- true + case "regID": + d <- false + } + close(d) + return d + }, + SetRegisteredFunc: func(s string, b bool) error { + return nil + }, + } + + type fields struct { + registry Registry + agentConfig agent + sensor map[string]Sensor + mu sync.RWMutex + } + type args struct { + ctx context.Context + sensorUpdate Sensor + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "sensor update", + fields: fields{sensor: mockMap, registry: mockRegistry}, + args: args{ctx: ctx, sensorUpdate: mockUpdate}, + }, + { + name: "sensor registration", + fields: fields{sensor: mockMap, registry: mockRegistry}, + args: args{ctx: ctx, sensorUpdate: mockRegistration}, + }, + { + name: "disabled sensor", + fields: fields{sensor: mockMap, registry: mockRegistry}, + args: args{ctx: ctx, sensorUpdate: mockDisabled}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &SensorTracker{ + registry: tt.fields.registry, + agentConfig: tt.fields.agentConfig, + sensor: tt.fields.sensor, + mu: tt.fields.mu, + } + tr.send(tt.args.ctx, tt.args.sensorUpdate) + }) + } +} - mockDisabledSensorUpdate := &SensorMock{ +func TestSensorTracker_handle(t *testing.T) { + mockUpdate := &SensorMock{ + IDFunc: func() string { return "updateID" }, + NameFunc: func() string { return "Update Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, + AttributesFunc: func() interface{} { return nil }, + IconFunc: func() string { return "anIcon" }, + SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, + } + mockUpdateResponse := &apiResponseMock{ + TypeFunc: func() api.ResponseType { return api.ResponseTypeUpdate }, + DisabledFunc: func() bool { return false }, + } + mockRegistration := &SensorMock{ + IDFunc: func() string { return "regID" }, + NameFunc: func() string { return "Registration Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, AttributesFunc: func() interface{} { return nil }, - StateFunc: func() interface{} { return "disabled" }, - IconFunc: func() string { return "mdi:icon" }, + IconFunc: func() string { return "anIcon" }, SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, - IDFunc: func() string { return "sensorID" }, - NameFunc: func() string { return "sensorName" }, - UnitsFunc: func() string { return "units" }, DeviceClassFunc: func() sensor.SensorDeviceClass { return sensor.Duration }, StateClassFunc: func() sensor.SensorStateClass { return sensor.StateMeasurement }, CategoryFunc: func() string { return "" }, } - - mockConfig := &agentMock{ - GetConfigFunc: func(s string, ifaceVal interface{}) error { - v := ifaceVal.(*string) + mockRegistrationResponse := &apiResponseMock{ + TypeFunc: func() api.ResponseType { return api.ResponseTypeRegistration }, + RegisteredFunc: func() bool { return true }, + } + mockDisabled := &SensorMock{ + IDFunc: func() string { return "disabledID" }, + NameFunc: func() string { return "Disabled Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, + AttributesFunc: func() interface{} { return nil }, + IconFunc: func() string { return "anIcon" }, + SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, + } + mockDisabledResponse := &apiResponseMock{ + TypeFunc: func() api.ResponseType { return api.ResponseTypeUpdate }, + DisabledFunc: func() bool { return true }, + } + mockMap := make(map[string]Sensor) + mockMap["updateID"] = mockUpdate + mockMap["regID"] = mockRegistration + mockMap["disabledID"] = mockDisabled + mockRegistry := &RegistryMock{ + IsDisabledFunc: func(s string) chan bool { + d := make(chan bool, 1) + d <- false + close(d) + return d + }, + IsRegisteredFunc: func(s string) chan bool { + d := make(chan bool, 1) switch s { - case config.PrefAPIURL: - *v = server.URL - return nil - case config.PrefSecret: - *v = "" - return nil - default: - return errors.New("not found") + case "updateID": + d <- true + case "regID": + d <- false } + close(d) + return d + }, + SetRegisteredFunc: func(s string, b bool) error { + return nil + }, + SetDisabledFunc: func(s string, b bool) error { + return nil }, } type fields struct { - registry Registry - sensor map[string]Sensor - mu sync.RWMutex + registry Registry + agentConfig agent + sensor map[string]Sensor + mu sync.RWMutex } type args struct { - ctx context.Context - config agent + response apiResponse sensorUpdate Sensor } tests := []struct { @@ -258,140 +397,156 @@ func TestSensorTracker_Update(t *testing.T) { args args }{ { - name: "successful update", - fields: fields{ - sensor: make(map[string]Sensor), - registry: mockExistingRegistry, - }, - args: args{ - ctx: context.Background(), - config: mockConfig, - sensorUpdate: mockSensorUpdate, - }, + name: "successful update", + args: args{response: mockUpdateResponse, sensorUpdate: mockUpdate}, + fields: fields{sensor: mockMap, registry: mockRegistry}, }, { - name: "bad update", - fields: fields{ - sensor: make(map[string]Sensor), - registry: mockExistingRegistry, - }, - args: args{ - ctx: context.Background(), - config: mockConfig, - sensorUpdate: mockBadSensorUpdate, - }, + name: "successful registration", + args: args{response: mockRegistrationResponse, sensorUpdate: mockRegistration}, + fields: fields{sensor: mockMap, registry: mockRegistry}, }, { - name: "disabled update", - fields: fields{ - sensor: make(map[string]Sensor), - registry: mockExistingRegistry, - }, - args: args{ - ctx: context.Background(), - config: mockConfig, - sensorUpdate: mockDisabledSensorUpdate, - }, - }, - { - name: "successful new", - fields: fields{ - sensor: make(map[string]Sensor), - registry: mockNewRegistry, - }, - args: args{ - ctx: context.Background(), - config: mockConfig, - sensorUpdate: mockSensorUpdate, - }, + name: "disabled update", + args: args{response: mockDisabledResponse, sensorUpdate: mockDisabled}, + fields: fields{sensor: mockMap, registry: mockRegistry}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := &SensorTracker{ - registry: tt.fields.registry, - sensor: tt.fields.sensor, - mu: tt.fields.mu, + tr := &SensorTracker{ + registry: tt.fields.registry, + agentConfig: tt.fields.agentConfig, + sensor: tt.fields.sensor, + mu: tt.fields.mu, } - tracker.send(tt.args.ctx, tt.args.config, tt.args.sensorUpdate) + tr.handle(tt.args.response, tt.args.sensorUpdate) }) } } -// func Test_startWorkers(t *testing.T) { -// ctx, cancelFunc := context.WithCancel(context.Background()) -// updateCh := make(chan interface{}) -// defer close(updateCh) -// defer cancelFunc() -// mockWorker := func(context.Context, chan interface{}) { -// t.Log("worker ran") -// } -// w := []func(context.Context, chan interface{}){mockWorker} - -// type args struct { -// ctx context.Context -// workers []func(context.Context, chan interface{}) -// updateCh chan interface{} -// } -// tests := []struct { -// name string -// args args -// }{ -// { -// name: "default test", -// args: args{ -// ctx: ctx, -// workers: w, -// updateCh: updateCh, -// }, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// startWorkers(tt.args.ctx, tt.args.workers, tt.args.updateCh) -// }) -// } -// } +func TestSensorTracker_UpdateSensors(t *testing.T) { + mockServer := mockServer(t) + defer mockServer.Close() + mockConfig := &api.APIConfig{ + APIURL: mockServer.URL, + } + ctx := api.NewContext(context.TODO(), mockConfig) + mockUpdate := &SensorMock{ + IDFunc: func() string { return "updateID" }, + NameFunc: func() string { return "Update Sensor" }, + UnitsFunc: func() string { return "" }, + StateFunc: func() interface{} { return "aState" }, + AttributesFunc: func() interface{} { return nil }, + IconFunc: func() string { return "anIcon" }, + SensorTypeFunc: func() sensor.SensorType { return sensor.TypeSensor }, + } + mockMap := make(map[string]Sensor) + mockMap["updateID"] = mockUpdate + mockRegistry := &RegistryMock{ + IsDisabledFunc: func(s string) chan bool { + d := make(chan bool, 1) + d <- false + close(d) + return d + }, + IsRegisteredFunc: func(s string) chan bool { + d := make(chan bool, 1) + switch s { + case "updateID": + d <- true + case "regID": + d <- false + } + close(d) + return d + }, + SetRegisteredFunc: func(s string, b bool) error { + return nil + }, + SetDisabledFunc: func(s string, b bool) error { + return nil + }, + } + var single, many []interface{} + single = append(single, mockUpdate) + many = append(many, mockUpdate, mockUpdate, mockUpdate) -// func TestSensorTracker_trackUpdates(t *testing.T) { -// ctx, cancelFunc := context.WithCancel(context.Background()) -// updateCh := make(chan interface{}) -// defer close(updateCh) -// defer cancelFunc() + type fields struct { + registry Registry + agentConfig agent + sensor map[string]Sensor + mu sync.RWMutex + } + type args struct { + ctx context.Context + sensors []interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "single update", + fields: fields{registry: mockRegistry, sensor: mockMap}, + args: args{ctx: ctx, sensors: single}, + }, + { + name: "many updates", + fields: fields{registry: mockRegistry, sensor: mockMap}, + args: args{ctx: ctx, sensors: many}, + }, + { + name: "no updates", + fields: fields{registry: mockRegistry, sensor: mockMap}, + args: args{ctx: ctx, sensors: nil}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &SensorTracker{ + registry: tt.fields.registry, + agentConfig: tt.fields.agentConfig, + sensor: tt.fields.sensor, + mu: tt.fields.mu, + } + if err := tr.UpdateSensors(tt.args.ctx, tt.args.sensors...); (err != nil) != tt.wantErr { + t.Errorf("SensorTracker.UpdateSensors() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} -// type fields struct { -// registry Registry -// sensor map[string]Sensor -// mu sync.RWMutex -// } -// type args struct { -// ctx context.Context -// config agent -// updateCh chan interface{} -// } -// tests := []struct { -// name string -// fields fields -// args args -// }{ -// { -// name: "default test", -// args: args{ -// ctx: ctx, -// config: &agentMock{}, -// updateCh: updateCh, -// }, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// tr := &SensorTracker{ -// registry: tt.fields.registry, -// sensor: tt.fields.sensor, -// mu: tt.fields.mu, -// } -// go tr.trackUpdates(tt.args.ctx, tt.args.config, tt.args.updateCh) -// cancelFunc() -// }) -// } -// } +func TestNewSensorTracker(t *testing.T) { + agentCfg := &agentMock{ + StoragePathFunc: func(s string) (string, error) { return t.TempDir(), nil }, + } + type args struct { + agentConfig agent + } + tests := []struct { + name string + args args + want *SensorTracker + wantErr bool + }{ + { + name: "default test", + args: args{agentConfig: agentCfg}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSensorTracker(tt.args.agentConfig) + if (err != nil) != tt.wantErr { + t.Errorf("NewSensorTracker() error = %v, wantErr %v", err, tt.wantErr) + return + } + // if !reflect.DeepEqual(got, tt.want) { + // t.Errorf("NewSensorTracker() = %v, want %v", got, tt.want) + // } + }) + } +}