From 81177926ff64037bceaa42fa09b33b19c1d9b5cc Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 30 Oct 2024 06:25:32 +0800 Subject: [PATCH] implemented login and jwt auth --- Makefile | 12 ++-- go.mod | 1 + go.sum | 2 + internal/api/handler.go | 20 +++--- internal/api/v1/auth/auth.go | 126 +++++++++++++++++++++++++++++++++ internal/api/v1/utils/error.go | 11 ++- internal/common/constants.go | 4 ++ internal/common/crypto.go | 31 ++++++++ internal/common/env.go | 12 +++- internal/error/subject.go | 6 +- internal/notif/dispatcher.go | 7 +- internal/route/stream.go | 1 - 12 files changed, 206 insertions(+), 27 deletions(-) create mode 100644 internal/api/v1/auth/auth.go create mode 100644 internal/common/crypto.go diff --git a/Makefile b/Makefile index cba758c..3561f95 100755 --- a/Makefile +++ b/Makefile @@ -28,20 +28,20 @@ get: go get -u ./cmd && go mod tidy debug: - make build && sudo GOPROXY_DEBUG=1 bin/go-proxy + make build + GOPROXY_DEBUG=1 sudo bin/go-proxy debug-trace: - make build && sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy + make build + GOPROXY_DEBUG=1 GOPROXY_TRACE=1 sudo bin/go-proxy profile: - GODEBUG=gctrace=1 make build && sudo GOPROXY_DEBUG=1 bin/go-proxy + GODEBUG=gctrace=1 make build + GOPROXY_DEBUG=1 sudo bin/go-proxy mtrace: bin/go-proxy debug-ls-mtrace > mtrace.json -run-test: - make build && sudo GOPROXY_TEST=1 bin/go-proxy - run: make build && sudo bin/go-proxy diff --git a/go.mod b/go.mod index 3370a40..306f62a 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/docker/docker v27.3.1+incompatible github.com/fsnotify/fsnotify v1.7.0 github.com/go-acme/lego/v4 v4.19.2 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/gotify/server/v2 v2.5.0 github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/rs/zerolog v1.33.0 diff --git a/go.sum b/go.sum index 62f4b45..94cd447 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PU github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/internal/api/handler.go b/internal/api/handler.go index 5ee66a0..12333de 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -6,6 +6,7 @@ import ( "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" + "github.com/yusing/go-proxy/internal/api/v1/auth" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -25,16 +26,17 @@ func NewHandler() http.Handler { mux := NewServeMux() mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) - mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth) - mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth) + // mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth) + // mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth) + mux.HandleFunc("POST", "/v1/login", auth.LoginHandler) mux.HandleFunc("POST", "/v1/reload", v1.Reload) - mux.HandleFunc("GET", "/v1/list", v1.List) - mux.HandleFunc("GET", "/v1/list/{what}", v1.List) - mux.HandleFunc("GET", "/v1/list/{what}/{which}", v1.List) - mux.HandleFunc("GET", "/v1/file", v1.GetFileContent) - mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent) - mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent) - mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) + mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(v1.List)) + mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(v1.List)) + mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(v1.List)) + mux.HandleFunc("GET", "/v1/file", auth.RequireAuth(v1.GetFileContent)) + mux.HandleFunc("GET", "/v1/file/{filename...}", auth.RequireAuth(v1.GetFileContent)) + mux.HandleFunc("POST", "/v1/file/{filename...}", auth.RequireAuth(v1.SetFileContent)) + mux.HandleFunc("PUT", "/v1/file/{filename...}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("GET", "/v1/stats", v1.Stats) mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) return mux diff --git a/internal/api/v1/auth/auth.go b/internal/api/v1/auth/auth.go new file mode 100644 index 0000000..b70ce67 --- /dev/null +++ b/internal/api/v1/auth/auth.go @@ -0,0 +1,126 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + U "github.com/yusing/go-proxy/internal/api/v1/utils" + "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type ( + Credentials struct { + Username string `json:"username"` + Password string `json:"password"` + } + Claims struct { + Username string `json:"username"` + jwt.RegisteredClaims + } +) + +var ( + ErrInvalidUsername = E.New("invalid username") + ErrInvalidPassword = E.New("invalid password") +) + +const tokenExpiration = 24 * time.Hour + +const jwtClaimKeyUsername = "username" + +func validatePassword(cred *Credentials) error { + if cred.Username != common.APIUser { + return ErrInvalidUsername.Subject(cred.Username) + } + if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { + return ErrInvalidPassword.Subject(cred.Password) + } + return nil +} + +func LoginHandler(w http.ResponseWriter, r *http.Request) { + var creds Credentials + err := json.NewDecoder(r.Body).Decode(&creds) + if err != nil { + U.HandleErr(w, r, err, http.StatusBadRequest) + return + } + if err := validatePassword(&creds); err != nil { + U.HandleErr(w, r, err, http.StatusUnauthorized) + return + } + + expiresAt := time.Now().Add(tokenExpiration) + claim := &Claims{ + Username: creds.Username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expiresAt), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodES512, claim) + tokenStr, err := token.SignedString(common.APIJWTSecret) + if err != nil { + U.HandleErr(w, r, err) + return + } + http.SetCookie(w, &http.Cookie{ + Name: "token", + Value: tokenStr, + Expires: expiresAt, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Path: "/", + }) + w.WriteHeader(http.StatusOK) +} + +func RequireAuth(next http.HandlerFunc) http.HandlerFunc { + if common.IsDebugSkipAuth { + return next + } + + return func(w http.ResponseWriter, r *http.Request) { + if checkToken(w, r) { + next(w, r) + } + } +} + +func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { + tokenCookie, err := r.Cookie("token") + if err != nil { + U.HandleErr(w, r, E.PrependSubject("token", err), http.StatusUnauthorized) + return false + } + var claims Claims + token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) + } + return common.APIJWTSecret, nil + }) + + switch { + case err != nil: + break + case !token.Valid: + err = E.New("invalid token") + case claims.Username != common.APIUser: + err = E.New("username mismatch").Subject(claims.Username) + case claims.ExpiresAt.Before(time.Now()): + err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) + } + + if err != nil { + U.HandleErr(w, r, err, http.StatusForbidden) + return false + } + + return true +} diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go index 3c55d95..e2775df 100644 --- a/internal/api/v1/utils/error.go +++ b/internal/api/v1/utils/error.go @@ -6,16 +6,21 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) +// HandleErr logs the error and returns an HTTP error response to the client. +// If code is specified, it will be used as the HTTP status code; otherwise, +// http.StatusInternalServerError is used. +// +// The error is only logged but not returned to the client. func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) { if origErr == nil { return } LogError(r).Msg(origErr.Error()) + statusCode := http.StatusInternalServerError if len(code) > 0 { - http.Error(w, origErr.Error(), code[0]) - return + statusCode = code[0] } - http.Error(w, origErr.Error(), http.StatusInternalServerError) + http.Error(w, http.StatusText(statusCode), statusCode) } func ErrMissingKey(k string) error { diff --git a/internal/common/constants.go b/internal/common/constants.go index e9ab187..0a2c2bb 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -13,11 +13,15 @@ const ( // file, folder structure const ( + DotEnvPath = ".env" + ConfigBasePath = "config" ConfigFileName = "config.yml" ConfigExampleFileName = "config.example.yml" ConfigPath = ConfigBasePath + "/" + ConfigFileName + JWTKeyPath = ConfigBasePath + "/jwt.key" + MiddlewareComposeBasePath = ConfigBasePath + "/middlewares" SchemaBasePath = "schema" diff --git a/internal/common/crypto.go b/internal/common/crypto.go new file mode 100644 index 0000000..751e5d8 --- /dev/null +++ b/internal/common/crypto.go @@ -0,0 +1,31 @@ +package common + +import ( + "crypto/rand" + "crypto/sha512" + "encoding/base64" + + "github.com/rs/zerolog/log" +) + +func HashPassword(pwd string) []byte { + h := sha512.New() + h.Write([]byte(pwd)) + return h.Sum(nil) +} + +func generateJWTKey(size int) string { + bytes := make([]byte, size) + if _, err := rand.Read(bytes); err != nil { + log.Panic().Err(err).Msg("failed to generate jwt key") + } + return base64.URLEncoding.EncodeToString(bytes) +} + +func decodeJWTKey(key string) []byte { + bytes, err := base64.URLEncoding.DecodeString(key) + if err != nil { + log.Panic().Err(err).Msg("failed to decode jwt key") + } + return bytes +} diff --git a/internal/common/env.go b/internal/common/env.go index a8ad9b7..d58880b 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -2,17 +2,19 @@ package common import ( "fmt" - "log" "net" "os" "strconv" "strings" + + "github.com/rs/zerolog/log" ) var ( NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true) IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test") IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest) + IsDebugSkipAuth = GetEnvBool("GOPROXY_DEBUG_SKIP_AUTH", false) IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug ProxyHTTPAddr, @@ -29,6 +31,10 @@ var ( APIHTTPHost, APIHTTPPort, APIHTTPURL = GetAddrEnv("GOPROXY_API_ADDR", "127.0.0.1:8888", "http") + + APIJWTSecret = decodeJWTKey(GetEnv("GOPROXY_API_JWT_SECRET", generateJWTKey(32))) + APIUser = GetEnv("GOPROXY_API_USER", "admin") + APIPasswordHash = HashPassword(GetEnv("GOPROXY_API_PASSWORD", "password")) ) func GetEnvBool(key string, defaultValue bool) bool { @@ -38,7 +44,7 @@ func GetEnvBool(key string, defaultValue bool) bool { } b, err := strconv.ParseBool(value) if err != nil { - log.Fatalf("env %s: invalid boolean value: %s", key, value) + log.Fatal().Msgf("env %s: invalid boolean value: %s", key, value) } return b } @@ -55,7 +61,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str addr = GetEnv(key, defaultValue) host, port, err := net.SplitHostPort(addr) if err != nil { - log.Fatalf("env %s: invalid address: %s", key, addr) + log.Fatal().Msgf("env %s: invalid address: %s", key, addr) } if host == "" { host = "localhost" diff --git a/internal/error/subject.go b/internal/error/subject.go index c78ef2d..0c5ffef 100644 --- a/internal/error/subject.go +++ b/internal/error/subject.go @@ -17,12 +17,12 @@ func highlight(subject string) string { return ansi.HighlightRed + subject + ansi.Reset } -func PrependSubject(subject string, err error) *withSubject { +func PrependSubject(subject string, err error) error { switch err := err.(type) { case *withSubject: return err.Prepend(subject) - case *baseError: - return PrependSubject(subject, err.Err) + case Error: + return err.Subject(subject) default: return &withSubject{subject, err} } diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index 2b74f42..fed2c9a 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -3,6 +3,7 @@ package notif import ( "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -70,8 +71,8 @@ func (disp *Dispatcher) start() { select { case <-disp.task.Context().Done(): return - case entry := <-disp.logCh: - go disp.dispatch(entry) + case msg := <-disp.logCh: + go disp.dispatch(msg) } } } @@ -88,6 +89,8 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) { }) if errs.HasError() { E.LogError(errs.About(), errs.Error()) + } else { + logging.Debug().Msgf("dispatched notif: %s", msg.Message) } } diff --git a/internal/route/stream.go b/internal/route/stream.go index c94d75b..71d89eb 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -100,7 +100,6 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.Error { }) r.l.Info(). - Str("proto", string(r.Scheme.ListeningScheme)). Int("port", int(r.Port.ListeningPort)). Msg("listening")