diff --git a/.gitignore b/.gitignore index 4b042b1..a5e92c4 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,11 @@ report.json # Ignore the vendor structure *vendor/ +# VSCode +*.code-workspace +.vscode/* +.dev/* + # Ignore the various build artifacts .ignore diff --git a/anclafx/provide.go b/anclafx/provide.go index 849effa..361af1c 100644 --- a/anclafx/provide.go +++ b/anclafx/provide.go @@ -19,7 +19,6 @@ func Provide() fx.Option { ancla.ProvideListener, ancla.ProvideDefaultListenerWatchers, chrysom.ProvideBasicClient, - chrysom.ProvideDefaultListenerReader, chrysom.ProvideListenerClient, ), chrysom.ProvideMetrics(), diff --git a/anclafx/provide_test.go b/anclafx/provide_test.go index 67ce214..bf132ba 100644 --- a/anclafx/provide_test.go +++ b/anclafx/provide_test.go @@ -3,7 +3,6 @@ package anclafx_test import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -19,10 +18,9 @@ import ( type out struct { fx.Out - Factory *touchstone.Factory - BasicClientConfig chrysom.BasicClientConfig - GetLogger chrysom.GetLogger - SetLogger chrysom.SetLogger + Factory *touchstone.Factory + ClientOptions chrysom.ClientOptions `group:"client_options,flatten"` + ListenerOptions chrysom.ListenerOptions `group:"listener_options,flatten"` } func provideDefaults() (out, error) { @@ -37,21 +35,20 @@ func provideDefaults() (out, error) { return out{ Factory: touchstone.NewFactory(cfg, zap.NewNop(), pr), - BasicClientConfig: chrysom.BasicClientConfig{ - Address: "example.com", - Bucket: "bucket-name", + ClientOptions: chrysom.ClientOptions{ + chrysom.Bucket("bucket-name"), }, - GetLogger: func(context.Context) *zap.Logger { return zap.NewNop() }, - SetLogger: func(context.Context, *zap.Logger) context.Context { return context.Background() }, + // Listener has no required options + ListenerOptions: chrysom.ListenerOptions{}, }, nil } func TestProvide(t *testing.T) { t.Run("Test anclafx.Provide() defaults", func(t *testing.T) { var ( - svc ancla.Service - bc *chrysom.BasicClient - l *chrysom.ListenerClient + svc ancla.Service + pushReader chrysom.PushReader + listener *chrysom.ListenerClient ) app := fxtest.New(t, @@ -61,8 +58,8 @@ func TestProvide(t *testing.T) { ), fx.Populate( &svc, - &bc, - &l, + &pushReader, + &listener, ), ) @@ -71,8 +68,8 @@ func TestProvide(t *testing.T) { require.NoError(app.Err()) app.RequireStart() require.NotNil(svc) - require.NotNil(bc) - require.NotNil(l) + require.NotNil(pushReader) + require.NotNil(listener) app.RequireStop() }) } diff --git a/auth/acquire.go b/auth/acquire.go index 1927de5..91c36fe 100644 --- a/auth/acquire.go +++ b/auth/acquire.go @@ -13,3 +13,9 @@ type Decorator interface { // Decorate decorates the given http request with authorization header(s). Decorate(ctx context.Context, req *http.Request) error } + +type DecoratorFunc func(context.Context, *http.Request) error + +func (f DecoratorFunc) Decorate(ctx context.Context, req *http.Request) error { return f(ctx, req) } + +var Nop = DecoratorFunc(func(context.Context, *http.Request) error { return nil }) diff --git a/auth/context_test.go b/auth/context_test.go new file mode 100644 index 0000000..a74a655 --- /dev/null +++ b/auth/context_test.go @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: 2025 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPrincipal(t *testing.T) { + t.Run("Test SetPartnerIDs, GetPartnerIDs", func(t *testing.T) { + assert := assert.New(t) + partnerIDs := []string{"foo", "bar"} + ctx := SetPartnerIDs(context.Background(), partnerIDs) + actualPartnerIDs, ok := GetPartnerIDs(ctx) + assert.True(ok) + assert.Equal(partnerIDs, actualPartnerIDs) + actualPartnerIDs, ok = GetPartnerIDs(context.Background()) + assert.False(ok) + var empty []string + assert.Equal(empty, actualPartnerIDs) + }) + t.Run("Test SetPrincipal, GetPrincipal", func(t *testing.T) { + assert := assert.New(t) + principal := "foo" + ctx := SetPrincipal(context.Background(), principal) + actualPrincipal, ok := GetPrincipal(ctx) + assert.True(ok) + assert.Equal(principal, actualPrincipal) + actualPrincipal, ok = GetPrincipal(context.Background()) + assert.False(ok) + assert.Equal("", actualPrincipal) + }) +} diff --git a/chrysom/basicClient.go b/chrysom/basicClient.go index cead42d..d9134e8 100644 --- a/chrysom/basicClient.go +++ b/chrysom/basicClient.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/xmidt-org/ancla/auth" "github.com/xmidt-org/ancla/model" @@ -25,53 +24,28 @@ const ( ) var ( - ErrNilMeasures = errors.New("measures cannot be nil") - ErrAddressEmpty = errors.New("argus address is required") - ErrBucketEmpty = errors.New("bucket name is required") - ErrItemIDEmpty = errors.New("item ID is required") - ErrItemDataEmpty = errors.New("data field in item is required") - ErrUndefinedIntervalTicker = errors.New("interval ticker is nil. Can't listen for updates") - ErrAuthDecoratorFailure = errors.New("failed decorating auth header") - ErrBadRequest = errors.New("argus rejected the request as invalid") + ErrItemIDEmpty = errors.New("item ID is required") + ErrItemDataEmpty = errors.New("data field in item is required") + ErrAuthDecoratorFailure = errors.New("failed decorating auth header") + ErrBadRequest = errors.New("argus rejected the request as invalid") ) var ( - errNonSuccessResponse = errors.New("argus responded with a non-success status code") - errNewRequestFailure = errors.New("failed creating an HTTP request") - errDoRequestFailure = errors.New("http client failed while sending request") - errReadingBodyFailure = errors.New("failed while reading http response body") - errJSONUnmarshal = errors.New("failed unmarshaling JSON response payload") - errJSONMarshal = errors.New("failed marshaling item as JSON payload") - errFailedConfig = errors.New("ancla configuration error") + ErrFailedAuthentication = errors.New("failed to authentication with argus") + errNonSuccessResponse = errors.New("argus responded with a non-success status code") + errNewRequestFailure = errors.New("failed creating an HTTP request") + errDoRequestFailure = errors.New("http client failed while sending request") + errReadingBodyFailure = errors.New("failed while reading http response body") + errJSONUnmarshal = errors.New("failed unmarshaling JSON response payload") + errJSONMarshal = errors.New("failed marshaling item as JSON payload") ) -// BasicClientConfig contains config data for the client that will be used to -// make requests to the Argus client. -type BasicClientConfig struct { - // Address is the Argus URL (i.e. https://example-argus.io:8090) - Address string - - // Bucket partition to be used by this client. - Bucket string - - // HTTPClient refers to the client that will be used to send requests. - // (Optional) Defaults to http.DefaultClient. - HTTPClient *http.Client - - // Auth provides the mechanism to add auth headers to outgoing requests. - // (Optional) If not provided, no auth headers are added. - Auth auth.Decorator - - // PullInterval is how often listeners should get updates. - // (Optional). Defaults to 5 seconds. - PullInterval time.Duration -} - // BasicClient is the client used to make requests to Argus. type BasicClient struct { client *http.Client auth auth.Decorator storeBaseURL string + storeAPIPath string bucket string getLogger func(context.Context) *zap.Logger } @@ -83,31 +57,32 @@ type response struct { } const ( - storeAPIPath = "/api/v1/store" + storeV1APIPath = "/api/v1/store" errWrappedFmt = "%w: %s" errStatusCodeFmt = "%w: received status %v" errorHeaderKey = "errorHeader" ) -// Items is a slice of model.Item(s) . -type Items []model.Item - // NewBasicClient creates a new BasicClient that can be used to // make requests to Argus. -func NewBasicClient(config BasicClientConfig, - getLogger func(context.Context) *zap.Logger) (*BasicClient, error) { - err := validateBasicConfig(&config) - if err != nil { - return nil, err - } +func NewBasicClient(opts ...ClientOption) (*BasicClient, error) { + var ( + client BasicClient + defaultClientOptions = ClientOptions{ + // localhost defaults + StoreBaseURL(""), + StoreAPIPath(""), + // Nop defaults + HTTPClient(nil), + GetClientLogger(nil), + Auth(nil), + } + ) - return &BasicClient{ - client: config.HTTPClient, - auth: config.Auth, - bucket: config.Bucket, - storeBaseURL: config.Address + storeAPIPath, - getLogger: getLogger, - }, nil + opts = append(defaultClientOptions, opts...) + opts = append(opts, clientValidator()) + + return &client, ClientOptions(opts).apply(&client) } // GetItems fetches all items that belong to a given owner. @@ -251,19 +226,3 @@ func translateNonSuccessStatusCode(code int) error { return errNonSuccessResponse } } - -func validateBasicConfig(config *BasicClientConfig) error { - if config.Address == "" { - return ErrAddressEmpty - } - - if config.Bucket == "" { - return ErrBucketEmpty - } - - if config.HTTPClient == nil { - config.HTTPClient = http.DefaultClient - } - - return nil -} diff --git a/chrysom/basicClientOptions.go b/chrysom/basicClientOptions.go new file mode 100644 index 0000000..6d24fd5 --- /dev/null +++ b/chrysom/basicClientOptions.go @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: 2021 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package chrysom + +import ( + "context" + "errors" + "net/http" + "net/url" + + "github.com/xmidt-org/ancla/auth" + "go.uber.org/zap" +) + +var ( + ErrMisconfiguredClient = errors.New("ancla client configuration error") +) + +// ClientOption is a functional option type for BasicClient. +type ClientOption interface { + apply(*BasicClient) error +} + +type ClientOptions []ClientOption + +func (opts ClientOptions) apply(c *BasicClient) (errs error) { + for _, o := range opts { + errs = errors.Join(errs, o.apply(c)) + } + + return errs +} + +type clientOptionFunc func(*BasicClient) error + +func (f clientOptionFunc) apply(c *BasicClient) error { + return f(c) +} + +// StoreBaseURL sets the store address for the client. +func StoreBaseURL(url string) ClientOption { + return clientOptionFunc( + func(c *BasicClient) error { + c.storeBaseURL = "http://localhost:6600" + if url != "" { + c.storeBaseURL = url + } + + return nil + }) +} + +// StoreAPIPath sets the store url api path. +// (Optional) Default is "/api/v1/store". +func StoreAPIPath(path string) ClientOption { + return clientOptionFunc( + func(c *BasicClient) error { + c.storeAPIPath = storeV1APIPath + if path != "" { + c.storeAPIPath = path + } + + return nil + }) +} + +// Bucket sets the partition to be used by this client. +func Bucket(bucket string) ClientOption { + return clientOptionFunc( + func(c *BasicClient) error { + c.bucket = bucket + + return nil + }) +} + +// HTTPClient sets the HTTP client. +func HTTPClient(client *http.Client) ClientOption { + return clientOptionFunc( + func(c *BasicClient) error { + c.client = http.DefaultClient + if client != nil { + c.client = client + } + + return nil + }) +} + +// GetLogger sets the getlogger, a func that returns a logger from the given context. +func GetClientLogger(get func(context.Context) *zap.Logger) ClientOption { + return clientOptionFunc( + func(c *BasicClient) error { + c.getLogger = func(context.Context) *zap.Logger { return zap.NewNop() } + if get != nil { + c.getLogger = get + } + + return nil + }) +} + +// Auth sets auth, auth provides the mechanism to add auth headers to outgoing requests. +// (Optional) If not provided, no auth headers are added. +func Auth(authD auth.Decorator) ClientOption { + return clientOptionFunc( + func(c *BasicClient) error { + c.auth = auth.Nop + if authD != nil { + c.auth = authD + } + + return nil + }) +} + +func clientValidator() ClientOption { + return clientOptionFunc( + func(c *BasicClient) (errs error) { + c.storeBaseURL, errs = url.JoinPath(c.storeBaseURL, c.storeAPIPath) + if errs != nil { + errs = errors.Join(errors.New("failed to combine StoreBaseURL & StoreAPIPath"), errs) + } + if c.bucket == "" { + errs = errors.Join(errs, errors.New("empty string Bucket")) + } + if errs != nil { + errs = errors.Join(ErrMisconfiguredClient, errs) + } + + return + }) +} diff --git a/chrysom/basicClient_test.go b/chrysom/basicClient_test.go index 6f19710..0cc7a54 100644 --- a/chrysom/basicClient_test.go +++ b/chrysom/basicClient_test.go @@ -22,75 +22,77 @@ import ( "go.uber.org/zap" ) -const failingURL = "nowhere://" +const ( + failingURL = "nowhere://" + bucket = "bucket-name" +) var ( - _ Pusher = &BasicClient{} - _ Reader = &BasicClient{} - errFails = errors.New("fails") + errFails = errors.New("fails") ) -func TestValidateBasicConfig(t *testing.T) { - type testCase struct { - Description string - Input *BasicClientConfig - Client *http.Client - ExpectedErr error - ExpectedConfig *BasicClientConfig +var ( + requiredClientOptions = ClientOptions{ + Bucket(bucket), } +) - allDefaultsCaseConfig := &BasicClientConfig{ - HTTPClient: http.DefaultClient, - Address: "example.com", - Bucket: "bucket-name", - } - allDefinedCaseConfig := &BasicClientConfig{ - HTTPClient: http.DefaultClient, - Address: "example.com", - Bucket: "amazing-bucket", +func TestClientOptions(t *testing.T) { + type testCase struct { + Description string + ClientOptions ClientOptions + ExpectedErr error } tcs := []testCase{ { - Description: "No address", - Input: &BasicClientConfig{ - Bucket: "bucket-name", - }, - ExpectedErr: ErrAddressEmpty, + Description: "Missing required options failure", + ClientOptions: ClientOptions{}, + ExpectedErr: ErrMisconfiguredClient, }, { - Description: "No bucket", - Input: &BasicClientConfig{ - Address: "example.com", - }, - ExpectedErr: ErrBucketEmpty, + Description: "Incorrect client values without defaults", + ClientOptions: ClientOptions{Bucket("")}, + ExpectedErr: ErrMisconfiguredClient, }, { - Description: "All default values", - Input: &BasicClientConfig{ - Address: "example.com", - Bucket: "bucket-name", - }, - ExpectedConfig: allDefaultsCaseConfig, + Description: "Correct required values and bad optional values (ignored)", + ClientOptions: append(requiredClientOptions, ClientOptions{ + StoreBaseURL(""), + StoreAPIPath(""), + GetClientLogger(nil), + HTTPClient(nil), + Auth(nil), + }), }, { - Description: "All defined", - Input: &BasicClientConfig{ - Address: "example.com", - Bucket: "amazing-bucket", - }, - ExpectedConfig: allDefinedCaseConfig, + Description: "Correct required and optional values", + ClientOptions: append(requiredClientOptions, ClientOptions{ + StoreBaseURL("localhost"), + StoreAPIPath(storeV1APIPath), + GetClientLogger(func(context.Context) *zap.Logger { return zap.NewNop() }), + HTTPClient(http.DefaultClient), + Auth(auth.Nop), + }), + }, + { + Description: "Correct required values only", + ClientOptions: requiredClientOptions, }, } for _, tc := range tcs { t.Run(tc.Description, func(t *testing.T) { assert := assert.New(t) - err := validateBasicConfig(tc.Input) - assert.Equal(tc.ExpectedErr, err) - if tc.ExpectedErr == nil { - assert.Equal(tc.ExpectedConfig, tc.Input) + client, errs := NewBasicClient(tc.ClientOptions) + if tc.ExpectedErr != nil { + assert.ErrorIs(errs, tc.ExpectedErr) + + return } + + assert.NoError(errs) + assert.NotNil(client) }) } } @@ -111,20 +113,20 @@ func TestSendRequest(t *testing.T) { tcs := []testCase{ { - Description: "New Request fails", + Description: "New Request failure", Method: "what method?", URL: "example.com", ExpectedErr: errNewRequestFailure, }, { - Description: "Auth decorator fails", + Description: "Auth decorator failure", Method: http.MethodGet, URL: "example.com", MockError: errFails, ExpectedErr: ErrAuthDecoratorFailure, }, { - Description: "Client Do fails", + Description: "Client Do failure", Method: http.MethodPut, ClientDoFails: true, ExpectedErr: errDoRequestFailure, @@ -183,17 +185,12 @@ func TestSendRequest(t *testing.T) { server := httptest.NewServer(echoHandler) defer server.Close() - client, err := NewBasicClient(BasicClientConfig{ - Address: "example.com", - Bucket: "bucket-name", - }, - func(context.Context) *zap.Logger { - return zap.NewNop() - }) + opts := append(requiredClientOptions, StoreBaseURL(server.URL)) + client, err := NewBasicClient(opts) + authDecorator := new(auth.MockDecorator) if tc.MockAuth != "" || tc.MockError != nil { - authDecorator := new(auth.MockDecorator) - authDecorator.On("Decorate").Return(tc.MockError) + authDecorator.On("Decorate").Return(tc.MockError).Once() client.auth = authDecorator } @@ -211,6 +208,8 @@ func TestSendRequest(t *testing.T) { } else { assert.True(errors.Is(err, tc.ExpectedErr)) } + + authDecorator.AssertExpectations(t) }) } } @@ -230,12 +229,12 @@ func TestGetItems(t *testing.T) { tcs := []testCase{ { - Description: "Make request fails", + Description: "Make request failure", ExpectedErr: ErrAuthDecoratorFailure, MockError: errFails, }, { - Description: "Do request fails", + Description: "Do request failure", ShouldDoRequestFail: true, ExpectedErr: errDoRequestFailure, }, @@ -280,33 +279,28 @@ func TestGetItems(t *testing.T) { var ( assert = assert.New(t) require = require.New(t) - bucket = "bucket-name" + bucket = bucket owner = "owner-name" ) server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(http.MethodGet, r.Method) assert.Equal(owner, r.Header.Get(ItemOwnerHeaderKey)) - assert.Equal(fmt.Sprintf("%s/%s", storeAPIPath, bucket), r.URL.Path) + assert.Equal(fmt.Sprintf("%s/%s", storeV1APIPath, bucket), r.URL.Path) assert.Equal(tc.MockAuth, r.Header.Get(auth.MockAuthHeaderName)) rw.WriteHeader(tc.ResponseCode) rw.Write(tc.ResponsePayload) })) - client, err := NewBasicClient(BasicClientConfig{ - Address: server.URL, - Bucket: bucket, - }, - func(context.Context) *zap.Logger { - return zap.NewNop() - }) + opts := append(requiredClientOptions, StoreBaseURL(server.URL)) + client, err := NewBasicClient(opts) require.Nil(err) + authDecorator := new(auth.MockDecorator) if tc.MockAuth != "" || tc.MockError != nil { - authDecorator := new(auth.MockDecorator) - authDecorator.On("Decorate").Return(tc.MockError) + authDecorator.On("Decorate").Return(tc.MockError).Once() client.auth = authDecorator } @@ -320,6 +314,8 @@ func TestGetItems(t *testing.T) { if tc.ExpectedErr == nil { assert.EqualValues(tc.ExpectedOutput, output) } + + authDecorator.AssertExpectations(t) }) } } @@ -360,13 +356,13 @@ func TestPushItem(t *testing.T) { ExpectedErr: ErrItemDataEmpty, }, { - Description: "Make request fails", + Description: "Make request failure", Item: validItem, ExpectedErr: ErrAuthDecoratorFailure, MockError: errFails, }, { - Description: "Do request fails", + Description: "Do request failure", Item: validItem, ShouldDoRequestFail: true, ExpectedErr: errDoRequestFailure, @@ -419,12 +415,12 @@ func TestPushItem(t *testing.T) { var ( assert = assert.New(t) require = require.New(t) - bucket = "bucket-name" + bucket = bucket id = "252f10c83610ebca1a059c0bae8255eba2f95be4d1d7bcfa89d7248a82d9f111" ) server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - assert.Equal(fmt.Sprintf("%s/%s/%s", storeAPIPath, bucket, id), r.URL.Path) + assert.Equal(fmt.Sprintf("%s/%s/%s", storeV1APIPath, bucket, id), r.URL.Path) assert.Equal(tc.Owner, r.Header.Get(ItemOwnerHeaderKey)) assert.Equal(tc.MockAuth, r.Header.Get(auth.MockAuthHeaderName)) @@ -439,17 +435,12 @@ func TestPushItem(t *testing.T) { } })) - client, err := NewBasicClient(BasicClientConfig{ - Address: server.URL, - Bucket: bucket, - }, - func(context.Context) *zap.Logger { - return zap.NewNop() - }) + opts := append(requiredClientOptions, StoreBaseURL(server.URL)) + client, err := NewBasicClient(opts) + authDecorator := new(auth.MockDecorator) if tc.MockAuth != "" || tc.MockError != nil { - authDecorator := new(auth.MockDecorator) - authDecorator.On("Decorate").Return(tc.MockError) + authDecorator.On("Decorate").Return(tc.MockError).Once() client.auth = authDecorator } @@ -469,6 +460,8 @@ func TestPushItem(t *testing.T) { } else { assert.True(errors.Is(err, tc.ExpectedErr)) } + + authDecorator.AssertExpectations(t) }) } } @@ -489,12 +482,12 @@ func TestRemoveItem(t *testing.T) { tcs := []testCase{ { - Description: "Make request fails", + Description: "Make request failure", ExpectedErr: ErrAuthDecoratorFailure, MockError: errFails, }, { - Description: "Do request fails", + Description: "Do request failure", ShouldDoRequestFail: true, ExpectedErr: errDoRequestFailure, }, @@ -533,12 +526,12 @@ func TestRemoveItem(t *testing.T) { var ( assert = assert.New(t) require = require.New(t) - bucket = "bucket-name" + bucket = bucket // nolint:gosec id = "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7" ) server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - assert.Equal(fmt.Sprintf("%s/%s/%s", storeAPIPath, bucket, id), r.URL.Path) + assert.Equal(fmt.Sprintf("%s/%s/%s", storeV1APIPath, bucket, id), r.URL.Path) assert.Equal(http.MethodDelete, r.Method) assert.Equal(tc.MockAuth, r.Header.Get(auth.MockAuthHeaderName)) @@ -546,16 +539,12 @@ func TestRemoveItem(t *testing.T) { rw.Write(tc.ResponsePayload) })) - client, err := NewBasicClient(BasicClientConfig{ - Address: server.URL, - Bucket: bucket, - }, func(context.Context) *zap.Logger { - return zap.NewNop() - }) + opts := append(requiredClientOptions, StoreBaseURL(server.URL)) + client, err := NewBasicClient(opts) + authDecorator := new(auth.MockDecorator) if tc.MockAuth != "" || tc.MockError != nil { - authDecorator := new(auth.MockDecorator) - authDecorator.On("Decorate").Return(tc.MockError) + authDecorator.On("Decorate").Return(tc.MockError).Once() client.auth = authDecorator } @@ -571,6 +560,8 @@ func TestRemoveItem(t *testing.T) { } else { assert.True(errors.Is(err, tc.ExpectedErr)) } + + authDecorator.AssertExpectations(t) }) } } @@ -648,7 +639,7 @@ func getItemsValidPayload() []byte { } func getItemsHappyOutput() Items { - return []model.Item{ + return Items{ { ID: "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7", Data: map[string]interface{}{ diff --git a/chrysom/fx.go b/chrysom/fx.go index 6c7c81c..68933b0 100644 --- a/chrysom/fx.go +++ b/chrysom/fx.go @@ -10,66 +10,59 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/xmidt-org/touchstone" "go.uber.org/fx" - "go.uber.org/zap" ) -// GetLogger returns a logger from the given context. -type GetLogger func(context.Context) *zap.Logger - -// SetLogger embeds the `Listener.logger` in outgoing request contexts for `Listener.Update` calls. -type SetLogger func(context.Context, *zap.Logger) context.Context +var ( + ErrMisconfiguredListener = errors.New("ancla listener configuration error") +) -type BasicClientIn struct { +type ProvideBasicClientIn struct { fx.In - // Ancla Client config. - Config BasicClientConfig - // GetLogger returns a logger from the given context. - GetLogger GetLogger + Options ClientOptions `group:"client_options"` } -// ProvideBasicClient provides a new BasicClient. -func ProvideBasicClient(in BasicClientIn) (*BasicClient, error) { - client, err := NewBasicClient(in.Config, in.GetLogger) - if err != nil { - return nil, errors.Join(errFailedConfig, err) - } +type ProvideBasicClientOut struct { + fx.Out - return client, nil + // Ancla service's db client. + PushReader PushReader + // Ancla listener's db client option. + Reader ListenerOption `group:"listener_options"` +} + +func ProvideBasicClient(in ProvideBasicClientIn) (ProvideBasicClientOut, error) { + client, err := NewBasicClient(in.Options...) + return ProvideBasicClientOut{ + PushReader: client, + Reader: reader(client), + }, err } // ListenerConfig contains config data for polling the Argus client. type ListenerClientIn struct { fx.In - // Listener fetches a copy of all items within a bucket on - // an interval based on `BasicClientConfig.PullInterval`. - // (Optional). If not provided, listening won't be enabled for this client. - Listener Listener - // Config configures the ancla client and its listeners. - Config BasicClientConfig // PollsTotalCounter measures the number of polls (and their success/failure outcomes) to fetch new items. PollsTotalCounter *prometheus.CounterVec `name:"chrysom_polls_total"` - // Reader is the DB interface used to fetch new items using `GeItems`. - Reader Reader - // GetLogger returns a logger from the given context. - GetLogger GetLogger - // SetLogger embeds the `Listener.logger` in outgoing request contexts for `Listener.Update` calls. - SetLogger SetLogger + Options ListenerOptions `group:"listener_options"` } // ProvideListenerClient provides a new ListenerClient. func ProvideListenerClient(in ListenerClientIn) (*ListenerClient, error) { - client, err := NewListenerClient(in.Listener, in.GetLogger, in.SetLogger, in.Config.PullInterval, in.PollsTotalCounter, in.Reader) + client, err := NewListenerClient(in.PollsTotalCounter, in.Options...) if err != nil { - return nil, errors.Join(err, errFailedConfig) + return nil, errors.Join(err, ErrMisconfiguredListener) } return client, nil } -func ProvideDefaultListenerReader(client *BasicClient) Reader { - return client +// ReaderOptionOut contains options data for Listener client's reader. +type ReaderOptionOut struct { + fx.Out + + Option ListenerOption `group:"listener_options"` } type StartListenerIn struct { diff --git a/chrysom/listenerClient.go b/chrysom/listenerClient.go index 6682f00..228d0d0 100644 --- a/chrysom/listenerClient.go +++ b/chrysom/listenerClient.go @@ -10,20 +10,14 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" ) -// Errors that can be returned by this package. Since some of these errors are returned wrapped, it -// is safest to use errors.Is() to check for them. -// Some internal errors might be unwrapped from output errors but unless these errors become exported, -// they are not part of the library API and may change in future versions. var ( - ErrFailedAuthentication = errors.New("failed to authentication with argus") - - ErrListenerNotStopped = errors.New("listener is either running or starting") - ErrListenerNotRunning = errors.New("listener is either stopped or stopping") - ErrNoListenerProvided = errors.New("no listener provided") - ErrNoReaderProvided = errors.New("no reader provided") + ErrListenerNotStopped = errors.New("listener is either running or starting") + ErrListenerNotRunning = errors.New("listener is either stopped or stopping") + ErrUndefinedIntervalTicker = errors.New("interval ticker is nil. Can't listen for updates") ) // listening states @@ -39,17 +33,13 @@ const ( // ListenerClient is the client used to poll Argus for updates. type ListenerClient struct { - observer *observerConfig - getLogger func(context.Context) *zap.Logger - setLogger func(context.Context, *zap.Logger) context.Context - reader Reader -} - -type observerConfig struct { - listener Listener + listener ListenerInterface ticker *time.Ticker pullInterval time.Duration pollsTotalCounter *prometheus.CounterVec + getLogger func(context.Context) *zap.Logger + setLogger func(context.Context, *zap.Logger) context.Context + reader Reader shutdown chan struct{} state int32 @@ -57,36 +47,23 @@ type observerConfig struct { // NewListenerClient creates a new ListenerClient to be used to poll Argus // for updates. -func NewListenerClient(listener Listener, - getLogger func(context.Context) *zap.Logger, - setLogger func(context.Context, *zap.Logger) context.Context, - pullInterval time.Duration, pollsTotalCounter *prometheus.CounterVec, reader Reader) (*ListenerClient, error) { - if listener == nil { - return nil, ErrNoListenerProvided - } - if pullInterval == 0 { - pullInterval = defaultPullInterval +func NewListenerClient(pollsTotalCounter *prometheus.CounterVec, opts ...ListenerOption) (*ListenerClient, error) { + defaultListenerOptions := ListenerOptions{ + // defaultPullInterval + PullInterval(0), + // Nops defaults + GetListenerLogger(nil), + SetListenerLogger(nil), } - if setLogger == nil { - setLogger = func(ctx context.Context, _ *zap.Logger) context.Context { - return ctx - } + client := ListenerClient{ + pollsTotalCounter: pollsTotalCounter, + shutdown: make(chan struct{}), } - if reader == nil { - return nil, ErrNoReaderProvided - } - return &ListenerClient{ - observer: &observerConfig{ - listener: listener, - ticker: time.NewTicker(pullInterval), - pullInterval: pullInterval, - pollsTotalCounter: pollsTotalCounter, - shutdown: make(chan struct{}), - }, - getLogger: getLogger, - setLogger: setLogger, - reader: reader, - }, nil + + opts = append(defaultListenerOptions, opts...) + opts = append(opts, listenerValidator()) + + return &client, ListenerOptions(opts).apply(&client) } // Start begins listening for updates on an interval given that client configuration @@ -94,43 +71,39 @@ func NewListenerClient(listener Listener, // is a NoOp. If you want to restart the current listener process, call Stop() first. func (c *ListenerClient) Start(ctx context.Context) error { logger := c.getLogger(ctx) - if c.observer == nil || c.observer.listener == nil { - logger.Warn("No listener was setup to receive updates.") - return nil - } - if c.observer.ticker == nil { + if c.ticker == nil { logger.Error("Observer ticker is nil", zap.Error(ErrUndefinedIntervalTicker)) return ErrUndefinedIntervalTicker } - if !atomic.CompareAndSwapInt32(&c.observer.state, stopped, transitioning) { + if !atomic.CompareAndSwapInt32(&c.state, stopped, transitioning) { logger.Error("Start called when a listener was not in stopped state", zap.Error(ErrListenerNotStopped)) return ErrListenerNotStopped } - c.observer.ticker.Reset(c.observer.pullInterval) + c.ticker.Reset(c.pullInterval) go func() { for { select { - case <-c.observer.shutdown: + case <-c.shutdown: return - case <-c.observer.ticker.C: + case <-c.ticker.C: outcome := SuccessOutcome ctx := c.setLogger(context.Background(), logger) items, err := c.reader.GetItems(ctx, "") if err == nil { - c.observer.listener.Update(items) + c.listener.Update(items) } else { outcome = FailureOutcome logger.Error("Failed to get items for listeners", zap.Error(err)) } - c.observer.pollsTotalCounter.With(prometheus.Labels{ + c.pollsTotalCounter.With(prometheus.Labels{ OutcomeLabel: outcome}).Add(1) } } }() - atomic.SwapInt32(&c.observer.state, running) + atomic.SwapInt32(&c.state, running) return nil } @@ -138,18 +111,18 @@ func (c *ListenerClient) Start(ctx context.Context) error { // Calling Stop() when a listener is not running (or while one is getting stopped) returns an // error. func (c *ListenerClient) Stop(ctx context.Context) error { - if c.observer == nil || c.observer.ticker == nil { + if c.ticker == nil { return nil } logger := c.getLogger(ctx) - if !atomic.CompareAndSwapInt32(&c.observer.state, running, transitioning) { + if !atomic.CompareAndSwapInt32(&c.state, running, transitioning) { logger.Error("Stop called when a listener was not in running state", zap.Error(ErrListenerNotStopped)) return ErrListenerNotRunning } - c.observer.ticker.Stop() - c.observer.shutdown <- struct{}{} - atomic.SwapInt32(&c.observer.state, stopped) + c.ticker.Stop() + c.shutdown <- struct{}{} + atomic.SwapInt32(&c.state, stopped) return nil } diff --git a/chrysom/listenerClientOptions.go b/chrysom/listenerClientOptions.go new file mode 100644 index 0000000..fd81b8b --- /dev/null +++ b/chrysom/listenerClientOptions.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: 2021 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package chrysom + +import ( + "context" + "errors" + "time" + + "go.uber.org/zap" +) + +// ListenerOption is a functional option type for ListenerClient. +type ListenerOption interface { + apply(*ListenerClient) error +} + +type ListenerOptions []ListenerOption + +func (opts ListenerOptions) apply(c *ListenerClient) (errs error) { + for _, o := range opts { + errs = errors.Join(errs, o.apply(c)) + } + + return errs +} + +type listenerOptionFunc func(*ListenerClient) error + +func (f listenerOptionFunc) apply(c *ListenerClient) error { + return f(c) +} + +// reader sets the reader. +// Used internally by `ProvideBasicClient` for fx dependency injection. +func reader(reader Reader) ListenerOption { + return listenerOptionFunc( + func(c *ListenerClient) error { + c.reader = reader + + return nil + }) +} + +// GetListenerLogger sets the getlogger, a func that returns a logger from the given context. +func GetListenerLogger(get func(context.Context) *zap.Logger) ListenerOption { + return listenerOptionFunc( + func(c *ListenerClient) error { + c.getLogger = func(context.Context) *zap.Logger { return zap.NewNop() } + if get != nil { + c.getLogger = get + } + + return nil + }) +} + +// SetListenerLogger sets the getlogger, a func that embeds the a given logger in outgoing request contexts. +func SetListenerLogger(set func(context.Context, *zap.Logger) context.Context) ListenerOption { + return listenerOptionFunc( + func(c *ListenerClient) error { + c.setLogger = func(context.Context, *zap.Logger) context.Context { return context.TODO() } + if set != nil { + c.setLogger = set + } + + return nil + }) +} + +// PullInterval sets the pull interval, determines how often listeners should get updates. +// (Optional). Defaults to 5 seconds. +func PullInterval(duration time.Duration) ListenerOption { + return listenerOptionFunc( + func(c *ListenerClient) error { + c.pullInterval = defaultPullInterval + if duration > 0 { + c.pullInterval = duration + } + + c.ticker = time.NewTicker(c.pullInterval) + + return nil + }) +} + +// Listener sets the Listener client's listener, listener is called during every PullInterval. +func Listener(listener ListenerInterface) ListenerOption { + return listenerOptionFunc( + func(c *ListenerClient) error { + c.listener = listener + + return nil + }) +} + +func listenerValidator() ListenerOption { + return listenerOptionFunc( + func(c *ListenerClient) (errs error) { + if c.reader == nil { + errs = errors.Join(errs, errors.New("nil Reader")) + } + if c.listener == nil { + errs = errors.Join(errs, errors.New("nil Listener")) + } + if errs != nil { + errs = errors.Join(ErrMisconfiguredListener, errs) + } + + return + }) +} diff --git a/chrysom/listenerClient_test.go b/chrysom/listenerClient_test.go index 028fb65..d2b338f 100644 --- a/chrysom/listenerClient_test.go +++ b/chrysom/listenerClient_test.go @@ -5,7 +5,6 @@ package chrysom import ( "context" - "errors" "fmt" "net/http" "net/http/httptest" @@ -32,10 +31,88 @@ var ( ) ) +func TestListenerOptions(t *testing.T) { + + listener := ListenerInterface(mockListener) + anclaClient, err := NewBasicClient(requiredClientOptions) + require.NoError(t, err) + require.NotNil(t, anclaClient) + + requiredListenerOptions := ListenerOptions{ + reader(anclaClient), + Listener(listener), + } + + type testCase struct { + Description string + ListenerOptions ListenerOptions + ExpectedErr error + } + + tcs := []testCase{ + { + Description: "Nil reader failure", + ListenerOptions: ListenerOptions{ + reader(nil), + Listener(listener), + }, + ExpectedErr: ErrMisconfiguredListener, + }, + { + Description: "Nil listener failure", + ListenerOptions: ListenerOptions{ + reader(anclaClient), + Listener(nil), + }, + ExpectedErr: ErrMisconfiguredListener, + }, + { + Description: "Correct required values and bad optional values (ignored)", + ListenerOptions: append(requiredListenerOptions, + ListenerOptions{ + GetListenerLogger(nil), + SetListenerLogger(nil), + PullInterval(-1), + }, + ), + }, + { + Description: "Correct required and optional values", + ListenerOptions: append(requiredListenerOptions, + ListenerOptions{ + GetListenerLogger(func(context.Context) *zap.Logger { return zap.NewNop() }), + SetListenerLogger(func(context.Context, *zap.Logger) context.Context { return context.TODO() }), + PullInterval(1), + }, + ), + }, + { + Description: "Correct listener values", + ListenerOptions: requiredListenerOptions, + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + assert := assert.New(t) + listener, errs := NewListenerClient(pollsTotalCounter, tc.ListenerOptions) + if tc.ExpectedErr != nil { + assert.ErrorIs(errs, tc.ExpectedErr) + + return + } + + assert.NoError(errs) + assert.NotNil(listener) + }) + } +} + func TestListenerStartStopPairsParallel(t *testing.T) { require := require.New(t) client, close, err := newStartStopClient(true) - assert.Nil(t, err) + require.NoError(err) + require.NotNil(client) defer close() t.Run("ParallelGroup", func(t *testing.T) { @@ -48,7 +125,7 @@ func TestListenerStartStopPairsParallel(t *testing.T) { if errStart != nil { assert.Equal(ErrListenerNotStopped, errStart) } - client.observer.listener.Update(Items{}) + client.listener.Update(Items{}) time.Sleep(time.Millisecond * 400) errStop := client.Stop(context.Background()) if errStop != nil { @@ -58,7 +135,7 @@ func TestListenerStartStopPairsParallel(t *testing.T) { } }) - require.Equal(stopped, client.observer.state) + require.Equal(stopped, client.state) } func TestListenerStartStopPairsSerial(t *testing.T) { @@ -77,13 +154,13 @@ func TestListenerStartStopPairsSerial(t *testing.T) { fmt.Printf("%d: Done\n", testNumber) }) } - require.Equal(stopped, client.observer.state) + require.Equal(stopped, client.state) } func TestListenerEdgeCases(t *testing.T) { t.Run("NoListener", func(t *testing.T) { _, _, err := newStartStopClient(false) - assert.Equal(t, ErrNoListenerProvided, err) + assert.ErrorIs(t, err, ErrMisconfiguredListener) }) t.Run("NilTicker", func(t *testing.T) { @@ -91,7 +168,7 @@ func TestListenerEdgeCases(t *testing.T) { client, stopServer, err := newStartStopClient(true) assert.Nil(err) defer stopServer() - client.observer.ticker = nil + client.ticker = nil assert.Equal(ErrUndefinedIntervalTicker, client.Start(context.Background())) }) } @@ -101,60 +178,26 @@ func newStartStopClient(includeListener bool) (*ListenerClient, func(), error) { rw.Write(getItemsValidPayload()) })) - var listener Listener + var listener ListenerInterface if includeListener { listener = mockListener } - client, err := NewListenerClient(listener, - func(context.Context) *zap.Logger { return zap.NewNop() }, - func(context.Context, *zap.Logger) context.Context { return context.Background() }, - time.Millisecond*200, pollsTotalCounter, &BasicClient{client: http.DefaultClient}) + anclaClient, err := NewBasicClient(requiredClientOptions) if err != nil { - return nil, nil, err + return nil, func() {}, err } - return client, server.Close, nil -} - -func TestValidateListenerConfig(t *testing.T) { - tcs := []struct { - desc string - listener Listener - pullInterval time.Duration - expectedErr error - pollsTotalCounter *prometheus.CounterVec - reader Reader - }{ - { - desc: "Listener Config Failure", - expectedErr: ErrNoListenerProvided, - }, - { - desc: "No reader Failure", - listener: mockListener, - pullInterval: time.Second, - pollsTotalCounter: pollsTotalCounter, - expectedErr: ErrNoReaderProvided, - }, - { - desc: "Happy case Success", - listener: mockListener, - pullInterval: time.Second, - pollsTotalCounter: pollsTotalCounter, - reader: &BasicClient{}, - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - assert := assert.New(t) - _, err := NewListenerClient(tc.listener, - func(context.Context) *zap.Logger { return zap.NewNop() }, - func(context.Context, *zap.Logger) context.Context { return context.Background() }, - tc.pullInterval, tc.pollsTotalCounter, tc.reader) - assert.True(errors.Is(err, tc.expectedErr), - fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", - err, tc.expectedErr), - ) + listenerClient, err := NewListenerClient(pollsTotalCounter, + ListenerOptions{ + PullInterval(time.Millisecond * 200), + reader(anclaClient), + Listener(listener), + GetListenerLogger(nil), + SetListenerLogger(nil), }) + if err != nil { + return nil, nil, err } + + return listenerClient, server.Close, nil } diff --git a/chrysom/store.go b/chrysom/store.go index 70e34c3..9e969f1 100644 --- a/chrysom/store.go +++ b/chrysom/store.go @@ -9,6 +9,8 @@ import ( "github.com/xmidt-org/ancla/model" ) +type Items []model.Item + type PushReader interface { Pusher Reader @@ -22,7 +24,7 @@ type Pusher interface { RemoveItem(ctx context.Context, id, owner string) (model.Item, error) } -type Listener interface { +type ListenerInterface interface { // Update is called when we get changes to our item listeners with either // additions, or updates. // @@ -43,5 +45,5 @@ type Reader interface { type ConfigureListener interface { // SetListener will attempt to set the lister. - SetListener(listener Listener) error + SetListener(listener ListenerInterface) error } diff --git a/fx.go b/fx.go index 61b3df2..8e1ca82 100644 --- a/fx.go +++ b/fx.go @@ -13,16 +13,16 @@ import ( type ServiceIn struct { fx.In - // Ancla Client. - BasicClient *chrysom.BasicClient + // PushReader is the user provided db client. + PushReader chrysom.PushReader `optional:"true"` } -// ProvideService builds the Argus client service from the given configuration. -func ProvideService(in ServiceIn) Service { - return NewService(in.BasicClient) +// ProvideService provides the Argus client service from the given configuration. +func ProvideService(in ServiceIn) (Service, error) { + return NewService(in.PushReader), nil } -// TODO: Refactor and move Watch and Listener related code to chrysom. +// TODO: Refactor and move Watch and ListenerInterface related code to chrysom. type DefaultListenersIn struct { fx.In @@ -54,17 +54,26 @@ type ListenerIn struct { Watchers []Watch `group:"watchers"` } -func ProvideListener(in ListenerIn) chrysom.Listener { - return chrysom.ListenerFunc(func(items chrysom.Items) { - iws, err := ItemsToInternalWebhooks(items) - if err != nil { - in.Shutdowner.Shutdown(fx.ExitCode(1)) +// ListenerOut contains options data for Listener client's reader. +type ListenerOut struct { + fx.Out + + Option chrysom.ListenerOption `group:"listener_options"` +} - return - } +func ProvideListener(in ListenerIn) ListenerOut { + return ListenerOut{ + Option: chrysom.Listener(chrysom.ListenerFunc(func(items chrysom.Items) { + iws, err := ItemsToInternalWebhooks(items) + if err != nil { + in.Shutdowner.Shutdown(fx.ExitCode(1)) - for _, watch := range in.Watchers { - watch.Update(iws) - } - }) + return + } + + for _, watch := range in.Watchers { + watch.Update(iws) + } + })), + } } diff --git a/service.go b/service.go index 120b74d..993c274 100644 --- a/service.go +++ b/service.go @@ -34,9 +34,6 @@ type Service interface { // Config contains information needed to initialize the Argus database client. type Config struct { - // BasicClientConfig is the configuration for the Argus database client. - BasicClientConfig chrysom.BasicClientConfig - // DisablePartnerIDs, if true, will allow webhooks to register without // checking the validity of the partnerIDs in the request. DisablePartnerIDs bool @@ -95,7 +92,7 @@ func (s *service) GetAll(ctx context.Context) ([]InternalWebhook, error) { } // NewService returns an ancla client used to interact with an Argus database. -func NewService(client *chrysom.BasicClient) *service { +func NewService(client chrysom.PushReader) *service { return &service{ argus: client, now: time.Now,