feat: add groups support for OIDC claims

Allow users to specify allowed groups in the env and use it to inspect the claims.

This performs a logical AND of users and groups (additive).
This commit is contained in:
Peter Olds 2025-01-13 21:40:10 -08:00
parent b44c8586cc
commit 73de18e197
No known key found for this signature in database
7 changed files with 441 additions and 42 deletions

View file

@ -19,8 +19,21 @@ GODOXY_API_JWT_TOKEN_TTL=1h
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
# Comma-separated list of scopes
# GODOXY_OIDC_SCOPES=openid, profile, email
# Comma-separated list of allowed users
#
# User definitions: Uncomment and configure these values to restrict access to specific users or groups.
# These two fields act as a logical AND operator. For example, given the following membership:
# user1, group1
# user2, group1
# user3, group2
# user1, group2
# You can allow access to user3 AND all users of group1 by providing:
# # GODOXY_OIDC_ALLOWED_USERS=user3
# # GODOXY_OIDC_ALLOWED_GROUPS=group1
#
# Comma-separated list of allowed users.
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
# Optional: Comma-separated list of allowed groups.
# GODOXY_OIDC_ALLOWED_GROUPS=group1,group2
# Proxy listening address
GODOXY_HTTP_ADDR=:80

View file

@ -14,16 +14,18 @@ import (
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
CE "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
"golang.org/x/oauth2"
)
type OIDCProvider struct {
oauthConfig *oauth2.Config
oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier
allowedUsers []string
isMiddleware bool
oauthConfig *oauth2.Config
oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier
allowedUsers []string
allowedGroups []string
isMiddleware bool
}
const CookieOauthState = "godoxy_oidc_state"
@ -33,9 +35,9 @@ const (
OIDCLogoutPath = "/auth/logout"
)
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
if len(allowedUsers) == 0 {
return nil, errors.New("OIDC allowed users must not be empty")
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
if len(allowedUsers)+len(allowedGroups) == 0 {
return nil, errors.New("OIDC users, groups, or both must not be empty")
}
provider, err := oidc.NewProvider(context.Background(), issuerURL)
@ -55,7 +57,8 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allo
oidcVerifier: provider.Verifier(&oidc.Config{
ClientID: clientID,
}),
allowedUsers: allowedUsers,
allowedUsers: allowedUsers,
allowedGroups: allowedGroups,
}, nil
}
@ -67,6 +70,7 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
common.OIDCClientSecret,
common.OIDCRedirectURL,
common.OIDCAllowedUsers,
common.OIDCAllowedGroups,
)
}
@ -85,6 +89,10 @@ func (auth *OIDCProvider) SetAllowedUsers(users []string) {
auth.allowedUsers = users
}
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
auth.allowedGroups = groups
}
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
token, err := r.Cookie(auth.TokenCookieName())
if err != nil {
@ -94,7 +102,7 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error {
// checks for Expiry, Audience == ClientID, Issuer, etc.
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
if err != nil {
return fmt.Errorf("failed to verify ID token: %w", err)
return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
}
if len(idToken.Audience) == 0 {
@ -102,14 +110,18 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error {
}
var claims struct {
Email string `json:"email"`
Username string `json:"preferred_username"`
Email string `json:"email"`
Username string `json:"preferred_username"`
Groups []string `json:"groups"`
}
if err := idToken.Claims(&claims); err != nil {
return fmt.Errorf("failed to parse claims: %w", err)
}
if !slices.Contains(auth.allowedUsers, claims.Username) {
// Logical AND between allowed users and groups.
allowedUser := slices.Contains(auth.allowedUsers, claims.Username)
allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0
if !allowedUser && !allowedGroup {
return ErrUserNotAllowed.Subject(claims.Username)
}
return nil

View file

@ -2,11 +2,17 @@ package auth
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"golang.org/x/oauth2"
@ -34,7 +40,95 @@ func setupMockOIDC(t *testing.T) {
oidcVerifier: provider.Verifier(&oidc.Config{
ClientID: "test-client",
}),
allowedUsers: []string{"test-user"},
allowedUsers: []string{"test-user"},
allowedGroups: []string{"test-group1", "test-group2"},
}
}
// discoveryDocument returns a mock OIDC discovery document.
func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any {
t.Helper()
discovery := map[string]any{
"issuer": server.URL,
"authorization_endpoint": server.URL + "/auth",
"token_endpoint": server.URL + "/token",
}
return discovery
}
const (
keyID = "test-key-id"
clientID = "test-client-id"
)
type provider struct {
ts *httptest.Server
key *rsa.PrivateKey
verifier *oidc.IDTokenVerifier
}
func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
t.Helper()
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = keyID
signed, err := token.SignedString(j.key)
ExpectNoError(t, err)
return signed
}
func setupProvider(t *testing.T) *provider {
t.Helper()
// Generate an RSA key pair for the test.
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
ExpectNoError(t, err)
// Build the matching public JWK that will be served by the endpoint.
jwk := buildRSAJWK(t, &privKey.PublicKey, keyID)
// Start a test server that serves the JWKS endpoint.
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/jwks.json":
_ = json.NewEncoder(w).Encode(map[string]any{
"keys": []any{jwk},
})
default:
http.NotFound(w, r)
}
}))
t.Cleanup(ts.Close)
// Create a test OIDCProvider.
providerCtx := oidc.ClientContext(context.Background(), ts.Client())
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
return &provider{
ts: ts,
key: privKey,
verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{
ClientID: clientID, // matches audience in the token
}),
}
}
// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint
func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any {
t.Helper()
nBytes := pub.N.Bytes()
eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537
return map[string]any{
"kty": "RSA",
"alg": "RS256",
"use": "sig",
"kid": kid,
"n": base64.RawURLEncoding.EncodeToString(nBytes),
"e": base64.RawURLEncoding.EncodeToString(eBytes),
}
}
@ -145,14 +239,27 @@ func TestOIDCCallbackHandler(t *testing.T) {
}
func TestInitOIDC(t *testing.T) {
setupMockOIDC(t)
// Create a test server that serves the discovery document
var server *httptest.Server
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
})
server = httptest.NewServer(mux)
t.Cleanup(server.Close)
t.Cleanup(cleanup)
tests := []struct {
name string
issuerURL string
clientID string
clientSecret string
redirectURL string
allowedUsers []string
wantErr bool
name string
issuerURL string
clientID string
clientSecret string
redirectURL string
allowedUsers []string
allowedGroups []string
wantErr bool
}{
{
name: "Fail - Empty configuration",
@ -163,33 +270,179 @@ func TestInitOIDC(t *testing.T) {
allowedUsers: nil,
wantErr: true,
},
// {
// name: "Success - Valid configuration",
// issuerURL: "https://example.com",
// clientID: "client_id",
// clientSecret: "client_secret",
// redirectURL: "https://example.com/callback",
// allowedUsers: []string{"user1", "user2"},
// wantErr: false,
// },
{
name: "Fail - No allowed users",
name: "Success - Valid configuration with users",
issuerURL: server.URL,
clientID: "client_id",
clientSecret: "client_secret",
redirectURL: "https://example.com/callback",
allowedUsers: []string{"user1", "user2"},
wantErr: false,
},
{
name: "Success - Valid configuration with groups",
issuerURL: server.URL,
clientID: "client_id",
clientSecret: "client_secret",
redirectURL: "https://example.com/callback",
allowedGroups: []string{"group1", "group2"},
wantErr: false,
},
{
name: "Fail - No allowed users or allowed groups",
issuerURL: "https://example.com",
clientID: "client_id",
clientSecret: "client_secret",
redirectURL: "https://example.com/callback",
allowedUsers: []string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(cleanup)
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers)
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups)
if (err != nil) != tt.wantErr {
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCheckToken(t *testing.T) {
provider := setupProvider(t)
tests := []struct {
name string
allowedUsers []string
allowedGroups []string
claims jwt.Claims
wantErr error
}{
{
name: "Success - Valid token with allowed user",
allowedUsers: []string{"user1"},
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
},
{
name: "Success - Valid token with allowed group",
allowedGroups: []string{"group1"},
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
},
{
name: "Success - Server omits groups, but user is allowed",
allowedUsers: []string{"user1"},
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
},
},
{
name: "Success - Server omits preferred_username, but group is allowed",
allowedGroups: []string{"group1"},
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"groups": []string{"group1"},
},
},
{
name: "Success - Valid token with allowed user and group",
allowedUsers: []string{"user1"},
allowedGroups: []string{"group1"},
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
},
{
name: "Error - User not allowed",
allowedUsers: []string{"user2", "user3"},
allowedGroups: []string{"group2", "group3"},
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrUserNotAllowed,
},
{
name: "Error - Server returns incorrect issuer",
claims: jwt.MapClaims{
"iss": "https://example.com",
"aud": clientID,
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrInvalidToken,
},
{
name: "Error - Server returns incorrect audience",
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": "some-other-audience",
"exp": time.Now().Add(time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrInvalidToken,
},
{
name: "Error - Server returns expired token",
claims: jwt.MapClaims{
"iss": provider.ts.URL,
"aud": clientID,
"exp": time.Now().Add(-time.Hour).Unix(),
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrInvalidToken,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create the Auth Provider.
auth := &OIDCProvider{
oidcVerifier: provider.verifier,
allowedUsers: tc.allowedUsers,
allowedGroups: tc.allowedGroups,
}
// Sign the claims to create a token.
signedToken := provider.SignClaims(t, tc.claims)
// Craft a test HTTP request that includes the token as a cookie.
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{
Name: auth.TokenCookieName(),
Value: signedToken,
})
// Call CheckToken and verify the result.
err := auth.CheckToken(req)
if tc.wantErr == nil {
ExpectNoError(t, err)
} else {
ExpectError(t, tc.wantErr, err)
}
})
}
}

View file

@ -47,12 +47,13 @@ var (
APIPassword = GetEnvString("API_PASSWORD", "password")
// OIDC Configuration.
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "")
)
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {

View file

@ -8,7 +8,8 @@ import (
)
type oidcMiddleware struct {
AllowedUsers []string `json:"allowed_users"`
AllowedUsers []string `json:"allowed_users"`
AllowedGroups []string `json:"allowed_groups"`
auth auth.Provider
authMux *http.ServeMux
@ -30,6 +31,9 @@ func (amw *oidcMiddleware) finalize() error {
if len(amw.AllowedUsers) > 0 {
authProvider.SetAllowedUsers(amw.AllowedUsers)
}
if len(amw.AllowedGroups) > 0 {
authProvider.SetAllowedGroups(amw.AllowedGroups)
}
amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)

20
internal/utils/slices.go Normal file
View file

@ -0,0 +1,20 @@
package utils
// Intersect returns a new slice containing the elements that are present in both input slices.
// This provides a more efficient solution than using two nested loops.
func Intersect[T comparable, Slice ~[]T](slice1 Slice, slice2 Slice) Slice {
var result Slice
seen := map[T]struct{}{}
for i := range slice1 {
seen[slice1[i]] = struct{}{}
}
for i := range slice2 {
if _, ok := seen[slice2[i]]; ok {
result = append(result, slice2[i])
}
}
return result
}

View file

@ -0,0 +1,96 @@
package utils
import (
"slices"
"strings"
"testing"
utils "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestIntersect(t *testing.T) {
t.Run("strings", func(t *testing.T) {
t.Run("no intersection", func(t *testing.T) {
var (
slice1 = []string{"a", "b", "c"}
slice2 = []string{"d", "e", "f"}
want []string
)
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
slice1 = []string{"a", "b", "c"}
slice2 = []string{"b", "c", "d"}
want = []string{"b", "c"}
)
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
})
})
t.Run("ints", func(t *testing.T) {
t.Run("no intersection", func(t *testing.T) {
var (
slice1 = []int{1, 2, 3}
slice2 = []int{4, 5, 6}
want []int
)
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
slice1 = []int{1, 2, 3}
slice2 = []int{2, 3, 4}
want = []int{2, 3}
)
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
})
})
t.Run("complex", func(t *testing.T) {
type T struct {
A string
B int
}
t.Run("no intersection", func(t *testing.T) {
var (
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
slice2 = []T{{"d", 4}, {"e", 5}, {"f", 6}}
want []T
)
result := Intersect(slice1, slice2)
slices.SortFunc(result, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
slices.SortFunc(want, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
utils.ExpectDeepEqual(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
slice2 = []T{{"b", 2}, {"c", 3}, {"d", 4}}
want = []T{{"b", 2}, {"c", 3}}
)
result := Intersect(slice1, slice2)
slices.SortFunc(result, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
slices.SortFunc(want, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
utils.ExpectDeepEqual(t, result, want)
})
})
}