package auth

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"slices"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/yusing/go-proxy/internal/common"
	"github.com/yusing/go-proxy/internal/net/gphttp"
	CE "github.com/yusing/go-proxy/internal/utils"
	"github.com/yusing/go-proxy/internal/utils/strutils"
	"golang.org/x/oauth2"
)

// OIDCProvider authenticates users through an OpenID Connect provider using
// the OAuth2 authorization-code flow, keeps the raw ID token in a cookie
// (see TokenCookieName), and authorizes requests against allow-lists of
// usernames and groups.
type OIDCProvider struct {
	// OAuth2 client settings used to build the authorization URL and to
	// redeem authorization codes.
	oauthConfig *oauth2.Config
	// Discovered provider metadata (not read elsewhere in this file).
	oidcProvider *oidc.Provider
	// Verifies ID tokens (signature, expiry, audience == client ID, issuer).
	oidcVerifier *oidc.IDTokenVerifier
	// Provider logout endpoint; nil means LogoutCallbackHandler falls back to
	// DefaultLogoutCallbackHandler.
	oidcLogoutURL *url.URL
	// Usernames matched against the "preferred_username" claim.
	allowedUsers []string
	// Groups matched against the "groups" claim; membership in any one suffices.
	allowedGroups []string
	// When true, the redirect URI is derived per request from the Host header
	// as "https://<host>" + OIDCMiddlewareCallbackPath.
	isMiddleware bool
}

const (
	// CookieOauthState names the short-lived cookie holding the OAuth2 state
	// value generated when the login flow starts.
	CookieOauthState = "godoxy_oidc_state"

	// OIDCMiddlewareCallbackPath is the path appended to the request host to
	// form the redirect URI when the provider runs in middleware mode.
	OIDCMiddlewareCallbackPath = "/auth/callback"
	// OIDCLogoutPath is the logout endpoint path (presumably routed to
	// LogoutCallbackHandler elsewhere — not shown in this file).
	OIDCLogoutPath = "/auth/logout"
)

// NewOIDCProvider discovers the OpenID Connect provider at issuerURL and
// returns an OIDCProvider configured for the authorization-code flow.
//
// At least one of allowedUsers or allowedGroups must be non-empty; a user is
// admitted if they match either list (see CheckToken). logoutURL is optional:
// when empty, logout falls back to DefaultLogoutCallbackHandler. Requested
// scopes come from the comma-separated common.OIDCScopes. The allow-lists are
// copied so later mutation of the caller's slices cannot change who is
// authorized.
//
// It returns an error if both allow-lists are empty, if logoutURL cannot be
// parsed, or if provider discovery fails.
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL, logoutURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
	if len(allowedUsers) == 0 && len(allowedGroups) == 0 {
		return nil, errors.New("OIDC users, groups, or both must not be empty")
	}

	var logout *url.URL
	if logoutURL != "" {
		parsed, err := url.Parse(logoutURL)
		if err != nil {
			return nil, fmt.Errorf("failed to parse logout URL: %w", err)
		}
		logout = parsed
	}

	// NOTE(review): discovery runs with context.Background(), so an
	// unresponsive issuer may block the caller for as long as the request
	// takes — consider a bounded context.
	provider, err := oidc.NewProvider(context.Background(), issuerURL)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
	}

	return &OIDCProvider{
		oauthConfig: &oauth2.Config{
			ClientID:     clientID,
			ClientSecret: clientSecret,
			RedirectURL:  redirectURL,
			Endpoint:     provider.Endpoint(),
			Scopes:       strutils.CommaSeperatedList(common.OIDCScopes),
		},
		oidcProvider: provider,
		oidcVerifier: provider.Verifier(&oidc.Config{
			ClientID: clientID,
		}),
		oidcLogoutURL: logout,
		allowedUsers:  slices.Clone(allowedUsers),
		allowedGroups: slices.Clone(allowedGroups),
	}, nil
}

// NewOIDCProviderFromEnv builds an OIDCProvider from the OIDC settings exposed
// by the common package (issuer, client credentials, redirect and logout URLs,
// allowed users and groups), which are presumably populated from environment
// variables. Errors are those of NewOIDCProvider.
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
	issuer, id, secret := common.OIDCIssuerURL, common.OIDCClientID, common.OIDCClientSecret
	redirect, logout := common.OIDCRedirectURL, common.OIDCLogoutURL
	users, groups := common.OIDCAllowedUsers, common.OIDCAllowedGroups
	return NewOIDCProvider(issuer, id, secret, redirect, logout, users, groups)
}

// TokenCookieName reports the name of the cookie that carries the raw OIDC
// ID token.
func (auth *OIDCProvider) TokenCookieName() string {
	const name = "godoxy_oidc_token"
	return name
}

// SetIsMiddleware toggles middleware mode.
//
// In middleware mode the OAuth2 redirect URI is not fixed: RedirectLoginPage
// and exchange derive it per request as "https://<host>" +
// OIDCMiddlewareCallbackPath, so the configured RedirectURL is cleared when
// enabling. Disabling leaves RedirectURL untouched; previously it was wiped
// unconditionally, which broke the non-middleware flow. Disabling cannot
// restore a value that an earlier enable already cleared.
func (auth *OIDCProvider) SetIsMiddleware(enabled bool) {
	auth.isMiddleware = enabled
	if enabled {
		auth.oauthConfig.RedirectURL = ""
	}
}

// SetAllowedUsers replaces the usernames (matched against the
// "preferred_username" claim in CheckToken) that are granted access.
// The slice is copied, so later changes by the caller have no effect.
// Unlike NewOIDCProvider, no emptiness check is made.
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
	auth.allowedUsers = slices.Clone(users)
}

// SetAllowedGroups replaces the groups (intersected with the "groups" claim in
// CheckToken) whose members are granted access. The slice is copied, so later
// changes by the caller have no effect. Unlike NewOIDCProvider, no emptiness
// check is made.
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
	auth.allowedGroups = slices.Clone(groups)
}

// CheckToken authenticates and authorizes a request from its ID-token cookie.
//
// Return values:
//   - ErrMissingToken when the token cookie is absent.
//   - An error wrapping ErrInvalidToken when verification fails (expiry,
//     audience == client ID, issuer, etc.), or ErrInvalidToken itself when
//     the token has no audience.
//   - A wrapped error when the claims cannot be decoded.
//   - ErrUserNotAllowed when the user is neither an allowed user nor a member
//     of an allowed group.
//
// A nil error means access is granted.
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
	token, err := r.Cookie(auth.TokenCookieName())
	if err != nil {
		return ErrMissingToken
	}

	// checks for Expiry, Audience == ClientID, Issuer, etc.
	idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
	if err != nil {
		return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
	}

	if len(idToken.Audience) == 0 {
		return ErrInvalidToken
	}

	var claims struct {
		Email    string   `json:"email"`
		Username string   `json:"preferred_username"`
		Groups   []string `json:"groups"`
	}
	if err := idToken.Claims(&claims); err != nil {
		return fmt.Errorf("failed to parse claims: %w", err)
	}

	// Logical OR: access is granted if the username is explicitly allowed OR
	// the user belongs to at least one allowed group. This matches
	// NewOIDCProvider, which only requires one of the two lists to be set.
	if slices.Contains(auth.allowedUsers, claims.Username) {
		return nil
	}
	if len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0 {
		return nil
	}
	return ErrUserNotAllowed.Subject(claims.Username)
}

// oidcStateLength is the length, in characters, of the OAuth2 state values
// produced by generateState.
const oidcStateLength = 32

// generateState returns a random, URL-safe string of oidcStateLength
// characters for use as the OAuth2 "state" parameter.
//
// It reads oidcStateLength bytes from crypto/rand, base64url-encodes them, and
// keeps the first oidcStateLength characters of the encoding. Those characters
// encode the first 24 random bytes (192 bits). An error from the random source
// is returned unchanged.
func generateState() (string, error) {
	buf := make([]byte, oidcStateLength)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(buf)
	return encoded[:oidcStateLength], nil
}

// RedirectLoginPage starts the OIDC authorization-code flow.
//
// It stores a freshly generated state value in an HttpOnly, SameSite=Lax
// cookie that lives for 300 seconds, then redirects the browser to the
// provider's authorization endpoint.
//
// In middleware mode the redirect_uri is built as "https://<host>" +
// OIDCMiddlewareCallbackPath on a per-request copy of the OAuth2 config. This
// is exactly how exchange builds it, so the URI sent here always matches the
// one presented when the code is redeemed.
func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
	state, err := generateState()
	if err != nil {
		gphttp.ServerError(w, r, err)
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:     CookieOauthState,
		Value:    state,
		MaxAge:   300,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   r.TLS != nil,
		Path:     "/",
	})

	cfg := auth.oauthConfig
	if auth.isMiddleware {
		// Copy so the shared config is never mutated by concurrent requests
		// for different hosts.
		perRequest := *cfg
		perRequest.RedirectURL = "https://" + r.Host + OIDCMiddlewareCallbackPath
		cfg = &perRequest
	}
	http.Redirect(w, r, cfg.AuthCodeURL(state), http.StatusTemporaryRedirect)
}

// exchange redeems the authorization code from the callback request's "code"
// query parameter for an OAuth2 token.
//
// In middleware mode the redirect_uri sent to the token endpoint is rebuilt
// from the request host ("https://<host>" + OIDCMiddlewareCallbackPath). It is
// set on a copy of the config, so the shared config is left untouched.
func (auth *OIDCProvider) exchange(r *http.Request) (*oauth2.Token, error) {
	code := r.URL.Query().Get("code")
	cfg := auth.oauthConfig
	if auth.isMiddleware {
		perRequest := *cfg
		perRequest.RedirectURL = "https://" + r.Host + OIDCMiddlewareCallbackPath
		cfg = &perRequest
	}
	return cfg.Exchange(r.Context(), code)
}

// LoginCallbackHandler handles the provider's redirect back after login.
//
// It checks the "state" query parameter against the state cookie set by
// RedirectLoginPage and then deletes that cookie so the state is single-use.
// It rejects provider error responses with a 400. Otherwise it exchanges the
// authorization code, verifies the returned ID token, stores the raw ID token
// in the token cookie (expiring with the token), and redirects to "/".
// When common.IsTest is set, the provider round-trip is skipped entirely.
func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
	// For testing purposes, skip provider verification
	if common.IsTest {
		auth.handleTestCallback(w, r)
		return
	}

	state, err := r.Cookie(CookieOauthState)
	if err != nil {
		gphttp.BadRequest(w, "missing state cookie")
		return
	}

	query := r.URL.Query()
	if query.Get("state") != state.Value {
		gphttp.BadRequest(w, "invalid oauth state")
		return
	}

	// The state is single-use: expire the cookie now (same name/path/flags as
	// when it was set) so it cannot be replayed, whether or not the rest of
	// the flow succeeds.
	http.SetCookie(w, &http.Cookie{
		Name:     CookieOauthState,
		Value:    "",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   r.TLS != nil,
		Path:     "/",
	})

	// Providers report failures (e.g. denied consent) through the "error"
	// query parameter instead of a code (RFC 6749 §4.1.2.1). Treat that as a
	// client-side failure rather than letting the exchange fail with a 500.
	// The value is deliberately not echoed back to the client.
	if query.Get("error") != "" {
		gphttp.BadRequest(w, "authorization failed")
		return
	}

	oauth2Token, err := auth.exchange(r)
	if err != nil {
		gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err))
		return
	}

	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
	if !ok {
		gphttp.BadRequest(w, "missing id_token")
		return
	}

	idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken)
	if err != nil {
		gphttp.ServerError(w, r, fmt.Errorf("failed to verify ID token: %w", err))
		return
	}

	setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry))

	// Redirect to home page
	http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
}

// LogoutCallbackHandler logs the user out.
//
// Without a configured provider logout URL it defers to
// DefaultLogoutCallbackHandler. Otherwise it requires the token cookie (400 if
// absent) and clears it. It then redirects (302) to the provider's logout URL
// with the ID token passed as the "id_token_hint" query parameter.
func (auth *OIDCProvider) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
	if auth.oidcLogoutURL == nil {
		DefaultLogoutCallbackHandler(auth, w, r)
		return
	}

	token, err := r.Cookie(auth.TokenCookieName())
	if err != nil {
		gphttp.BadRequest(w, "missing token cookie")
		return
	}
	clearTokenCookie(w, r, auth.TokenCookieName())

	// url.URL.Query returns a parsed copy, so the updated values must be
	// re-encoded into RawQuery. Previously the hint was added to a discarded
	// copy and never reached the provider. The struct copy keeps the
	// configured URL itself unmodified.
	logoutURL := *auth.oidcLogoutURL
	q := logoutURL.Query()
	q.Set("id_token_hint", token.Value)
	logoutURL.RawQuery = q.Encode()

	http.Redirect(w, r, logoutURL.String(), http.StatusFound)
}

// handleTestCallback is the login-callback path used when common.IsTest is
// set. It performs the same state-cookie check as the real flow but does not
// contact the provider. Instead it stores the placeholder token "test" (valid
// for one hour) in the token cookie and redirects to "/".
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
	stateCookie, err := r.Cookie(CookieOauthState)
	switch {
	case err != nil:
		gphttp.BadRequest(w, "missing state cookie")
	case r.URL.Query().Get("state") != stateCookie.Value:
		gphttp.BadRequest(w, "invalid oauth state")
	default:
		// Placeholder value only; this is not a real or signed ID token.
		setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour)
		http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
	}
}