From abbe4ffcebce18373a03862ab8bbee873e66a08e Mon Sep 17 00:00:00 2001 From: Peter Olds Date: Mon, 13 Jan 2025 22:15:57 -0800 Subject: [PATCH] feat: add groups support for OIDC claims (#41) Allow users to specify allowed groups in the env and use it to inspect the claims. This performs a logical AND of users and groups (additive). --- .env.example | 15 +- internal/api/v1/auth/oidc.go | 38 ++-- internal/api/v1/auth/oidc_test.go | 295 +++++++++++++++++++++++++-- internal/common/env.go | 13 +- internal/net/http/middleware/oidc.go | 6 +- internal/utils/slices.go | 20 ++ internal/utils/slices_test.go | 96 +++++++++ 7 files changed, 441 insertions(+), 42 deletions(-) create mode 100644 internal/utils/slices.go create mode 100644 internal/utils/slices_test.go diff --git a/.env.example b/.env.example index 7588c4a..f022fab 100644 --- a/.env.example +++ b/.env.example @@ -19,8 +19,21 @@ GODOXY_API_JWT_TOKEN_TTL=1h # GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback # Comma-separated list of scopes # GODOXY_OIDC_SCOPES=openid, profile, email -# Comma-separated list of allowed users +# +# User definitions: Uncomment and configure these values to restrict access to specific users or groups. +# These two fields act as a logical AND operator. For example, given the following membership: +# user1, group1 +# user2, group1 +# user3, group2 +# user1, group2 +# You can allow access to user3 AND all users of group1 by providing: +# # GODOXY_OIDC_ALLOWED_USERS=user3 +# # GODOXY_OIDC_ALLOWED_GROUPS=group1 +# +# Comma-separated list of allowed users. # GODOXY_OIDC_ALLOWED_USERS=user1,user2 +# Optional: Comma-separated list of allowed groups. +# GODOXY_OIDC_ALLOWED_GROUPS=group1,group2 # Proxy listening address GODOXY_HTTP_ADDR=:80 diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index bd376a5..0c96559 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -14,16 +14,18 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + CE "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" "golang.org/x/oauth2" ) type OIDCProvider struct { - oauthConfig *oauth2.Config - oidcProvider *oidc.Provider - oidcVerifier *oidc.IDTokenVerifier - allowedUsers []string - isMiddleware bool + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + allowedUsers []string + allowedGroups []string + isMiddleware bool } const CookieOauthState = "godoxy_oidc_state" @@ -33,9 +35,9 @@ const ( OIDCLogoutPath = "/auth/logout" ) -func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) { - if len(allowedUsers) == 0 { - return nil, errors.New("OIDC allowed users must not be empty") +func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { + if len(allowedUsers)+len(allowedGroups) == 0 { + return nil, errors.New("OIDC users, groups, or both must not be empty") } provider, err := oidc.NewProvider(context.Background(), issuerURL) @@ -55,7 +57,8 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allo oidcVerifier: provider.Verifier(&oidc.Config{ ClientID: clientID, }), - allowedUsers: allowedUsers, + allowedUsers: allowedUsers, + allowedGroups: allowedGroups, }, nil } @@ -67,6 +70,7 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) { common.OIDCClientSecret, common.OIDCRedirectURL, common.OIDCAllowedUsers, + common.OIDCAllowedGroups, ) } @@ -82,6 +86,10 @@ func (auth *OIDCProvider) SetAllowedUsers(users []string) { auth.allowedUsers = users } +func (auth *OIDCProvider) SetAllowedGroups(groups []string) { + auth.allowedGroups = groups +} + func (auth *OIDCProvider) CheckToken(r *http.Request) error { token, err := r.Cookie(auth.TokenCookieName()) if err != nil { @@ -91,7 +99,7 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error { // checks for Expiry, Audience == ClientID, Issuer, etc. idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value) if err != nil { - return fmt.Errorf("failed to verify ID token: %w", err) + return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err) } if len(idToken.Audience) == 0 { @@ -99,14 +107,18 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error { } var claims struct { - Email string `json:"email"` - Username string `json:"preferred_username"` + Email string `json:"email"` + Username string `json:"preferred_username"` + Groups []string `json:"groups"` } if err := idToken.Claims(&claims); err != nil { return fmt.Errorf("failed to parse claims: %w", err) } - if !slices.Contains(auth.allowedUsers, claims.Username) { + // Logical AND between allowed users and groups. + allowedUser := slices.Contains(auth.allowedUsers, claims.Username) + allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0 + if !allowedUser && !allowedGroup { return ErrUserNotAllowed.Subject(claims.Username) } return nil diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index c83579e..d14715e 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -2,11 +2,17 @@ package auth import ( "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" "net/http" "net/http/httptest" "testing" + "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v5" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "golang.org/x/oauth2" @@ -34,7 +40,95 @@ func setupMockOIDC(t *testing.T) { oidcVerifier: provider.Verifier(&oidc.Config{ ClientID: "test-client", }), - allowedUsers: []string{"test-user"}, + allowedUsers: []string{"test-user"}, + allowedGroups: []string{"test-group1", "test-group2"}, + } +} + +// discoveryDocument returns a mock OIDC discovery document. +func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any { + t.Helper() + + discovery := map[string]any{ + "issuer": server.URL, + "authorization_endpoint": server.URL + "/auth", + "token_endpoint": server.URL + "/token", + } + + return discovery +} + +const ( + keyID = "test-key-id" + clientID = "test-client-id" +) + +type provider struct { + ts *httptest.Server + key *rsa.PrivateKey + verifier *oidc.IDTokenVerifier +} + +func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string { + t.Helper() + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + signed, err := token.SignedString(j.key) + ExpectNoError(t, err) + return signed +} + +func setupProvider(t *testing.T) *provider { + t.Helper() + + // Generate an RSA key pair for the test. + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + ExpectNoError(t, err) + + // Build the matching public JWK that will be served by the endpoint. + jwk := buildRSAJWK(t, &privKey.PublicKey, keyID) + + // Start a test server that serves the JWKS endpoint. + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/jwks.json": + _ = json.NewEncoder(w).Encode(map[string]any{ + "keys": []any{jwk}, + }) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(ts.Close) + + // Create a test OIDCProvider. + providerCtx := oidc.ClientContext(context.Background(), ts.Client()) + keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json") + + return &provider{ + ts: ts, + key: privKey, + verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{ + ClientID: clientID, // matches audience in the token + }), + } +} + +// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint +func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any { + t.Helper() + + nBytes := pub.N.Bytes() + eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537 + + return map[string]any{ + "kty": "RSA", + "alg": "RS256", + "use": "sig", + "kid": kid, + "n": base64.RawURLEncoding.EncodeToString(nBytes), + "e": base64.RawURLEncoding.EncodeToString(eBytes), } } @@ -145,14 +239,27 @@ func TestOIDCCallbackHandler(t *testing.T) { } func TestInitOIDC(t *testing.T) { + setupMockOIDC(t) + // Create a test server that serves the discovery document + var server *httptest.Server + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server))) + }) + server = httptest.NewServer(mux) + t.Cleanup(server.Close) + t.Cleanup(cleanup) + tests := []struct { - name string - issuerURL string - clientID string - clientSecret string - redirectURL string - allowedUsers []string - wantErr bool + name string + issuerURL string + clientID string + clientSecret string + redirectURL string + allowedUsers []string + allowedGroups []string + wantErr bool }{ { name: "Fail - Empty configuration", @@ -163,33 +270,179 @@ func TestInitOIDC(t *testing.T) { allowedUsers: nil, wantErr: true, }, - // { - // name: "Success - Valid configuration", - // issuerURL: "https://example.com", - // clientID: "client_id", - // clientSecret: "client_secret", - // redirectURL: "https://example.com/callback", - // allowedUsers: []string{"user1", "user2"}, - // wantErr: false, - // }, { - name: "Fail - No allowed users", + name: "Success - Valid configuration with users", + issuerURL: server.URL, + clientID: "client_id", + clientSecret: "client_secret", + redirectURL: "https://example.com/callback", + allowedUsers: []string{"user1", "user2"}, + wantErr: false, + }, + { + name: "Success - Valid configuration with groups", + issuerURL: server.URL, + clientID: "client_id", + clientSecret: "client_secret", + redirectURL: "https://example.com/callback", + allowedGroups: []string{"group1", "group2"}, + wantErr: false, + }, + { + name: "Fail - No allowed users or allowed groups", issuerURL: "https://example.com", clientID: "client_id", clientSecret: "client_secret", redirectURL: "https://example.com/callback", - allowedUsers: []string{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Cleanup(cleanup) - _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers) + _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups) if (err != nil) != tt.wantErr { t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) } }) } } + +func TestCheckToken(t *testing.T) { + provider := setupProvider(t) + + tests := []struct { + name string + allowedUsers []string + allowedGroups []string + claims jwt.Claims + wantErr error + }{ + { + name: "Success - Valid token with allowed user", + allowedUsers: []string{"user1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + }, + { + name: "Success - Valid token with allowed group", + allowedGroups: []string{"group1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + }, + { + name: "Success - Server omits groups, but user is allowed", + allowedUsers: []string{"user1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + }, + }, + { + name: "Success - Server omits preferred_username, but group is allowed", + allowedGroups: []string{"group1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "groups": []string{"group1"}, + }, + }, + { + name: "Success - Valid token with allowed user and group", + allowedUsers: []string{"user1"}, + allowedGroups: []string{"group1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + }, + { + name: "Error - User not allowed", + allowedUsers: []string{"user2", "user3"}, + allowedGroups: []string{"group2", "group3"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrUserNotAllowed, + }, + { + name: "Error - Server returns incorrect issuer", + claims: jwt.MapClaims{ + "iss": "https://example.com", + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrInvalidToken, + }, + { + name: "Error - Server returns incorrect audience", + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": "some-other-audience", + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrInvalidToken, + }, + { + name: "Error - Server returns expired token", + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(-time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrInvalidToken, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create the Auth Provider. + auth := &OIDCProvider{ + oidcVerifier: provider.verifier, + allowedUsers: tc.allowedUsers, + allowedGroups: tc.allowedGroups, + } + // Sign the claims to create a token. + signedToken := provider.SignClaims(t, tc.claims) + // Craft a test HTTP request that includes the token as a cookie. + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{ + Name: auth.TokenCookieName(), + Value: signedToken, + }) + + // Call CheckToken and verify the result. + err := auth.CheckToken(req) + if tc.wantErr == nil { + ExpectNoError(t, err) + } else { + ExpectError(t, tc.wantErr, err) + } + }) + } +} diff --git a/internal/common/env.go b/internal/common/env.go index 3d207d7..3cd7523 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -47,12 +47,13 @@ var ( APIPassword = GetEnvString("API_PASSWORD", "password") // OIDC Configuration. - OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") - OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") - OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") - OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") - OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") - OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "") + OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") + OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") + OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") + OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") + OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") + OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "") + OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "") ) func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T { diff --git a/internal/net/http/middleware/oidc.go b/internal/net/http/middleware/oidc.go index b8ad7cd..8e1b3e6 100644 --- a/internal/net/http/middleware/oidc.go +++ b/internal/net/http/middleware/oidc.go @@ -8,7 +8,8 @@ import ( ) type oidcMiddleware struct { - AllowedUsers []string `json:"allowed_users"` + AllowedUsers []string `json:"allowed_users"` + AllowedGroups []string `json:"allowed_groups"` auth auth.Provider authMux *http.ServeMux @@ -30,6 +31,9 @@ func (amw *oidcMiddleware) finalize() error { if len(amw.AllowedUsers) > 0 { authProvider.SetAllowedUsers(amw.AllowedUsers) } + if len(amw.AllowedGroups) > 0 { + authProvider.SetAllowedGroups(amw.AllowedGroups) + } amw.authMux = http.NewServeMux() amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) diff --git a/internal/utils/slices.go b/internal/utils/slices.go new file mode 100644 index 0000000..afe2914 --- /dev/null +++ b/internal/utils/slices.go @@ -0,0 +1,20 @@ +package utils + +// Intersect returns a new slice containing the elements that are present in both input slices. +// This provides a more efficient solution than using two nested loops. +func Intersect[T comparable, Slice ~[]T](slice1 Slice, slice2 Slice) Slice { + var result Slice + seen := map[T]struct{}{} + + for i := range slice1 { + seen[slice1[i]] = struct{}{} + } + + for i := range slice2 { + if _, ok := seen[slice2[i]]; ok { + result = append(result, slice2[i]) + } + } + + return result +} diff --git a/internal/utils/slices_test.go b/internal/utils/slices_test.go new file mode 100644 index 0000000..8d2a1f1 --- /dev/null +++ b/internal/utils/slices_test.go @@ -0,0 +1,96 @@ +package utils + +import ( + "slices" + "strings" + "testing" + + utils "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestIntersect(t *testing.T) { + t.Run("strings", func(t *testing.T) { + t.Run("no intersection", func(t *testing.T) { + var ( + slice1 = []string{"a", "b", "c"} + slice2 = []string{"d", "e", "f"} + want []string + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + t.Run("intersection", func(t *testing.T) { + var ( + slice1 = []string{"a", "b", "c"} + slice2 = []string{"b", "c", "d"} + want = []string{"b", "c"} + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + }) + t.Run("ints", func(t *testing.T) { + t.Run("no intersection", func(t *testing.T) { + var ( + slice1 = []int{1, 2, 3} + slice2 = []int{4, 5, 6} + want []int + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + t.Run("intersection", func(t *testing.T) { + var ( + slice1 = []int{1, 2, 3} + slice2 = []int{2, 3, 4} + want = []int{2, 3} + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + }) + t.Run("complex", func(t *testing.T) { + type T struct { + A string + B int + } + t.Run("no intersection", func(t *testing.T) { + var ( + slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}} + slice2 = []T{{"d", 4}, {"e", 5}, {"f", 6}} + want []T + ) + result := Intersect(slice1, slice2) + slices.SortFunc(result, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + slices.SortFunc(want, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + utils.ExpectDeepEqual(t, result, want) + }) + t.Run("intersection", func(t *testing.T) { + var ( + slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}} + slice2 = []T{{"b", 2}, {"c", 3}, {"d", 4}} + want = []T{{"b", 2}, {"c", 3}} + ) + result := Intersect(slice1, slice2) + slices.SortFunc(result, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + slices.SortFunc(want, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + utils.ExpectDeepEqual(t, result, want) + }) + }) +}