mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
cleanup code, redirect to auth page when need
This commit is contained in:
parent
ef277ef57f
commit
76fe5345d8
11 changed files with 113 additions and 109 deletions
14
cmd/main.go
14
cmd/main.go
|
@ -109,16 +109,16 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
if common.APIJWTSecret == nil {
|
||||
logging.Warn().Msg("API JWT secret is empty, authentication is disabled")
|
||||
}
|
||||
|
||||
cfg.Start()
|
||||
config.WatchChanges()
|
||||
|
||||
// Initialize authentication providers
|
||||
if err := auth.Initialize(); err != nil {
|
||||
logging.Warn().Err(err).Msg("Failed to initialize authentication providers")
|
||||
if !auth.IsEnabled() {
|
||||
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
|
||||
} else {
|
||||
// Initialize authentication providers
|
||||
if err := auth.Initialize(); err != nil {
|
||||
logging.Fatal().Err(err).Msg("Failed to initialize authentication providers")
|
||||
}
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
|
|
|
@ -22,9 +22,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
|
|||
mux := ServeMux{http.NewServeMux()}
|
||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler)
|
||||
mux.HandleFunc("GET", "/v1/login/method", auth.AuthMethodHandler)
|
||||
mux.HandleFunc("GET", "/v1/login/oidc", auth.OIDCLoginHandler)
|
||||
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
|
||||
mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler)
|
||||
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler)
|
||||
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
||||
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -25,51 +23,37 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUsername = E.New("invalid username")
|
||||
ErrInvalidPassword = E.New("invalid password")
|
||||
)
|
||||
|
||||
func validatePassword(cred *Credentials) error {
|
||||
if cred.Username != common.APIUser {
|
||||
return ErrInvalidUsername.Subject(cred.Username)
|
||||
}
|
||||
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
|
||||
return ErrInvalidPassword.Subject(cred.Password)
|
||||
// Initialize sets up authentication providers.
|
||||
func Initialize() error {
|
||||
// Initialize OIDC if configured.
|
||||
if common.OIDCIssuerURL != "" {
|
||||
return InitOIDC(
|
||||
common.OIDCIssuerURL,
|
||||
common.OIDCClientID,
|
||||
common.OIDCClientSecret,
|
||||
common.OIDCRedirectURL,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds Credentials
|
||||
err := json.NewDecoder(r.Body).Decode(&creds)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := validatePassword(&creds); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
func IsEnabled() bool {
|
||||
return common.APIJWTSecret != nil || common.OIDCIssuerURL != ""
|
||||
}
|
||||
|
||||
func AuthMethodHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration.
|
||||
func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case oauthConfig != nil:
|
||||
RedirectOIDC(w, r)
|
||||
return
|
||||
case common.APIJWTSecret == nil:
|
||||
U.WriteBody(w, []byte("skip"))
|
||||
case common.OIDCIssuerURL != "":
|
||||
U.WriteBody(w, []byte("oidc"))
|
||||
case common.APIPasswordHash != nil:
|
||||
U.WriteBody(w, []byte("password"))
|
||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||
return
|
||||
default:
|
||||
U.WriteBody(w, []byte("skip"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
|
||||
|
@ -86,57 +70,44 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
|
|||
return err
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "token",
|
||||
Name: CookieToken,
|
||||
Value: tokenStr,
|
||||
Expires: expiresAt,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutHandler clear authentication cookie and redirect to login page.
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "token",
|
||||
Name: CookieToken,
|
||||
Value: "",
|
||||
Expires: time.Unix(0, 0),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
w.Header().Set("location", "/login")
|
||||
w.WriteHeader(http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// Initialize sets up authentication providers.
|
||||
func Initialize() error {
|
||||
// Initialize OIDC if configured.
|
||||
if common.OIDCIssuerURL != "" {
|
||||
return InitOIDC(
|
||||
common.OIDCIssuerURL,
|
||||
common.OIDCClientID,
|
||||
common.OIDCClientSecret,
|
||||
common.OIDCRedirectURL,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
AuthRedirectHandler(w, r)
|
||||
}
|
||||
|
||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
if common.IsDebugSkipAuth || common.APIJWTSecret == nil {
|
||||
return next
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if checkToken(w, r) {
|
||||
next(w, r)
|
||||
if IsEnabled() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if checkToken(w, r) {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
tokenCookie, err := r.Cookie("token")
|
||||
tokenCookie, err := r.Cookie(CookieToken)
|
||||
if err != nil {
|
||||
U.RespondError(w, E.New("missing token"), http.StatusUnauthorized)
|
||||
return false
|
||||
|
|
6
internal/api/v1/auth/cookies.go
Normal file
6
internal/api/v1/auth/cookies.go
Normal file
|
@ -0,0 +1,6 @@
|
|||
package auth
|
||||
|
||||
const (
|
||||
CookieToken = "token"
|
||||
CookieOauthState = "oauth_state"
|
||||
)
|
|
@ -4,10 +4,8 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -47,8 +45,8 @@ func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// OIDCLoginHandler initiates the OIDC login flow.
|
||||
func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// RedirectOIDC initiates the OIDC login flow.
|
||||
func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
|
||||
if oauthConfig == nil {
|
||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||
return
|
||||
|
@ -56,7 +54,7 @@ func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
state := common.GenerateRandomString(32)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Name: CookieOauthState,
|
||||
Value: state,
|
||||
MaxAge: 300,
|
||||
HttpOnly: true,
|
||||
|
@ -87,7 +85,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
state, err := r.Cookie("oauth_state")
|
||||
state, err := r.Cookie(CookieOauthState)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
|
||||
return
|
||||
|
@ -137,7 +135,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// handleTestCallback handles OIDC callback in test environment.
|
||||
func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := r.Cookie("oauth_state")
|
||||
state, err := r.Cookie(CookieOauthState)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
|
||||
return
|
||||
|
@ -149,29 +147,10 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Create test JWT token
|
||||
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
||||
jwtClaims := &Claims{
|
||||
Username: "test-user",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, jwtClaims)
|
||||
tokenStr, err := token.SignedString(common.APIJWTSecret)
|
||||
if err != nil {
|
||||
if err := setAuthenticatedCookie(w, "test-user"); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "token",
|
||||
Value: tokenStr,
|
||||
Expires: expiresAt,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
|
|
@ -68,10 +68,10 @@ func TestOIDCLoginHandler(t *testing.T) {
|
|||
oauthConfig = nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/oidc", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
OIDCLoginHandler(w, req)
|
||||
RedirectOIDC(w, req)
|
||||
|
||||
if got := w.Code; got != tt.wantStatus {
|
||||
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
||||
|
|
45
internal/api/v1/auth/userpass.go
Normal file
45
internal/api/v1/auth/userpass.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUsername = E.New("invalid username")
|
||||
ErrInvalidPassword = E.New("invalid password")
|
||||
)
|
||||
|
||||
func validatePassword(cred *Credentials) error {
|
||||
if cred.Username != common.APIUser {
|
||||
return ErrInvalidUsername.Subject(cred.Username)
|
||||
}
|
||||
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
|
||||
return ErrInvalidPassword.Subject(cred.Password)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserPassLoginHandler handles user login.
|
||||
func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds Credentials
|
||||
err := json.NewDecoder(r.Body).Decode(&creds)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := validatePassword(&creds); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
||||
)
|
||||
|
||||
// HandleErr logs the error and returns an HTTP error response to the client.
|
||||
// HandleErr logs the error and returns an error code to the client.
|
||||
// If code is specified, it will be used as the HTTP status code; otherwise,
|
||||
// http.StatusInternalServerError is used.
|
||||
//
|
||||
|
@ -23,10 +23,14 @@ func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) {
|
|||
http.Error(w, http.StatusText(code[0]), code[0])
|
||||
}
|
||||
|
||||
// RespondError returns error details to the client.
|
||||
// If code is specified, it will be used as the HTTP status code; otherwise,
|
||||
// http.StatusBadRequest is used.
|
||||
func RespondError(w http.ResponseWriter, err error, code ...int) {
|
||||
if len(code) == 0 {
|
||||
code = []int{http.StatusBadRequest}
|
||||
}
|
||||
// strip ANSI color codes added from Error.WithSubject
|
||||
http.Error(w, ansi.StripANSI(err.Error()), code[0])
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||
if _, err := w.Write(body); err != nil {
|
||||
HandleErr(w, nil, err)
|
||||
logging.Err(err).Msg("failed to write body")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,10 @@ import (
|
|||
var (
|
||||
prefixes = []string{"GODOXY_", "GOPROXY_", ""}
|
||||
|
||||
IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test")
|
||||
IsDebug = GetEnvBool("DEBUG", IsTest)
|
||||
IsDebugSkipAuth = GetEnvBool("DEBUG_SKIP_AUTH", false)
|
||||
IsTrace = GetEnvBool("TRACE", false) && IsDebug
|
||||
IsProduction = !IsTest && !IsDebug
|
||||
IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test")
|
||||
IsDebug = GetEnvBool("DEBUG", IsTest)
|
||||
IsTrace = GetEnvBool("TRACE", false) && IsDebug
|
||||
IsProduction = !IsTest && !IsDebug
|
||||
|
||||
ProxyHTTPAddr,
|
||||
ProxyHTTPHost,
|
||||
|
@ -46,7 +45,7 @@ var (
|
|||
APIUser = GetEnvString("API_USER", "admin")
|
||||
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password"))
|
||||
|
||||
// OIDC Configuration
|
||||
// OIDC Configuration.
|
||||
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
||||
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
||||
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
||||
|
|
|
@ -31,6 +31,7 @@ func makeEntries(cont *types.Container, dockerHostIP ...string) route.RawEntries
|
|||
} else {
|
||||
host = client.DefaultDockerHost
|
||||
}
|
||||
p.name = "test"
|
||||
entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(cont, host)))
|
||||
entries.RangeAll(func(k string, v *route.RawEntry) {
|
||||
v.Finalize()
|
||||
|
|
Loading…
Add table
Reference in a new issue