diff --git a/auth.go b/auth.go index 91faa7e4..40b01d9c 100644 --- a/auth.go +++ b/auth.go @@ -234,39 +234,44 @@ func (s *Service) AddProviderWithUserAttributes(name, cid, csecret string, userA L: s.logger, UserAttributes: userAttributes, } - s.addProvider(name, p) + s.addProviderByName(name, p) } -func (s *Service) addProvider(name string, p provider.Params) { +func (s *Service) addProviderByName(name string, p provider.Params) { + var prov provider.Provider switch strings.ToLower(name) { case "github": - s.providers = append(s.providers, provider.NewService(provider.NewGithub(p))) + prov = provider.NewGithub(p) case "google": - s.providers = append(s.providers, provider.NewService(provider.NewGoogle(p))) + prov = provider.NewGoogle(p) case "facebook": - s.providers = append(s.providers, provider.NewService(provider.NewFacebook(p))) + prov = provider.NewFacebook(p) case "yandex": - s.providers = append(s.providers, provider.NewService(provider.NewYandex(p))) + prov = provider.NewYandex(p) case "battlenet": - s.providers = append(s.providers, provider.NewService(provider.NewBattlenet(p))) + prov = provider.NewBattlenet(p) case "microsoft": - s.providers = append(s.providers, provider.NewService(provider.NewMicrosoft(p))) + prov = provider.NewMicrosoft(p) case "twitter": - s.providers = append(s.providers, provider.NewService(provider.NewTwitter(p))) + prov = provider.NewTwitter(p) case "patreon": - s.providers = append(s.providers, provider.NewService(provider.NewPatreon(p))) + prov = provider.NewPatreon(p) case "dev": - s.providers = append(s.providers, provider.NewService(provider.NewDev(p))) + prov = provider.NewDev(p) default: return } + s.addProvider(prov) +} + +func (s *Service) addProvider(prov provider.Provider) { + s.providers = append(s.providers, provider.NewService(prov)) s.authMiddleware.Providers = s.providers } // AddProvider adds provider for given name func (s *Service) AddProvider(name, cid, csecret string) { - p := provider.Params{ URL: s.opts.URL, JwtService: s.jwtService, @@ -277,8 +282,7 @@ func (s *Service) AddProvider(name, cid, csecret string) { L: s.logger, UserAttributes: map[string]string{}, } - - s.addProvider(name, p) + s.addProviderByName(name, p) } // AddDevProvider with a custom host and port @@ -292,7 +296,7 @@ func (s *Service) AddDevProvider(host string, port int) { Port: port, Host: host, } - s.providers = append(s.providers, provider.NewService(provider.NewDev(p))) + s.addProvider(provider.NewDev(p)) } // AddAppleProvider allow SignIn with Apple ID @@ -311,7 +315,7 @@ func (s *Service) AddAppleProvider(appleConfig provider.AppleConfig, privKeyLoad return fmt.Errorf("an AppleProvider creating failed: %w", err) } - s.providers = append(s.providers, provider.NewService(appleProvider)) + s.addProvider(appleProvider) return nil } @@ -326,9 +330,7 @@ func (s *Service) AddCustomProvider(name string, client Client, copts provider.C Csecret: client.Csecret, L: s.logger, } - - s.providers = append(s.providers, provider.NewService(provider.NewCustom(name, p, copts))) - s.authMiddleware.Providers = s.providers + s.addProvider(provider.NewCustom(name, p, copts)) } // AddDirectProvider adds provider with direct check against data store @@ -342,8 +344,7 @@ func (s *Service) AddDirectProvider(name string, credChecker provider.CredChecke CredChecker: credChecker, AvatarSaver: s.avatarProxy, } - s.providers = append(s.providers, provider.NewService(dh)) - s.authMiddleware.Providers = s.providers + s.addProvider(dh) } // AddDirectProviderWithUserIDFunc adds provider with direct check against data store and sets custom UserIDFunc allows @@ -359,8 +360,7 @@ func (s *Service) AddDirectProviderWithUserIDFunc(name string, credChecker provi AvatarSaver: s.avatarProxy, UserIDFunc: ufn, } - s.providers = append(s.providers, provider.NewService(dh)) - s.authMiddleware.Providers = s.providers + s.addProvider(dh) } // AddVerifProvider adds provider user's verification sent by sender @@ -375,14 +375,12 @@ func (s *Service) AddVerifProvider(name, msgTmpl string, sender provider.Sender) Template: msgTmpl, UseGravatar: s.useGravatar, } - s.providers = append(s.providers, provider.NewService(dh)) - s.authMiddleware.Providers = s.providers + s.addProvider(dh) } // AddCustomHandler adds user-defined self-implemented handler of auth provider -func (s *Service) AddCustomHandler(handler provider.Provider) { - s.providers = append(s.providers, provider.NewService(handler)) - s.authMiddleware.Providers = s.providers +func (s *Service) AddCustomHandler(p provider.Provider) { + s.addProvider(p) } // DevAuth makes dev oauth2 server, for testing and development only! diff --git a/auth_test.go b/auth_test.go index 15b9efbe..fab65103 100644 --- a/auth_test.go +++ b/auth_test.go @@ -52,12 +52,13 @@ func TestProvider(t *testing.T) { _, err := svc.Provider("some provider") assert.EqualError(t, err, "provider some provider not found") - svc.AddProvider("dev", "cid", "csecret") + svc.AddProviderWithUserAttributes("dev", "cid", "csecret", provider.UserAttributes{"attrName": "attrValue"}) svc.AddProvider("github", "cid", "csecret") svc.AddProvider("google", "cid", "csecret") svc.AddProvider("facebook", "cid", "csecret") svc.AddProvider("yandex", "cid", "csecret") svc.AddProvider("microsoft", "cid", "csecret") + svc.AddProvider("twitter", "cid", "csecret") svc.AddProvider("battlenet", "cid", "csecret") svc.AddProvider("patreon", "cid", "csecret") svc.AddProvider("bad", "cid", "csecret") @@ -72,6 +73,7 @@ func TestProvider(t *testing.T) { assert.Equal(t, "cid", op.Cid) assert.Equal(t, "csecret", op.Csecret) assert.Equal(t, "go-pkgz/auth", op.Issuer) + assert.Equal(t, provider.UserAttributes{"attrName": "attrValue"}, op.Params.UserAttributes) p, err = svc.Provider("github") assert.NoError(t, err) @@ -79,7 +81,7 @@ func TestProvider(t *testing.T) { assert.Equal(t, "github", op.Name()) pp := svc.Providers() - assert.Equal(t, 9, len(pp)) + assert.Equal(t, 10, len(pp)) ch, err := svc.Provider("telegramBotMySiteCom") assert.NoError(t, err) @@ -227,7 +229,11 @@ func TestIntegrationAvatar(t *testing.T) { } func TestIntegrationList(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddProvider("github", "cid", "csec") + // add go-oauth2/oauth2 provider + svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + }) defer teardown() resp, err := http.Get("http://127.0.0.1:8089/auth/list") @@ -237,7 +243,7 @@ func TestIntegrationList(t *testing.T) { b, err := io.ReadAll(resp.Body) require.NoError(t, err) - assert.Equal(t, `["dev","github","custom123","direct","direct_custom","email"]`+"\n", string(b)) + assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } func TestIntegrationUserInfo(t *testing.T) { @@ -336,7 +342,11 @@ func TestBadRequests(t *testing.T) { } func TestDirectProvider(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { + return user == "dev_direct" && password == "password", nil + })) + }) defer teardown() // login @@ -374,19 +384,28 @@ func TestDirectProvider(t *testing.T) { } func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddDirectProviderWithUserIDFunc("directCustom", + provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { + return user == "dev_direct" && password == "password", nil + }), + func(user string, r *http.Request) string { + return "blah" + }, + ) + }) defer teardown() // login jar, err := cookiejar.New(nil) require.Nil(t, err) client := &http.Client{Jar: jar, Timeout: 5 * time.Second} - resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") + resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) - resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") + resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode) @@ -396,7 +415,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) - assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) @@ -412,7 +431,9 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { } func TestVerifProvider(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddVerifProvider("email", "{{.Token}}", &sender) + }) defer teardown() // login @@ -488,7 +509,16 @@ func TestStatus(t *testing.T) { } -func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unparam +func TestDevAuthServerWithoutDevProvider(t *testing.T) { + svc := NewService(Opts{}) + assert.NotNil(t, svc) + + _, err := svc.DevAuth() + require.NotNil(t, err) + assert.EqualError(t, err, "dev provider not registered: provider dev not found") +} + +func prepService(t *testing.T, providerConfigFunctions ...func(svc *Service)) (svc *Service, teardown func()) { //nolint unparam options := Opts{ SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), @@ -509,28 +539,12 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara } svc = NewService(options) - svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 - svc.AddProvider("github", "cid", "csec") // add github provider - - // add go-oauth2/oauth2 provider - svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) - - // add direct provider - svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { - return user == "dev_direct" && password == "password", nil - })) - // add direct provider with custom user id func - svc.AddDirectProviderWithUserIDFunc("direct_custom", - provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { - return user == "dev_direct" && password == "password", nil - }), - func(user string, r *http.Request) string { - return "blah" - }, - ) + svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 - svc.AddVerifProvider("email", "{{.Token}}", &sender) + for _, f := range providerConfigFunctions { + f(svc) + } // run dev/test oauth2 server on :18084 devAuth, err := svc.DevAuth() @@ -546,7 +560,7 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara _, _ = w.Write([]byte("open route, no token needed\n")) })) mux.Handle("/private", m.Auth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // token required - _, _ = w.Write([]byte("open route, no token needed\n")) + _, _ = w.Write([]byte("protected route, authenticated with token\n")) }))) // setup auth routes