refactor and typo fixes

This commit is contained in:
yusing 2024-11-02 03:14:47 +08:00
parent 76454df5e6
commit a86d316d07
34 changed files with 160 additions and 128 deletions

View file

@ -108,6 +108,7 @@ linters:
- prealloc # Too many false-positive. - prealloc # Too many false-positive.
- makezero # Not relevant - makezero # Not relevant
- dupl # Too strict - dupl # Too strict
- gci # I don't care
- gosec # Too strict - gosec # Too strict
- gochecknoinits - gochecknoinits
- gochecknoglobals - gochecknoglobals

View file

@ -14,7 +14,7 @@ import (
) )
func ReloadServer() E.Error { func ReloadServer() E.Error {
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil) resp, err := U.Post(common.APIHTTPURL+"/v1/reload", "", nil)
if err != nil { if err != nil {
return E.From(err) return E.From(err)
} }

View file

@ -69,17 +69,18 @@ func HomepageConfig() homepage.Config {
) )
} }
if entry.IsDocker(r) { switch {
case entry.IsDocker(r):
if item.Category == "" { if item.Category == "" {
item.Category = "Docker" item.Category = "Docker"
} }
item.SourceType = string(proxy.ProviderTypeDocker) item.SourceType = string(proxy.ProviderTypeDocker)
} else if entry.UseLoadBalance(r) { case entry.UseLoadBalance(r):
if item.Category == "" { if item.Category == "" {
item.Category = "Load-balanced" item.Category = "Load-balanced"
} }
item.SourceType = "loadbalancer" item.SourceType = "loadbalancer"
} else { default:
if item.Category == "" { if item.Category == "" {
item.Category = "Others" item.Category = "Others"
} }

View file

@ -52,13 +52,10 @@ func (c *SharedClient) Connected() bool {
} }
// if the client is still referenced, this is no-op. // if the client is still referenced, this is no-op.
func (c *SharedClient) Close() error { func (c *SharedClient) Close() {
if !c.Connected() { if c.Connected() {
return nil c.refCount.Sub()
} }
c.refCount.Sub()
return nil
} }
// ConnectClient creates a new Docker client connection to the specified host. // ConnectClient creates a new Docker client connection to the specified host.
@ -115,7 +112,6 @@ func ConnectClient(host string) (Client, error) {
} }
client, err := client.NewClientWithOpts(opt...) client, err := client.NewClientWithOpts(opt...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -3,7 +3,6 @@ package idlewatcher
import ( import (
"bytes" "bytes"
_ "embed" _ "embed"
"fmt"
"strings" "strings"
"text/template" "text/template"
@ -21,7 +20,7 @@ var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage))) var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
func (w *Watcher) makeLoadingPageBody() []byte { func (w *Watcher) makeLoadingPageBody() []byte {
msg := fmt.Sprintf("%s is starting...", w.ContainerName) msg := w.ContainerName + " is starting..."
data := new(templateData) data := new(templateData)
data.CheckRedirectHeader = common.HeaderCheckRedirect data.CheckRedirectHeader = common.HeaderCheckRedirect

View file

@ -45,17 +45,18 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
return nil, E.Errorf("register watcher: %w", err) return nil, E.Errorf("register watcher: %w", err)
} }
if rp != nil { switch {
case rp != nil:
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport) waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
} else if stream != nil { case stream != nil:
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg) waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
} else { default:
panic("both nil") panic("both nil")
} }
return watcher, nil return watcher, nil
} }
// lifetime should follow route provider // lifetime should follow route provider.
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(providerSubTask, entry, rp, nil) return newWaker(providerSubTask, entry, rp, nil)
} }

View file

@ -28,7 +28,7 @@ func (w *Watcher) Accept() (conn types.StreamConn, err error) {
return return
} }
if wakeErr := w.wakeFromStream(); wakeErr != nil { if wakeErr := w.wakeFromStream(); wakeErr != nil {
w.WakeError(wakeErr).Msg("error waking from stream") w.WakeError(wakeErr)
} }
return return
} }
@ -58,7 +58,7 @@ func (w *Watcher) wakeFromStream() error {
wakeErr := w.wakeIfStopped() wakeErr := w.wakeIfStopped()
if wakeErr != nil { if wakeErr != nil {
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr) wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
w.WakeError(wakeErr).Msg("wake failed") w.WakeError(wakeErr)
return wakeErr return wakeErr
} }

View file

@ -17,7 +17,6 @@ import (
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events" "github.com/yusing/go-proxy/internal/watcher/events"
) )
@ -108,8 +107,8 @@ func (w *Watcher) WakeTrace() *zerolog.Event {
return w.Trace().Str("action", "wake") return w.Trace().Str("action", "wake")
} }
func (w *Watcher) WakeError(err error) *zerolog.Event { func (w *Watcher) WakeError(err error) {
return w.Err(err).Str("action", "wake") w.Err(err).Str("action", "wake").Msg("error")
} }
func (w *Watcher) LogReason(action, reason string) { func (w *Watcher) LogReason(action, reason string) {
@ -204,17 +203,17 @@ func (w *Watcher) resetIdleTimer() {
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) { func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
eventTask = w.task.Subtask("docker event watcher") eventTask = w.task.Subtask("docker event watcher")
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
Filters: W.NewDockerFilter( Filters: watcher.NewDockerFilter(
W.DockerFilterContainer, watcher.DockerFilterContainer,
W.DockerFilterContainerNameID(w.ContainerID), watcher.DockerFilterContainerNameID(w.ContainerID),
W.DockerFilterStart, watcher.DockerFilterStart,
W.DockerFilterStop, watcher.DockerFilterStop,
W.DockerFilterDie, watcher.DockerFilterDie,
W.DockerFilterKill, watcher.DockerFilterKill,
W.DockerFilterDestroy, watcher.DockerFilterDestroy,
W.DockerFilterPause, watcher.DockerFilterPause,
W.DockerFilterUnpause, watcher.DockerFilterUnpause,
), ),
}) })
return return
@ -230,9 +229,9 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
// stop method. // stop method.
// //
// it exits only if the context is canceled, the container is destroyed, // it exits only if the context is canceled, the container is destroyed,
// errors occured on docker client, or route provider died (mainly caused by config reload). // errors occurred on docker client, or route provider died (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() (returnCause error) { func (w *Watcher) watchUntilDestroy() (returnCause error) {
dockerWatcher := W.NewDockerWatcherWithClient(w.client) dockerWatcher := watcher.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
defer eventTask.Finish("stopped") defer eventTask.Finish("stopped")
@ -279,9 +278,13 @@ func (w *Watcher) watchUntilDestroy() (returnCause error) {
case <-w.ticker.C: case <-w.ticker.C:
w.ticker.Stop() w.ticker.Stop()
if w.ContainerRunning { if w.ContainerRunning {
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) { err := w.stopByMethod()
switch {
case errors.Is(err, context.Canceled):
continue
case err != nil:
w.Err(err).Msgf("container stop with method %q failed", w.StopMethod) w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
} else { default:
w.LogReason("container stopped", "idle timeout") w.LogReason("container stopped", "idle timeout")
} }
} }

View file

@ -27,26 +27,31 @@ func (b *Builder) HasError() bool {
return len(b.errs) > 0 return len(b.errs) > 0
} }
func (b *Builder) Error() Error { func (b *Builder) error() Error {
if !b.HasError() { if !b.HasError() {
return nil return nil
} }
if len(b.errs) == 1 {
return From(b.errs[0])
}
return &nestedError{Err: New(b.about), Extras: b.errs} return &nestedError{Err: New(b.about), Extras: b.errs}
} }
func (b *Builder) Error() Error {
if len(b.errs) == 1 {
return From(b.errs[0])
}
return b.error()
}
func (b *Builder) String() string { func (b *Builder) String() string {
if !b.HasError() { err := b.error()
if err == nil {
return "" return ""
} }
return (&nestedError{Err: New(b.about), Extras: b.errs}).Error() return err.Error()
} }
// Add adds an error to the Builder. // Add adds an error to the Builder.
// //
// adding nil is no-op, // adding nil is no-op.
func (b *Builder) Add(err error) *Builder { func (b *Builder) Add(err error) *Builder {
if err == nil { if err == nil {
return b return b
@ -90,6 +95,21 @@ func (b *Builder) Addf(format string, args ...any) *Builder {
return b return b
} }
func (b *Builder) AddFrom(other *Builder, flatten bool) *Builder {
if other == nil || !other.HasError() {
return b
}
b.Lock()
defer b.Unlock()
if flatten {
b.errs = append(b.errs, other.errs...)
} else {
b.errs = append(b.errs, other.error())
}
return b
}
func (b *Builder) AddRange(errs ...error) *Builder { func (b *Builder) AddRange(errs ...error) *Builder {
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()

View file

@ -22,7 +22,7 @@ type Error interface {
Subjectf(format string, args ...any) Error Subjectf(format string, args ...any) Error
} }
// this makes JSON marshalling work, // this makes JSON marshaling work,
// as the builtin one doesn't. // as the builtin one doesn't.
type errStr string type errStr string

View file

@ -15,13 +15,14 @@ func init() {
var level zerolog.Level var level zerolog.Level
var exclude []string var exclude []string
if common.IsTrace { switch {
case common.IsTrace:
timeFmt = "04:05" timeFmt = "04:05"
level = zerolog.TraceLevel level = zerolog.TraceLevel
} else if common.IsDebug { case common.IsDebug:
timeFmt = "01-02 15:04" timeFmt = "01-02 15:04"
level = zerolog.DebugLevel level = zerolog.DebugLevel
} else { default:
timeFmt = "01-02 15:04" timeFmt = "01-02 15:04"
level = zerolog.InfoLevel level = zerolog.InfoLevel
exclude = []string{"module"} exclude = []string{"module"}

View file

@ -5,8 +5,10 @@ import (
"net/http" "net/http"
) )
type ContentType string type (
type AcceptContentType []ContentType ContentType string
AcceptContentType []ContentType
)
func GetContentType(h http.Header) ContentType { func GetContentType(h http.Header) ContentType {
ct := h.Get("Content-Type") ct := h.Get("Content-Type")

View file

@ -55,7 +55,6 @@ func New(cfg *Config) *LoadBalancer {
Logger: logger.With().Str("name", cfg.Link).Logger(), Logger: logger.With().Str("name", cfg.Link).Logger(),
Config: new(Config), Config: new(Config),
pool: newPool(), pool: newPool(),
task: task.DummyTask(),
} }
lb.UpdateConfigIfNeeded(cfg) lb.UpdateConfigIfNeeded(cfg)
return lb return lb

View file

@ -1,7 +1,7 @@
package loadbalancer package loadbalancer
import ( import (
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type Mode string type Mode string
@ -14,7 +14,7 @@ const (
) )
func (mode *Mode) ValidateUpdate() bool { func (mode *Mode) ValidateUpdate() bool {
switch U.ToLowerNoSnake(string(*mode)) { switch strutils.ToLowerNoSnake(string(*mode)) {
case "": case "":
return true return true
case string(RoundRobin): case string(RoundRobin):

View file

@ -109,9 +109,9 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
_, cidr, err := net.ParseCIDR(line) _, cidr, err := net.ParseCIDR(line)
if err != nil { if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line) return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else {
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
} }
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
} }
return nil return nil

View file

@ -19,24 +19,33 @@ import (
const errPagesBasePath = common.ErrorPagesBasePath const errPagesBasePath = common.ErrorPagesBasePath
var ( var (
setupMu sync.Mutex
dirWatcher W.Watcher dirWatcher W.Watcher
fileContentMap = F.NewMapOf[string, []byte]() fileContentMap = F.NewMapOf[string, []byte]()
) )
var setup = sync.OnceFunc(func() { func setup() {
setupMu.Lock()
defer setupMu.Unlock()
if dirWatcher != nil {
return
}
task := task.GlobalTask("error page") task := task.GlobalTask("error page")
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath) dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
loadContent() loadContent()
go watchDir(task) go watchDir(task)
}) }
func GetStaticFile(filename string) ([]byte, bool) { func GetStaticFile(filename string) ([]byte, bool) {
setup()
return fileContentMap.Load(filename) return fileContentMap.Load(filename)
} }
// try <statusCode>.html -> 404.html -> not ok. // try <statusCode>.html -> 404.html -> not ok.
func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) { func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode)) content, ok = GetStaticFile(fmt.Sprintf("%d.html", statusCode))
if !ok && statusCode != 404 { if !ok && statusCode != 404 {
return fileContentMap.Load("404.html") return fileContentMap.Load("404.html")
} }

View file

@ -9,7 +9,6 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -21,7 +20,7 @@ var (
) )
func Get(name string) (*Middleware, Error) { func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)] middleware, ok := allMiddlewares[strutils.ToLowerNoSnake(name)]
if !ok { if !ok {
return nil, ErrUnknownMiddleware. return nil, ErrUnknownMiddleware.
Subject(name). Subject(name).
@ -34,7 +33,7 @@ func All() map[string]*Middleware {
return allMiddlewares return allMiddlewares
} }
// initialize middleware names and label parsers // initialize middleware names and label parsers.
func init() { func init() {
allMiddlewares = map[string]*Middleware{ allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded, "setxforwarded": SetXForwarded,
@ -67,7 +66,7 @@ func init() {
func LoadComposeFiles() { func LoadComposeFiles() {
errs := E.NewBuilder("middleware compile errors") errs := E.NewBuilder("middleware compile errors")
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0) middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil { if err != nil {
logger.Err(err).Msg("failed to list middleware definitions") logger.Err(err).Msg("failed to list middleware definitions")
return return
@ -82,7 +81,7 @@ func LoadComposeFiles() {
errs.Add(ErrDuplicatedMiddleware.Subject(name)) errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue continue
} }
allMiddlewares[U.ToLowerNoSnake(name)] = m allMiddlewares[strutils.ToLowerNoSnake(name)] = m
logger.Info(). logger.Info().
Str("name", name). Str("name", name).
Str("src", path.Base(defFile)). Str("src", path.Base(defFile)).

View file

@ -22,8 +22,10 @@ type Trace struct {
type Traces []*Trace type Traces []*Trace
var traces = Traces{} var (
var tracesMu sync.Mutex traces = make(Traces, 0)
tracesMu sync.Mutex
)
const MaxTraceNum = 100 const MaxTraceNum = 100

View file

@ -10,18 +10,20 @@ import (
"net/http" "net/http"
) )
type ModifyResponseFunc func(*http.Response) error type (
type ModifyResponseWriter struct { ModifyResponseFunc func(*http.Response) error
w http.ResponseWriter ModifyResponseWriter struct {
r *http.Request w http.ResponseWriter
r *http.Request
headerSent bool headerSent bool
code int code int
modifier ModifyResponseFunc modifier ModifyResponseFunc
modified bool modified bool
modifierErr error modifierErr error
} }
)
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter { func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
return &ModifyResponseWriter{ return &ModifyResponseWriter{

View file

@ -53,15 +53,15 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, er
Subject(name). Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, Providers))) Withf(strutils.DoYouMean(utils.NearestField(name, Providers)))
} }
if provider, err := createFunc(cfg); err != nil {
return nil, err provider, err := createFunc(cfg)
} else { if err == nil {
dispatcher.providers.Add(provider) dispatcher.providers.Add(provider)
configSubTask.OnCancel("remove provider", func() { configSubTask.OnCancel("remove provider", func() {
dispatcher.providers.Remove(provider) dispatcher.providers.Remove(provider)
}) })
return provider, nil
} }
return provider, err
} }
func (disp *Dispatcher) start() { func (disp *Dispatcher) start() {

View file

@ -68,7 +68,7 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProx
port := E.Collect(errs, fields.ValidatePort, m.Port) port := E.Collect(errs, fields.ValidatePort, m.Port)
pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns) pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns)
url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port)) url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port))
iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container) iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, cont)
if errs.HasError() { if errs.HasError() {
return nil return nil

View file

@ -61,7 +61,7 @@ func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry {
port := E.Collect(errs, fields.ValidateStreamPort, m.Port) port := E.Collect(errs, fields.ValidateStreamPort, m.Port)
scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme) scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme)
url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort)) url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort))
idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container) idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, cont)
if errs.HasError() { if errs.HasError() {
return nil return nil

View file

@ -89,7 +89,6 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
r := &HTTPRoute{ r := &HTTPRoute{
ReverseProxyEntry: entry, ReverseProxyEntry: entry,
rp: rp, rp: rp,
task: task.DummyTask(),
l: logger.With(). l: logger.With().
Str("type", string(entry.Scheme)). Str("type", string(entry.Scheme)).
Str("name", string(entry.Alias)). Str("name", string(entry.Alias)).

View file

@ -18,9 +18,9 @@ type EventHandler struct {
updated *E.Builder updated *E.Builder
} }
func (provider *Provider) newEventHandler() *EventHandler { func (p *Provider) newEventHandler() *EventHandler {
return &EventHandler{ return &EventHandler{
provider: provider, provider: p,
errs: E.NewBuilder("event errors"), errs: E.NewBuilder("event errors"),
added: E.NewBuilder("added"), added: E.NewBuilder("added"),
removed: E.NewBuilder("removed"), removed: E.NewBuilder("removed"),
@ -60,11 +60,12 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
oldRoutes.RangeAll(func(k string, oldr *route.Route) { oldRoutes.RangeAll(func(k string, oldr *route.Route) {
newr, ok := newRoutes.Load(k) newr, ok := newRoutes.Load(k)
if !ok { switch {
case !ok:
handler.Remove(oldr) handler.Remove(oldr)
} else if handler.matchAny(events, newr) { case handler.matchAny(events, newr):
handler.Update(parent, oldr, newr) handler.Update(parent, oldr, newr)
} else if entry.ShouldNotServe(newr) { case entry.ShouldNotServe(newr):
handler.Remove(oldr) handler.Remove(oldr)
} }
}) })
@ -122,11 +123,11 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
} }
func (handler *EventHandler) Log() { func (handler *EventHandler) Log() {
results := E.NewBuilder("event occured") results := E.NewBuilder("event occurred")
results.Add(handler.added.Error()) results.AddFrom(handler.added, false)
results.Add(handler.removed.Error()) results.AddFrom(handler.removed, false)
results.Add(handler.updated.Error()) results.AddFrom(handler.updated, false)
results.Add(handler.errs.Error()) results.AddFrom(handler.errs, false)
if result := results.String(); result != "" { if result := results.String(); result != "" {
handler.provider.Logger().Info().Msg(result) handler.provider.Logger().Info().Msg(result)
} }

View file

@ -45,9 +45,7 @@ const (
providerEventFlushInterval = 300 * time.Millisecond providerEventFlushInterval = 300 * time.Millisecond
) )
var ( var ErrEmptyProviderName = errors.New("empty provider name")
ErrEmptyProviderName = errors.New("empty provider name")
)
func newProvider(name string, t ProviderType) *Provider { func newProvider(name string, t ProviderType) *Provider {
return &Provider{ return &Provider{
@ -109,12 +107,11 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
p.routes.Delete(r.Entry.Alias) p.routes.Delete(r.Entry.Alias)
subtask.Finish(err) // just to ensure subtask.Finish(err) // just to ensure
return err.Subject(r.Entry.Alias) return err.Subject(r.Entry.Alias)
} else {
p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() {
p.routes.Delete(r.Entry.Alias)
})
} }
p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() {
p.routes.Delete(r.Entry.Alias)
})
return nil return nil
} }

View file

@ -80,11 +80,12 @@ func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
entries.RangeAllParallel(func(alias string, en *entry.RawEntry) { entries.RangeAllParallel(func(alias string, en *entry.RawEntry) {
en.Alias = alias en.Alias = alias
r, err := NewRoute(en) r, err := NewRoute(en)
if err != nil { switch {
case err != nil:
b.Add(err.Subject(alias)) b.Add(err.Subject(alias))
} else if entry.ShouldNotServe(r) { case entry.ShouldNotServe(r):
return return
} else { default:
routes.Store(alias, r) routes.Store(alias, r)
} }
}) })

View file

@ -44,7 +44,6 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
} }
return &StreamRoute{ return &StreamRoute{
StreamEntry: entry, StreamEntry: entry,
task: task.DummyTask(),
l: logger.With(). l: logger.With().
Str("type", string(entry.Scheme.ListeningScheme)). Str("type", string(entry.Scheme.ListeningScheme)).
Str("name", entry.TargetName()). Str("name", entry.TargetName()).

View file

@ -6,7 +6,6 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"net/http" "net/http"
"time" "time"

View file

@ -102,7 +102,7 @@ func (p BidirectionalPipe) Start() E.Error {
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// This is a copy of io.Copy with context handling // This is a copy of io.Copy with context handling
// Author: yusing <yusing@6uo.me> // Author: yusing <yusing@6uo.me>.
func Copy(dst *ContextWriter, src *ContextReader) (err error) { func Copy(dst *ContextWriter, src *ContextReader) (err error) {
size := 32 * 1024 size := 32 * 1024
if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N { if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N {

View file

@ -18,9 +18,10 @@ func NearestField(input string, s any) string {
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
t = t.Elem() t = t.Elem()
} }
if t.Kind() == reflect.Struct { switch t.Kind() {
case reflect.Struct:
fields = make([]string, 0) fields = make([]string, 0)
for i := 0; i < t.NumField(); i++ { for i := range t.NumField() {
jsonTag, ok := t.Field(i).Tag.Lookup("json") jsonTag, ok := t.Field(i).Tag.Lookup("json")
if ok { if ok {
fields = append(fields, jsonTag) fields = append(fields, jsonTag)
@ -28,13 +29,13 @@ func NearestField(input string, s any) string {
fields = append(fields, t.Field(i).Name) fields = append(fields, t.Field(i).Name)
} }
} }
} else if t.Kind() == reflect.Map { case reflect.Map:
keys := reflect.ValueOf(s).MapKeys() keys := reflect.ValueOf(s).MapKeys()
fields = make([]string, len(keys)) fields = make([]string, len(keys))
for i, key := range keys { for i, key := range keys {
fields[i] = key.String() fields[i] = key.String()
} }
} else { default:
panic("unsupported type: " + t.String()) panic("unsupported type: " + t.String())
} }
} }

View file

@ -27,7 +27,7 @@ var (
ErrInvalidType = E.New("invalid type") ErrInvalidType = E.New("invalid type")
ErrNilValue = E.New("nil") ErrNilValue = E.New("nil")
ErrUnsettable = E.New("unsettable") ErrUnsettable = E.New("unsettable")
ErrUnsupportedConvertion = E.New("unsupported convertion") ErrUnsupportedConversion = E.New("unsupported conversion")
ErrMapMissingColon = E.New("map missing colon") ErrMapMissingColon = E.New("map missing colon")
ErrMapTooManyColons = E.New("map too many colons") ErrMapTooManyColons = E.New("map too many colons")
ErrUnknownField = E.New("unknown field") ErrUnknownField = E.New("unknown field")
@ -176,10 +176,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
case reflect.Struct: case reflect.Struct:
mapping := make(map[string]reflect.Value) mapping := make(map[string]reflect.Value)
for _, field := range reflect.VisibleFields(dstT) { for _, field := range reflect.VisibleFields(dstT) {
mapping[ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name) mapping[strutils.ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name)
} }
for k, v := range src { for k, v := range src {
if field, ok := mapping[ToLowerNoSnake(k)]; ok { if field, ok := mapping[strutils.ToLowerNoSnake(k)]; ok {
err := Convert(reflect.ValueOf(v), field) err := Convert(reflect.ValueOf(v), field)
if err != nil { if err != nil {
errs.Add(err.Subject(k)) errs.Add(err.Subject(k))
@ -199,11 +199,11 @@ func Deserialize(src SerializedObject, dst any) E.Error {
if err != nil { if err != nil {
errs.Add(err.Subject(k)) errs.Add(err.Subject(k))
} }
dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp) dstV.SetMapIndex(reflect.ValueOf(strutils.ToLowerNoSnake(k)), tmp)
} }
return errs.Error() return errs.Error()
default: default:
return ErrUnsupportedConvertion.Subject("deserialize to " + dstT.String()) return ErrUnsupportedConversion.Subject("deserialize to " + dstT.String())
} }
} }
@ -250,12 +250,12 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
case srcT.Kind() == reflect.Map: case srcT.Kind() == reflect.Map:
obj, ok := src.Interface().(SerializedObject) obj, ok := src.Interface().(SerializedObject)
if !ok { if !ok {
return ErrUnsupportedConvertion.Subject(dstT.String() + " to " + srcT.String()) return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String())
} }
return Deserialize(obj, dst.Addr().Interface()) return Deserialize(obj, dst.Addr().Interface())
case srcT.Kind() == reflect.Slice: case srcT.Kind() == reflect.Slice:
if dstT.Kind() != reflect.Slice { if dstT.Kind() != reflect.Slice {
return ErrUnsupportedConvertion.Subject(dstT.String() + " to slice") return ErrUnsupportedConversion.Subject(dstT.String() + " to slice")
} }
newSlice := reflect.MakeSlice(dstT, 0, src.Len()) newSlice := reflect.MakeSlice(dstT, 0, src.Len())
i := 0 i := 0
@ -280,7 +280,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error {
var ok bool var ok bool
// check if (*T).Convertor is implemented // check if (*T).Convertor is implemented
if converter, ok = dst.Addr().Interface().(Converter); !ok { if converter, ok = dst.Addr().Interface().(Converter); !ok {
return ErrUnsupportedConvertion.Subjectf("%s to %s", srcT, dstT) return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
} }
return converter.ConvertFrom(src.Interface()) return converter.ConvertFrom(src.Interface())
@ -310,6 +310,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
} }
dst.Set(reflect.ValueOf(d)) dst.Set(reflect.ValueOf(d))
return return
default:
} }
// primitive types / simple types // primitive types / simple types
switch dst.Kind() { switch dst.Kind() {
@ -392,7 +393,3 @@ func DeserializeJSON(j map[string]string, target any) error {
} }
return json.Unmarshal(data, target) return json.Unmarshal(data, target)
} }
func ToLowerNoSnake(s string) string {
return strings.ToLower(strings.ReplaceAll(s, "_", ""))
}

View file

@ -2,6 +2,7 @@ package strutils
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"time" "time"
@ -57,6 +58,10 @@ func ParseBool(s string) bool {
} }
} }
func PortString(port uint16) string {
return strconv.FormatUint(uint64(port), 10)
}
func DoYouMean(s string) string { func DoYouMean(s string) string {
return "Did you mean " + ansi.HighlightGreen + s + ansi.Reset + "?" return "Did you mean " + ansi.HighlightGreen + s + ansi.Reset + "?"
} }

View file

@ -2,7 +2,6 @@ package strutils
import ( import (
"net/url" "net/url"
"strconv"
"strings" "strings"
"golang.org/x/text/cases" "golang.org/x/text/cases"
@ -29,8 +28,8 @@ func ExtractPort(fullURL string) (int, error) {
return Atoi(url.Port()) return Atoi(url.Port())
} }
func PortString(port uint16) string { func ToLowerNoSnake(s string) string {
return strconv.FormatUint(uint64(port), 10) return strings.ToLower(strings.ReplaceAll(s, "_", ""))
} }
func LevenshteinDistance(a, b string) int { func LevenshteinDistance(a, b string) int {
@ -60,7 +59,7 @@ func LevenshteinDistance(a, b string) int {
cost = 1 cost = 1
} }
v1[j+1] = min(v1[j]+1, v0[j+1]+1, v0[j]+cost) v1[j+1] = min3(v1[j]+1, v0[j+1]+1, v0[j]+cost)
} }
for j := 0; j <= len(b); j++ { for j := 0; j <= len(b); j++ {
@ -71,7 +70,7 @@ func LevenshteinDistance(a, b string) int {
return v1[len(b)] return v1[len(b)]
} }
func min(a, b, c int) int { func min3(a, b, c int) int {
if a < b && a < c { if a < b && a < c {
return a return a
} }

View file

@ -37,7 +37,6 @@ func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc Health
config: config, config: config,
checkHealth: healthCheckFunc, checkHealth: healthCheckFunc,
startTime: time.Now(), startTime: time.Now(),
task: task.DummyTask(),
} }
mon.url.Store(url) mon.url.Store(url)
mon.status.Store(StatusHealthy) mon.status.Store(StatusHealthy)