mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-04 02:42:34 +02:00
feat: add groups support for OIDC claims
Allow users to specify allowed groups in the env and use it to inspect the claims. This performs a logical AND of users and groups (additive).
This commit is contained in:
parent
b44c8586cc
commit
73de18e197
7 changed files with 441 additions and 42 deletions
15
.env.example
15
.env.example
|
@ -19,8 +19,21 @@ GODOXY_API_JWT_TOKEN_TTL=1h
|
||||||
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
||||||
# Comma-separated list of scopes
|
# Comma-separated list of scopes
|
||||||
# GODOXY_OIDC_SCOPES=openid, profile, email
|
# GODOXY_OIDC_SCOPES=openid, profile, email
|
||||||
# Comma-separated list of allowed users
|
#
|
||||||
|
# User definitions: Uncomment and configure these values to restrict access to specific users or groups.
|
||||||
|
# These two fields act as a logical AND operator. For example, given the following membership:
|
||||||
|
# user1, group1
|
||||||
|
# user2, group1
|
||||||
|
# user3, group2
|
||||||
|
# user1, group2
|
||||||
|
# You can allow access to user3 AND all users of group1 by providing:
|
||||||
|
# # GODOXY_OIDC_ALLOWED_USERS=user3
|
||||||
|
# # GODOXY_OIDC_ALLOWED_GROUPS=group1
|
||||||
|
#
|
||||||
|
# Comma-separated list of allowed users.
|
||||||
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
|
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
|
||||||
|
# Optional: Comma-separated list of allowed groups.
|
||||||
|
# GODOXY_OIDC_ALLOWED_GROUPS=group1,group2
|
||||||
|
|
||||||
# Proxy listening address
|
# Proxy listening address
|
||||||
GODOXY_HTTP_ADDR=:80
|
GODOXY_HTTP_ADDR=:80
|
||||||
|
|
|
@ -14,16 +14,18 @@ import (
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
CE "github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OIDCProvider struct {
|
type OIDCProvider struct {
|
||||||
oauthConfig *oauth2.Config
|
oauthConfig *oauth2.Config
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oidcVerifier *oidc.IDTokenVerifier
|
oidcVerifier *oidc.IDTokenVerifier
|
||||||
allowedUsers []string
|
allowedUsers []string
|
||||||
isMiddleware bool
|
allowedGroups []string
|
||||||
|
isMiddleware bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const CookieOauthState = "godoxy_oidc_state"
|
const CookieOauthState = "godoxy_oidc_state"
|
||||||
|
@ -33,9 +35,9 @@ const (
|
||||||
OIDCLogoutPath = "/auth/logout"
|
OIDCLogoutPath = "/auth/logout"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
|
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
|
||||||
if len(allowedUsers) == 0 {
|
if len(allowedUsers)+len(allowedGroups) == 0 {
|
||||||
return nil, errors.New("OIDC allowed users must not be empty")
|
return nil, errors.New("OIDC users, groups, or both must not be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||||
|
@ -55,7 +57,8 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allo
|
||||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
}),
|
}),
|
||||||
allowedUsers: allowedUsers,
|
allowedUsers: allowedUsers,
|
||||||
|
allowedGroups: allowedGroups,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,6 +70,7 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
||||||
common.OIDCClientSecret,
|
common.OIDCClientSecret,
|
||||||
common.OIDCRedirectURL,
|
common.OIDCRedirectURL,
|
||||||
common.OIDCAllowedUsers,
|
common.OIDCAllowedUsers,
|
||||||
|
common.OIDCAllowedGroups,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,6 +89,10 @@ func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||||
auth.allowedUsers = users
|
auth.allowedUsers = users
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
||||||
|
auth.allowedGroups = groups
|
||||||
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
token, err := r.Cookie(auth.TokenCookieName())
|
token, err := r.Cookie(auth.TokenCookieName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -94,7 +102,7 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
||||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to verify ID token: %w", err)
|
return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(idToken.Audience) == 0 {
|
if len(idToken.Audience) == 0 {
|
||||||
|
@ -102,14 +110,18 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims struct {
|
var claims struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Username string `json:"preferred_username"`
|
Username string `json:"preferred_username"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
}
|
}
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
return fmt.Errorf("failed to parse claims: %w", err)
|
return fmt.Errorf("failed to parse claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slices.Contains(auth.allowedUsers, claims.Username) {
|
// Logical AND between allowed users and groups.
|
||||||
|
allowedUser := slices.Contains(auth.allowedUsers, claims.Username)
|
||||||
|
allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0
|
||||||
|
if !allowedUser && !allowedGroup {
|
||||||
return ErrUserNotAllowed.Subject(claims.Username)
|
return ErrUserNotAllowed.Subject(claims.Username)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -2,11 +2,17 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -34,7 +40,95 @@ func setupMockOIDC(t *testing.T) {
|
||||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
ClientID: "test-client",
|
ClientID: "test-client",
|
||||||
}),
|
}),
|
||||||
allowedUsers: []string{"test-user"},
|
allowedUsers: []string{"test-user"},
|
||||||
|
allowedGroups: []string{"test-group1", "test-group2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoveryDocument returns a mock OIDC discovery document.
|
||||||
|
func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
discovery := map[string]any{
|
||||||
|
"issuer": server.URL,
|
||||||
|
"authorization_endpoint": server.URL + "/auth",
|
||||||
|
"token_endpoint": server.URL + "/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
return discovery
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
keyID = "test-key-id"
|
||||||
|
clientID = "test-client-id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type provider struct {
|
||||||
|
ts *httptest.Server
|
||||||
|
key *rsa.PrivateKey
|
||||||
|
verifier *oidc.IDTokenVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = keyID
|
||||||
|
signed, err := token.SignedString(j.key)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
return signed
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupProvider(t *testing.T) *provider {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate an RSA key pair for the test.
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
|
||||||
|
// Build the matching public JWK that will be served by the endpoint.
|
||||||
|
jwk := buildRSAJWK(t, &privKey.PublicKey, keyID)
|
||||||
|
|
||||||
|
// Start a test server that serves the JWKS endpoint.
|
||||||
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/.well-known/jwks.json":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"keys": []any{jwk},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(ts.Close)
|
||||||
|
|
||||||
|
// Create a test OIDCProvider.
|
||||||
|
providerCtx := oidc.ClientContext(context.Background(), ts.Client())
|
||||||
|
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
|
||||||
|
|
||||||
|
return &provider{
|
||||||
|
ts: ts,
|
||||||
|
key: privKey,
|
||||||
|
verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{
|
||||||
|
ClientID: clientID, // matches audience in the token
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint
|
||||||
|
func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
nBytes := pub.N.Bytes()
|
||||||
|
eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"kty": "RSA",
|
||||||
|
"alg": "RS256",
|
||||||
|
"use": "sig",
|
||||||
|
"kid": kid,
|
||||||
|
"n": base64.RawURLEncoding.EncodeToString(nBytes),
|
||||||
|
"e": base64.RawURLEncoding.EncodeToString(eBytes),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,14 +239,27 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInitOIDC(t *testing.T) {
|
func TestInitOIDC(t *testing.T) {
|
||||||
|
setupMockOIDC(t)
|
||||||
|
// Create a test server that serves the discovery document
|
||||||
|
var server *httptest.Server
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
|
||||||
|
})
|
||||||
|
server = httptest.NewServer(mux)
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
issuerURL string
|
issuerURL string
|
||||||
clientID string
|
clientID string
|
||||||
clientSecret string
|
clientSecret string
|
||||||
redirectURL string
|
redirectURL string
|
||||||
allowedUsers []string
|
allowedUsers []string
|
||||||
wantErr bool
|
allowedGroups []string
|
||||||
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Fail - Empty configuration",
|
name: "Fail - Empty configuration",
|
||||||
|
@ -163,33 +270,179 @@ func TestInitOIDC(t *testing.T) {
|
||||||
allowedUsers: nil,
|
allowedUsers: nil,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
// {
|
|
||||||
// name: "Success - Valid configuration",
|
|
||||||
// issuerURL: "https://example.com",
|
|
||||||
// clientID: "client_id",
|
|
||||||
// clientSecret: "client_secret",
|
|
||||||
// redirectURL: "https://example.com/callback",
|
|
||||||
// allowedUsers: []string{"user1", "user2"},
|
|
||||||
// wantErr: false,
|
|
||||||
// },
|
|
||||||
{
|
{
|
||||||
name: "Fail - No allowed users",
|
name: "Success - Valid configuration with users",
|
||||||
|
issuerURL: server.URL,
|
||||||
|
clientID: "client_id",
|
||||||
|
clientSecret: "client_secret",
|
||||||
|
redirectURL: "https://example.com/callback",
|
||||||
|
allowedUsers: []string{"user1", "user2"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid configuration with groups",
|
||||||
|
issuerURL: server.URL,
|
||||||
|
clientID: "client_id",
|
||||||
|
clientSecret: "client_secret",
|
||||||
|
redirectURL: "https://example.com/callback",
|
||||||
|
allowedGroups: []string{"group1", "group2"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Fail - No allowed users or allowed groups",
|
||||||
issuerURL: "https://example.com",
|
issuerURL: "https://example.com",
|
||||||
clientID: "client_id",
|
clientID: "client_id",
|
||||||
clientSecret: "client_secret",
|
clientSecret: "client_secret",
|
||||||
redirectURL: "https://example.com/callback",
|
redirectURL: "https://example.com/callback",
|
||||||
allowedUsers: []string{},
|
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Cleanup(cleanup)
|
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups)
|
||||||
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers)
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCheckToken(t *testing.T) {
|
||||||
|
provider := setupProvider(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
allowedUsers []string
|
||||||
|
allowedGroups []string
|
||||||
|
claims jwt.Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Success - Valid token with allowed user",
|
||||||
|
allowedUsers: []string{"user1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid token with allowed group",
|
||||||
|
allowedGroups: []string{"group1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Server omits groups, but user is allowed",
|
||||||
|
allowedUsers: []string{"user1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Server omits preferred_username, but group is allowed",
|
||||||
|
allowedGroups: []string{"group1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid token with allowed user and group",
|
||||||
|
allowedUsers: []string{"user1"},
|
||||||
|
allowedGroups: []string{"group1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - User not allowed",
|
||||||
|
allowedUsers: []string{"user2", "user3"},
|
||||||
|
allowedGroups: []string{"group2", "group3"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrUserNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - Server returns incorrect issuer",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": "https://example.com",
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - Server returns incorrect audience",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": "some-other-audience",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - Server returns expired token",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidToken,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create the Auth Provider.
|
||||||
|
auth := &OIDCProvider{
|
||||||
|
oidcVerifier: provider.verifier,
|
||||||
|
allowedUsers: tc.allowedUsers,
|
||||||
|
allowedGroups: tc.allowedGroups,
|
||||||
|
}
|
||||||
|
// Sign the claims to create a token.
|
||||||
|
signedToken := provider.SignClaims(t, tc.claims)
|
||||||
|
// Craft a test HTTP request that includes the token as a cookie.
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: auth.TokenCookieName(),
|
||||||
|
Value: signedToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call CheckToken and verify the result.
|
||||||
|
err := auth.CheckToken(req)
|
||||||
|
if tc.wantErr == nil {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
} else {
|
||||||
|
ExpectError(t, tc.wantErr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -47,12 +47,13 @@ var (
|
||||||
APIPassword = GetEnvString("API_PASSWORD", "password")
|
APIPassword = GetEnvString("API_PASSWORD", "password")
|
||||||
|
|
||||||
// OIDC Configuration.
|
// OIDC Configuration.
|
||||||
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
||||||
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
||||||
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
||||||
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
||||||
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
||||||
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
|
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
|
||||||
|
OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
|
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
|
||||||
|
|
|
@ -8,7 +8,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type oidcMiddleware struct {
|
type oidcMiddleware struct {
|
||||||
AllowedUsers []string `json:"allowed_users"`
|
AllowedUsers []string `json:"allowed_users"`
|
||||||
|
AllowedGroups []string `json:"allowed_groups"`
|
||||||
|
|
||||||
auth auth.Provider
|
auth auth.Provider
|
||||||
authMux *http.ServeMux
|
authMux *http.ServeMux
|
||||||
|
@ -30,6 +31,9 @@ func (amw *oidcMiddleware) finalize() error {
|
||||||
if len(amw.AllowedUsers) > 0 {
|
if len(amw.AllowedUsers) > 0 {
|
||||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||||
}
|
}
|
||||||
|
if len(amw.AllowedGroups) > 0 {
|
||||||
|
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
||||||
|
}
|
||||||
|
|
||||||
amw.authMux = http.NewServeMux()
|
amw.authMux = http.NewServeMux()
|
||||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||||
|
|
20
internal/utils/slices.go
Normal file
20
internal/utils/slices.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
// Intersect returns a new slice containing the elements that are present in both input slices.
|
||||||
|
// This provides a more efficient solution than using two nested loops.
|
||||||
|
func Intersect[T comparable, Slice ~[]T](slice1 Slice, slice2 Slice) Slice {
|
||||||
|
var result Slice
|
||||||
|
seen := map[T]struct{}{}
|
||||||
|
|
||||||
|
for i := range slice1 {
|
||||||
|
seen[slice1[i]] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range slice2 {
|
||||||
|
if _, ok := seen[slice2[i]]; ok {
|
||||||
|
result = append(result, slice2[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
96
internal/utils/slices_test.go
Normal file
96
internal/utils/slices_test.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
utils "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntersect(t *testing.T) {
|
||||||
|
t.Run("strings", func(t *testing.T) {
|
||||||
|
t.Run("no intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []string{"a", "b", "c"}
|
||||||
|
slice2 = []string{"d", "e", "f"}
|
||||||
|
want []string
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
t.Run("intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []string{"a", "b", "c"}
|
||||||
|
slice2 = []string{"b", "c", "d"}
|
||||||
|
want = []string{"b", "c"}
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("ints", func(t *testing.T) {
|
||||||
|
t.Run("no intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []int{1, 2, 3}
|
||||||
|
slice2 = []int{4, 5, 6}
|
||||||
|
want []int
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
t.Run("intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []int{1, 2, 3}
|
||||||
|
slice2 = []int{2, 3, 4}
|
||||||
|
want = []int{2, 3}
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("complex", func(t *testing.T) {
|
||||||
|
type T struct {
|
||||||
|
A string
|
||||||
|
B int
|
||||||
|
}
|
||||||
|
t.Run("no intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
|
||||||
|
slice2 = []T{{"d", 4}, {"e", 5}, {"f", 6}}
|
||||||
|
want []T
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.SortFunc(result, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
slices.SortFunc(want, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
t.Run("intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
|
||||||
|
slice2 = []T{{"b", 2}, {"c", 3}, {"d", 4}}
|
||||||
|
want = []T{{"b", 2}, {"c", 3}}
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.SortFunc(result, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
slices.SortFunc(want, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue