mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
style: coed cleanup and fix styling
This commit is contained in:
parent
a06787593c
commit
c05059765d
24 changed files with 161 additions and 218 deletions
|
@ -9,6 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Matcher func(*maxmind.IPInfo) bool
|
type Matcher func(*maxmind.IPInfo) bool
|
||||||
|
|
||||||
type Matchers []Matcher
|
type Matchers []Matcher
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -26,6 +27,7 @@ var errMatcherFormat = gperr.Multiline().AddLines(
|
||||||
"tz:Asia/Shanghai",
|
"tz:Asia/Shanghai",
|
||||||
"country:GB",
|
"country:GB",
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errSyntax = gperr.New("syntax error")
|
errSyntax = gperr.New("syntax error")
|
||||||
errInvalidIP = gperr.New("invalid IP")
|
errInvalidIP = gperr.New("invalid IP")
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
package dockerapi
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
const reqTimeout = 10 * time.Second
|
|
|
@ -18,7 +18,7 @@ type Container struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Containers(w http.ResponseWriter, r *http.Request) {
|
func Containers(w http.ResponseWriter, r *http.Request) {
|
||||||
serveHTTP[Container, []Container](w, r, GetContainers)
|
serveHTTP[Container](w, r, GetContainers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetContainers(ctx context.Context, dockerClients DockerClients) ([]Container, gperr.Error) {
|
func GetContainers(ctx context.Context, dockerClients DockerClients) ([]Container, gperr.Error) {
|
||||||
|
|
|
@ -22,7 +22,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
|
||||||
until := query.Get("to")
|
until := query.Get("to")
|
||||||
levels := query.Get("levels") // TODO: implement levels
|
levels := query.Get("levels") // TODO: implement levels
|
||||||
|
|
||||||
dockerClient, found, err := getDockerClient(w, server)
|
dockerClient, found, err := getDockerClient(server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gphttp.BadRequest(w, err.Error())
|
gphttp.BadRequest(w, err.Error())
|
||||||
return
|
return
|
||||||
|
|
|
@ -56,7 +56,7 @@ func getDockerClients() (DockerClients, gperr.Error) {
|
||||||
return dockerClients, connErrs.Error()
|
return dockerClients, connErrs.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDockerClient(w http.ResponseWriter, server string) (*docker.SharedClient, bool, error) {
|
func getDockerClient(server string) (*docker.SharedClient, bool, error) {
|
||||||
cfg := config.GetInstance()
|
cfg := config.GetInstance()
|
||||||
var host string
|
var host string
|
||||||
for name, h := range cfg.Value().Providers.Docker {
|
for name, h := range cfg.Value().Providers.Docker {
|
||||||
|
@ -98,7 +98,7 @@ func handleResult[V any, T ResultType[V]](w http.ResponseWriter, errs error, res
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
json.NewEncoder(w).Encode(result)
|
json.NewEncoder(w).Encode(result) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, getResult func(ctx context.Context, dockerClients DockerClients) (T, gperr.Error)) {
|
func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, getResult func(ctx context.Context, dockerClients DockerClients) (T, gperr.Error)) {
|
||||||
|
@ -119,6 +119,6 @@ func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, g
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
result, err := getResult(r.Context(), dockerClients)
|
result, err := getResult(r.Context(), dockerClients)
|
||||||
handleResult[V, T](w, err, result)
|
handleResult[V](w, err, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ type oauthRefreshToken struct {
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
Expiry time.Time `json:"expiry"`
|
Expiry time.Time `json:"expiry"`
|
||||||
|
|
||||||
result *refreshResult
|
result *RefreshResult
|
||||||
err error
|
err error
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ type Session struct {
|
||||||
Groups []string `json:"groups"`
|
Groups []string `json:"groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type refreshResult struct {
|
type RefreshResult struct {
|
||||||
newSession Session
|
newSession Session
|
||||||
jwt string
|
jwt string
|
||||||
jwtExpiry time.Time
|
jwtExpiry time.Time
|
||||||
|
@ -50,7 +50,6 @@ var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken]
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
||||||
refreshBefore = 30 * time.Second
|
|
||||||
sessionInvalidateDelay = 3 * time.Second
|
sessionInvalidateDelay = 3 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -148,7 +147,7 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
|
||||||
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*refreshResult, error) {
|
func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*RefreshResult, error) {
|
||||||
// verify the session cookie
|
// verify the session cookie
|
||||||
claims, valid, err := auth.parseSessionJWT(sessionJWT)
|
claims, valid, err := auth.parseSessionJWT(sessionJWT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -171,7 +170,7 @@ func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string
|
||||||
return auth.doRefreshToken(ctx, refreshToken, &claims.Session)
|
return auth.doRefreshToken(ctx, refreshToken, &claims.Session)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*refreshResult, error) {
|
func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*RefreshResult, error) {
|
||||||
refreshToken.mu.Lock()
|
refreshToken.mu.Lock()
|
||||||
defer refreshToken.mu.Unlock()
|
defer refreshToken.mu.Unlock()
|
||||||
|
|
||||||
|
@ -209,7 +208,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
|
||||||
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
||||||
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
||||||
|
|
||||||
refreshToken.result = &refreshResult{
|
refreshToken.result = &RefreshResult{
|
||||||
newSession: Session{
|
newSession: Session{
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
Username: claims.Username,
|
Username: claims.Username,
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -24,7 +23,7 @@ import (
|
||||||
func setupMockOIDC(t *testing.T) {
|
func setupMockOIDC(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
|
provider := (&oidc.ProviderConfig{}).NewProvider(t.Context())
|
||||||
defaultAuth = &OIDCProvider{
|
defaultAuth = &OIDCProvider{
|
||||||
oauthConfig: &oauth2.Config{
|
oauthConfig: &oauth2.Config{
|
||||||
ClientID: "test-client",
|
ClientID: "test-client",
|
||||||
|
@ -104,7 +103,7 @@ func setupProvider(t *testing.T) *provider {
|
||||||
t.Cleanup(ts.Close)
|
t.Cleanup(ts.Close)
|
||||||
|
|
||||||
// Create a test OIDCProvider.
|
// Create a test OIDCProvider.
|
||||||
providerCtx := oidc.ClientContext(context.Background(), ts.Client())
|
providerCtx := oidc.ClientContext(t.Context(), ts.Client())
|
||||||
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
|
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
|
||||||
|
|
||||||
return &provider{
|
return &provider{
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-acme/lego/v4/certificate"
|
"github.com/go-acme/lego/v4/certificate"
|
||||||
|
@ -33,8 +32,6 @@ type (
|
||||||
legoCert *certificate.Resource
|
legoCert *certificate.Resource
|
||||||
tlsCert *tls.Certificate
|
tlsCert *tls.Certificate
|
||||||
certExpiries CertExpiries
|
certExpiries CertExpiries
|
||||||
|
|
||||||
obtainMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CertExpiries map[string]time.Time
|
CertExpiries map[string]time.Time
|
||||||
|
|
|
@ -46,5 +46,5 @@ oauth2_config:
|
||||||
opt := make(map[string]any)
|
opt := make(map[string]any)
|
||||||
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
|
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
|
||||||
require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
|
require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
|
||||||
require.Equal(t, cfg, cfgExpected)
|
require.Equal(t, cfgExpected, cfg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,13 +14,12 @@ import (
|
||||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
PortMapping = map[int]container.Port
|
PortMapping = map[int]container.Port
|
||||||
Container struct {
|
Container struct {
|
||||||
_ U.NoCopy
|
_ utils.NoCopy
|
||||||
|
|
||||||
DockerHost string `json:"docker_host"`
|
DockerHost string `json:"docker_host"`
|
||||||
Image *ContainerImage `json:"image"`
|
Image *ContainerImage `json:"image"`
|
||||||
|
@ -104,6 +103,33 @@ func (c *Container) IsBlacklisted() bool {
|
||||||
return c.Image.IsBlacklisted() || c.isDatabase()
|
return c.Image.IsBlacklisted() || c.isDatabase()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Container) UpdatePorts() error {
|
||||||
|
client, err := NewClient(c.DockerHost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
inspect, err := client.ContainerInspect(context.Background(), c.ContainerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for port := range inspect.Config.ExposedPorts {
|
||||||
|
proto, portStr := nat.SplitProtoPort(string(port))
|
||||||
|
portInt, _ := nat.ParsePort(portStr)
|
||||||
|
if portInt == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.PublicPortMapping[portInt] = container.Port{
|
||||||
|
PublicPort: uint16(portInt),
|
||||||
|
PrivatePort: uint16(portInt),
|
||||||
|
Type: proto,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var databaseMPs = map[string]struct{}{
|
var databaseMPs = map[string]struct{}{
|
||||||
"/var/lib/postgresql/data": {},
|
"/var/lib/postgresql/data": {},
|
||||||
"/var/lib/mysql": {},
|
"/var/lib/mysql": {},
|
||||||
|
@ -205,30 +231,3 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Container) UpdatePorts() error {
|
|
||||||
client, err := NewClient(c.DockerHost)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
inspect, err := client.ContainerInspect(context.Background(), c.ContainerID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for port := range inspect.Config.ExposedPorts {
|
|
||||||
proto, portStr := nat.SplitProtoPort(string(port))
|
|
||||||
portInt, _ := nat.ParsePort(portStr)
|
|
||||||
if portInt == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
c.PublicPortMapping[portInt] = container.Port{
|
|
||||||
PublicPort: uint16(portInt),
|
|
||||||
PrivatePort: uint16(portInt),
|
|
||||||
Type: proto,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,10 +24,6 @@ type Builder struct {
|
||||||
rwLock
|
rwLock
|
||||||
}
|
}
|
||||||
|
|
||||||
type multiline struct {
|
|
||||||
*Builder
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBuilder creates a new Builder.
|
// NewBuilder creates a new Builder.
|
||||||
//
|
//
|
||||||
// If about is not provided, the Builder will not have a subject
|
// If about is not provided, the Builder will not have a subject
|
||||||
|
@ -88,23 +84,6 @@ func (b *Builder) Add(err error) {
|
||||||
b.add(err)
|
b.add(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) add(err error) {
|
|
||||||
switch err := err.(type) {
|
|
||||||
case *baseError:
|
|
||||||
b.errs = append(b.errs, err.Err)
|
|
||||||
case *nestedError:
|
|
||||||
if err.Err == nil {
|
|
||||||
b.errs = append(b.errs, err.Extras...)
|
|
||||||
} else {
|
|
||||||
b.errs = append(b.errs, err)
|
|
||||||
}
|
|
||||||
case *MultilineError:
|
|
||||||
b.add(&err.nestedError)
|
|
||||||
default:
|
|
||||||
b.errs = append(b.errs, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) Adds(err string) {
|
func (b *Builder) Adds(err string) {
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
@ -160,3 +139,20 @@ func (b *Builder) ForEach(fn func(error)) {
|
||||||
fn(err)
|
fn(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Builder) add(err error) {
|
||||||
|
switch err := err.(type) { //nolint:errorlint
|
||||||
|
case *baseError:
|
||||||
|
b.errs = append(b.errs, err.Err)
|
||||||
|
case *nestedError:
|
||||||
|
if err.Err == nil {
|
||||||
|
b.errs = append(b.errs, err.Extras...)
|
||||||
|
} else {
|
||||||
|
b.errs = append(b.errs, err)
|
||||||
|
}
|
||||||
|
case *MultilineError:
|
||||||
|
b.add(&err.nestedError)
|
||||||
|
default:
|
||||||
|
b.errs = append(b.errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -60,7 +60,7 @@ func fetchIconAbsolute(ctx context.Context, url string) *FetchResult {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||||
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "request timeout"}
|
return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "request timeout"}
|
||||||
|
@ -161,7 +161,7 @@ func findIconSlow(ctx context.Context, r httpRoute, uri string, stack []string)
|
||||||
ctx, cancel := context.WithTimeoutCause(ctx, faviconFetchTimeout, errors.New("favicon request timeout"))
|
ctx, cancel := context.WithTimeoutCause(ctx, faviconFetchTimeout, errors.New("favicon request timeout"))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
newReq, err := http.NewRequestWithContext(ctx, "GET", r.TargetURL().String(), nil)
|
newReq, err := http.NewRequestWithContext(ctx, http.MethodGet, r.TargetURL().String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: "cannot create request"}
|
return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: "cannot create request"}
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ func (c *Config) validateProvider() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) validateTimeouts() error {
|
func (c *Config) validateTimeouts() error { //nolint:unparam
|
||||||
if c.WakeTimeout == 0 {
|
if c.WakeTimeout == 0 {
|
||||||
c.WakeTimeout = WakeTimeoutDefault
|
c.WakeTimeout = WakeTimeoutDefault
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,12 +139,12 @@ func logEntry() []byte {
|
||||||
Format: FormatJSON,
|
Format: FormatJSON,
|
||||||
})
|
})
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("hello"))
|
_, _ = w.Write([]byte("hello"))
|
||||||
}))
|
}))
|
||||||
srv.URL = "http://localhost:8080"
|
srv.URL = "http://localhost:8080"
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
// make a request to the server
|
// make a request to the server
|
||||||
req, _ := http.NewRequest("GET", srv.URL, nil)
|
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||||
res := httptest.NewRecorder()
|
res := httptest.NewRecorder()
|
||||||
// server the request
|
// server the request
|
||||||
srv.Config.Handler.ServeHTTP(res, req)
|
srv.Config.Handler.ServeHTTP(res, req)
|
||||||
|
@ -179,7 +179,10 @@ func TestReset(t *testing.T) {
|
||||||
t.Errorf("scanner error: %v", err)
|
t.Errorf("scanner error: %v", err)
|
||||||
}
|
}
|
||||||
expect.Equal(t, linesRead, nLines)
|
expect.Equal(t, linesRead, nLines)
|
||||||
s.Reset()
|
err = s.Reset()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to reset scanner: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
linesRead = 0
|
linesRead = 0
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
|
@ -191,7 +194,7 @@ func TestReset(t *testing.T) {
|
||||||
expect.Equal(t, linesRead, nLines)
|
expect.Equal(t, linesRead, nLines)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 100000 log entries
|
// 100000 log entries.
|
||||||
func BenchmarkBackScanner(b *testing.B) {
|
func BenchmarkBackScanner(b *testing.B) {
|
||||||
mockFile := NewMockFile()
|
mockFile := NewMockFile()
|
||||||
line := logEntry()
|
line := logEntry()
|
||||||
|
|
|
@ -94,32 +94,6 @@ func (f *CombinedFormatter) AppendRequestLog(line []byte, req *http.Request, res
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
type zeroLogStringStringMapMarshaler struct {
|
|
||||||
values map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (z *zeroLogStringStringMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
|
|
||||||
if len(z.values) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for k, v := range z.values {
|
|
||||||
e.Str(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type zeroLogStringStringSliceMapMarshaler struct {
|
|
||||||
values map[string][]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
|
|
||||||
if len(z.values) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for k, v := range z.values {
|
|
||||||
e.Strs(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
|
func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
|
||||||
query := f.cfg.Query.ZerologQuery(req.URL.Query())
|
query := f.cfg.Query.ZerologQuery(req.URL.Query())
|
||||||
headers := f.cfg.Headers.ZerologHeaders(req.Header)
|
headers := f.cfg.Headers.ZerologHeaders(req.Header)
|
||||||
|
|
|
@ -27,10 +27,6 @@ type memLogger struct {
|
||||||
|
|
||||||
type MemLogger io.Writer
|
type MemLogger io.Writer
|
||||||
|
|
||||||
type buffer struct {
|
|
||||||
data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxMemLogSize = 16 * 1024
|
maxMemLogSize = 16 * 1024
|
||||||
truncateSize = maxMemLogSize / 2
|
truncateSize = maxMemLogSize / 2
|
||||||
|
@ -59,64 +55,6 @@ func Events() (<-chan []byte, func()) {
|
||||||
return memLoggerInstance.events()
|
return memLoggerInstance.events()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *memLogger) truncateIfNeeded(n int) {
|
|
||||||
m.RLock()
|
|
||||||
needTruncate := m.Len()+n > maxMemLogSize
|
|
||||||
m.RUnlock()
|
|
||||||
|
|
||||||
if needTruncate {
|
|
||||||
m.Lock()
|
|
||||||
defer m.Unlock()
|
|
||||||
needTruncate = m.Len()+n > maxMemLogSize
|
|
||||||
if !needTruncate {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Truncate(truncateSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memLogger) notifyWS(pos, n int) {
|
|
||||||
if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := time.NewTimer(3 * time.Second)
|
|
||||||
defer timeout.Stop()
|
|
||||||
|
|
||||||
m.notifyLock.RLock()
|
|
||||||
defer m.notifyLock.RUnlock()
|
|
||||||
|
|
||||||
m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
|
|
||||||
select {
|
|
||||||
case ch <- &logEntryRange{pos, pos + n}:
|
|
||||||
return true
|
|
||||||
case <-timeout.C:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if m.listeners.Size() > 0 {
|
|
||||||
msg := m.Buffer.Bytes()[pos : pos+n]
|
|
||||||
m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
|
|
||||||
select {
|
|
||||||
case <-timeout.C:
|
|
||||||
return false
|
|
||||||
case ch <- msg:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
|
|
||||||
m.Lock()
|
|
||||||
defer m.Unlock()
|
|
||||||
pos = m.Len()
|
|
||||||
_, err = m.Buffer.Write(b)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write implements io.Writer.
|
// Write implements io.Writer.
|
||||||
func (m *memLogger) Write(p []byte) (n int, err error) {
|
func (m *memLogger) Write(p []byte) (n int, err error) {
|
||||||
n = len(p)
|
n = len(p)
|
||||||
|
@ -159,6 +97,64 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
m.wsStreamLog(r.Context(), conn, logCh)
|
m.wsStreamLog(r.Context(), conn, logCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *memLogger) truncateIfNeeded(n int) {
|
||||||
|
m.RLock()
|
||||||
|
needTruncate := m.Len()+n > maxMemLogSize
|
||||||
|
m.RUnlock()
|
||||||
|
|
||||||
|
if needTruncate {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
needTruncate = m.Len()+n > maxMemLogSize
|
||||||
|
if !needTruncate {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Truncate(truncateSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memLogger) notifyWS(pos, n int) {
|
||||||
|
if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.NewTimer(3 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
m.notifyLock.RLock()
|
||||||
|
defer m.notifyLock.RUnlock()
|
||||||
|
|
||||||
|
m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
|
||||||
|
select {
|
||||||
|
case ch <- &logEntryRange{pos, pos + n}:
|
||||||
|
return true
|
||||||
|
case <-timeout.C:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.listeners.Size() > 0 {
|
||||||
|
msg := m.Bytes()[pos : pos+n]
|
||||||
|
m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
|
||||||
|
select {
|
||||||
|
case <-timeout.C:
|
||||||
|
return false
|
||||||
|
case ch <- msg:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
pos = m.Len()
|
||||||
|
_, err = m.Buffer.Write(b)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (m *memLogger) events() (logs <-chan []byte, cancel func()) {
|
func (m *memLogger) events() (logs <-chan []byte, cancel func()) {
|
||||||
ch := make(chan []byte)
|
ch := make(chan []byte)
|
||||||
m.notifyLock.Lock()
|
m.notifyLock.Lock()
|
||||||
|
@ -181,7 +177,7 @@ func (m *memLogger) wsInitial(ctx context.Context, conn *websocket.Conn) error {
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
return m.writeBytes(ctx, conn, m.Buffer.Bytes())
|
return m.writeBytes(ctx, conn, m.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) {
|
func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) {
|
||||||
|
@ -191,7 +187,7 @@ func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-
|
||||||
return
|
return
|
||||||
case logRange := <-ch:
|
case logRange := <-ch:
|
||||||
m.RLock()
|
m.RLock()
|
||||||
msg := m.Buffer.Bytes()[logRange.Start:logRange.End]
|
msg := m.Bytes()[logRange.Start:logRange.End]
|
||||||
err := m.writeBytes(ctx, conn, msg)
|
err := m.writeBytes(ctx, conn, msg)
|
||||||
m.RUnlock()
|
m.RUnlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -9,18 +9,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
// --- Helper for MaxMindConfig ---
|
|
||||||
type testLogger struct{ zerolog.Logger }
|
|
||||||
|
|
||||||
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
|
|
||||||
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
|
|
||||||
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
|
|
||||||
|
|
||||||
func testCfg() *MaxMind {
|
func testCfg() *MaxMind {
|
||||||
return &MaxMind{
|
return &MaxMind{
|
||||||
Config: &Config{
|
Config: &Config{
|
||||||
|
@ -41,16 +33,17 @@ func testDoReq(cfg *MaxMind, w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat))
|
w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat))
|
||||||
gz := gzip.NewWriter(w)
|
gz := gzip.NewWriter(w)
|
||||||
t := tar.NewWriter(gz)
|
t := tar.NewWriter(gz)
|
||||||
t.WriteHeader(&tar.Header{
|
_ = t.WriteHeader(&tar.Header{
|
||||||
Name: cfg.dbFilename(),
|
Name: cfg.dbFilename(),
|
||||||
})
|
})
|
||||||
t.Write([]byte("1234"))
|
_, _ = t.Write([]byte("1234"))
|
||||||
t.Close()
|
_ = t.Close()
|
||||||
gz.Close()
|
_ = gz.Close()
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockDoReq(cfg *MaxMind, t *testing.T) {
|
func mockDoReq(t *testing.T, cfg *MaxMind) {
|
||||||
|
t.Helper()
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
oldDoReq := doReq
|
oldDoReq := doReq
|
||||||
doReq = func(req *http.Request) (*http.Response, error) {
|
doReq = func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -61,12 +54,14 @@ func mockDoReq(cfg *MaxMind, t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockDataDir(t *testing.T) {
|
func mockDataDir(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
oldDataDir := dataDir
|
oldDataDir := dataDir
|
||||||
dataDir = t.TempDir()
|
dataDir = t.TempDir()
|
||||||
t.Cleanup(func() { dataDir = oldDataDir })
|
t.Cleanup(func() { dataDir = oldDataDir })
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockMaxMindDBOpen(t *testing.T) {
|
func mockMaxMindDBOpen(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
oldMaxMindDBOpen := maxmindDBOpen
|
oldMaxMindDBOpen := maxmindDBOpen
|
||||||
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
|
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
|
||||||
return &maxminddb.Reader{}, nil
|
return &maxminddb.Reader{}, nil
|
||||||
|
@ -76,7 +71,7 @@ func mockMaxMindDBOpen(t *testing.T) {
|
||||||
|
|
||||||
func Test_MaxMindConfig_doReq(t *testing.T) {
|
func Test_MaxMindConfig_doReq(t *testing.T) {
|
||||||
cfg := testCfg()
|
cfg := testCfg()
|
||||||
mockDoReq(cfg, t)
|
mockDoReq(t, cfg)
|
||||||
resp, err := cfg.doReq(http.MethodGet)
|
resp, err := cfg.doReq(http.MethodGet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("newReq() error = %v", err)
|
t.Fatalf("newReq() error = %v", err)
|
||||||
|
@ -88,7 +83,7 @@ func Test_MaxMindConfig_doReq(t *testing.T) {
|
||||||
|
|
||||||
func Test_MaxMindConfig_checkLatest(t *testing.T) {
|
func Test_MaxMindConfig_checkLatest(t *testing.T) {
|
||||||
cfg := testCfg()
|
cfg := testCfg()
|
||||||
mockDoReq(cfg, t)
|
mockDoReq(t, cfg)
|
||||||
|
|
||||||
latest, err := cfg.checkLastest()
|
latest, err := cfg.checkLastest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -103,7 +98,7 @@ func Test_MaxMindConfig_download(t *testing.T) {
|
||||||
cfg := testCfg()
|
cfg := testCfg()
|
||||||
mockDataDir(t)
|
mockDataDir(t)
|
||||||
mockMaxMindDBOpen(t)
|
mockMaxMindDBOpen(t)
|
||||||
mockDoReq(cfg, t)
|
mockDoReq(t, cfg)
|
||||||
|
|
||||||
err := cfg.download()
|
err := cfg.download()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -21,13 +21,6 @@ const (
|
||||||
CacheKeyBasicAuth = "basic_auth"
|
CacheKeyBasicAuth = "basic_auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var cacheKeys = []string{
|
|
||||||
CacheKeyQueries,
|
|
||||||
CacheKeyCookies,
|
|
||||||
CacheKeyRemoteIP,
|
|
||||||
CacheKeyBasicAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
var cachePool = &sync.Pool{
|
var cachePool = &sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return make(Cache)
|
return make(Cache)
|
||||||
|
|
|
@ -60,7 +60,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
||||||
r.HealthMon = monitor.NewMonitor(r)
|
r.HealthMon = monitor.NewMonitor(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.Stream.Setup(); err != nil {
|
if err := r.Setup(); err != nil {
|
||||||
r.task.Finish(err)
|
r.task.Finish(err)
|
||||||
return gperr.Wrap(err)
|
return gperr.Wrap(err)
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ func (r *StreamRoute) acceptConnections() {
|
||||||
case <-r.task.Context().Done():
|
case <-r.task.Context().Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
conn, err := r.Stream.Accept()
|
conn, err := r.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
select {
|
select {
|
||||||
case <-r.task.Context().Done():
|
case <-r.task.Context().Done():
|
||||||
|
@ -118,7 +118,7 @@ func (r *StreamRoute) acceptConnections() {
|
||||||
panic("connection is nil")
|
panic("connection is nil")
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
err := r.Stream.Handle(conn)
|
err := r.Handle(conn)
|
||||||
if err != nil && !errors.Is(err, context.Canceled) {
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
gperr.LogError("handle connection error", err, &r.l)
|
gperr.LogError("handle connection error", err, &r.l)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,11 +26,11 @@ func AppendDuration(d time.Duration, buf []byte) []byte {
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case d < time.Millisecond:
|
case d < time.Millisecond:
|
||||||
buf = strconv.AppendInt(buf, int64(d.Nanoseconds()), 10)
|
buf = strconv.AppendInt(buf, d.Nanoseconds(), 10)
|
||||||
buf = append(buf, []byte(" ns")...)
|
buf = append(buf, []byte(" ns")...)
|
||||||
return buf
|
return buf
|
||||||
case d < time.Second:
|
case d < time.Second:
|
||||||
buf = strconv.AppendInt(buf, int64(d.Milliseconds()), 10)
|
buf = strconv.AppendInt(buf, d.Milliseconds(), 10)
|
||||||
buf = append(buf, []byte(" ms")...)
|
buf = append(buf, []byte(" ms")...)
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,7 +93,7 @@ func TestFormatTime(t *testing.T) {
|
||||||
result := FormatTimeWithReference(tt.time, now)
|
result := FormatTimeWithReference(tt.time, now)
|
||||||
|
|
||||||
if tt.expectedLength > 0 {
|
if tt.expectedLength > 0 {
|
||||||
require.Equal(t, tt.expectedLength, len(result), result)
|
require.Len(t, result, tt.expectedLength)
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, tt.expected, result)
|
require.Equal(t, tt.expected, result)
|
||||||
}
|
}
|
||||||
|
@ -213,11 +213,8 @@ func TestFormatLastSeen(t *testing.T) {
|
||||||
|
|
||||||
if tt.name == "zero time" {
|
if tt.name == "zero time" {
|
||||||
require.Equal(t, tt.expected, result)
|
require.Equal(t, tt.expected, result)
|
||||||
} else {
|
} else if result == "never" { // Just make sure it's not "never", the actual formatting is tested in TestFormatTime
|
||||||
// Just make sure it's not "never", the actual formatting is tested in TestFormatTime
|
t.Errorf("Expected non-zero time to not return 'never', got %s", result)
|
||||||
if result == "never" {
|
|
||||||
t.Errorf("Expected non-zero time to not return 'never', got %s", result)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,6 @@ func (p *Pool[T]) Get() []T {
|
||||||
|
|
||||||
func (p *Pool[T]) Put(b []T) {
|
func (p *Pool[T]) Put(b []T) {
|
||||||
if cap(b) <= p.maxSize {
|
if cap(b) <= p.maxSize {
|
||||||
p.pool.Put(b[:0])
|
p.pool.Put(b[:0]) //nolint:staticcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,12 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
type YieldFunc = func(part string, value any) bool
|
type (
|
||||||
type YieldKeyFunc = func(key string) bool
|
YieldFunc = func(part string, value any) bool
|
||||||
type Iterator = func(YieldFunc)
|
YieldKeyFunc = func(key string) bool
|
||||||
type KeyIterator = func(YieldKeyFunc)
|
Iterator = func(YieldFunc)
|
||||||
|
KeyIterator = func(YieldKeyFunc)
|
||||||
|
)
|
||||||
|
|
||||||
// WalkAll walks all nodes in the trie, yields full key and series
|
// WalkAll walks all nodes in the trie, yields full key and series
|
||||||
func (node *Node) Walk(yield YieldFunc) {
|
func (node *Node) Walk(yield YieldFunc) {
|
||||||
|
@ -17,10 +19,7 @@ func (node *Node) Walk(yield YieldFunc) {
|
||||||
|
|
||||||
func (node *Node) walkAll(yield YieldFunc) bool {
|
func (node *Node) walkAll(yield YieldFunc) bool {
|
||||||
if !node.value.IsNil() {
|
if !node.value.IsNil() {
|
||||||
if !yield(node.key, node.value.Load()) {
|
return yield(node.key, node.value.Load())
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
for _, v := range node.children.Range {
|
for _, v := range node.children.Range {
|
||||||
if !v.walkAll(yield) {
|
if !v.walkAll(yield) {
|
||||||
|
@ -57,10 +56,9 @@ func (node *Node) Map() map[string]any {
|
||||||
func (tree Root) Query(key *Key) Iterator {
|
func (tree Root) Query(key *Key) Iterator {
|
||||||
if !key.hasWildcard {
|
if !key.hasWildcard {
|
||||||
return func(yield YieldFunc) {
|
return func(yield YieldFunc) {
|
||||||
if v, ok := tree.Node.Get(key); ok {
|
if v, ok := tree.Get(key); ok {
|
||||||
yield(key.full, v)
|
yield(key.full, v)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(yield YieldFunc) {
|
return func(yield YieldFunc) {
|
||||||
|
|
|
@ -55,11 +55,11 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan gperr.Error) {
|
||||||
// recover panic in onFlush when in production mode
|
// recover panic in onFlush when in production mode
|
||||||
e.onFlush = func(events []Event) {
|
e.onFlush = func(events []Event) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if errV := recover(); errV != nil {
|
||||||
if err, ok := err.(error); ok {
|
if err, ok := errV.(error); ok {
|
||||||
e.onError(gperr.Wrap(err).Subject(e.task.Name()))
|
e.onError(gperr.Wrap(err).Subject(e.task.Name()))
|
||||||
} else {
|
} else {
|
||||||
e.onError(gperr.New("recovered panic in onFlush").Withf("%v", err).Subject(e.task.Name()))
|
e.onError(gperr.New("recovered panic in onFlush").Withf("%v", errV).Subject(e.task.Name()))
|
||||||
}
|
}
|
||||||
if common.IsDebug {
|
if common.IsDebug {
|
||||||
panic(string(debug.Stack()))
|
panic(string(debug.Stack()))
|
||||||
|
|
Loading…
Add table
Reference in a new issue