package auth

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"slices"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/yusing/go-proxy/internal/common"
	"github.com/yusing/go-proxy/internal/logging"
	"github.com/yusing/go-proxy/internal/net/gphttp"
	"github.com/yusing/go-proxy/internal/utils"
	"github.com/yusing/go-proxy/internal/utils/strutils"
	"golang.org/x/oauth2"
)

type (
	// OIDCProvider bundles the OAuth2 client configuration, the OIDC
	// provider/verifier and the access-control lists used by the auth handlers.
	OIDCProvider struct {
		// oauthConfig builds the authorization URL and exchanges auth codes.
		oauthConfig *oauth2.Config
		// oidcProvider is the provider returned by oidc.NewProvider for the issuer.
		oidcProvider *oidc.Provider
		// oidcVerifier verifies ID tokens; it is configured with the client ID.
		oidcVerifier *oidc.IDTokenVerifier
		// endSessionURL is the parsed end-session endpoint of the provider;
		// LogoutHandler redirects to it when it is non-nil.
		endSessionURL *url.URL
		// allowedUsers and allowedGroups are the lists consulted by checkAllowed.
		allowedUsers  []string
		allowedGroups []string
	}

	// IDTokenClaims is the subset of ID token claims decoded by parseClaims.
	IDTokenClaims struct {
		// Username is taken from the "preferred_username" claim.
		Username string `json:"preferred_username"`
		// Groups is taken from the "groups" claim.
		Groups []string `json:"groups"`
	}
)

// Names of the cookies used by the OIDC auth flow.
const (
	// CookieOauthState holds the random state value set by LoginHandler and
	// compared with the "state" query parameter in the callback.
	CookieOauthState = "godoxy_oidc_state"
	// CookieOauthSessionID is not referenced in the code visible here;
	// presumably it identifies the session — confirm against its users.
	CookieOauthSessionID = "godoxy_session_id"
	// CookieOauthToken holds the raw ID token JWT verified by CheckToken.
	CookieOauthToken = "godoxy_oauth_token"
	// CookieOauthSessionToken holds the session token that LoginHandler passes
	// to TryRefreshToken and LogoutHandler parses with parseSessionJWT.
	CookieOauthSessionToken = "godoxy_session_token"
)

// URL paths dispatched by HandleAuth.
const (
	// OIDCAuthInitPath starts the login flow (LoginHandler). HandleAuth also
	// redirects any unrecognized path here.
	OIDCAuthInitPath = "/auth/init"
	// OIDCPostAuthPath is the callback path (PostAuthCallbackHandler); it is
	// also the path used for the "redirect_uri" parameter.
	OIDCPostAuthPath = "/auth/postauth"
	// OIDCLogoutPath ends the session (LogoutHandler).
	OIDCLogoutPath = "/auth/logout"
)

// errMissingIDToken is returned by getIdToken when the OAuth2 token response
// carries no string "id_token" field.
var errMissingIDToken = errors.New("missing id_token field from oauth token")

// oidcStateLength is the length, in characters, of the string returned by
// generateState. The same number of random bytes is drawn before encoding.
const oidcStateLength = 32

// generateState returns a random, URL-safe string of oidcStateLength
// characters for use as the OIDC "state" parameter.
//
// The state protects the login flow against CSRF, so it must be
// unpredictable. If the system random source fails, the function panics
// rather than silently returning a state derived from a zeroed buffer.
func generateState() string {
	b := make([]byte, oidcStateLength)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("generating OIDC state: %w", err))
	}
	// Base64 expands the input, so the encoded string is always longer than
	// oidcStateLength; truncate it to the advertised length.
	return base64.URLEncoding.EncodeToString(b)[:oidcStateLength]
}

// NewOIDCProvider discovers the OIDC provider at issuerURL and returns an
// OIDCProvider configured with the given client credentials and access lists.
//
// At least one of allowedUsers and allowedGroups must be non-empty; otherwise
// an error is returned. An error is also returned when provider discovery
// fails. A missing or unparsable end-session endpoint is not fatal: the
// provider is created with a nil endSessionURL (a parse failure is logged as
// a warning).
//
// RedirectURL is left empty in the OAuth2 config because the handlers supply
// "redirect_uri" per request via optRedirectPostAuth.
func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
	if len(allowedUsers)+len(allowedGroups) == 0 {
		return nil, errors.New("OIDC users, groups, or both must not be empty")
	}
	provider, err := oidc.NewProvider(context.Background(), issuerURL)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
	}

	// url.Parse("") succeeds and yields a non-nil empty URL, which would make
	// the logout handler treat a missing endpoint as present. Only parse a
	// non-empty endpoint so that "no endpoint" is represented by nil.
	var endSessionURL *url.URL
	if endpoint := provider.EndSessionEndpoint(); endpoint != "" {
		endSessionURL, err = url.Parse(endpoint)
		if err != nil {
			// non critical, just warn; url.Parse returns a nil URL on error
			logging.Warn().
				Str("issuer", issuerURL).
				Err(err).
				Msg("failed to parse end session URL")
		}
	}

	return &OIDCProvider{
		oauthConfig: &oauth2.Config{
			ClientID:     clientID,
			ClientSecret: clientSecret,
			RedirectURL:  "",
			Endpoint:     provider.Endpoint(),
			Scopes:       strutils.CommaSeperatedList(common.OIDCScopes),
		},
		oidcProvider: provider,
		oidcVerifier: provider.Verifier(&oidc.Config{
			ClientID: clientID,
		}),
		endSessionURL: endSessionURL,
		allowedUsers:  allowedUsers,
		allowedGroups: allowedGroups,
	}, nil
}

// NewOIDCProviderFromEnv creates a new OIDCProvider by calling
// NewOIDCProvider with the issuer URL, client ID, client secret, allowed
// users and allowed groups held in the common package (presumably populated
// from environment variables — confirm in package common).
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
	return NewOIDCProvider(
		common.OIDCIssuerURL,
		common.OIDCClientID,
		common.OIDCClientSecret,
		common.OIDCAllowedUsers,
		common.OIDCAllowedGroups,
	)
}

// SetAllowedUsers replaces the list of allowed users. The slice is stored
// as-is (not copied) and the assignment is not synchronized.
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
	auth.allowedUsers = users
}

// SetAllowedGroups replaces the list of allowed groups. The slice is stored
// as-is (not copied) and the assignment is not synchronized.
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
	auth.allowedGroups = groups
}

// optRedirectPostAuth builds the oauth2 option that overrides "redirect_uri"
// with the HTTPS post-auth URL on the host of the incoming request, so the
// identity provider sends the user back to the host they started from.
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
	redirectURI := "https://" + requestHost(r) + OIDCPostAuthPath
	return oauth2.SetAuthURLParam("redirect_uri", redirectURI)
}

// getIdToken pulls the "id_token" field out of an OAuth2 token response and
// verifies it with the provider's verifier.
//
// It returns the raw JWT together with the verified token. It returns
// errMissingIDToken when the field is absent or not a string, and a wrapped
// error when verification fails.
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
	rawJWT, isString := oauthToken.Extra("id_token").(string)
	if !isString {
		return "", nil, errMissingIDToken
	}

	verified, err := auth.oidcVerifier.Verify(ctx, rawJWT)
	if err == nil {
		return rawJWT, verified, nil
	}
	return "", nil, fmt.Errorf("failed to verify ID token: %w", err)
}

// HandleAuth routes an auth request by its URL path: the init, post-auth and
// logout paths go to their respective handlers, and any other path is
// redirected to the init path.
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
	var handler func(http.ResponseWriter, *http.Request)

	switch r.URL.Path {
	case OIDCAuthInitPath:
		handler = auth.LoginHandler
	case OIDCPostAuthPath:
		handler = auth.PostAuthCallbackHandler
	case OIDCLogoutPath:
		handler = auth.LogoutHandler
	default:
		// Unknown path: send the client to the start of the login flow.
		http.Redirect(w, r, OIDCAuthInitPath, http.StatusFound)
		return
	}

	handler(w, r)
}

// LoginHandler starts the login flow.
//
// Without a session-token cookie, it stores a fresh random state in a
// short-lived cookie and redirects the user to the identity provider. With
// one, it attempts a token refresh instead and redirects to "/"; if the
// refresh fails the auth cookies are cleared first.
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
	sessionCookie, err := r.Cookie(CookieOauthSessionToken)
	if err != nil {
		// No session yet: begin a new authorization-code flow.
		state := generateState()
		setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
		authURL := auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r))
		http.Redirect(w, r, authURL, http.StatusFound)
		return
	}

	// Existing session: try to refresh rather than re-authenticating.
	if err = auth.TryRefreshToken(w, r, sessionCookie.Value); err != nil {
		logging.Debug().Err(err).Msg("failed to refresh token")
		auth.clearCookie(w, r)
	}
	http.Redirect(w, r, "/", http.StatusFound)
}

// parseClaims decodes the claims of a verified ID token into IDTokenClaims.
// It returns an error when decoding fails or when the username claim is empty.
func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
	claims := new(IDTokenClaims)
	if err := idToken.Claims(claims); err != nil {
		return nil, fmt.Errorf("failed to parse claims: %w", err)
	}
	if claims.Username == "" {
		return nil, errors.New("missing username in ID token")
	}
	return claims, nil
}

// checkAllowed reports whether the user may sign in.
//
// A user is allowed when their name is in allowedUsers, or when at least one
// of their groups is in allowedGroups. Either list may be empty (the
// constructor only requires that they are not both empty), so a
// groups-only configuration must still admit members of those groups.
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
	if slices.Contains(auth.allowedUsers, user) {
		return true
	}
	if len(auth.allowedGroups) == 0 {
		// Not an allowed user and there are no allowed groups to fall back on.
		return false
	}
	// utils.Intersect presumably returns the elements common to both slices.
	return len(utils.Intersect(groups, auth.allowedGroups)) > 0
}

// CheckToken validates the ID token cookie of a request.
//
// It returns ErrMissingOAuthToken when the cookie is absent,
// ErrInvalidOAuthToken (wrapping the cause) when the token fails verification
// or its claims cannot be parsed, ErrUserNotAllowed when checkAllowed rejects
// the user, and nil otherwise.
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
	// invalid wraps a cause under ErrInvalidOAuthToken.
	invalid := func(cause error) error {
		return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, cause)
	}

	cookie, err := r.Cookie(CookieOauthToken)
	if err != nil {
		return ErrMissingOAuthToken
	}

	verified, err := auth.oidcVerifier.Verify(r.Context(), cookie.Value)
	if err != nil {
		return invalid(err)
	}

	claims, err := parseClaims(verified)
	if err != nil {
		return invalid(err)
	}

	if auth.checkAllowed(claims.Username, claims.Groups) {
		return nil
	}
	return ErrUserNotAllowed
}

// PostAuthCallbackHandler completes the authorization-code flow.
//
// It checks the "state" query parameter against the state cookie, exchanges
// the "code" for an OAuth2 token, verifies the ID token and stores it in a
// cookie that expires with the token. When the provider also issued a refresh
// token, a session is created, the refresh token is stored for it and a
// session-token cookie is set. Finally the user is redirected to "/".
//
// A missing or mismatching state yields a bad-request response; exchange,
// verification and claim-parsing failures yield a server error. When
// common.IsTest is set the whole flow is delegated to handleTestCallback.
func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) {
	if common.IsTest {
		// Tests bypass the provider entirely.
		auth.handleTestCallback(w, r)
		return
	}

	// The state echoed back by the provider must match the one we issued.
	stateCookie, err := r.Cookie(CookieOauthState)
	if err != nil {
		gphttp.BadRequest(w, "missing state cookie")
		return
	}
	query := r.URL.Query()
	if query.Get("state") != stateCookie.Value {
		gphttp.BadRequest(w, "invalid oauth state")
		return
	}

	ctx := r.Context()
	token, err := auth.oauthConfig.Exchange(ctx, query.Get("code"), optRedirectPostAuth(r))
	if err != nil {
		gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err))
		return
	}

	rawIDToken, idToken, err := auth.getIdToken(ctx, token)
	if err != nil {
		gphttp.ServerError(w, r, err)
		return
	}

	// A session is only created when there is a refresh token to back it.
	if refreshToken := token.RefreshToken; refreshToken != "" {
		claims, err := parseClaims(idToken)
		if err != nil {
			gphttp.ServerError(w, r, err)
			return
		}
		session := newSession(claims.Username, claims.Groups)
		storeOAuthRefreshToken(session.SessionID, claims.Username, refreshToken)
		auth.setSessionTokenCookie(w, r, session)
	}
	auth.setIDTokenCookie(w, r, rawIDToken, time.Until(idToken.Expiry))

	http.Redirect(w, r, "/", http.StatusFound)
}

// LogoutHandler signs the user out.
//
// It clears the auth cookies and, when the session-token cookie parses
// successfully, invalidates the refresh token stored for that session. The
// user is then redirected to:
//   - "/" when there is no end-session URL;
//   - the end-session URL with "id_token_hint" and
//     "post_logout_redirect_uri" added, when an ID token cookie was present;
//   - the bare end-session URL otherwise.
func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
	// Read both cookies before clearing; a missing cookie leaves a nil pointer.
	idTokenCookie, _ := r.Cookie(CookieOauthToken)
	sessionCookie, _ := r.Cookie(CookieOauthSessionToken)
	auth.clearCookie(w, r)

	if sessionCookie != nil {
		if claims, _, err := auth.parseSessionJWT(sessionCookie.Value); err == nil {
			invalidateOAuthRefreshToken(claims.SessionID)
		}
	}

	target := "/"
	switch {
	case auth.endSessionURL == nil:
		// No provider-side logout available; stay with "/".
	case idTokenCookie != nil:
		params := auth.endSessionURL.Query()
		params.Set("id_token_hint", idTokenCookie.Value)
		params.Set("post_logout_redirect_uri", "https://"+requestHost(r))

		// Work on a copy so the shared endSessionURL is never modified.
		endSession := *auth.endSessionURL
		endSession.RawQuery = params.Encode()
		target = endSession.String()
	default:
		target = auth.endSessionURL.String()
	}

	http.Redirect(w, r, target, http.StatusFound)
}

// setIDTokenCookie stores the raw ID token JWT in the CookieOauthToken cookie
// by calling setTokenCookie with the given ttl.
func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) {
	setTokenCookie(w, r, CookieOauthToken, jwt, ttl)
}

// clearCookie clears the ID token cookie and the session token cookie. The
// state cookie is not touched.
func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) {
	clearTokenCookie(w, r, CookieOauthToken)
	clearTokenCookie(w, r, CookieOauthSessionToken)
}

// handleTestCallback is the callback used in the test environment. It performs
// the same state check as the real callback but never contacts a provider:
// on success it sets the token cookie to the literal value "test" for one
// hour and redirects to "/".
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
	stateCookie, err := r.Cookie(CookieOauthState)

	switch {
	case err != nil:
		gphttp.BadRequest(w, "missing state cookie")
	case r.URL.Query().Get("state") != stateCookie.Value:
		gphttp.BadRequest(w, "invalid oauth state")
	default:
		// Placeholder token value; it is not a real JWT.
		setTokenCookie(w, r, CookieOauthToken, "test", time.Hour)
		http.Redirect(w, r, "/", http.StatusFound)
	}
}