diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index f380bea9..39e6a6ec 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -114,31 +114,61 @@ func TestJwtFromHeader(t *testing.T) { } }() - for _, test := range hamac { - // Arrange - app := fiber.New() + t.Run("regular", func(t *testing.T) { + for _, test := range hamac { + // Arrange + app := fiber.New() + + app.Use(jwtware.New(jwtware.Config{ + SigningKey: jwtware.SigningKey{ + JWTAlg: test.SigningMethod, + Key: []byte(defaultSigningKey), + }, + })) + + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest("GET", "/ok", nil) + req.Header.Add("Authorization", "Bearer "+test.Token) + + // Act + resp, err := app.Test(req) + + // Assert + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + } + }) - app.Use(jwtware.New(jwtware.Config{ - SigningKey: jwtware.SigningKey{ - JWTAlg: test.SigningMethod, - Key: []byte(defaultSigningKey), - }, - })) + t.Run("malformed header", func(t *testing.T) { + for _, test := range hamac { + // Arrange + app := fiber.New() - app.Get("/ok", func(c *fiber.Ctx) error { - return c.SendString("OK") - }) + app.Use(jwtware.New(jwtware.Config{ + SigningKey: jwtware.SigningKey{ + JWTAlg: test.SigningMethod, + Key: []byte(defaultSigningKey), + }, + })) - req := httptest.NewRequest("GET", "/ok", nil) - req.Header.Add("Authorization", "Bearer "+test.Token) + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) - // Act - resp, err := app.Test(req) + req := httptest.NewRequest("GET", "/ok", nil) + req.Header.Add("Authorization", "Bearer"+test.Token) - // Assert - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 200, resp.StatusCode) - } + // Act + resp, err := app.Test(req) + + // Assert + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 400, resp.StatusCode) + } + }) } func TestJwtFromCookie(t *testing.T) { diff --git a/jwt/utils.go b/jwt/utils.go index c44213a2..aafbd7c7 100644 --- a/jwt/utils.go +++ b/jwt/utils.go @@ -16,6 +16,8 @@ type jwtExtractor func(c *fiber.Ctx) (string, error) // jwtFromHeader returns a function that extracts token from the request header. func jwtFromHeader(header string, authScheme string) func(c *fiber.Ctx) (string, error) { + // Enforce the presence of a space between the authentication scheme and the token + authScheme = authScheme + " " return func(c *fiber.Ctx) (string, error) { auth := c.Get(header) l := len(authScheme)