mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-01 09:32:35 +02:00
feat: add groups support for OIDC claims
Allow users to specify allowed groups in the env and use it to inspect the claims. This performs a logical AND of users and groups (additive).
This commit is contained in:
parent
b44c8586cc
commit
73de18e197
7 changed files with 441 additions and 42 deletions
15
.env.example
15
.env.example
|
@ -19,8 +19,21 @@ GODOXY_API_JWT_TOKEN_TTL=1h
|
|||
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
||||
# Comma-separated list of scopes
|
||||
# GODOXY_OIDC_SCOPES=openid, profile, email
|
||||
# Comma-separated list of allowed users
|
||||
#
|
||||
# User definitions: Uncomment and configure these values to restrict access to specific users or groups.
|
||||
# These two fields act as a logical AND operator. For example, given the following membership:
|
||||
# user1, group1
|
||||
# user2, group1
|
||||
# user3, group2
|
||||
# user1, group2
|
||||
# You can allow access to user3 AND all users of group1 by providing:
|
||||
# # GODOXY_OIDC_ALLOWED_USERS=user3
|
||||
# # GODOXY_OIDC_ALLOWED_GROUPS=group1
|
||||
#
|
||||
# Comma-separated list of allowed users.
|
||||
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
|
||||
# Optional: Comma-separated list of allowed groups.
|
||||
# GODOXY_OIDC_ALLOWED_GROUPS=group1,group2
|
||||
|
||||
# Proxy listening address
|
||||
GODOXY_HTTP_ADDR=:80
|
||||
|
|
|
@ -14,16 +14,18 @@ import (
|
|||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
CE "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OIDCProvider struct {
|
||||
oauthConfig *oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
allowedUsers []string
|
||||
isMiddleware bool
|
||||
oauthConfig *oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
allowedUsers []string
|
||||
allowedGroups []string
|
||||
isMiddleware bool
|
||||
}
|
||||
|
||||
const CookieOauthState = "godoxy_oidc_state"
|
||||
|
@ -33,9 +35,9 @@ const (
|
|||
OIDCLogoutPath = "/auth/logout"
|
||||
)
|
||||
|
||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
|
||||
if len(allowedUsers) == 0 {
|
||||
return nil, errors.New("OIDC allowed users must not be empty")
|
||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
|
||||
if len(allowedUsers)+len(allowedGroups) == 0 {
|
||||
return nil, errors.New("OIDC users, groups, or both must not be empty")
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||
|
@ -55,7 +57,8 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allo
|
|||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: clientID,
|
||||
}),
|
||||
allowedUsers: allowedUsers,
|
||||
allowedUsers: allowedUsers,
|
||||
allowedGroups: allowedGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -67,6 +70,7 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
|||
common.OIDCClientSecret,
|
||||
common.OIDCRedirectURL,
|
||||
common.OIDCAllowedUsers,
|
||||
common.OIDCAllowedGroups,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -85,6 +89,10 @@ func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
|||
auth.allowedUsers = users
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
||||
auth.allowedGroups = groups
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||
token, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
|
@ -94,7 +102,7 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
|||
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify ID token: %w", err)
|
||||
return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
|
||||
}
|
||||
|
||||
if len(idToken.Audience) == 0 {
|
||||
|
@ -102,14 +110,18 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
|||
}
|
||||
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return fmt.Errorf("failed to parse claims: %w", err)
|
||||
}
|
||||
|
||||
if !slices.Contains(auth.allowedUsers, claims.Username) {
|
||||
// Logical AND between allowed users and groups.
|
||||
allowedUser := slices.Contains(auth.allowedUsers, claims.Username)
|
||||
allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0
|
||||
if !allowedUser && !allowedGroup {
|
||||
return ErrUserNotAllowed.Subject(claims.Username)
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -2,11 +2,17 @@ package auth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -34,7 +40,95 @@ func setupMockOIDC(t *testing.T) {
|
|||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: "test-client",
|
||||
}),
|
||||
allowedUsers: []string{"test-user"},
|
||||
allowedUsers: []string{"test-user"},
|
||||
allowedGroups: []string{"test-group1", "test-group2"},
|
||||
}
|
||||
}
|
||||
|
||||
// discoveryDocument returns a mock OIDC discovery document.
|
||||
func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
discovery := map[string]any{
|
||||
"issuer": server.URL,
|
||||
"authorization_endpoint": server.URL + "/auth",
|
||||
"token_endpoint": server.URL + "/token",
|
||||
}
|
||||
|
||||
return discovery
|
||||
}
|
||||
|
||||
const (
|
||||
keyID = "test-key-id"
|
||||
clientID = "test-client-id"
|
||||
)
|
||||
|
||||
type provider struct {
|
||||
ts *httptest.Server
|
||||
key *rsa.PrivateKey
|
||||
verifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
|
||||
t.Helper()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = keyID
|
||||
signed, err := token.SignedString(j.key)
|
||||
ExpectNoError(t, err)
|
||||
return signed
|
||||
}
|
||||
|
||||
func setupProvider(t *testing.T) *provider {
|
||||
t.Helper()
|
||||
|
||||
// Generate an RSA key pair for the test.
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
// Build the matching public JWK that will be served by the endpoint.
|
||||
jwk := buildRSAJWK(t, &privKey.PublicKey, keyID)
|
||||
|
||||
// Start a test server that serves the JWKS endpoint.
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/jwks.json":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"keys": []any{jwk},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
// Create a test OIDCProvider.
|
||||
providerCtx := oidc.ClientContext(context.Background(), ts.Client())
|
||||
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
|
||||
|
||||
return &provider{
|
||||
ts: ts,
|
||||
key: privKey,
|
||||
verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{
|
||||
ClientID: clientID, // matches audience in the token
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint
|
||||
func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
nBytes := pub.N.Bytes()
|
||||
eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537
|
||||
|
||||
return map[string]any{
|
||||
"kty": "RSA",
|
||||
"alg": "RS256",
|
||||
"use": "sig",
|
||||
"kid": kid,
|
||||
"n": base64.RawURLEncoding.EncodeToString(nBytes),
|
||||
"e": base64.RawURLEncoding.EncodeToString(eBytes),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -145,14 +239,27 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInitOIDC(t *testing.T) {
|
||||
setupMockOIDC(t)
|
||||
// Create a test server that serves the discovery document
|
||||
var server *httptest.Server
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
|
||||
})
|
||||
server = httptest.NewServer(mux)
|
||||
t.Cleanup(server.Close)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURL string
|
||||
allowedUsers []string
|
||||
wantErr bool
|
||||
name string
|
||||
issuerURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURL string
|
||||
allowedUsers []string
|
||||
allowedGroups []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Fail - Empty configuration",
|
||||
|
@ -163,33 +270,179 @@ func TestInitOIDC(t *testing.T) {
|
|||
allowedUsers: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
// {
|
||||
// name: "Success - Valid configuration",
|
||||
// issuerURL: "https://example.com",
|
||||
// clientID: "client_id",
|
||||
// clientSecret: "client_secret",
|
||||
// redirectURL: "https://example.com/callback",
|
||||
// allowedUsers: []string{"user1", "user2"},
|
||||
// wantErr: false,
|
||||
// },
|
||||
{
|
||||
name: "Fail - No allowed users",
|
||||
name: "Success - Valid configuration with users",
|
||||
issuerURL: server.URL,
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
redirectURL: "https://example.com/callback",
|
||||
allowedUsers: []string{"user1", "user2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Success - Valid configuration with groups",
|
||||
issuerURL: server.URL,
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
redirectURL: "https://example.com/callback",
|
||||
allowedGroups: []string{"group1", "group2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Fail - No allowed users or allowed groups",
|
||||
issuerURL: "https://example.com",
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
redirectURL: "https://example.com/callback",
|
||||
allowedUsers: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Cleanup(cleanup)
|
||||
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers)
|
||||
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckToken(t *testing.T) {
|
||||
provider := setupProvider(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedUsers []string
|
||||
allowedGroups []string
|
||||
claims jwt.Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "Success - Valid token with allowed user",
|
||||
allowedUsers: []string{"user1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Valid token with allowed group",
|
||||
allowedGroups: []string{"group1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Server omits groups, but user is allowed",
|
||||
allowedUsers: []string{"user1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Server omits preferred_username, but group is allowed",
|
||||
allowedGroups: []string{"group1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Valid token with allowed user and group",
|
||||
allowedUsers: []string{"user1"},
|
||||
allowedGroups: []string{"group1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error - User not allowed",
|
||||
allowedUsers: []string{"user2", "user3"},
|
||||
allowedGroups: []string{"group2", "group3"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrUserNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns incorrect issuer",
|
||||
claims: jwt.MapClaims{
|
||||
"iss": "https://example.com",
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns incorrect audience",
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": "some-other-audience",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns expired token",
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create the Auth Provider.
|
||||
auth := &OIDCProvider{
|
||||
oidcVerifier: provider.verifier,
|
||||
allowedUsers: tc.allowedUsers,
|
||||
allowedGroups: tc.allowedGroups,
|
||||
}
|
||||
// Sign the claims to create a token.
|
||||
signedToken := provider.SignClaims(t, tc.claims)
|
||||
// Craft a test HTTP request that includes the token as a cookie.
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: auth.TokenCookieName(),
|
||||
Value: signedToken,
|
||||
})
|
||||
|
||||
// Call CheckToken and verify the result.
|
||||
err := auth.CheckToken(req)
|
||||
if tc.wantErr == nil {
|
||||
ExpectNoError(t, err)
|
||||
} else {
|
||||
ExpectError(t, tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,12 +47,13 @@ var (
|
|||
APIPassword = GetEnvString("API_PASSWORD", "password")
|
||||
|
||||
// OIDC Configuration.
|
||||
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
||||
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
||||
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
||||
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
||||
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
||||
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
|
||||
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
||||
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
||||
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
||||
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
||||
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
||||
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
|
||||
OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "")
|
||||
)
|
||||
|
||||
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
|
||||
|
|
|
@ -8,7 +8,8 @@ import (
|
|||
)
|
||||
|
||||
type oidcMiddleware struct {
|
||||
AllowedUsers []string `json:"allowed_users"`
|
||||
AllowedUsers []string `json:"allowed_users"`
|
||||
AllowedGroups []string `json:"allowed_groups"`
|
||||
|
||||
auth auth.Provider
|
||||
authMux *http.ServeMux
|
||||
|
@ -30,6 +31,9 @@ func (amw *oidcMiddleware) finalize() error {
|
|||
if len(amw.AllowedUsers) > 0 {
|
||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||
}
|
||||
if len(amw.AllowedGroups) > 0 {
|
||||
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
||||
}
|
||||
|
||||
amw.authMux = http.NewServeMux()
|
||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
|
|
20
internal/utils/slices.go
Normal file
20
internal/utils/slices.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package utils
|
||||
|
||||
// Intersect returns a new slice containing the elements that are present in both input slices.
|
||||
// This provides a more efficient solution than using two nested loops.
|
||||
func Intersect[T comparable, Slice ~[]T](slice1 Slice, slice2 Slice) Slice {
|
||||
var result Slice
|
||||
seen := map[T]struct{}{}
|
||||
|
||||
for i := range slice1 {
|
||||
seen[slice1[i]] = struct{}{}
|
||||
}
|
||||
|
||||
for i := range slice2 {
|
||||
if _, ok := seen[slice2[i]]; ok {
|
||||
result = append(result, slice2[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
96
internal/utils/slices_test.go
Normal file
96
internal/utils/slices_test.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
utils "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestIntersect(t *testing.T) {
|
||||
t.Run("strings", func(t *testing.T) {
|
||||
t.Run("no intersection", func(t *testing.T) {
|
||||
var (
|
||||
slice1 = []string{"a", "b", "c"}
|
||||
slice2 = []string{"d", "e", "f"}
|
||||
want []string
|
||||
)
|
||||
result := Intersect(slice1, slice2)
|
||||
slices.Sort(result)
|
||||
slices.Sort(want)
|
||||
utils.ExpectDeepEqual(t, result, want)
|
||||
})
|
||||
t.Run("intersection", func(t *testing.T) {
|
||||
var (
|
||||
slice1 = []string{"a", "b", "c"}
|
||||
slice2 = []string{"b", "c", "d"}
|
||||
want = []string{"b", "c"}
|
||||
)
|
||||
result := Intersect(slice1, slice2)
|
||||
slices.Sort(result)
|
||||
slices.Sort(want)
|
||||
utils.ExpectDeepEqual(t, result, want)
|
||||
})
|
||||
})
|
||||
t.Run("ints", func(t *testing.T) {
|
||||
t.Run("no intersection", func(t *testing.T) {
|
||||
var (
|
||||
slice1 = []int{1, 2, 3}
|
||||
slice2 = []int{4, 5, 6}
|
||||
want []int
|
||||
)
|
||||
result := Intersect(slice1, slice2)
|
||||
slices.Sort(result)
|
||||
slices.Sort(want)
|
||||
utils.ExpectDeepEqual(t, result, want)
|
||||
})
|
||||
t.Run("intersection", func(t *testing.T) {
|
||||
var (
|
||||
slice1 = []int{1, 2, 3}
|
||||
slice2 = []int{2, 3, 4}
|
||||
want = []int{2, 3}
|
||||
)
|
||||
result := Intersect(slice1, slice2)
|
||||
slices.Sort(result)
|
||||
slices.Sort(want)
|
||||
utils.ExpectDeepEqual(t, result, want)
|
||||
})
|
||||
})
|
||||
t.Run("complex", func(t *testing.T) {
|
||||
type T struct {
|
||||
A string
|
||||
B int
|
||||
}
|
||||
t.Run("no intersection", func(t *testing.T) {
|
||||
var (
|
||||
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
|
||||
slice2 = []T{{"d", 4}, {"e", 5}, {"f", 6}}
|
||||
want []T
|
||||
)
|
||||
result := Intersect(slice1, slice2)
|
||||
slices.SortFunc(result, func(i T, j T) int {
|
||||
return strings.Compare(i.A, j.A)
|
||||
})
|
||||
slices.SortFunc(want, func(i T, j T) int {
|
||||
return strings.Compare(i.A, j.A)
|
||||
})
|
||||
utils.ExpectDeepEqual(t, result, want)
|
||||
})
|
||||
t.Run("intersection", func(t *testing.T) {
|
||||
var (
|
||||
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
|
||||
slice2 = []T{{"b", 2}, {"c", 3}, {"d", 4}}
|
||||
want = []T{{"b", 2}, {"c", 3}}
|
||||
)
|
||||
result := Intersect(slice1, slice2)
|
||||
slices.SortFunc(result, func(i T, j T) int {
|
||||
return strings.Compare(i.A, j.A)
|
||||
})
|
||||
slices.SortFunc(want, func(i T, j T) int {
|
||||
return strings.Compare(i.A, j.A)
|
||||
})
|
||||
utils.ExpectDeepEqual(t, result, want)
|
||||
})
|
||||
})
|
||||
}
|
Loading…
Add table
Reference in a new issue