package auth

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/golang-jwt/jwt/v5"
	"github.com/yusing/go-proxy/internal/common"
	"golang.org/x/oauth2"

	. "github.com/yusing/go-proxy/internal/utils/testing"
)

// setupMockOIDC replaces the package-level defaultAuth with an OIDCProvider
// built from fixed test credentials, placeholder endpoints, and a small
// allow-list of users and groups.
func setupMockOIDC(t *testing.T) {
	t.Helper()

	const mockClientID = "test-client"

	mockProvider := new(oidc.ProviderConfig).NewProvider(context.TODO())
	verifier := mockProvider.Verifier(&oidc.Config{ClientID: mockClientID})
	oauthCfg := &oauth2.Config{
		ClientID:     mockClientID,
		ClientSecret: "test-secret",
		RedirectURL:  "http://localhost/callback",
		Endpoint: oauth2.Endpoint{
			AuthURL:  "http://mock-provider/auth",
			TokenURL: "http://mock-provider/token",
		},
		Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
	}

	defaultAuth = &OIDCProvider{
		oauthConfig:   oauthCfg,
		oidcProvider:  mockProvider,
		oidcVerifier:  verifier,
		allowedUsers:  []string{"test-user"},
		allowedGroups: []string{"test-group1", "test-group2"},
	}
}

// discoveryDocument builds a minimal OIDC discovery payload whose issuer and
// authorization/token endpoints are all derived from server's URL.
func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any {
	t.Helper()

	base := server.URL
	return map[string]any{
		"issuer":                 base,
		"authorization_endpoint": base + "/auth",
		"token_endpoint":         base + "/token",
	}
}

const (
	// keyID is the "kid" stamped on signed test tokens and on the JWK served
	// by the test JWKS endpoint.
	keyID    = "test-key-id"
	// clientID is the client ID the test verifier is configured with; test
	// tokens use it as their audience.
	clientID = "test-client-id"
)

// provider is the token-verification test fixture: a local test server, the
// RSA private key used to sign tokens, and a verifier for those tokens.
type provider struct {
	// ts is the test HTTP server; its URL is used as the token issuer.
	ts       *httptest.Server
	// key is the private key SignClaims signs with.
	key      *rsa.PrivateKey
	// verifier validates ID tokens issued under ts's URL.
	verifier *oidc.IDTokenVerifier
}

// SignClaims signs claims with the provider's RSA key using RS256, sets the
// "kid" header to keyID, and returns the serialized token. A signing error is
// reported through ExpectNoError.
func (p *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
	t.Helper()

	tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	tok.Header["kid"] = keyID

	out, signErr := tok.SignedString(p.key)
	ExpectNoError(t, signErr)
	return out
}

// setupProvider creates the token-verification fixture: a fresh 2048-bit RSA
// key, a TLS test server exposing the matching public key as a JWKS document,
// and an ID-token verifier that fetches keys from that server. The server is
// closed automatically when the test finishes.
func setupProvider(t *testing.T) *provider {
	t.Helper()

	const jwksPath = "/.well-known/jwks.json"

	// Fresh key pair for this test run.
	privKey, err := rsa.GenerateKey(rand.Reader, 2048)
	ExpectNoError(t, err)

	// JWKS document advertising the public half of the key.
	jwks := map[string]any{
		"keys": []any{buildRSAJWK(t, &privKey.PublicKey, keyID)},
	}

	// Only the JWKS path is served; everything else is a 404.
	serveJWKS := func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != jwksPath {
			http.NotFound(w, r)
			return
		}
		// The encode error is intentionally discarded.
		_ = json.NewEncoder(w).Encode(jwks)
	}
	srv := httptest.NewTLSServer(http.HandlerFunc(serveJWKS))
	t.Cleanup(srv.Close)

	// The key set must use the test server's client so its TLS certificate
	// is trusted.
	ctx := oidc.ClientContext(context.Background(), srv.Client())
	keys := oidc.NewRemoteKeySet(ctx, srv.URL+jwksPath)
	// ClientID matches the audience used in the test tokens.
	verifier := oidc.NewVerifier(srv.URL, keys, &oidc.Config{ClientID: clientID})

	return &provider{ts: srv, key: privKey, verifier: verifier}
}

// buildRSAJWK constructs a minimal RSA signing JWK (RS256) for the JWKS
// endpoint.
//
// Both the modulus and the public exponent are taken from pub and encoded as
// unpadded base64url big-endian integers; kid becomes the key ID.
func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any {
	t.Helper()

	// Encode the key's actual exponent as minimal big-endian bytes rather
	// than assuming 65537, so keys with a different exponent still produce a
	// matching JWK. For 65537 this yields {0x01, 0x00, 0x01}.
	var eBytes []byte
	for e := pub.E; e > 0; e >>= 8 {
		eBytes = append([]byte{byte(e)}, eBytes...)
	}

	return map[string]any{
		"kty": "RSA",
		"alg": "RS256",
		"use": "sig",
		"kid": kid,
		"n":   base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
		"e":   base64.RawURLEncoding.EncodeToString(eBytes),
	}
}

// cleanup resets the package-level defaultAuth to nil. Tests in this file
// register it with t.Cleanup after installing a provider.
func cleanup() {
	defaultAuth = nil
}

// TestOIDCLoginHandler checks that RedirectLoginPage responds with the
// expected status and, for redirects, sets both a Location header and a
// cookie.
func TestOIDCLoginHandler(t *testing.T) {
	common.APIJWTSecret = []byte("test-secret")
	t.Cleanup(cleanup)
	setupMockOIDC(t)

	type testCase struct {
		name         string
		wantStatus   int
		wantRedirect bool
	}
	cases := []testCase{
		{
			name:         "Success - Redirects to provider",
			wantStatus:   http.StatusTemporaryRedirect,
			wantRedirect: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			defaultAuth.RedirectLoginPage(rec, httptest.NewRequest(http.MethodGet, "/auth/redirect", nil))

			if rec.Code != tc.wantStatus {
				t.Errorf("OIDCLoginHandler() status = %v, want %v", rec.Code, tc.wantStatus)
			}
			if !tc.wantRedirect {
				return
			}

			// A redirect must say where to go and carry a cookie.
			hdr := rec.Header()
			if hdr.Get("Location") == "" {
				t.Error("OIDCLoginHandler() missing redirect location")
			}
			if hdr.Get("Set-Cookie") == "" {
				t.Error("OIDCLoginHandler() missing state cookie")
			}
		})
	}
}

// TestOIDCCallbackHandler drives LoginCallbackHandler with different
// code/state combinations and checks the response status. For cases that
// expect a redirect it also checks the attributes of the cookie that was set.
func TestOIDCCallbackHandler(t *testing.T) {
	// Setup
	common.APIJWTSecret = []byte("test-secret")
	t.Cleanup(cleanup)
	tests := []struct {
		name       string
		state      string
		code       string
		setupMocks bool
		wantStatus int
	}{
		{
			name:       "Success - Valid callback",
			state:      "valid-state",
			code:       "valid-code",
			setupMocks: true,
			wantStatus: http.StatusTemporaryRedirect,
		},
		{
			name:       "Failure - Missing state",
			code:       "valid-code",
			setupMocks: true,
			wantStatus: http.StatusBadRequest,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.setupMocks {
				setupMockOIDC(t)
			}

			req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
			// The state cookie is only attached when the case supplies a state.
			if tt.state != "" {
				req.AddCookie(&http.Cookie{
					Name:  CookieOauthState,
					Value: tt.state,
				})
			}
			w := httptest.NewRecorder()

			defaultAuth.LoginCallbackHandler(w, req)

			if got := w.Code; got != tt.wantStatus {
				t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
			}

			if tt.wantStatus == http.StatusTemporaryRedirect {
				// Fail this subtest instead of panicking when the header is
				// missing or malformed (e.g. after a status mismatch above).
				setCookie, err := http.ParseSetCookie(w.Header().Get("Set-Cookie"))
				if err != nil {
					t.Fatalf("OIDCCallbackHandler() invalid Set-Cookie header: %v", err)
				}
				ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
				ExpectTrue(t, setCookie.Value != "")
				ExpectEqual(t, setCookie.Path, "/")
				ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
				ExpectEqual(t, setCookie.HttpOnly, true)
			}
		})
	}
}

// TestInitOIDC exercises NewOIDCProvider against a local server that serves a
// discovery document, covering valid configurations as well as empty and
// allow-list-less ones that must be rejected.
func TestInitOIDC(t *testing.T) {
	setupMockOIDC(t)
	// Create a test server that serves the discovery document. The handler
	// closes over server, which is assigned right after the mux is built.
	var server *httptest.Server
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// This runs on the HTTP server's goroutine, where FailNow must not
		// be called, so report with t.Errorf rather than an assertion
		// helper that might stop the test.
		if err := json.NewEncoder(w).Encode(discoveryDocument(t, server)); err != nil {
			t.Errorf("encoding discovery document: %v", err)
		}
	})
	server = httptest.NewServer(mux)
	t.Cleanup(server.Close)
	t.Cleanup(cleanup)

	tests := []struct {
		name          string
		issuerURL     string
		clientID      string
		clientSecret  string
		redirectURL   string
		logoutURL     string
		allowedUsers  []string
		allowedGroups []string
		wantErr       bool
	}{
		{
			name:    "Fail - Empty configuration",
			wantErr: true,
		},
		{
			name:         "Success - Valid configuration with users",
			issuerURL:    server.URL,
			clientID:     "client_id",
			clientSecret: "client_secret",
			redirectURL:  "https://example.com/callback",
			allowedUsers: []string{"user1", "user2"},
			wantErr:      false,
		},
		{
			name:          "Success - Valid configuration with groups",
			issuerURL:     server.URL,
			clientID:      "client_id",
			clientSecret:  "client_secret",
			redirectURL:   "https://example.com/callback",
			allowedGroups: []string{"group1", "group2"},
			wantErr:       false,
		},
		{
			name:          "Success - Valid configuration with users, groups and logout URL",
			issuerURL:     server.URL,
			clientID:      "client_id",
			clientSecret:  "client_secret",
			redirectURL:   "https://example.com/callback",
			logoutURL:     "https://example.com/logout",
			allowedUsers:  []string{"user1", "user2"},
			allowedGroups: []string{"group1", "group2"},
			wantErr:       false,
		},
		{
			name:         "Fail - No allowed users or allowed groups",
			issuerURL:    "https://example.com",
			clientID:     "client_id",
			clientSecret: "client_secret",
			redirectURL:  "https://example.com/callback",
			wantErr:      true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.logoutURL, tt.allowedUsers, tt.allowedGroups)
			if (err != nil) != tt.wantErr {
				t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestCheckToken signs a range of claim sets with the test provider's key,
// sends each token as a cookie, and checks CheckToken's result for the given
// allow-lists: nil for accepted tokens, otherwise the expected error.
func TestCheckToken(t *testing.T) {
	idp := setupProvider(t)

	// newClaims assembles a claim set expiring ttl from now. An empty user
	// or nil groups leaves the corresponding claim out entirely.
	newClaims := func(iss, aud string, ttl time.Duration, user string, groups []string) jwt.MapClaims {
		c := jwt.MapClaims{
			"iss": iss,
			"aud": aud,
			"exp": time.Now().Add(ttl).Unix(),
		}
		if user != "" {
			c["preferred_username"] = user
		}
		if groups != nil {
			c["groups"] = groups
		}
		return c
	}
	// validClaims is the fully populated, correctly issued claim set.
	validClaims := func() jwt.MapClaims {
		return newClaims(idp.ts.URL, clientID, time.Hour, "user1", []string{"group1"})
	}

	type testCase struct {
		name          string
		allowedUsers  []string
		allowedGroups []string
		claims        jwt.Claims
		wantErr       error
	}
	cases := []testCase{
		{
			name:         "Success - Valid token with allowed user",
			allowedUsers: []string{"user1"},
			claims:       validClaims(),
		},
		{
			name:          "Success - Valid token with allowed group",
			allowedGroups: []string{"group1"},
			claims:        validClaims(),
		},
		{
			name:         "Success - Server omits groups, but user is allowed",
			allowedUsers: []string{"user1"},
			claims:       newClaims(idp.ts.URL, clientID, time.Hour, "user1", nil),
		},
		{
			name:          "Success - Server omits preferred_username, but group is allowed",
			allowedGroups: []string{"group1"},
			claims:        newClaims(idp.ts.URL, clientID, time.Hour, "", []string{"group1"}),
		},
		{
			name:          "Success - Valid token with allowed user and group",
			allowedUsers:  []string{"user1"},
			allowedGroups: []string{"group1"},
			claims:        validClaims(),
		},
		{
			name:          "Error - User not allowed",
			allowedUsers:  []string{"user2", "user3"},
			allowedGroups: []string{"group2", "group3"},
			claims:        validClaims(),
			wantErr:       ErrUserNotAllowed,
		},
		{
			name:    "Error - Server returns incorrect issuer",
			claims:  newClaims("https://example.com", clientID, time.Hour, "user1", []string{"group1"}),
			wantErr: ErrInvalidToken,
		},
		{
			name:    "Error - Server returns incorrect audience",
			claims:  newClaims(idp.ts.URL, "some-other-audience", time.Hour, "user1", []string{"group1"}),
			wantErr: ErrInvalidToken,
		},
		{
			name:    "Error - Server returns expired token",
			claims:  newClaims(idp.ts.URL, clientID, -time.Hour, "user1", []string{"group1"}),
			wantErr: ErrInvalidToken,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Auth provider under test, wired to the fixture's verifier.
			auth := &OIDCProvider{
				oidcVerifier:  idp.verifier,
				allowedUsers:  tc.allowedUsers,
				allowedGroups: tc.allowedGroups,
			}

			// The signed token travels in the provider's token cookie.
			tokenCookie := &http.Cookie{
				Name:  auth.TokenCookieName(),
				Value: idp.SignClaims(t, tc.claims),
			}
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.AddCookie(tokenCookie)

			err := auth.CheckToken(req)
			if tc.wantErr != nil {
				ExpectError(t, tc.wantErr, err)
				return
			}
			ExpectNoError(t, err)
		})
	}
}