style: code cleanup and fix styling

yusing 2025-05-10 10:46:31 +08:00
parent a06787593c
commit c05059765d
24 changed files with 161 additions and 218 deletions

View file

@@ -9,6 +9,7 @@ import (
 )

 type Matcher func(*maxmind.IPInfo) bool
 type Matchers []Matcher

 const (
@@ -26,6 +27,7 @@ var errMatcherFormat = gperr.Multiline().AddLines(
     "tz:Asia/Shanghai",
     "country:GB",
 )

 var (
     errSyntax    = gperr.New("syntax error")
     errInvalidIP = gperr.New("invalid IP")

View file

@@ -1,5 +0,0 @@
-package dockerapi
-
-import "time"
-
-const reqTimeout = 10 * time.Second

View file

@@ -18,7 +18,7 @@ type Container struct {
 }

 func Containers(w http.ResponseWriter, r *http.Request) {
-    serveHTTP[Container, []Container](w, r, GetContainers)
+    serveHTTP[Container](w, r, GetContainers)
 }

 func GetContainers(ctx context.Context, dockerClients DockerClients) ([]Container, gperr.Error) {

View file

@@ -22,7 +22,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
     until := query.Get("to")
     levels := query.Get("levels") // TODO: implement levels

-    dockerClient, found, err := getDockerClient(w, server)
+    dockerClient, found, err := getDockerClient(server)
     if err != nil {
         gphttp.BadRequest(w, err.Error())
         return

View file

@@ -56,7 +56,7 @@ func getDockerClients() (DockerClients, gperr.Error) {
     return dockerClients, connErrs.Error()
 }

-func getDockerClient(w http.ResponseWriter, server string) (*docker.SharedClient, bool, error) {
+func getDockerClient(server string) (*docker.SharedClient, bool, error) {
     cfg := config.GetInstance()
     var host string
     for name, h := range cfg.Value().Providers.Docker {
@@ -98,7 +98,7 @@ func handleResult[V any, T ResultType[V]](w http.ResponseWriter, errs error, res
             return
         }
     }
-    json.NewEncoder(w).Encode(result)
+    json.NewEncoder(w).Encode(result) //nolint
 }

 func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, getResult func(ctx context.Context, dockerClients DockerClients) (T, gperr.Error)) {
@@ -119,6 +119,6 @@ func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, g
         })
     } else {
         result, err := getResult(r.Context(), dockerClients)
-        handleResult[V, T](w, err, result)
+        handleResult[V](w, err, result)
     }
 }

View file

@@ -22,7 +22,7 @@ type oauthRefreshToken struct {
     RefreshToken string    `json:"refresh_token"`
     Expiry       time.Time `json:"expiry"`

-    result *refreshResult
+    result *RefreshResult
     err    error
     mu     sync.Mutex
 }
@@ -33,7 +33,7 @@ type Session struct {
     Groups   []string `json:"groups"`
 }

-type refreshResult struct {
+type RefreshResult struct {
     newSession Session
     jwt        string
     jwtExpiry  time.Time
@@ -50,7 +50,6 @@ var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken]

 var (
     defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
-    refreshBefore             = 30 * time.Second
     sessionInvalidateDelay    = 3 * time.Second
 )
@@ -148,7 +147,7 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
     return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
 }

-func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*refreshResult, error) {
+func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*RefreshResult, error) {
     // verify the session cookie
     claims, valid, err := auth.parseSessionJWT(sessionJWT)
     if err != nil {
@@ -171,7 +170,7 @@ func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string
     return auth.doRefreshToken(ctx, refreshToken, &claims.Session)
 }

-func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*refreshResult, error) {
+func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*RefreshResult, error) {
     refreshToken.mu.Lock()
     defer refreshToken.mu.Unlock()
@@ -209,7 +208,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
     logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
     storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)

-    refreshToken.result = &refreshResult{
+    refreshToken.result = &RefreshResult{
         newSession: Session{
             SessionID: sessionID,
             Username:  claims.Username,

View file

@@ -1,7 +1,6 @@
 package auth

 import (
-    "context"
     "crypto/rand"
     "crypto/rsa"
     "encoding/base64"
@@ -24,7 +23,7 @@ import (
 func setupMockOIDC(t *testing.T) {
     t.Helper()

-    provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
+    provider := (&oidc.ProviderConfig{}).NewProvider(t.Context())
     defaultAuth = &OIDCProvider{
         oauthConfig: &oauth2.Config{
             ClientID: "test-client",
@@ -104,7 +103,7 @@ func setupProvider(t *testing.T) *provider {
     t.Cleanup(ts.Close)

     // Create a test OIDCProvider.
-    providerCtx := oidc.ClientContext(context.Background(), ts.Client())
+    providerCtx := oidc.ClientContext(t.Context(), ts.Client())
     keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")

     return &provider{
return &provider{ return &provider{

View file

@@ -9,7 +9,6 @@ import (
     "path"
     "reflect"
     "sort"
-    "sync"
     "time"

     "github.com/go-acme/lego/v4/certificate"
@@ -33,8 +32,6 @@ type (
         legoCert     *certificate.Resource
         tlsCert      *tls.Certificate
         certExpiries CertExpiries
-
-        obtainMu sync.Mutex
     }

     CertExpiries map[string]time.Time

View file

@@ -46,5 +46,5 @@ oauth2_config:
     opt := make(map[string]any)
     require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
     require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
-    require.Equal(t, cfg, cfgExpected)
+    require.Equal(t, cfgExpected, cfg)
 }

View file

@@ -14,13 +14,12 @@ import (
     idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
     "github.com/yusing/go-proxy/internal/logging"
     "github.com/yusing/go-proxy/internal/utils"
-    U "github.com/yusing/go-proxy/internal/utils"
 )

 type (
     PortMapping = map[int]container.Port
     Container   struct {
-        _ U.NoCopy
+        _ utils.NoCopy

         DockerHost    string          `json:"docker_host"`
         Image         *ContainerImage `json:"image"`
@@ -104,6 +103,33 @@ func (c *Container) IsBlacklisted() bool {
     return c.Image.IsBlacklisted() || c.isDatabase()
 }

+func (c *Container) UpdatePorts() error {
+    client, err := NewClient(c.DockerHost)
+    if err != nil {
+        return err
+    }
+    defer client.Close()
+
+    inspect, err := client.ContainerInspect(context.Background(), c.ContainerID)
+    if err != nil {
+        return err
+    }
+
+    for port := range inspect.Config.ExposedPorts {
+        proto, portStr := nat.SplitProtoPort(string(port))
+        portInt, _ := nat.ParsePort(portStr)
+        if portInt == 0 {
+            continue
+        }
+        c.PublicPortMapping[portInt] = container.Port{
+            PublicPort:  uint16(portInt),
+            PrivatePort: uint16(portInt),
+            Type:        proto,
+        }
+    }
+    return nil
+}
+
 var databaseMPs = map[string]struct{}{
     "/var/lib/postgresql/data": {},
     "/var/lib/mysql":           {},
@@ -205,30 +231,3 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
         }
     }
 }
-
-func (c *Container) UpdatePorts() error {
-    client, err := NewClient(c.DockerHost)
-    if err != nil {
-        return err
-    }
-    defer client.Close()
-
-    inspect, err := client.ContainerInspect(context.Background(), c.ContainerID)
-    if err != nil {
-        return err
-    }
-
-    for port := range inspect.Config.ExposedPorts {
-        proto, portStr := nat.SplitProtoPort(string(port))
-        portInt, _ := nat.ParsePort(portStr)
-        if portInt == 0 {
-            continue
-        }
-        c.PublicPortMapping[portInt] = container.Port{
-            PublicPort:  uint16(portInt),
-            PrivatePort: uint16(portInt),
-            Type:        proto,
-        }
-    }
-    return nil
-}

View file

@@ -24,10 +24,6 @@ type Builder struct {
     rwLock
 }

-type multiline struct {
-    *Builder
-}
-
 // NewBuilder creates a new Builder.
 //
 // If about is not provided, the Builder will not have a subject
@@ -88,23 +84,6 @@ func (b *Builder) Add(err error) {
     b.add(err)
 }

-func (b *Builder) add(err error) {
-    switch err := err.(type) {
-    case *baseError:
-        b.errs = append(b.errs, err.Err)
-    case *nestedError:
-        if err.Err == nil {
-            b.errs = append(b.errs, err.Extras...)
-        } else {
-            b.errs = append(b.errs, err)
-        }
-    case *MultilineError:
-        b.add(&err.nestedError)
-    default:
-        b.errs = append(b.errs, err)
-    }
-}
-
 func (b *Builder) Adds(err string) {
     b.Lock()
     defer b.Unlock()
@@ -160,3 +139,20 @@ func (b *Builder) ForEach(fn func(error)) {
         fn(err)
     }
 }
+
+func (b *Builder) add(err error) {
+    switch err := err.(type) { //nolint:errorlint
+    case *baseError:
+        b.errs = append(b.errs, err.Err)
+    case *nestedError:
+        if err.Err == nil {
+            b.errs = append(b.errs, err.Extras...)
+        } else {
+            b.errs = append(b.errs, err)
+        }
+    case *MultilineError:
+        b.add(&err.nestedError)
+    default:
+        b.errs = append(b.errs, err)
+    }
+}

View file

@@ -60,7 +60,7 @@ func fetchIconAbsolute(ctx context.Context, url string) *FetchResult {
         return result
     }

-    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+    req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
     if err != nil {
         if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
             return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "request timeout"}
@@ -161,7 +161,7 @@ func findIconSlow(ctx context.Context, r httpRoute, uri string, stack []string)
     ctx, cancel := context.WithTimeoutCause(ctx, faviconFetchTimeout, errors.New("favicon request timeout"))
     defer cancel()

-    newReq, err := http.NewRequestWithContext(ctx, "GET", r.TargetURL().String(), nil)
+    newReq, err := http.NewRequestWithContext(ctx, http.MethodGet, r.TargetURL().String(), nil)
     if err != nil {
         return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: "cannot create request"}
     }

View file

@@ -80,7 +80,7 @@ func (c *Config) validateProvider() error {
     return nil
 }

-func (c *Config) validateTimeouts() error {
+func (c *Config) validateTimeouts() error { //nolint:unparam
     if c.WakeTimeout == 0 {
         c.WakeTimeout = WakeTimeoutDefault
     }

View file

@@ -139,12 +139,12 @@ func logEntry() []byte {
         Format: FormatJSON,
     })
     srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        w.Write([]byte("hello"))
+        _, _ = w.Write([]byte("hello"))
     }))
     srv.URL = "http://localhost:8080"
     defer srv.Close()
     // make a request to the server
-    req, _ := http.NewRequest("GET", srv.URL, nil)
+    req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
     res := httptest.NewRecorder()
     // server the request
     srv.Config.Handler.ServeHTTP(res, req)
@@ -179,7 +179,10 @@ func TestReset(t *testing.T) {
         t.Errorf("scanner error: %v", err)
     }
     expect.Equal(t, linesRead, nLines)
-    s.Reset()
+    err = s.Reset()
+    if err != nil {
+        t.Errorf("failed to reset scanner: %v", err)
+    }
     linesRead = 0
     for s.Scan() {
@@ -191,7 +194,7 @@ func TestReset(t *testing.T) {
     expect.Equal(t, linesRead, nLines)
 }

-// 100000 log entries
+// 100000 log entries.
 func BenchmarkBackScanner(b *testing.B) {
     mockFile := NewMockFile()
     line := logEntry()

View file

@@ -94,32 +94,6 @@ func (f *CombinedFormatter) AppendRequestLog(line []byte, req *http.Request, res
     return line
 }

-type zeroLogStringStringMapMarshaler struct {
-    values map[string]string
-}
-
-func (z *zeroLogStringStringMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
-    if len(z.values) == 0 {
-        return
-    }
-    for k, v := range z.values {
-        e.Str(k, v)
-    }
-}
-
-type zeroLogStringStringSliceMapMarshaler struct {
-    values map[string][]string
-}
-
-func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
-    if len(z.values) == 0 {
-        return
-    }
-    for k, v := range z.values {
-        e.Strs(k, v)
-    }
-}
-
 func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
     query := f.cfg.Query.ZerologQuery(req.URL.Query())
     headers := f.cfg.Headers.ZerologHeaders(req.Header)

View file

@@ -27,10 +27,6 @@ type memLogger struct {

 type MemLogger io.Writer

-type buffer struct {
-    data []byte
-}
-
 const (
     maxMemLogSize = 16 * 1024
     truncateSize  = maxMemLogSize / 2
@@ -59,64 +55,6 @@ func Events() (<-chan []byte, func()) {
     return memLoggerInstance.events()
 }

-func (m *memLogger) truncateIfNeeded(n int) {
-    m.RLock()
-    needTruncate := m.Len()+n > maxMemLogSize
-    m.RUnlock()
-
-    if needTruncate {
-        m.Lock()
-        defer m.Unlock()
-
-        needTruncate = m.Len()+n > maxMemLogSize
-        if !needTruncate {
-            return
-        }
-
-        m.Truncate(truncateSize)
-    }
-}
-
-func (m *memLogger) notifyWS(pos, n int) {
-    if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
-        return
-    }
-
-    timeout := time.NewTimer(3 * time.Second)
-    defer timeout.Stop()
-
-    m.notifyLock.RLock()
-    defer m.notifyLock.RUnlock()
-    m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
-        select {
-        case ch <- &logEntryRange{pos, pos + n}:
-            return true
-        case <-timeout.C:
-            return false
-        }
-    })
-
-    if m.listeners.Size() > 0 {
-        msg := m.Buffer.Bytes()[pos : pos+n]
-        m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
-            select {
-            case <-timeout.C:
-                return false
-            case ch <- msg:
-                return true
-            }
-        })
-    }
-}
-
-func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
-    m.Lock()
-    defer m.Unlock()
-    pos = m.Len()
-    _, err = m.Buffer.Write(b)
-    return
-}
-
 // Write implements io.Writer.
 func (m *memLogger) Write(p []byte) (n int, err error) {
     n = len(p)
@@ -159,6 +97,64 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
     m.wsStreamLog(r.Context(), conn, logCh)
 }

+func (m *memLogger) truncateIfNeeded(n int) {
+    m.RLock()
+    needTruncate := m.Len()+n > maxMemLogSize
+    m.RUnlock()
+
+    if needTruncate {
+        m.Lock()
+        defer m.Unlock()
+
+        needTruncate = m.Len()+n > maxMemLogSize
+        if !needTruncate {
+            return
+        }
+
+        m.Truncate(truncateSize)
+    }
+}
+
+func (m *memLogger) notifyWS(pos, n int) {
+    if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
+        return
+    }
+
+    timeout := time.NewTimer(3 * time.Second)
+    defer timeout.Stop()
+
+    m.notifyLock.RLock()
+    defer m.notifyLock.RUnlock()
+    m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
+        select {
+        case ch <- &logEntryRange{pos, pos + n}:
+            return true
+        case <-timeout.C:
+            return false
+        }
+    })
+
+    if m.listeners.Size() > 0 {
+        msg := m.Bytes()[pos : pos+n]
+        m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
+            select {
+            case <-timeout.C:
+                return false
+            case ch <- msg:
+                return true
+            }
+        })
+    }
+}
+
+func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
+    m.Lock()
+    defer m.Unlock()
+    pos = m.Len()
+    _, err = m.Buffer.Write(b)
+    return
+}
+
 func (m *memLogger) events() (logs <-chan []byte, cancel func()) {
     ch := make(chan []byte)
     m.notifyLock.Lock()
@@ -181,7 +177,7 @@ func (m *memLogger) wsInitial(ctx context.Context, conn *websocket.Conn) error {
     m.Lock()
     defer m.Unlock()

-    return m.writeBytes(ctx, conn, m.Buffer.Bytes())
+    return m.writeBytes(ctx, conn, m.Bytes())
 }

 func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) {
@@ -191,7 +187,7 @@ func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-
             return
         case logRange := <-ch:
             m.RLock()
-            msg := m.Buffer.Bytes()[logRange.Start:logRange.End]
+            msg := m.Bytes()[logRange.Start:logRange.End]
             err := m.writeBytes(ctx, conn, msg)
             m.RUnlock()
             if err != nil {

View file

@@ -9,18 +9,10 @@ import (
     "time"

     "github.com/oschwald/maxminddb-golang"
-    "github.com/rs/zerolog"
     maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
     "github.com/yusing/go-proxy/internal/task"
 )

-// --- Helper for MaxMindConfig ---
-
-type testLogger struct{ zerolog.Logger }
-
-func (testLogger) Info() *zerolog.Event       { return &zerolog.Event{} }
-func (testLogger) Warn() *zerolog.Event       { return &zerolog.Event{} }
-func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
-
 func testCfg() *MaxMind {
     return &MaxMind{
         Config: &Config{
@@ -41,16 +33,17 @@ func testDoReq(cfg *MaxMind, w http.ResponseWriter, r *http.Request) {
     w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat))
     gz := gzip.NewWriter(w)
     t := tar.NewWriter(gz)
-    t.WriteHeader(&tar.Header{
+    _ = t.WriteHeader(&tar.Header{
         Name: cfg.dbFilename(),
     })
-    t.Write([]byte("1234"))
-    t.Close()
-    gz.Close()
+    _, _ = t.Write([]byte("1234"))
+    _ = t.Close()
+    _ = gz.Close()
     w.WriteHeader(http.StatusOK)
 }

-func mockDoReq(cfg *MaxMind, t *testing.T) {
+func mockDoReq(t *testing.T, cfg *MaxMind) {
+    t.Helper()
     rw := httptest.NewRecorder()
     oldDoReq := doReq
     doReq = func(req *http.Request) (*http.Response, error) {
@@ -61,12 +54,14 @@ func mockDoReq(cfg *MaxMind, t *testing.T) {
 }

 func mockDataDir(t *testing.T) {
+    t.Helper()
     oldDataDir := dataDir
     dataDir = t.TempDir()
     t.Cleanup(func() { dataDir = oldDataDir })
 }

 func mockMaxMindDBOpen(t *testing.T) {
+    t.Helper()
     oldMaxMindDBOpen := maxmindDBOpen
     maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
         return &maxminddb.Reader{}, nil
@@ -76,7 +71,7 @@ func mockMaxMindDBOpen(t *testing.T) {

 func Test_MaxMindConfig_doReq(t *testing.T) {
     cfg := testCfg()
-    mockDoReq(cfg, t)
+    mockDoReq(t, cfg)
     resp, err := cfg.doReq(http.MethodGet)
     if err != nil {
         t.Fatalf("newReq() error = %v", err)
@@ -88,7 +83,7 @@ func Test_MaxMindConfig_doReq(t *testing.T) {

 func Test_MaxMindConfig_checkLatest(t *testing.T) {
     cfg := testCfg()
-    mockDoReq(cfg, t)
+    mockDoReq(t, cfg)
     latest, err := cfg.checkLastest()
     if err != nil {
@@ -103,7 +98,7 @@ func Test_MaxMindConfig_download(t *testing.T) {
     cfg := testCfg()
     mockDataDir(t)
     mockMaxMindDBOpen(t)
-    mockDoReq(cfg, t)
+    mockDoReq(t, cfg)
     err := cfg.download()
     if err != nil {

View file

@@ -21,13 +21,6 @@ const (
     CacheKeyBasicAuth = "basic_auth"
 )

-var cacheKeys = []string{
-    CacheKeyQueries,
-    CacheKeyCookies,
-    CacheKeyRemoteIP,
-    CacheKeyBasicAuth,
-}
-
 var cachePool = &sync.Pool{
     New: func() any {
         return make(Cache)

View file

@@ -60,7 +60,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
         r.HealthMon = monitor.NewMonitor(r)
     }

-    if err := r.Stream.Setup(); err != nil {
+    if err := r.Setup(); err != nil {
         r.task.Finish(err)
         return gperr.Wrap(err)
     }
@@ -104,7 +104,7 @@ func (r *StreamRoute) acceptConnections() {
         case <-r.task.Context().Done():
             return
         default:
-            conn, err := r.Stream.Accept()
+            conn, err := r.Accept()
             if err != nil {
                 select {
                 case <-r.task.Context().Done():
@@ -118,7 +118,7 @@ func (r *StreamRoute) acceptConnections() {
                 panic("connection is nil")
             }
             go func() {
-                err := r.Stream.Handle(conn)
+                err := r.Handle(conn)
                 if err != nil && !errors.Is(err, context.Canceled) {
                     gperr.LogError("handle connection error", err, &r.l)
                 }

View file

@@ -26,11 +26,11 @@ func AppendDuration(d time.Duration, buf []byte) []byte {
     switch {
     case d < time.Millisecond:
-        buf = strconv.AppendInt(buf, int64(d.Nanoseconds()), 10)
+        buf = strconv.AppendInt(buf, d.Nanoseconds(), 10)
         buf = append(buf, []byte(" ns")...)
         return buf
     case d < time.Second:
-        buf = strconv.AppendInt(buf, int64(d.Milliseconds()), 10)
+        buf = strconv.AppendInt(buf, d.Milliseconds(), 10)
         buf = append(buf, []byte(" ms")...)
         return buf
     }

View file

@@ -93,7 +93,7 @@ func TestFormatTime(t *testing.T) {
             result := FormatTimeWithReference(tt.time, now)
             if tt.expectedLength > 0 {
-                require.Equal(t, tt.expectedLength, len(result), result)
+                require.Len(t, result, tt.expectedLength)
             } else {
                 require.Equal(t, tt.expected, result)
             }
@@ -213,11 +213,8 @@ func TestFormatLastSeen(t *testing.T) {
             if tt.name == "zero time" {
                 require.Equal(t, tt.expected, result)
-            } else {
-                // Just make sure it's not "never", the actual formatting is tested in TestFormatTime
-                if result == "never" {
-                    t.Errorf("Expected non-zero time to not return 'never', got %s", result)
-                }
+            } else if result == "never" { // Just make sure it's not "never", the actual formatting is tested in TestFormatTime
+                t.Errorf("Expected non-zero time to not return 'never', got %s", result)
             }
         })
     }

View file

@@ -37,6 +37,6 @@ func (p *Pool[T]) Get() []T {

 func (p *Pool[T]) Put(b []T) {
     if cap(b) <= p.maxSize {
-        p.pool.Put(b[:0])
+        p.pool.Put(b[:0]) //nolint:staticcheck
     }
 }

View file

@@ -5,10 +5,12 @@ import (
     "slices"
 )

-type YieldFunc = func(part string, value any) bool
-type YieldKeyFunc = func(key string) bool
-type Iterator = func(YieldFunc)
-type KeyIterator = func(YieldKeyFunc)
+type (
+    YieldFunc    = func(part string, value any) bool
+    YieldKeyFunc = func(key string) bool
+    Iterator     = func(YieldFunc)
+    KeyIterator  = func(YieldKeyFunc)
+)

 // WalkAll walks all nodes in the trie, yields full key and series
 func (node *Node) Walk(yield YieldFunc) {
@@ -17,10 +19,7 @@ func (node *Node) Walk(yield YieldFunc) {

 func (node *Node) walkAll(yield YieldFunc) bool {
     if !node.value.IsNil() {
-        if !yield(node.key, node.value.Load()) {
-            return false
-        }
-        return true
+        return yield(node.key, node.value.Load())
     }
     for _, v := range node.children.Range {
         if !v.walkAll(yield) {
@@ -57,10 +56,9 @@ func (node *Node) Map() map[string]any {
 func (tree Root) Query(key *Key) Iterator {
     if !key.hasWildcard {
         return func(yield YieldFunc) {
-            if v, ok := tree.Node.Get(key); ok {
+            if v, ok := tree.Get(key); ok {
                 yield(key.full, v)
             }
-            return
         }
     }
     return func(yield YieldFunc) {

View file

@@ -55,11 +55,11 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan gperr.Error) {
         // recover panic in onFlush when in production mode
         e.onFlush = func(events []Event) {
             defer func() {
-                if err := recover(); err != nil {
-                    if err, ok := err.(error); ok {
+                if errV := recover(); errV != nil {
+                    if err, ok := errV.(error); ok {
                         e.onError(gperr.Wrap(err).Subject(e.task.Name()))
                     } else {
-                        e.onError(gperr.New("recovered panic in onFlush").Withf("%v", err).Subject(e.task.Name()))
+                        e.onError(gperr.New("recovered panic in onFlush").Withf("%v", errV).Subject(e.task.Name()))
                     }
                     if common.IsDebug {
                         panic(string(debug.Stack()))