Mirror of https://github.com/yusing/godoxy.git, synced 2025-05-19 20:32:35 +02:00
style: code cleanup and fix styling
This commit is contained in:
parent a06787593c
commit c05059765d
24 changed files with 161 additions and 218 deletions
@@ -9,6 +9,7 @@ import (
)

type Matcher func(*maxmind.IPInfo) bool

type Matchers []Matcher

const (

@@ -26,6 +27,7 @@ var errMatcherFormat = gperr.Multiline().AddLines(
    "tz:Asia/Shanghai",
    "country:GB",
)

var (
    errSyntax    = gperr.New("syntax error")
    errInvalidIP = gperr.New("invalid IP")

@@ -1,5 +0,0 @@
-package dockerapi
-
-import "time"
-
-const reqTimeout = 10 * time.Second

@@ -18,7 +18,7 @@ type Container struct {
}

func Containers(w http.ResponseWriter, r *http.Request) {
-   serveHTTP[Container, []Container](w, r, GetContainers)
+   serveHTTP[Container](w, r, GetContainers)
}

func GetContainers(ctx context.Context, dockerClients DockerClients) ([]Container, gperr.Error) {

@@ -22,7 +22,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
    until := query.Get("to")
    levels := query.Get("levels") // TODO: implement levels

-   dockerClient, found, err := getDockerClient(w, server)
+   dockerClient, found, err := getDockerClient(server)
    if err != nil {
        gphttp.BadRequest(w, err.Error())
        return

@@ -56,7 +56,7 @@ func getDockerClients() (DockerClients, gperr.Error) {
    return dockerClients, connErrs.Error()
}

-func getDockerClient(w http.ResponseWriter, server string) (*docker.SharedClient, bool, error) {
+func getDockerClient(server string) (*docker.SharedClient, bool, error) {
    cfg := config.GetInstance()
    var host string
    for name, h := range cfg.Value().Providers.Docker {

@@ -98,7 +98,7 @@ func handleResult[V any, T ResultType[V]](w http.ResponseWriter, errs error, res
            return
        }
    }
-   json.NewEncoder(w).Encode(result)
+   json.NewEncoder(w).Encode(result) //nolint
}

func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, getResult func(ctx context.Context, dockerClients DockerClients) (T, gperr.Error)) {

@@ -119,6 +119,6 @@ func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, g
        })
    } else {
        result, err := getResult(r.Context(), dockerClients)
-       handleResult[V, T](w, err, result)
+       handleResult[V](w, err, result)
    }
}

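For context on the shortened calls above: Go permits a partial type-argument list when the compiler can infer the remaining type parameters from the value arguments, so `serveHTTP[Container](...)` and `handleResult[V](...)` are equivalent to the fully spelled-out forms. A minimal sketch of the same inference pattern (names here are illustrative, not from the repository):

package main

import "fmt"

// serve mirrors the serveHTTP shape: V is the element type and T,
// constrained to slices of V, is inferred from getResult's type.
func serve[V any, T ~[]V](getResult func() T) {
    fmt.Println(getResult())
}

func getInts() []int { return []int{1, 2, 3} }

func main() {
    // Partial type-argument list: only V is supplied; T is
    // inferred as []int from the argument getInts.
    serve[int](getInts)
}
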
@@ -22,7 +22,7 @@ type oauthRefreshToken struct {
    RefreshToken string    `json:"refresh_token"`
    Expiry       time.Time `json:"expiry"`

-   result *refreshResult
+   result *RefreshResult
    err    error
    mu     sync.Mutex
}

@@ -33,7 +33,7 @@ type Session struct {
    Groups []string `json:"groups"`
}

-type refreshResult struct {
+type RefreshResult struct {
    newSession Session
    jwt        string
    jwtExpiry  time.Time

@@ -50,7 +50,6 @@ var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken]

var (
    defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
    refreshBefore             = 30 * time.Second
    sessionInvalidateDelay    = 3 * time.Second
)

@@ -148,7 +147,7 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
    return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
}

-func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*refreshResult, error) {
+func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*RefreshResult, error) {
    // verify the session cookie
    claims, valid, err := auth.parseSessionJWT(sessionJWT)
    if err != nil {

@@ -171,7 +170,7 @@ func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string
    return auth.doRefreshToken(ctx, refreshToken, &claims.Session)
}

-func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*refreshResult, error) {
+func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*RefreshResult, error) {
    refreshToken.mu.Lock()
    defer refreshToken.mu.Unlock()

@@ -209,7 +208,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
    logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
    storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)

-   refreshToken.result = &refreshResult{
+   refreshToken.result = &RefreshResult{
        newSession: Session{
            SessionID: sessionID,
            Username:  claims.Username,

@@ -1,7 +1,6 @@
package auth

import (
-   "context"
    "crypto/rand"
    "crypto/rsa"
    "encoding/base64"

@@ -24,7 +23,7 @@ import (
func setupMockOIDC(t *testing.T) {
    t.Helper()

-   provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
+   provider := (&oidc.ProviderConfig{}).NewProvider(t.Context())
    defaultAuth = &OIDCProvider{
        oauthConfig: &oauth2.Config{
            ClientID: "test-client",

@@ -104,7 +103,7 @@ func setupProvider(t *testing.T) *provider {
    t.Cleanup(ts.Close)

    // Create a test OIDCProvider.
-   providerCtx := oidc.ClientContext(context.Background(), ts.Client())
+   providerCtx := oidc.ClientContext(t.Context(), ts.Client())
    keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")

    return &provider{

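On the `t.Context()` substitutions above: `testing.T.Context`, added in Go 1.24, returns a per-test context that is canceled shortly before the test ends (just before Cleanup functions run), which is why it can replace ad-hoc `context.TODO()` and `context.Background()` in tests. A minimal sketch, assuming Go 1.24 or newer:

package demo_test

import (
    "net/http"
    "testing"
)

func TestRequestWithTestContext(t *testing.T) {
    // t.Context() is canceled automatically when the test finishes,
    // so nothing derived from it can outlive the test.
    req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://127.0.0.1/", nil)
    if err != nil {
        t.Fatal(err)
    }
    _ = req
}
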
@@ -9,7 +9,6 @@ import (
    "path"
    "reflect"
    "sort"
    "sync"
    "time"

    "github.com/go-acme/lego/v4/certificate"

@@ -33,8 +32,6 @@ type (
        legoCert     *certificate.Resource
        tlsCert      *tls.Certificate
        certExpiries CertExpiries

        obtainMu sync.Mutex
    }

    CertExpiries map[string]time.Time

@@ -46,5 +46,5 @@ oauth2_config:
    opt := make(map[string]any)
    require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
    require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
-   require.Equal(t, cfg, cfgExpected)
+   require.Equal(t, cfgExpected, cfg)
}

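The argument swap above matters for diagnostics: testify's signature is `require.Equal(t, expected, actual)`, so reversed arguments don't change pass/fail but do label the got/want values backwards in failure output. A tiny illustration (hypothetical test, not from the repository):

package config_test

import (
    "testing"

    "github.com/stretchr/testify/require"
)

func TestExpectedActualOrder(t *testing.T) {
    got, want := 8080, 8080
    // testify prints "expected: X, actual: Y" on failure, so the
    // expected value must come first for the message to make sense.
    require.Equal(t, want, got)
}
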
@@ -14,13 +14,12 @@ import (
    idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
    "github.com/yusing/go-proxy/internal/logging"
    "github.com/yusing/go-proxy/internal/utils"
-   U "github.com/yusing/go-proxy/internal/utils"
)

type (
    PortMapping = map[int]container.Port
    Container   struct {
-       _ U.NoCopy
+       _ utils.NoCopy

        DockerHost string          `json:"docker_host"`
        Image      *ContainerImage `json:"image"`

@@ -104,6 +103,33 @@ func (c *Container) IsBlacklisted() bool {
    return c.Image.IsBlacklisted() || c.isDatabase()
}

+func (c *Container) UpdatePorts() error {
+   client, err := NewClient(c.DockerHost)
+   if err != nil {
+       return err
+   }
+   defer client.Close()
+
+   inspect, err := client.ContainerInspect(context.Background(), c.ContainerID)
+   if err != nil {
+       return err
+   }
+
+   for port := range inspect.Config.ExposedPorts {
+       proto, portStr := nat.SplitProtoPort(string(port))
+       portInt, _ := nat.ParsePort(portStr)
+       if portInt == 0 {
+           continue
+       }
+       c.PublicPortMapping[portInt] = container.Port{
+           PublicPort:  uint16(portInt),
+           PrivatePort: uint16(portInt),
+           Type:        proto,
+       }
+   }
+   return nil
+}
+
var databaseMPs = map[string]struct{}{
    "/var/lib/postgresql/data": {},
    "/var/lib/mysql":           {},

@@ -205,30 +231,3 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
        }
    }
}
-
-func (c *Container) UpdatePorts() error {
-   client, err := NewClient(c.DockerHost)
-   if err != nil {
-       return err
-   }
-   defer client.Close()
-
-   inspect, err := client.ContainerInspect(context.Background(), c.ContainerID)
-   if err != nil {
-       return err
-   }
-
-   for port := range inspect.Config.ExposedPorts {
-       proto, portStr := nat.SplitProtoPort(string(port))
-       portInt, _ := nat.ParsePort(portStr)
-       if portInt == 0 {
-           continue
-       }
-       c.PublicPortMapping[portInt] = container.Port{
-           PublicPort:  uint16(portInt),
-           PrivatePort: uint16(portInt),
-           Type:        proto,
-       }
-   }
-   return nil
-}

@@ -24,10 +24,6 @@ type Builder struct {
    rwLock
}

-type multiline struct {
-   *Builder
-}
-
// NewBuilder creates a new Builder.
//
// If about is not provided, the Builder will not have a subject

@@ -88,23 +84,6 @@ func (b *Builder) Add(err error) {
    b.add(err)
}

-func (b *Builder) add(err error) {
-   switch err := err.(type) {
-   case *baseError:
-       b.errs = append(b.errs, err.Err)
-   case *nestedError:
-       if err.Err == nil {
-           b.errs = append(b.errs, err.Extras...)
-       } else {
-           b.errs = append(b.errs, err)
-       }
-   case *MultilineError:
-       b.add(&err.nestedError)
-   default:
-       b.errs = append(b.errs, err)
-   }
-}
-
func (b *Builder) Adds(err string) {
    b.Lock()
    defer b.Unlock()

@@ -160,3 +139,20 @@ func (b *Builder) ForEach(fn func(error)) {
        fn(err)
    }
}
+
+func (b *Builder) add(err error) {
+   switch err := err.(type) { //nolint:errorlint
+   case *baseError:
+       b.errs = append(b.errs, err.Err)
+   case *nestedError:
+       if err.Err == nil {
+           b.errs = append(b.errs, err.Extras...)
+       } else {
+           b.errs = append(b.errs, err)
+       }
+   case *MultilineError:
+       b.add(&err.nestedError)
+   default:
+       b.errs = append(b.errs, err)
+   }
+}

@@ -60,7 +60,7 @@ func fetchIconAbsolute(ctx context.Context, url string) *FetchResult {
        return result
    }

-   req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+   req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
    if err != nil {
        if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
            return &FetchResult{StatusCode: http.StatusBadGateway, ErrMsg: "request timeout"}

@@ -161,7 +161,7 @@ func findIconSlow(ctx context.Context, r httpRoute, uri string, stack []string)
    ctx, cancel := context.WithTimeoutCause(ctx, faviconFetchTimeout, errors.New("favicon request timeout"))
    defer cancel()

-   newReq, err := http.NewRequestWithContext(ctx, "GET", r.TargetURL().String(), nil)
+   newReq, err := http.NewRequestWithContext(ctx, http.MethodGet, r.TargetURL().String(), nil)
    if err != nil {
        return &FetchResult{StatusCode: http.StatusInternalServerError, ErrMsg: "cannot create request"}
    }

@@ -80,7 +80,7 @@ func (c *Config) validateProvider() error {
    return nil
}

-func (c *Config) validateTimeouts() error {
+func (c *Config) validateTimeouts() error { //nolint:unparam
    if c.WakeTimeout == 0 {
        c.WakeTimeout = WakeTimeoutDefault
    }

@@ -139,12 +139,12 @@ func logEntry() []byte {
        Format: FormatJSON,
    })
    srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-       w.Write([]byte("hello"))
+       _, _ = w.Write([]byte("hello"))
    }))
    srv.URL = "http://localhost:8080"
    defer srv.Close()
    // make a request to the server
-   req, _ := http.NewRequest("GET", srv.URL, nil)
+   req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
    res := httptest.NewRecorder()
    // server the request
    srv.Config.Handler.ServeHTTP(res, req)

@@ -179,7 +179,10 @@ func TestReset(t *testing.T) {
        t.Errorf("scanner error: %v", err)
    }
    expect.Equal(t, linesRead, nLines)
-   s.Reset()
+   err = s.Reset()
+   if err != nil {
+       t.Errorf("failed to reset scanner: %v", err)
+   }

    linesRead = 0
    for s.Scan() {

@@ -191,7 +194,7 @@ func TestReset(t *testing.T) {
    expect.Equal(t, linesRead, nLines)
}

-// 100000 log entries
+// 100000 log entries.
func BenchmarkBackScanner(b *testing.B) {
    mockFile := NewMockFile()
    line := logEntry()

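The `_, _ = w.Write(...)` rewrite above is the standard way to appease the errcheck linter when a return value is intentionally discarded; the explicit blank assignments document the decision instead of silently dropping it. A minimal sketch of the pattern:

package main

import "net/http"

func hello(w http.ResponseWriter, _ *http.Request) {
    // The byte count and error from Write are deliberately
    // discarded; the blank identifiers make that explicit and
    // satisfy errcheck.
    _, _ = w.Write([]byte("hello"))
}

func main() {
    http.HandleFunc("/", hello)
    _ = http.ListenAndServe(":8080", nil)
}
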
@@ -94,32 +94,6 @@ func (f *CombinedFormatter) AppendRequestLog(line []byte, req *http.Request, res
    return line
}

-type zeroLogStringStringMapMarshaler struct {
-   values map[string]string
-}
-
-func (z *zeroLogStringStringMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
-   if len(z.values) == 0 {
-       return
-   }
-   for k, v := range z.values {
-       e.Str(k, v)
-   }
-}
-
-type zeroLogStringStringSliceMapMarshaler struct {
-   values map[string][]string
-}
-
-func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
-   if len(z.values) == 0 {
-       return
-   }
-   for k, v := range z.values {
-       e.Strs(k, v)
-   }
-}
-
func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
    query := f.cfg.Query.ZerologQuery(req.URL.Query())
    headers := f.cfg.Headers.ZerologHeaders(req.Header)

@@ -27,10 +27,6 @@ type memLogger struct {

type MemLogger io.Writer

-type buffer struct {
-   data []byte
-}
-
const (
    maxMemLogSize = 16 * 1024
    truncateSize  = maxMemLogSize / 2

@@ -59,64 +55,6 @@ func Events() (<-chan []byte, func()) {
    return memLoggerInstance.events()
}

-func (m *memLogger) truncateIfNeeded(n int) {
-   m.RLock()
-   needTruncate := m.Len()+n > maxMemLogSize
-   m.RUnlock()
-
-   if needTruncate {
-       m.Lock()
-       defer m.Unlock()
-       needTruncate = m.Len()+n > maxMemLogSize
-       if !needTruncate {
-           return
-       }
-
-       m.Truncate(truncateSize)
-   }
-}
-
-func (m *memLogger) notifyWS(pos, n int) {
-   if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
-       return
-   }
-
-   timeout := time.NewTimer(3 * time.Second)
-   defer timeout.Stop()
-
-   m.notifyLock.RLock()
-   defer m.notifyLock.RUnlock()
-
-   m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
-       select {
-       case ch <- &logEntryRange{pos, pos + n}:
-           return true
-       case <-timeout.C:
-           return false
-       }
-   })
-
-   if m.listeners.Size() > 0 {
-       msg := m.Buffer.Bytes()[pos : pos+n]
-       m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
-           select {
-           case <-timeout.C:
-               return false
-           case ch <- msg:
-               return true
-           }
-       })
-   }
-}
-
-func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
-   m.Lock()
-   defer m.Unlock()
-   pos = m.Len()
-   _, err = m.Buffer.Write(b)
-   return
-}
-
// Write implements io.Writer.
func (m *memLogger) Write(p []byte) (n int, err error) {
    n = len(p)

@@ -159,6 +97,64 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    m.wsStreamLog(r.Context(), conn, logCh)
}

+func (m *memLogger) truncateIfNeeded(n int) {
+   m.RLock()
+   needTruncate := m.Len()+n > maxMemLogSize
+   m.RUnlock()
+
+   if needTruncate {
+       m.Lock()
+       defer m.Unlock()
+       needTruncate = m.Len()+n > maxMemLogSize
+       if !needTruncate {
+           return
+       }
+
+       m.Truncate(truncateSize)
+   }
+}
+
+func (m *memLogger) notifyWS(pos, n int) {
+   if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
+       return
+   }
+
+   timeout := time.NewTimer(3 * time.Second)
+   defer timeout.Stop()
+
+   m.notifyLock.RLock()
+   defer m.notifyLock.RUnlock()
+
+   m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
+       select {
+       case ch <- &logEntryRange{pos, pos + n}:
+           return true
+       case <-timeout.C:
+           return false
+       }
+   })
+
+   if m.listeners.Size() > 0 {
+       msg := m.Bytes()[pos : pos+n]
+       m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
+           select {
+           case <-timeout.C:
+               return false
+           case ch <- msg:
+               return true
+           }
+       })
+   }
+}
+
+func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
+   m.Lock()
+   defer m.Unlock()
+   pos = m.Len()
+   _, err = m.Buffer.Write(b)
+   return
+}
+
func (m *memLogger) events() (logs <-chan []byte, cancel func()) {
    ch := make(chan []byte)
    m.notifyLock.Lock()

@@ -181,7 +177,7 @@ func (m *memLogger) wsInitial(ctx context.Context, conn *websocket.Conn) error {
    m.Lock()
    defer m.Unlock()

-   return m.writeBytes(ctx, conn, m.Buffer.Bytes())
+   return m.writeBytes(ctx, conn, m.Bytes())
}

func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) {

@@ -191,7 +187,7 @@ func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-
        return
    case logRange := <-ch:
        m.RLock()
-       msg := m.Buffer.Bytes()[logRange.Start:logRange.End]
+       msg := m.Bytes()[logRange.Start:logRange.End]
        err := m.writeBytes(ctx, conn, msg)
        m.RUnlock()
        if err != nil {

@@ -9,18 +9,10 @@ import (
    "time"

    "github.com/oschwald/maxminddb-golang"
    "github.com/rs/zerolog"
    maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
    "github.com/yusing/go-proxy/internal/task"
)

-// --- Helper for MaxMindConfig ---
-type testLogger struct{ zerolog.Logger }
-
-func (testLogger) Info() *zerolog.Event       { return &zerolog.Event{} }
-func (testLogger) Warn() *zerolog.Event       { return &zerolog.Event{} }
-func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
-
func testCfg() *MaxMind {
    return &MaxMind{
        Config: &Config{

@@ -41,16 +33,17 @@ func testDoReq(cfg *MaxMind, w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat))
    gz := gzip.NewWriter(w)
    t := tar.NewWriter(gz)
-   t.WriteHeader(&tar.Header{
+   _ = t.WriteHeader(&tar.Header{
        Name: cfg.dbFilename(),
    })
-   t.Write([]byte("1234"))
-   t.Close()
-   gz.Close()
+   _, _ = t.Write([]byte("1234"))
+   _ = t.Close()
+   _ = gz.Close()
    w.WriteHeader(http.StatusOK)
}

-func mockDoReq(cfg *MaxMind, t *testing.T) {
+func mockDoReq(t *testing.T, cfg *MaxMind) {
+   t.Helper()
    rw := httptest.NewRecorder()
    oldDoReq := doReq
    doReq = func(req *http.Request) (*http.Response, error) {

@@ -61,12 +54,14 @@ func mockDoReq(cfg *MaxMind, t *testing.T) {
    }

func mockDataDir(t *testing.T) {
+   t.Helper()
    oldDataDir := dataDir
    dataDir = t.TempDir()
    t.Cleanup(func() { dataDir = oldDataDir })
}

func mockMaxMindDBOpen(t *testing.T) {
+   t.Helper()
    oldMaxMindDBOpen := maxmindDBOpen
    maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
        return &maxminddb.Reader{}, nil

@@ -76,7 +71,7 @@ func mockMaxMindDBOpen(t *testing.T) {

func Test_MaxMindConfig_doReq(t *testing.T) {
    cfg := testCfg()
-   mockDoReq(cfg, t)
+   mockDoReq(t, cfg)
    resp, err := cfg.doReq(http.MethodGet)
    if err != nil {
        t.Fatalf("newReq() error = %v", err)

@@ -88,7 +83,7 @@ func Test_MaxMindConfig_doReq(t *testing.T) {

func Test_MaxMindConfig_checkLatest(t *testing.T) {
    cfg := testCfg()
-   mockDoReq(cfg, t)
+   mockDoReq(t, cfg)

    latest, err := cfg.checkLastest()
    if err != nil {

@@ -103,7 +98,7 @@ func Test_MaxMindConfig_download(t *testing.T) {
    cfg := testCfg()
    mockDataDir(t)
    mockMaxMindDBOpen(t)
-   mockDoReq(cfg, t)
+   mockDoReq(t, cfg)

    err := cfg.download()
    if err != nil {

@@ -21,13 +21,6 @@ const (
    CacheKeyBasicAuth = "basic_auth"
)

-var cacheKeys = []string{
-   CacheKeyQueries,
-   CacheKeyCookies,
-   CacheKeyRemoteIP,
-   CacheKeyBasicAuth,
-}
-
var cachePool = &sync.Pool{
    New: func() any {
        return make(Cache)

@@ -60,7 +60,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
        r.HealthMon = monitor.NewMonitor(r)
    }

-   if err := r.Stream.Setup(); err != nil {
+   if err := r.Setup(); err != nil {
        r.task.Finish(err)
        return gperr.Wrap(err)
    }

@@ -104,7 +104,7 @@ func (r *StreamRoute) acceptConnections() {
        case <-r.task.Context().Done():
            return
        default:
-           conn, err := r.Stream.Accept()
+           conn, err := r.Accept()
            if err != nil {
                select {
                case <-r.task.Context().Done():

@@ -118,7 +118,7 @@ func (r *StreamRoute) acceptConnections() {
                panic("connection is nil")
            }
            go func() {
-               err := r.Stream.Handle(conn)
+               err := r.Handle(conn)
                if err != nil && !errors.Is(err, context.Canceled) {
                    gperr.LogError("handle connection error", err, &r.l)
                }

@@ -26,11 +26,11 @@ func AppendDuration(d time.Duration, buf []byte) []byte {

    switch {
    case d < time.Millisecond:
-       buf = strconv.AppendInt(buf, int64(d.Nanoseconds()), 10)
+       buf = strconv.AppendInt(buf, d.Nanoseconds(), 10)
        buf = append(buf, []byte(" ns")...)
        return buf
    case d < time.Second:
-       buf = strconv.AppendInt(buf, int64(d.Milliseconds()), 10)
+       buf = strconv.AppendInt(buf, d.Milliseconds(), 10)
        buf = append(buf, []byte(" ms")...)
        return buf
    }

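The dropped conversions above work because `time.Duration.Nanoseconds()` and `Milliseconds()` already return `int64`, exactly the type `strconv.AppendInt` takes, so the `int64(...)` wrappers were no-ops of the kind linters such as unconvert flag. A quick demonstration:

package main

import (
    "fmt"
    "strconv"
    "time"
)

func main() {
    d := 1500 * time.Microsecond
    // d.Milliseconds() is already int64; no conversion needed.
    buf := strconv.AppendInt([]byte("took "), d.Milliseconds(), 10)
    buf = append(buf, " ms"...)
    fmt.Println(string(buf)) // took 1 ms
}
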
@@ -93,7 +93,7 @@ func TestFormatTime(t *testing.T) {
    result := FormatTimeWithReference(tt.time, now)

    if tt.expectedLength > 0 {
-       require.Equal(t, tt.expectedLength, len(result), result)
+       require.Len(t, result, tt.expectedLength)
    } else {
        require.Equal(t, tt.expected, result)
    }

@@ -213,12 +213,9 @@ func TestFormatLastSeen(t *testing.T) {

        if tt.name == "zero time" {
            require.Equal(t, tt.expected, result)
-       } else {
-           // Just make sure it's not "never", the actual formatting is tested in TestFormatTime
-           if result == "never" {
+       } else if result == "never" { // Just make sure it's not "never", the actual formatting is tested in TestFormatTime
            t.Errorf("Expected non-zero time to not return 'never', got %s", result)
-           }
        }
        })
    }
}

@@ -37,6 +37,6 @@ func (p *Pool[T]) Get() []T {

func (p *Pool[T]) Put(b []T) {
    if cap(b) <= p.maxSize {
-       p.pool.Put(b[:0])
+       p.pool.Put(b[:0]) //nolint:staticcheck
    }
}

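The `//nolint:staticcheck` above silences SA6002, which warns that passing a non-pointer value (here a slice header) to `sync.Pool.Put` allocates an interface value on every call. Pooling byte slices this way is still a common, deliberate trade: a small boxing allocation in exchange for reusing the much larger backing array. A sketch of the same trade-off with a plain `sync.Pool`:

package main

import "sync"

var bufPool = sync.Pool{
    New: func() any { return make([]byte, 0, 1024) },
}

func main() {
    buf := bufPool.Get().([]byte)
    buf = append(buf, "payload"...)
    // SA6002 flags this Put because boxing the slice header
    // allocates, but the 1 KiB backing array is still reused.
    bufPool.Put(buf[:0]) //nolint:staticcheck
}
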
@@ -5,10 +5,12 @@ import (
    "slices"
)

-type YieldFunc = func(part string, value any) bool
-type YieldKeyFunc = func(key string) bool
-type Iterator = func(YieldFunc)
-type KeyIterator = func(YieldKeyFunc)
+type (
+   YieldFunc    = func(part string, value any) bool
+   YieldKeyFunc = func(key string) bool
+   Iterator     = func(YieldFunc)
+   KeyIterator  = func(YieldKeyFunc)
+)

// WalkAll walks all nodes in the trie, yields full key and series
func (node *Node) Walk(yield YieldFunc) {

@@ -17,10 +19,7 @@ func (node *Node) Walk(yield YieldFunc) {

func (node *Node) walkAll(yield YieldFunc) bool {
    if !node.value.IsNil() {
-       if !yield(node.key, node.value.Load()) {
-           return false
-       }
-       return true
+       return yield(node.key, node.value.Load())
    }
    for _, v := range node.children.Range {
        if !v.walkAll(yield) {

@@ -57,10 +56,9 @@ func (node *Node) Map() map[string]any {
func (tree Root) Query(key *Key) Iterator {
    if !key.hasWildcard {
        return func(yield YieldFunc) {
-           if v, ok := tree.Node.Get(key); ok {
+           if v, ok := tree.Get(key); ok {
                yield(key.full, v)
            }
-           return
        }
    }
    return func(yield YieldFunc) {

@@ -55,11 +55,11 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan gperr.Error) {
    // recover panic in onFlush when in production mode
    e.onFlush = func(events []Event) {
        defer func() {
-           if err := recover(); err != nil {
-               if err, ok := err.(error); ok {
+           if errV := recover(); errV != nil {
+               if err, ok := errV.(error); ok {
                    e.onError(gperr.Wrap(err).Subject(e.task.Name()))
                } else {
-                   e.onError(gperr.New("recovered panic in onFlush").Withf("%v", err).Subject(e.task.Name()))
+                   e.onError(gperr.New("recovered panic in onFlush").Withf("%v", errV).Subject(e.task.Name()))
                }
                if common.IsDebug {
                    panic(string(debug.Stack()))

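The err-to-errV rename above untangles two shadowed variables: the value returned by `recover()` is an `any`, and the `error` extracted from it previously reused the same name, so the inner branches were easy to misread. A minimal sketch of the disambiguated pattern:

package main

import "log"

func run() {
    defer func() {
        if errV := recover(); errV != nil { // errV is any, not error
            if err, ok := errV.(error); ok {
                log.Println("recovered error:", err)
            } else {
                log.Printf("recovered non-error panic: %v", errV)
            }
        }
    }()
    panic("boom")
}

func main() { run() }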