refactor: code refactor and improved context and error handling
Some checks are pending
Docker Image CI (nightly) / build-nightly (push) Waiting to run
Docker Image CI (nightly) / build-nightly-agent (push) Waiting to run

This commit is contained in:
yusing 2025-05-24 10:02:24 +08:00
parent 1f1ae38e4d
commit 5b7c392297
31 changed files with 116 additions and 98 deletions

View file

@ -2,15 +2,16 @@ version: "2"
linters: linters:
default: all default: all
disable: disable:
- bodyclose # - bodyclose
- containedctx - containedctx
- contextcheck # - contextcheck
- cyclop - cyclop
- depguard - depguard
- dupl # - dupl
- err113 - err113
- exhaustive - exhaustive
- exhaustruct - exhaustruct
- funcorder
- forcetypeassert - forcetypeassert
- gochecknoglobals - gochecknoglobals
- gochecknoinits - gochecknoinits
@ -18,7 +19,6 @@ linters:
- goconst - goconst
- gocyclo - gocyclo
- gomoddirectives - gomoddirectives
- gosec
- gosmopolitan - gosmopolitan
- ireturn - ireturn
- lll - lll
@ -27,12 +27,10 @@ linters:
- mnd - mnd
- nakedret - nakedret
- nestif - nestif
- nilnil
- nlreturn - nlreturn
- noctx
- nonamedreturns - nonamedreturns
- paralleltest - paralleltest
- prealloc - revive
- rowserrcheck - rowserrcheck
- sqlclosecheck - sqlclosecheck
- tagliatelle - tagliatelle

View file

@ -21,7 +21,7 @@ lint:
- markdownlint - markdownlint
- yamllint - yamllint
enabled: enabled:
- checkov@3.2.416 - checkov@3.2.432
- golangci-lint2@2.1.6 - golangci-lint2@2.1.6
- hadolint@2.12.1-beta - hadolint@2.12.1-beta
- actionlint@1.7.7 - actionlint@1.7.7
@ -32,7 +32,7 @@ lint:
- prettier@3.5.3 - prettier@3.5.3
- shellcheck@0.10.0 - shellcheck@0.10.0
- shfmt@3.6.0 - shfmt@3.6.0
- trufflehog@3.88.29 - trufflehog@3.88.33
actions: actions:
disabled: disabled:
- trunk-announce - trunk-announce

View file

@ -45,7 +45,7 @@ func (c *checkCache) Expired() bool {
return c.created.Add(cacheTTL).Before(utils.TimeNow()) return c.created.Add(cacheTTL).Before(utils.TimeNow())
} }
//TODO: add stats // TODO: add stats
const ( const (
ACLAllow = "allow" ACLAllow = "allow"

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types" maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/serialization"
) )
func TestMatchers(t *testing.T) { func TestMatchers(t *testing.T) {
@ -16,7 +16,7 @@ func TestMatchers(t *testing.T) {
} }
var mathers Matchers var mathers Matchers
err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false) err := serialization.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -22,12 +22,12 @@ func (noConn) SetDeadline(t time.Time) error { return nil }
func (noConn) SetReadDeadline(t time.Time) error { return nil } func (noConn) SetReadDeadline(t time.Time) error { return nil }
func (noConn) SetWriteDeadline(t time.Time) error { return nil } func (noConn) SetWriteDeadline(t time.Time) error { return nil }
func (cfg *Config) WrapTCP(lis net.Listener) net.Listener { func (c *Config) WrapTCP(lis net.Listener) net.Listener {
if cfg == nil { if c == nil {
return lis return lis
} }
return &TCPListener{ return &TCPListener{
acl: cfg, acl: c,
lis: lis, lis: lis,
} }
} }

View file

@ -190,7 +190,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
return nil, refreshToken.err return nil, refreshToken.err
} }
idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken) idTokenJWT, idToken, err := auth.getIDToken(ctx, newToken)
if err != nil { if err != nil {
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err) refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
return nil, refreshToken.err return nil, refreshToken.err

View file

@ -38,8 +38,8 @@ type (
const ( const (
CookieOauthState = "godoxy_oidc_state" CookieOauthState = "godoxy_oidc_state"
CookieOauthToken = "godoxy_oauth_token" CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
CookieOauthSessionToken = "godoxy_session_token" CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
) )
const ( const (
@ -129,7 +129,7 @@ func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath) return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
} }
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) { func (auth *OIDCProvider) getIDToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
idTokenJWT, ok := oauthToken.Extra("id_token").(string) idTokenJWT, ok := oauthToken.Extra("id_token").(string)
if !ok { if !ok {
return "", nil, errMissingIDToken return "", nil, errMissingIDToken
@ -257,7 +257,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
return return
} }
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token) idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token)
if err != nil { if err != nil {
gphttp.ServerError(w, r, err) gphttp.ServerError(w, r, err)
return return

View file

@ -212,7 +212,7 @@ func (s *testACMEServer) httpClient() *http.Client {
TLSHandshakeTimeout: 30 * time.Second, TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second, ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true, //nolint:gosec
}, },
}, },
} }

View file

@ -6,7 +6,7 @@ import (
"github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/goccy/go-yaml" "github.com/goccy/go-yaml"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/serialization"
) )
// type Config struct { // type Config struct {
@ -45,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any) opt := make(map[string]any)
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt)) require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
require.NoError(t, utils.MapUnmarshalValidate(opt, cfg)) require.NoError(t, serialization.MapUnmarshalValidate(opt, cfg))
require.Equal(t, cfgExpected, cfg) require.Equal(t, cfgExpected, cfg)
} }

View file

@ -190,7 +190,7 @@ func NewClient(host string) (*SharedClient, error) {
c.dial = client.Dialer() c.dial = client.Dialer()
} }
if c.addr == "" { if c.addr == "" {
c.addr = c.Client.DaemonHost() c.addr = c.DaemonHost()
} }
defer log.Debug().Str("host", host).Msg("docker client initialized") defer log.Debug().Str("host", host).Msg("docker client initialized")

View file

@ -1,6 +1,7 @@
package homepage package homepage
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -46,30 +47,30 @@ type (
func (icon *IconMeta) Filenames(ref string) []string { func (icon *IconMeta) Filenames(ref string) []string {
filenames := make([]string, 0) filenames := make([]string, 0)
if icon.SVG { if icon.SVG {
filenames = append(filenames, fmt.Sprintf("%s.svg", ref)) filenames = append(filenames, ref+".svg")
if icon.Light { if icon.Light {
filenames = append(filenames, fmt.Sprintf("%s-light.svg", ref)) filenames = append(filenames, ref+"-light.svg")
} }
if icon.Dark { if icon.Dark {
filenames = append(filenames, fmt.Sprintf("%s-dark.svg", ref)) filenames = append(filenames, ref+"-dark.svg")
} }
} }
if icon.PNG { if icon.PNG {
filenames = append(filenames, fmt.Sprintf("%s.png", ref)) filenames = append(filenames, ref+".png")
if icon.Light { if icon.Light {
filenames = append(filenames, fmt.Sprintf("%s-light.png", ref)) filenames = append(filenames, ref+"-light.png")
} }
if icon.Dark { if icon.Dark {
filenames = append(filenames, fmt.Sprintf("%s-dark.png", ref)) filenames = append(filenames, ref+"-dark.png")
} }
} }
if icon.WebP { if icon.WebP {
filenames = append(filenames, fmt.Sprintf("%s.webp", ref)) filenames = append(filenames, ref+".webp")
if icon.Light { if icon.Light {
filenames = append(filenames, fmt.Sprintf("%s-light.webp", ref)) filenames = append(filenames, ref+"-light.webp")
} }
if icon.Dark { if icon.Dark {
filenames = append(filenames, fmt.Sprintf("%s-dark.webp", ref)) filenames = append(filenames, ref+"-dark.webp")
} }
} }
return filenames return filenames
@ -113,7 +114,7 @@ func InitIconListCache() {
} }
task.OnProgramExit("save_icons_cache", func() { task.OnProgramExit("save_icons_cache", func() {
serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644) _ = serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
}) })
} }
@ -230,14 +231,17 @@ func updateIcons() error {
var httpGet = httpGetImpl var httpGet = httpGetImpl
func MockHttpGet(body []byte) { func MockHTTPGet(body []byte) {
httpGet = func(_ string) ([]byte, error) { httpGet = func(_ string) ([]byte, error) {
return body, nil return body, nil
} }
} }
func httpGetImpl(url string) ([]byte, error) { func httpGetImpl(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -347,7 +351,7 @@ func UpdateSelfhstIcons() error {
} }
data := make([]SelfhStIcon, 0) data := make([]SelfhStIcon, 0)
err = json.Unmarshal(body, &data) err = json.Unmarshal(body, &data) //nolint:musttag
if err != nil { if err != nil {
return err return err
} }

View file

@ -68,6 +68,8 @@ type testCases struct {
} }
func runTests(t *testing.T, iconsCache *Cache, test []testCases) { func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
t.Helper()
for _, item := range test { for _, item := range test {
icon, ok := iconsCache.Icons[item.Key] icon, ok := iconsCache.Icons[item.Key]
if !ok { if !ok {
@ -89,7 +91,7 @@ func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
} }
func TestListWalkxCodeIcons(t *testing.T) { func TestListWalkxCodeIcons(t *testing.T) {
MockHttpGet([]byte(walkxcodeIcons)) MockHTTPGet([]byte(walkxcodeIcons))
if err := UpdateWalkxCodeIcons(); err != nil { if err := UpdateWalkxCodeIcons(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -122,7 +124,7 @@ func TestListWalkxCodeIcons(t *testing.T) {
} }
func TestListSelfhstIcons(t *testing.T) { func TestListSelfhstIcons(t *testing.T) {
MockHttpGet([]byte(selfhstIcons)) MockHTTPGet([]byte(selfhstIcons))
if err := UpdateSelfhstIcons(); err != nil { if err := UpdateSelfhstIcons(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -33,17 +33,18 @@ var widgetProviders = map[string]struct{}{
var ErrInvalidProvider = gperr.New("invalid provider") var ErrInvalidProvider = gperr.New("invalid provider")
func (cfg *Config) UnmarshalMap(m map[string]any) error { func (cfg *Config) UnmarshalMap(m map[string]any) error {
cfg.Provider = m["provider"].(string) var ok bool
cfg.Provider, ok = m["provider"].(string)
if !ok {
return ErrInvalidProvider.Withf("non string")
}
if _, ok := widgetProviders[cfg.Provider]; !ok { if _, ok := widgetProviders[cfg.Provider]; !ok {
return ErrInvalidProvider.Subject(cfg.Provider) return ErrInvalidProvider.Subject(cfg.Provider)
} }
delete(m, "provider") delete(m, "provider")
m, ok := m["config"].(map[string]any) m, ok = m["config"].(map[string]any)
if !ok { if !ok {
return gperr.New("invalid config") return gperr.New("invalid config")
} }
if err := serialization.MapUnmarshalValidate(m, &cfg.Config); err != nil { return serialization.MapUnmarshalValidate(m, &cfg.Config)
return err
}
return nil
} }

View file

@ -73,13 +73,13 @@ var dummyHealthCheckConfig = &health.HealthCheckConfig{
} }
var ( var (
causeReload = gperr.New("reloaded") causeReload = gperr.New("reloaded") //nolint:errname
causeContainerDestroy = gperr.New("container destroyed") causeContainerDestroy = gperr.New("container destroyed") //nolint:errname
) )
const reqTimeout = 3 * time.Second const reqTimeout = 3 * time.Second
// TODO: fix stream type // TODO: fix stream type.
func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) { func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
cfg := r.IdlewatcherConfig() cfg := r.IdlewatcherConfig()
key := cfg.Key() key := cfg.Key()

View file

@ -2,12 +2,11 @@ package jsonstore
import ( import (
"encoding/json" "encoding/json"
"maps"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"maps"
"github.com/puzpuzpuz/xsync/v4" "github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
@ -36,8 +35,10 @@ type store interface {
json.Unmarshaler json.Unmarshaler
} }
var stores = make(map[namespace]store) var (
var storesPath = common.DataDir stores = make(map[namespace]store)
storesPath = common.DataDir
)
func init() { func init() {
task.OnProgramExit("save_stores", func() { task.OnProgramExit("save_stores", func() {
@ -117,7 +118,7 @@ func (s *MapStore[VT]) UnmarshalJSON(data []byte) error {
} }
s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp))) s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp)))
for k, v := range tmp { for k, v := range tmp {
s.Map.Store(k, v) s.Store(k, v)
} }
return nil return nil
} }

View file

@ -83,6 +83,9 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if io == nil {
return nil, nil //nolint:nilnil
}
return NewAccessLoggerWithIO(parent, io, cfg), nil return NewAccessLoggerWithIO(parent, io, cfg), nil
} }
@ -181,7 +184,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := l.lineBufPool.Get() line := l.lineBufPool.Get()
defer l.lineBufPool.Put(line) defer l.lineBufPool.Put(line)
line = l.ACLFormatter.AppendACLLog(line, info, blocked) line = l.AppendACLLog(line, info, blocked)
if line[len(line)-1] != '\n' { if line[len(line)-1] != '\n' {
line = append(line, '\n') line = append(line, '\n')
} }
@ -194,7 +197,7 @@ func (l *AccessLogger) ShouldRotate() bool {
func (l *AccessLogger) Rotate() (result *RotateResult, err error) { func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
if !l.ShouldRotate() { if !l.ShouldRotate() {
return nil, nil return nil, nil //nolint:nilnil
} }
l.writer.Flush() l.writer.Flush()

View file

@ -5,7 +5,7 @@ import (
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/logging/accesslog" . "github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/serialization"
expect "github.com/yusing/go-proxy/internal/utils/testing" expect "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) {
expect.NoError(t, err) expect.NoError(t, err)
var config RequestLoggerConfig var config RequestLoggerConfig
err = utils.MapUnmarshalValidate(parsed, &config) err = serialization.MapUnmarshalValidate(parsed, &config)
expect.NoError(t, err) expect.NoError(t, err)
expect.Equal(t, config.Format, FormatCombined) expect.Equal(t, config.Format, FormatCombined)

View file

@ -35,20 +35,19 @@ func newFileIO(path string) (SupportRotate, error) {
if opened, ok := openedFiles[path]; ok { if opened, ok := openedFiles[path]; ok {
opened.refCount.Add() opened.refCount.Add()
return opened, nil return opened, nil
} else {
// cannot open as O_APPEND as we need Seek and WriteAt
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("access log open error: %w", err)
}
if _, err := f.Seek(0, io.SeekEnd); err != nil {
return nil, fmt.Errorf("access log seek error: %w", err)
}
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[path] = file
go file.closeOnZero()
} }
// cannot open as O_APPEND as we need Seek and WriteAt
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("access log open error: %w", err)
}
if _, err := f.Seek(0, io.SeekEnd); err != nil {
return nil, fmt.Errorf("access log seek error: %w", err)
}
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[path] = file
go file.closeOnZero()
return file, nil return file, nil
} }

View file

@ -1,4 +1,3 @@
//nolint:zerologlint
package logging package logging
import ( import (

View file

@ -65,7 +65,7 @@ func NewPoller[T any, AggregateT json.Marshaler](
} }
func (p *Poller[T, AggregateT]) savePath() string { func (p *Poller[T, AggregateT]) savePath() string {
return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name)) return filepath.Join(saveBaseDir, p.name+".json")
} }
func (p *Poller[T, AggregateT]) load() error { func (p *Poller[T, AggregateT]) load() error {
@ -135,13 +135,14 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
func (p *Poller[T, AggregateT]) Start() { func (p *Poller[T, AggregateT]) Start() {
t := task.RootTask("poller." + p.name) t := task.RootTask("poller." + p.name)
l := log.With().Str("name", p.name).Logger()
err := p.load() err := p.load()
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log.Error().Err(err).Msgf("failed to load last metrics data for %s", p.name) l.Err(err).Msg("failed to load last metrics data")
} }
} else { } else {
log.Debug().Msgf("Loaded last metrics data for %s, %d entries", p.name, p.period.Total()) l.Debug().Int("entries", p.period.Total()).Msgf("Loaded last metrics data")
} }
go func() { go func() {
@ -154,11 +155,13 @@ func (p *Poller[T, AggregateT]) Start() {
gatherErrsTicker.Stop() gatherErrsTicker.Stop()
saveTicker.Stop() saveTicker.Stop()
p.save() if err := p.save(); err != nil {
l.Err(err).Msg("failed to save metrics data")
}
t.Finish(nil) t.Finish(nil)
}() }()
log.Debug().Msgf("Starting poller %s with interval %s", p.name, pollInterval) l.Debug().Dur("interval", pollInterval).Msg("Starting poller")
p.pollWithTimeout(t.Context()) p.pollWithTimeout(t.Context())

View file

@ -1,3 +1,3 @@
package types package types
type Weight uint16 type Weight int

View file

@ -8,10 +8,10 @@ import (
) )
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return log.WithLevel(level). return log.WithLevel(level). //nolint:zerologlint
Str("remote", r.RemoteAddr). Str("remote", r.RemoteAddr).
Str("host", r.Host). Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI) Str("uri", r.Method+" "+r.RequestURI)
} }
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) } func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }

View file

@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
ipStr = r.RemoteAddr ipStr = r.RemoteAddr
} }
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
for _, cidr := range wl.CIDRWhitelistOpts.Allow { for _, cidr := range wl.Allow {
if cidr.Contains(ip) { if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true) wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true allow = true
@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
} }
if !allow { if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false) wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow) wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
} }
} }
if !allow { if !allow {

View file

@ -8,7 +8,7 @@ import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage, "message": testMessage,
}) })
ExpectError(t, utils.ErrValidationError, err) ExpectError(t, serialization.ErrValidationError, err)
}) })
t.Run("invalid cidr", func(t *testing.T) { t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{ _, err := CIDRWhiteList.New(OptionsRaw{
@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"status_code": 600, "status_code": 600,
"message": testMessage, "message": testMessage,
}) })
ExpectError(t, utils.ErrValidationError, err) ExpectError(t, serialization.ErrValidationError, err)
}) })
} }

View file

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -103,7 +104,15 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
} }
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error { func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req) //nolint:gosec
if err != nil { if err != nil {
return err return err
} }

View file

@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport transport := p.Transport
ctx := req.Context() ctx := req.Context()
/* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil { if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been // CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains // entirely superseded by it. If the request contains
@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return nil return nil
}, },
} }
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
res.Header = rw.Header() res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil { if err := res.Write(brw); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */ //nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true) p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return return
} }
if err := brw.Flush(); err != nil { if err := brw.Flush(); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */ //nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true) p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return return
} }
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn) bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
/* trunk-ignore(golangci-lint/errcheck) */ //nolint:errcheck
bdp.Start() bdp.Start()
} }

View file

@ -16,7 +16,7 @@ import (
) )
type CertProvider interface { type CertProvider interface {
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
} }
type Server struct { type Server struct {

View file

@ -106,7 +106,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
// Always non-nil. // Always non-nil.
func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) { func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) {
if !container.IsExplicit && p.IsExplicitOnly() { if !container.IsExplicit && p.IsExplicitOnly() {
return nil, nil return make(route.Routes, 0), nil
} }
routes := make(route.Routes, len(container.Aliases)) routes := make(route.Routes, len(container.Aliases))

View file

@ -3,7 +3,7 @@ package rules
import ( import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -28,7 +28,7 @@ func TestParseRule(t *testing.T) {
var rules struct { var rules struct {
Rules Rules Rules Rules
} }
err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules) err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, len(rules.Rules), len(test)) ExpectEqual(t, len(rules.Rules), len(test))
ExpectEqual(t, rules.Rules[0].Name, "test") ExpectEqual(t, rules.Rules[0].Name, "test")

View file

@ -6,7 +6,7 @@ import (
. "github.com/yusing/go-proxy/internal/route" . "github.com/yusing/go-proxy/internal/route"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/serialization"
expect "github.com/yusing/go-proxy/internal/utils/testing" expect "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -40,7 +40,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cfg := Route{} cfg := Route{}
tt.input["host"] = "internal" tt.input["host"] = "internal"
err := utils.MapUnmarshalValidate(tt.input, &cfg) err := serialization.MapUnmarshalValidate(tt.input, &cfg)
if err != nil { if err != nil {
expect.NoError(t, err) expect.NoError(t, err)
} }

View file

@ -18,7 +18,7 @@ func GetLastVersion() Version {
func GetVersionHTTPHandler() http.HandlerFunc { func GetVersionHTTPHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(GetVersion().String())) fmt.Fprint(w, GetVersion().String())
} }
} }