refactored some stuff, added healthcheck support, fixed 'include file' reload not showing in log

This commit is contained in:
yusing 2024-10-12 13:56:38 +08:00
parent 64e30f59e8
commit d47b672aa5
41 changed files with 783 additions and 421 deletions

View file

@ -1,13 +1,11 @@
package v1
import (
"fmt"
"net/http"
"strings"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/watcher/health"
)
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
@ -17,26 +15,14 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
return
}
var ok bool
route := cfg.FindRoute(target)
switch {
case route == nil:
isHealthy, ok := health.IsHealthy(target)
if !ok {
HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound)
return
case route.Type() == R.RouteTypeReverseProxy:
ok = IsSiteHealthy(route.URL().String())
case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = IsStreamHealthy(
strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
)
}
if ok {
if isHealthy {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusRequestTimeout)
w.WriteHeader(http.StatusServiceUnavailable)
}
}

View file

@ -1,34 +0,0 @@
package v1
import (
"net"
"net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
)
func IsSiteHealthy(url string) bool {
// try HEAD first
// if HEAD is not allowed, try GET
resp, err := U.Head(url)
if resp != nil {
resp.Body.Close()
}
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
_, err = U.Get(url)
}
if resp != nil {
resp.Body.Close()
}
return err == nil
}
func IsStreamHealthy(scheme, address string) bool {
conn, err := net.DialTimeout(scheme, address, common.DialTimeout)
if err != nil {
return false
}
conn.Close()
return true
}

View file

@ -233,7 +233,7 @@ func (p *Provider) certState() CertState {
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch
}

View file

@ -19,20 +19,14 @@ const (
ConfigPath = ConfigBasePath + "/" + ConfigFileName
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
)
const (
SchemaBasePath = "schema"
ConfigSchemaPath = SchemaBasePath + "/config.schema.json"
FileProviderSchemaPath = SchemaBasePath + "/providers.schema.json"
)
const (
ComposeFileName = "compose.yml"
ComposeExampleFileName = "compose.example.yml"
)
const (
ErrorPagesBasePath = "error_pages"
)
@ -46,6 +40,9 @@ var RequiredDirectories = []string{
const DockerHostFromEnv = "$DOCKER_HOST"
const (
HealthCheckIntervalDefault = 5 * time.Second
HealthCheckTimeoutDefault = 5 * time.Second
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "30s"
StopTimeoutDefault = "10s"

View file

@ -103,7 +103,7 @@ func (cfg *Config) WatchChanges() {
case <-cfg.watcherCtx.Done():
return
case <-cfg.reloadReq:
if err := cfg.Reload(); err.HasError() {
if err := cfg.Reload(); err != nil {
cfg.l.Error(err)
}
}
@ -130,9 +130,9 @@ func (cfg *Config) WatchChanges() {
}()
}
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
func (cfg *Config) forEachRoute(do func(alias string, r *R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) {
p.RangeRoutes(func(a string, r *R.Route) {
do(a, r, p)
})
})
@ -146,20 +146,20 @@ func (cfg *Config) load() (res E.NestedError) {
defer cfg.l.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
if err != nil {
b.Add(E.FailWith("read config", err))
logrus.Fatal(b.Build())
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
if err = Validate(data); err != nil {
b.Add(E.FailWith("schema validation", err))
logrus.Fatal(b.Build())
}
}
model := types.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
if err := E.From(yaml.Unmarshal(data, model)); err != nil {
b.Add(E.FailWith("parse config", err))
logrus.Fatal(b.Build())
}
@ -182,7 +182,7 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
defer cfg.l.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err.HasError() {
if err != nil {
err = E.FailWith("autocert provider", err)
}
return
@ -220,12 +220,12 @@ func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.Neste
errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() {
if err := do(p); err != nil {
errors.Add(err.Subject(p))
}
})
if err := errors.Build(); err.HasError() {
if err := errors.Build(); err != nil {
cfg.l.Error(err)
}
}

View file

@ -5,7 +5,7 @@ import (
"strings"
"github.com/yusing/go-proxy/internal/common"
H "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/homepage"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
@ -15,8 +15,8 @@ import (
func (cfg *Config) DumpEntries() map[string]*types.RawEntry {
entries := make(map[string]*types.RawEntry)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
entries[alias] = r.Entry()
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) {
entries[alias] = r.Entry
})
return entries
}
@ -29,7 +29,7 @@ func (cfg *Config) DumpProviders() map[string]*PR.Provider {
return entries
}
func (cfg *Config) HomepageConfig() H.HomePageConfig {
func (cfg *Config) HomepageConfig() homepage.Config {
var proto, port string
domains := cfg.value.MatchDomains
cert, _ := cfg.autocertProvider.GetCert(nil)
@ -41,16 +41,16 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
port = common.ProxyHTTPPort
}
hpCfg := H.NewHomePageConfig()
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
hpCfg := homepage.NewHomePageConfig()
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) {
if !r.Started() {
return
}
entry := r.Entry()
entry := r.Entry
if entry.Homepage == nil {
entry.Homepage = &H.HomePageItem{
Show: r.Entry().IsExplicit || !p.IsExplicitOnly(),
entry.Homepage = &homepage.Item{
Show: r.Entry.IsExplicit || !p.IsExplicitOnly(),
}
}
@ -60,7 +60,7 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
item.Show = true
}
if !item.Show || r.Type() != R.RouteTypeReverseProxy {
if !item.Show || r.Type != R.RouteTypeReverseProxy {
return
}
@ -99,19 +99,19 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig {
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) {
if !r.Started() {
return
}
obj, err := U.Serialize(r)
if err.HasError() {
if err != nil {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
obj["type"] = string(r.Type)
obj["started"] = r.Started()
obj["raw"] = r.Entry()
obj["raw"] = r.Entry
routes[alias] = obj
})
return routes
@ -138,9 +138,9 @@ func (cfg *Config) Statistics() map[string]any {
}
}
func (cfg *Config) FindRoute(alias string) R.Route {
func (cfg *Config) FindRoute(alias string) *R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
func(p *PR.Provider) (*R.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}

View file

@ -105,7 +105,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
default:
l.Attribute = parts[2]
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
if err.HasError() {
if err != nil {
return nil, err
}
l.Value = nestedLabel

View file

@ -40,6 +40,8 @@ func (b Builder) Addf(format string, args ...any) Builder {
}
func (b Builder) AddRangeE(errs ...error) Builder {
b.Lock()
defer b.Unlock()
for _, err := range errs {
b.AddE(err)
}

View file

@ -42,7 +42,6 @@ func TestErrorNestedIs(t *testing.T) {
func TestIsNil(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
ExpectFalse(t, err.HasError())
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())

View file

@ -1,24 +1,24 @@
package homepage
type (
HomePageConfig map[string]HomePageCategory
HomePageCategory []*HomePageItem
Config map[string]Category
Category []*Item
HomePageItem struct {
Show bool `yaml:"show" json:"show"`
Name string `yaml:"name" json:"name"`
Icon string `yaml:"icon" json:"icon"`
URL string `yaml:"url" json:"url"` // alias + domain
Category string `yaml:"category" json:"category"`
Description string `yaml:"description" json:"description"`
WidgetConfig map[string]any `yaml:",flow" json:"widget_config"`
Item struct {
Show bool `json:"show" yaml:"show"`
Name string `json:"name" yaml:"name"`
Icon string `json:"icon" yaml:"icon"`
URL string `json:"url" yaml:"url"` // alias + domain
Category string `json:"category" yaml:"category"`
Description string `json:"description" yaml:"description"`
WidgetConfig map[string]any `json:"widget_config" yaml:",flow"`
SourceType string `yaml:"-" json:"source_type"`
AltURL string `yaml:"-" json:"alt_url"` // original proxy target
SourceType string `json:"source_type" yaml:"-"`
AltURL string `json:"alt_url" yaml:"-"` // original proxy target
}
)
func (item *HomePageItem) IsEmpty() bool {
func (item *Item) IsEmpty() bool {
return item == nil || (item.Name == "" &&
item.Icon == "" &&
item.URL == "" &&
@ -27,17 +27,17 @@ func (item *HomePageItem) IsEmpty() bool {
len(item.WidgetConfig) == 0)
}
func NewHomePageConfig() HomePageConfig {
return HomePageConfig(make(map[string]HomePageCategory))
func NewHomePageConfig() Config {
return Config(make(map[string]Category))
}
func (c *HomePageConfig) Clear() {
*c = make(HomePageConfig)
func (c *Config) Clear() {
*c = make(Config)
}
func (c HomePageConfig) Add(item *HomePageItem) {
func (c Config) Add(item *Item) {
if c[item.Category] == nil {
c[item.Category] = make(HomePageCategory, 0)
c[item.Category] = make(Category, 0)
}
c[item.Category] = append(c[item.Category], item)
}

View file

@ -4,15 +4,40 @@ import (
"hash/fnv"
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
type ipHash struct{ *LoadBalancer }
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
}
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 {
return impl
}
var err E.NestedError
impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil {
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)
}
return impl
}
func (ipHash) OnAddServer(srv *Server) {}
func (ipHash) OnRemoveServer(srv *Server) {}
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
impl.serveHTTP(rw, r)
}
}
func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
@ -20,7 +45,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
if !impl.pool[idx].available.Load() {
if !impl.pool[idx].IsHealthy() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
impl.pool[idx].handler.ServeHTTP(rw, r)

View file

@ -1,13 +1,12 @@
package loadbalancer
import (
"context"
"net/http"
"sync"
"time"
"github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
// TODO: stats of each server.
@ -19,20 +18,17 @@ type (
OnRemoveServer(srv *Server)
}
Config struct {
Link string
Mode Mode
Weight weightType
Link string `json:"link" yaml:"link"`
Mode Mode `json:"mode" yaml:"mode"`
Weight weightType `json:"weight" yaml:"weight"`
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
}
LoadBalancer struct {
impl
Config
pool servers
poolMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
done chan struct{}
poolMu sync.Mutex
sumWeight weightType
}
@ -73,8 +69,8 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.RLock()
defer lb.poolMu.RUnlock()
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.impl.OnRemoveServer(srv)
@ -85,7 +81,7 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
}
}
if lb.IsEmpty() {
lb.Stop()
lb.pool = nil
return
}
@ -171,54 +167,12 @@ func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet")
}
lb.done = make(chan struct{}, 1)
lb.ctx, lb.cancel = context.WithCancel(context.Background())
updateAll := func() {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
var wg sync.WaitGroup
wg.Add(len(lb.pool))
for _, s := range lb.pool {
go func(s *Server) {
defer wg.Done()
s.checkUpdateAvail(lb.ctx)
}(s)
}
wg.Wait()
}
logger.Debugf("loadbalancer %s started", lb.Link)
go func() {
defer lb.cancel()
defer close(lb.done)
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
updateAll()
for {
select {
case <-lb.ctx.Done():
return
case <-ticker.C:
updateAll()
}
}
}()
}
func (lb *LoadBalancer) Stop() {
if lb.cancel == nil {
return
}
lb.cancel()
<-lb.done
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = nil
logger.Debugf("loadbalancer %s stopped", lb.Link)
@ -228,9 +182,9 @@ func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
avail := servers{}
avail := make(servers, 0, len(lb.pool))
for _, s := range lb.pool {
if s.available.Load() {
if s.IsHealthy() {
avail = append(avail, s)
}
}

View file

@ -1,67 +1,42 @@
package loadbalancer
import (
"context"
"net/http"
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
Server struct {
Name string
URL types.URL
Weight weightType
handler http.Handler
_ U.NoCopy
pinger *http.Client
available atomic.Bool
Name string
URL types.URL
Weight weightType
handler http.Handler
healthMon health.HealthMonitor
}
servers []*Server
)
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
srv := &Server{
Name: name,
URL: url,
Weight: weight,
handler: handler,
pinger: &http.Client{Timeout: 3 * time.Second},
Name: name,
URL: url,
Weight: weight,
handler: handler,
healthMon: healthMon,
}
srv.available.Store(true)
return srv
}
func (srv *Server) checkUpdateAvail(ctx context.Context) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
srv.URL.String(),
nil,
)
if err != nil {
logger.Error("failed to create request: ", err)
srv.available.Store(false)
}
resp, err := srv.pinger.Do(req)
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
if !srv.available.Swap(true) {
logger.Infof("server %s is up", srv.Name)
}
} else if err != nil {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: %s", srv.Name, err)
}
} else {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
}
}
}
func (srv *Server) String() string {
return srv.Name
}
func (srv *Server) IsHealthy() bool {
return srv.healthMon.IsHealthy()
}

View file

@ -30,6 +30,8 @@ type (
Options any
Middleware struct {
_ U.NoCopy
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
@ -77,30 +79,37 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
return nil, err
} else {
return mWithOpt, nil
}
return m.withOptions(optsRaw)
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
m.name,
m.before,
m.modifyResponse,
nil,
m.impl,
m.parent,
m.children,
false,
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.before != nil {
m.before(next, w, r)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
}
return nil
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) {
middlewares = make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
@ -124,10 +133,15 @@ func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[strin
middlewares = append(middlewares, m)
}
if invalidM.HasError() {
return
}
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
if err != nil {
return
}
patchReverseProxy(rpName, rp, middlewares)
return
}

View file

@ -114,7 +114,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), &rr)
rp := gphttp.NewReverseProxy("test", types.NewURL(proxyURL), &rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr

View file

@ -86,7 +86,8 @@ type ReverseProxy struct {
ServeHTTP http.HandlerFunc
TargetURL types.URL
TargetName string
TargetURL types.URL
}
func singleJoiningSlash(a, b string) string {
@ -144,11 +145,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// }
//
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
rp := &ReverseProxy{Transport: transport, TargetURL: target}
rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target}
rp.ServeHTTP = rp.serveHTTP
return rp
}
@ -194,9 +195,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF):
logger.Debugf("http proxy to %s error: %s", r.URL.String(), err)
logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
default:
logger.Errorf("http proxy to %s error: %s", r.URL.String(), err)
logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err)
}
if writeHeader {
rw.WriteHeader(http.StatusBadGateway)

View file

@ -9,20 +9,21 @@ import (
type CIDR net.IPNet
func (*CIDR) ConvertFrom(val any) (any, E.NestedError) {
cidr, ok := val.(string)
func (cidr *CIDR) ConvertFrom(val any) E.NestedError {
cidrStr, ok := val.(string)
if !ok {
return nil, E.TypeMismatch[string](val)
return E.TypeMismatch[string](val)
}
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
if !strings.Contains(cidrStr, "/") {
cidrStr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
_, ipnet, err := net.ParseCIDR(cidrStr)
if err != nil {
return nil, E.Invalid("CIDR", cidr)
return E.Invalid("CIDR", cidr)
}
return (*CIDR)(ipnet), nil
*cidr = CIDR(*ipnet)
return nil
}
func (cidr *CIDR) Contains(ip net.IP) bool {

View file

@ -1,10 +1,22 @@
package types
import "net/url"
import (
urlPkg "net/url"
)
type URL struct{ *url.URL }
type URL struct {
*urlPkg.URL
}
func NewURL(url *url.URL) URL {
func ParseURL(url string) (URL, error) {
u, err := urlPkg.Parse(url)
if err != nil {
return URL{}, err
}
return URL{URL: u}, nil
}
func NewURL(url *urlPkg.URL) URL {
return URL{url}
}
@ -19,6 +31,10 @@ func (u URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u URL) Equals(other URL) bool {
func (u URL) Equals(other *URL) bool {
return u.URL == other.URL || u.String() == other.String()
}
func (u URL) JoinPath(path string) URL {
return URL{u.URL.JoinPath(path)}
}

View file

@ -11,17 +11,19 @@ import (
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
ReverseProxyEntry struct { // real model after validation
Alias T.Alias `json:"alias"`
Scheme T.Scheme `json:"scheme"`
URL net.URL `json:"url"`
NoTLSVerify bool `json:"no_tls_verify"`
PathPatterns T.PathPatterns `json:"path_patterns"`
LoadBalance loadbalancer.Config `json:"load_balance"`
Middlewares D.NestedLabelMap `json:"middlewares"`
Alias T.Alias `json:"alias"`
Scheme T.Scheme `json:"scheme"`
URL net.URL `json:"url"`
NoTLSVerify bool `json:"no_tls_verify"`
PathPatterns T.PathPatterns `json:"path_patterns"`
HealthCheck health.HealthCheckConfig `json:"healthcheck"`
LoadBalance loadbalancer.Config `json:"load_balance"`
Middlewares D.NestedLabelMap `json:"middlewares"`
/* Docker only */
IdleTimeout time.Duration `json:"idle_timeout"`
@ -35,10 +37,11 @@ type (
ContainerRunning bool `json:"container_running"`
}
StreamEntry struct {
Alias T.Alias `json:"alias"`
Scheme T.StreamScheme `json:"scheme"`
Host T.Host `json:"host"`
Port T.StreamPort `json:"port"`
Alias T.Alias `json:"alias"`
Scheme T.StreamScheme `json:"scheme"`
Host T.Host `json:"host"`
Port T.StreamPort `json:"port"`
Healthcheck health.HealthCheckConfig `json:"healthcheck"`
}
)
@ -58,7 +61,7 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
m.FillMissingFields()
scheme, err := T.NewScheme(m.Scheme)
if err.HasError() {
if err != nil {
return nil, err
}
@ -69,7 +72,7 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
} else {
entry = validateRPEntry(m, scheme, e)
}
if err := e.Build(); err.HasError() {
if err := e.Build(); err != nil {
return nil, err
}
return entry, nil
@ -107,7 +110,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
stopSignal, err := T.ValidateSignal(m.StopSignal)
b.Add(err)
if err.HasError() {
if err != nil {
return nil
}
@ -117,6 +120,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
HealthCheck: m.HealthCheck,
LoadBalance: m.LoadBalance,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
@ -146,9 +150,10 @@ func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry {
}
return &StreamEntry{
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
Healthcheck: m.HealthCheck,
}
}

View file

@ -1,24 +0,0 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type PathMode string
func NewPathMode(pm string) (PathMode, E.NestedError) {
switch pm {
case "", "forward":
return PathMode(pm), nil
default:
return "", E.Invalid("path mode", pm)
}
}
func (p PathMode) IsRemove() bool {
return p == ""
}
func (p PathMode) IsForward() bool {
return p == "forward"
}

View file

@ -13,7 +13,7 @@ type (
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
func NewPathPattern(s string) (PathPattern, E.NestedError) {
func ValidatePathPattern(s string) (PathPattern, E.NestedError) {
if len(s) == 0 {
return "", E.Invalid("path", "must not be empty")
}
@ -29,7 +29,7 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
}
pp := make(PathPatterns, len(s))
for i, v := range s {
pattern, err := NewPathPattern(v)
pattern, err := ValidatePathPattern(v)
if err != nil {
return nil, err
}

View file

@ -37,11 +37,11 @@ var invalidPatterns = []string{
func TestPathPatternRegex(t *testing.T) {
for _, pattern := range validPatterns {
_, err := NewPathPattern(pattern)
_, err := ValidatePathPattern(pattern)
U.ExpectNoError(t, err.Error())
}
for _, pattern := range invalidPatterns {
_, err := NewPathPattern(pattern)
_, err := ValidatePathPattern(pattern)
U.ExpectError2(t, pattern, E.ErrInvalid, err.Error())
}
}

View file

@ -46,7 +46,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
entries := types.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true)
if err.HasError() {
if err != nil {
return routes, E.FailWith("connect to docker", err)
}
@ -59,7 +59,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
}
newEntries, err := p.entriesFromContainerLabels(container)
if err.HasError() {
if err != nil {
errors.Add(err)
}
// although err is not nil
@ -98,9 +98,9 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v R.Route) {
if v.Entry().ContainerID == event.ActorID ||
v.Entry().ContainerName == event.ActorName {
routes.RangeAll(func(k string, v *R.Route) {
if v.Entry.ContainerID == event.ActorID ||
v.Entry.ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
@ -115,7 +115,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
b.Add(E.FailWith("rescan routes", err))
return
}
routesNew.Range(func(k string, v R.Route) bool {
routesNew.Range(func(k string, v *R.Route) bool {
if !routesOld.Has(k) {
routesOld.Store(k, v)
b.Add(v.Start())
@ -124,7 +124,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
}
return true
})
routesOld.Range(func(k string, v R.Route) bool {
routesOld.Range(func(k string, v *R.Route) bool {
if !routesNew.Has(k) {
b.Add(v.Stop())
routesOld.Delete(k)
@ -137,13 +137,13 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
}
client, err := D.ConnectClient(p.dockerHost)
if err.HasError() {
if err != nil {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err.HasError() {
if err != nil {
b.Add(E.FailWith("inspect container", err))
return
}
@ -159,7 +159,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
if routes.Has(alias) {
b.Add(E.Duplicated("alias", alias))
} else {
if route, err := R.NewRoute(entry); err.HasError() {
if route, err := R.NewRoute(entry); err != nil {
b.Add(err)
} else {
routes.Store(alias, route)
@ -221,7 +221,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt
}
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
if err != nil {
b.Add(err.Subject(key))
}
if lbl.Namespace != D.NSProxy {
@ -230,7 +230,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *types.RawEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
if err = D.ApplyLabel(e, lbl); err != nil {
b.Add(err)
}
})
@ -249,7 +249,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt
b.Add(E.NotExist("alias", lbl.Target))
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
if err = D.ApplyLabel(config, lbl); err != nil {
b.Add(err)
}
}

View file

@ -15,8 +15,10 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var dummyNames = []string{"/a"}
var p DockerProvider
var (
dummyNames = []string{"/a"}
p DockerProvider
)
func TestApplyLabelWildcard(t *testing.T) {
pathPatterns := `

View file

@ -47,19 +47,21 @@ func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err.HasError() {
if err != nil {
b.Add(err)
return
}
routes.RangeAllParallel(func(_ string, v R.Route) {
res.nRemoved = newRoutes.Size()
routes.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAllParallel(func(_ string, v R.Route) {
newRoutes.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Start())
})
res.nAdded = newRoutes.Size()
routes.MergeFrom(newRoutes)
return
@ -74,12 +76,12 @@ func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
entries := types.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path))
if err.HasError() {
if err != nil {
b.Add(E.FailWith("read file", err))
return
}
if err = entries.UnmarshalFromYAML(data); err.HasError() {
if err = entries.UnmarshalFromYAML(data); err != nil {
b.Add(err)
return
}

View file

@ -111,7 +111,7 @@ func (p *Provider) StartAllRoutes() (res E.NestedError) {
// start watcher no matter load success or not
go p.watchEvents()
p.routes.RangeAllParallel(func(alias string, r R.Route) {
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Start().Subject(r))
})
return
@ -126,17 +126,17 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) {
errors := E.NewBuilder("errors stopping routes")
defer errors.To(&res)
p.routes.RangeAllParallel(func(alias string, r R.Route) {
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Stop().Subject(r))
})
return
}
func (p *Provider) RangeRoutes(do func(string, R.Route)) {
func (p *Provider) RangeRoutes(do func(string, *R.Route)) {
p.routes.RangeAll(do)
}
func (p *Provider) GetRoute(alias string) (R.Route, bool) {
func (p *Provider) GetRoute(alias string) (*R.Route, bool) {
return p.routes.Load(alias)
}
@ -156,11 +156,11 @@ func (p *Provider) LoadRoutes() E.NestedError {
func (p *Provider) Statistics() ProviderStats {
numRPs := 0
numStreams := 0
p.routes.RangeAll(func(_ string, r R.Route) {
p.routes.RangeAll(func(_ string, r *R.Route) {
if !r.Started() {
return
}
switch r.Type() {
switch r.Type {
case R.RouteTypeReverseProxy:
numRPs++
case R.RouteTypeStream:
@ -187,9 +187,17 @@ func (p *Provider) watchEvents() {
res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event)
if res.nAdded > 0 || res.nRemoved > 0 {
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
n := res.nAdded - res.nRemoved
switch {
case n == 0:
l.Infof("%d route(s) reloaded", res.nAdded)
case n > 0:
l.Infof("%d route(s) added", n)
default:
l.Infof("%d route(s) removed", -n)
}
}
if res.err.HasError() {
if res.err != nil {
l.Error(res.err)
}
case err := <-errs:

View file

@ -1,6 +1,7 @@
package route
import (
"context"
"errors"
"fmt"
"net/http"
@ -14,9 +15,11 @@ import (
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
@ -24,9 +27,10 @@ type (
*P.ReverseProxyEntry
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer"`
server *loadbalancer.Server
handler http.Handler
rp *gphttp.ReverseProxy
healthMon health.HealthMonitor
server *loadbalancer.Server
handler http.Handler
rp *gphttp.ReverseProxy
}
SubdomainKey = PT.Alias
@ -65,7 +69,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
trans = gphttp.DefaultTransport.Clone()
}
rp := gphttp.NewReverseProxy(entry.URL, trans)
rp := gphttp.NewReverseProxy(string(entry.Alias), entry.URL, trans)
if len(entry.Middlewares) > 0 {
err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares)
@ -81,6 +85,18 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
ReverseProxyEntry: entry,
rp: rp,
}
if entry.LoadBalance.Link != "" && entry.HealthCheck.Disabled {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer is enabled", entry.Alias)
entry.HealthCheck.Disabled = true
}
if !entry.HealthCheck.Disabled {
r.healthMon = health.NewHTTPHealthMonitor(
context.Background(),
string(entry.Alias),
entry.URL,
entry.HealthCheck,
)
}
return r, nil
}
@ -88,6 +104,10 @@ func (r *HTTPRoute) String() string {
return string(r.Alias)
}
func (r *HTTPRoute) URL() url.URL {
return r.ReverseProxyEntry.URL
}
func (r *HTTPRoute) Start() E.NestedError {
if r.handler != nil {
return nil
@ -118,24 +138,13 @@ func (r *HTTPRoute) Start() E.NestedError {
if r.LoadBalance.Link == "" {
httpRoutes.Store(string(r.Alias), r)
return nil
} else {
r.addToLoadBalancer()
}
var lb *loadbalancer.LoadBalancer
linked, ok := httpRoutes.Load(r.LoadBalance.Link)
if ok {
lb = linked.LoadBalancer
} else {
lb = loadbalancer.New(r.LoadBalance)
lb.Start()
linked = &HTTPRoute{
LoadBalancer: lb,
handler: lb,
}
httpRoutes.Store(r.LoadBalance.Link, linked)
if r.healthMon != nil {
r.healthMon.Start()
}
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler)
lb.AddServer(r.server)
return nil
}
@ -164,6 +173,10 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
httpRoutes.Delete(string(r.Alias))
}
if r.healthMon != nil {
r.healthMon.Stop()
}
r.handler = nil
return
@ -173,8 +186,30 @@ func (r *HTTPRoute) Started() bool {
return r.handler != nil
}
func (r *HTTPRoute) addToLoadBalancer() {
var lb *loadbalancer.LoadBalancer
linked, ok := httpRoutes.Load(r.LoadBalance.Link)
if ok {
lb = linked.LoadBalancer
} else {
lb = loadbalancer.New(r.LoadBalance)
lb.Start()
linked = &HTTPRoute{
LoadBalancer: lb,
handler: lb,
}
httpRoutes.Store(r.LoadBalance.Link, linked)
}
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.healthMon)
lb.AddServer(r.server)
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMuxFunc(r.Host)
// Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway?
// On nginx, when route for domain does not exist, it returns StatusBadGateway.
// Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if err != nil {
if !middleware.ServeStaticErrorPageFile(w, r) {
logrus.Error(E.Failure("request").

View file

@ -1,35 +1,30 @@
package route
import (
"fmt"
"net/url"
E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
"github.com/yusing/go-proxy/internal/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
Route interface {
RouteImpl
Entry() *types.RawEntry
Type() RouteType
URL() *url.URL
RouteType string
Route struct {
_ U.NoCopy
impl
Type RouteType
Entry *types.RawEntry
}
Routes = F.Map[string, Route]
Routes = F.Map[string, *Route]
RouteImpl interface {
impl interface {
Start() E.NestedError
Stop() E.NestedError
Started() bool
String() string
}
RouteType string
route struct {
RouteImpl
type_ RouteType
entry *types.RawEntry
URL() url.URL
}
)
@ -38,44 +33,36 @@ const (
RouteTypeReverseProxy RouteType = "reverse_proxy"
)
// function alias
var NewRoutes = F.NewMapOf[string, Route]
// function alias.
var NewRoutes = F.NewMap[Routes]
func NewRoute(en *types.RawEntry) (Route, E.NestedError) {
func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
entry, err := P.ValidateEntry(en)
if err != nil {
return nil, err
}
var t RouteType
var rt RouteImpl
var rt impl
switch e := entry.(type) {
case *P.StreamEntry:
rt, err = NewStreamRoute(e)
t = RouteTypeStream
rt, err = NewStreamRoute(e)
case *P.ReverseProxyEntry:
rt, err = NewHTTPRoute(e)
t = RouteTypeReverseProxy
rt, err = NewHTTPRoute(e)
default:
panic("bug: should not reach here")
}
if err != nil {
return nil, err
}
return &route{RouteImpl: rt, entry: en, type_: t}, nil
}
func (rt *route) Entry() *types.RawEntry {
return rt.entry
}
func (rt *route) Type() RouteType {
return rt.type_
}
func (rt *route) URL() *url.URL {
url, _ := url.Parse(fmt.Sprintf("%s://%s:%s", rt.entry.Scheme, rt.entry.Host, rt.entry.Port))
return url
return &Route{
impl: rt,
Type: t,
Entry: en,
}, nil
}
func FromEntries(entries types.RawEntries) (Routes, E.NestedError) {
@ -85,7 +72,7 @@ func FromEntries(entries types.RawEntries) (Routes, E.NestedError) {
entries.RangeAll(func(alias string, entry *types.RawEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
if err.HasError() {
if err != nil {
b.Add(err.Subject(alias))
} else {
routes.Store(alias, r)

View file

@ -10,14 +10,19 @@ import (
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type StreamRoute struct {
*P.StreamEntry
StreamImpl `json:"-"`
url url.URL
healthMon health.HealthMonitor
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
@ -40,8 +45,14 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
url, err := url.ParseURL(fmt.Sprintf("%s://%s:%d", entry.Scheme.ProxyScheme, entry.Host, entry.Port.ProxyPort))
if err != nil {
// !! should not happen
panic(err)
}
base := &StreamRoute{
StreamEntry: entry,
url: url,
connCh: make(chan any, 100),
}
if entry.Scheme.ListeningScheme.IsTCP() {
@ -49,6 +60,9 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
} else {
base.StreamImpl = NewUDPRoute(base)
}
if !entry.Healthcheck.Disabled {
base.healthMon = health.NewRawHealthMonitor(base.ctx, string(entry.Alias), url, entry.Healthcheck)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, nil
}
@ -57,6 +71,10 @@ func (r *StreamRoute) String() string {
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
}
func (r *StreamRoute) URL() url.URL {
return r.url
}
func (r *StreamRoute) Start() E.NestedError {
if r.Port.ProxyPort == PT.NoPort || r.started.Load() {
return nil
@ -71,6 +89,9 @@ func (r *StreamRoute) Start() E.NestedError {
r.wg.Add(2)
go r.grAcceptConnections()
go r.grHandleConnections()
if r.healthMon != nil {
r.healthMon.Start()
}
return nil
}
@ -78,7 +99,12 @@ func (r *StreamRoute) Stop() E.NestedError {
if !r.started.Load() {
return nil
}
l := r.l
r.started.Store(false)
if r.healthMon != nil {
r.healthMon.Stop()
}
r.cancel()
r.CloseListeners()
@ -92,7 +118,7 @@ func (r *StreamRoute) Stop() E.NestedError {
for {
select {
case <-done:
l.Debug("stopped listening")
r.l.Debug("stopped listening")
return nil
case <-timeout:
return E.FailedWhy("stop", "timed out")

View file

@ -27,7 +27,7 @@ type (
UDPConnMap = F.Map[string, *UDPConn]
)
var NewUDPConnMap = F.NewMapOf[string, *UDPConn]
var NewUDPConnMap = F.NewMap[UDPConnMap]
func NewUDPRoute(base *StreamRoute) StreamImpl {
return &UDPRoute{

View file

@ -5,11 +5,12 @@ import (
"strings"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
H "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
@ -18,18 +19,19 @@ type (
// raw entry object before validation
// loaded from docker labels or yaml file
Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme" yaml:"scheme"`
Host string `json:"host" yaml:"host"`
Port string `json:"port" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
LoadBalance loadbalancer.Config `json:"load_balance" yaml:"load_balance"`
Middlewares D.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
Homepage *H.HomePageItem `json:"homepage,omitempty" yaml:"homepage"`
Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme" yaml:"scheme"`
Host string `json:"host" yaml:"host"`
Port string `json:"port" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
LoadBalance loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"`
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
/* Docker only */
*D.Container `json:"container" yaml:"-"`
*docker.Container `json:"container" yaml:"-"`
}
RawEntries = F.Map[string, *RawEntry]
@ -40,7 +42,7 @@ var NewProxyEntries = F.NewMapOf[string, *RawEntry]
func (e *RawEntry) FillMissingFields() {
isDocker := e.Container != nil
if !isDocker {
e.Container = &D.Container{}
e.Container = &docker.Container{}
}
if e.Host == "" {
@ -113,6 +115,9 @@ func (e *RawEntry) FillMissingFields() {
}
}
if e.HealthCheck.Interval == 0 {
e.HealthCheck.Interval = common.HealthCheckIntervalDefault
}
if e.IdleTimeout == "" {
e.IdleTimeout = common.IdleTimeoutDefault
}

View file

@ -24,6 +24,10 @@ func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) {
return
}
func NewMap[MapType Map[KT, VT], KT comparable, VT any]() Map[KT, VT] {
return NewMapOf[KT, VT]()
}
// MapFind iterates over the map and returns the first value
// that satisfies the given criteria. The iteration is stopped
// once a value is found. If no value satisfies the criteria,
@ -161,7 +165,7 @@ func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
return E.FailedWhy("unmarshal from yaml", "map is not empty")
}
tmp := make(map[KT]VT)
if err := E.From(yaml.Unmarshal(data, tmp)); err.HasError() {
if err := E.From(yaml.Unmarshal(data, tmp)); err != nil {
return err
}
for k, v := range tmp {

View file

@ -8,6 +8,7 @@ import (
"reflect"
"strconv"
"strings"
"time"
"unicode"
"github.com/santhosh-tekuri/jsonschema"
@ -18,7 +19,7 @@ import (
type (
SerializedObject = map[string]any
Converter interface {
ConvertFrom(value any) (any, E.NestedError)
ConvertFrom(value any) E.NestedError
}
)
@ -264,23 +265,10 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
var ok bool
// check if (*T).Convertor is implemented
if converter, ok = dst.Addr().Interface().(Converter); !ok {
// check if (T).Convertor is implemented
converter, ok = dst.Interface().(Converter)
if !ok {
return E.TypeError("conversion", srcT, dstT)
}
return E.TypeError("conversion", srcT, dstT)
}
converted, err := converter.ConvertFrom(src.Interface())
if err != nil {
return err
}
c := reflect.ValueOf(converted)
if c.Kind() == reflect.Ptr {
c = c.Elem()
}
dst.Set(c)
return nil
return converter.ConvertFrom(src.Interface())
}
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) {
@ -295,6 +283,20 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N
dst.SetString(src)
return
}
switch dst.Type() {
case reflect.TypeFor[time.Duration]():
if src == "" {
dst.Set(reflect.Zero(dst.Type()))
return
}
d, err := time.ParseDuration(src)
if err != nil {
convErr = E.Invalid("duration", src)
return
}
dst.Set(reflect.ValueOf(d))
return
}
// primitive types / simple types
switch dst.Kind() {
case reflect.Bool:

View file

@ -4,6 +4,7 @@ import (
"reflect"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -102,3 +103,48 @@ func TestStringIntConvert(t *testing.T) {
ExpectNoError(t, err.Error())
ExpectEqual(t, test.u64, uint64(127))
}
type testModel struct {
Test testType
}
type testType struct {
foo int
bar string
}
func (c *testType) ConvertFrom(v any) E.NestedError {
switch v := v.(type) {
case string:
c.bar = v
return nil
case int:
c.foo = v
return nil
default:
return E.Invalid("input type", v)
}
}
func TestConvertor(t *testing.T) {
t.Run("string", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, Deserialize(map[string]any{"Test": "bar"}, m).Error())
ExpectEqual(t, m.Test.foo, 0)
ExpectEqual(t, m.Test.bar, "bar")
})
t.Run("int", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, Deserialize(map[string]any{"Test": 123}, m).Error())
ExpectEqual(t, m.Test.foo, 123)
ExpectEqual(t, m.Test.bar, "")
})
t.Run("invalid", func(t *testing.T) {
m := new(testModel)
ExpectError(t, E.ErrInvalid, Deserialize(map[string]any{"Test": 123.456}, m).Error())
})
}

View file

@ -26,6 +26,14 @@ type DirWatcher struct {
ctx context.Context
}
// NewDirectoryWatcher returns a DirWatcher instance.
//
// The DirWatcher watches the given directory for file system events.
// Currently, only events on files directly in the given directory are watched, not
// recursively.
//
// Note that the returned DirWatcher is not ready to use until the goroutine
// started by NewDirectoryWatcher has finished.
func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher {
//! subdirectories are not watched
w, err := fsnotify.NewWatcher()
@ -70,16 +78,8 @@ func (h *DirWatcher) Add(relPath string) Watcher {
close(s.eventCh)
close(s.errCh)
}()
for {
select {
case <-h.ctx.Done():
return
case _, ok := <-h.eventCh:
if !ok { // directory watcher closed
return
}
}
}
<-h.ctx.Done()
logrus.Debugf("file watcher %s stopped", relPath)
}()
h.fwMap.Store(relPath, s)
return s
@ -88,6 +88,7 @@ func (h *DirWatcher) Add(relPath string) Watcher {
func (h *DirWatcher) start() {
defer close(h.eventCh)
defer h.w.Close()
defer logrus.Debugf("directory watcher %s stopped", h.dir)
for {
select {
@ -121,7 +122,9 @@ func (h *DirWatcher) start() {
// send event to directory watcher
select {
case h.eventCh <- msg:
logrus.Debugf("sent event to directory watcher %s", h.dir)
default:
logrus.Debugf("failed to send event to directory watcher %s", h.dir)
}
// send event to file watcher too
@ -129,8 +132,12 @@ func (h *DirWatcher) start() {
if ok {
select {
case w.eventCh <- msg:
logrus.Debugf("sent event to file watcher %s", relPath)
default:
logrus.Debugf("failed to send event to file watcher %s", relPath)
}
} else {
logrus.Debugf("file watcher not found: %s", relPath)
}
case err := <-h.w.Errors:
if errors.Is(err, fsnotify.ErrClosed) {

View file

@ -0,0 +1,22 @@
package health
import (
"time"
"github.com/yusing/go-proxy/internal/common"
)
type HealthCheckConfig struct {
Disabled bool `json:"disabled" yaml:"disabled"`
Path string `json:"path" yaml:"path"`
UseGet bool `json:"use_get" yaml:"use_get"`
Interval time.Duration `json:"interval" yaml:"interval"`
Timeout time.Duration `json:"timeout" yaml:"timeout"`
}
func DefaultHealthCheckConfig() HealthCheckConfig {
return HealthCheckConfig{
Interval: common.HealthCheckIntervalDefault,
Timeout: common.HealthCheckTimeoutDefault,
}
}

View file

@ -0,0 +1,63 @@
package health
import (
"context"
"crypto/tls"
"errors"
"net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
type HTTPHealthMonitor struct {
*monitor
method string
pinger *http.Client
}
func NewHTTPHealthMonitor(ctx context.Context, name string, url types.URL, config HealthCheckConfig) HealthMonitor {
mon := new(HTTPHealthMonitor)
mon.monitor = newMonitor(ctx, name, url, &config, mon.checkHealth)
mon.pinger = &http.Client{Timeout: config.Timeout}
if config.UseGet {
mon.method = http.MethodGet
} else {
mon.method = http.MethodHead
}
return mon
}
func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) {
req, reqErr := http.NewRequestWithContext(
mon.ctx,
mon.method,
mon.URL.String(),
nil,
)
if reqErr != nil {
err = reqErr
return
}
req.Header.Set("Connection", "close")
resp, respErr := mon.pinger.Do(req)
if respErr == nil {
resp.Body.Close()
}
switch {
case respErr != nil:
// treat tls error as healthy
var tlsErr *tls.CertificateVerificationError
if ok := errors.As(respErr, &tlsErr); !ok {
detail = respErr.Error()
return
}
case resp.StatusCode == http.StatusServiceUnavailable:
detail = resp.Status
return
}
healthy = true
return
}

View file

@ -0,0 +1,5 @@
package health
import "github.com/sirupsen/logrus"
var logger = logrus.WithField("module", "health_mon")

View file

@ -0,0 +1,139 @@
package health
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
HealthMonitor interface {
Start()
Stop()
IsHealthy() bool
String() string
}
HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct {
Name string
URL types.URL
Interval time.Duration
healthy atomic.Bool
checkHealth HealthCheckFunc
ctx context.Context
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
}
)
var monMap = F.NewMapOf[string, HealthMonitor]()
func newMonitor(parentCtx context.Context, name string, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
if parentCtx == nil {
parentCtx = context.Background()
}
ctx, cancel := context.WithCancel(parentCtx)
mon := &monitor{
Name: name,
URL: url.JoinPath(config.Path),
Interval: config.Interval,
checkHealth: healthCheckFunc,
ctx: ctx,
cancel: cancel,
done: make(chan struct{}),
}
mon.healthy.Store(true)
monMap.Store(name, mon)
return mon
}
func IsHealthy(name string) (healthy bool, ok bool) {
mon, ok := monMap.Load(name)
if !ok {
return
}
return mon.IsHealthy(), true
}
func (mon *monitor) Start() {
go func() {
defer close(mon.done)
ok := mon.checkUpdateHealth()
if !ok {
return
}
ticker := time.NewTicker(mon.Interval)
defer ticker.Stop()
for {
select {
case <-mon.ctx.Done():
return
case <-ticker.C:
ok = mon.checkUpdateHealth()
if !ok {
return
}
}
}
}()
logger.Debugf("health monitor %q started", mon)
}
func (mon *monitor) Stop() {
defer logger.Debugf("health monitor %q stopped", mon)
monMap.Delete(mon.Name)
mon.mu.Lock()
defer mon.mu.Unlock()
if mon.cancel == nil {
return
}
mon.cancel()
<-mon.done
mon.cancel = nil
}
func (mon *monitor) IsHealthy() bool {
return mon.healthy.Load()
}
func (mon *monitor) String() string {
return mon.Name
}
func (mon *monitor) checkUpdateHealth() (hasError bool) {
healthy, detail, err := mon.checkHealth()
if err != nil {
mon.healthy.Store(false)
if !errors.Is(err, context.Canceled) {
logger.Errorf("server %q failed to check health: %s", mon, err)
}
mon.Stop()
return false
}
if healthy != mon.healthy.Swap(healthy) {
if healthy {
logger.Infof("server %q is up", mon)
} else {
logger.Warnf("server %q is down: %s", mon, detail)
}
}
return true
}

View file

@ -0,0 +1,37 @@
package health
import (
"context"
"net"
"github.com/yusing/go-proxy/internal/net/types"
)
type (
RawHealthMonitor struct {
*monitor
dialer *net.Dialer
}
)
func NewRawHealthMonitor(ctx context.Context, name string, url types.URL, config HealthCheckConfig) HealthMonitor {
mon := new(RawHealthMonitor)
mon.monitor = newMonitor(ctx, name, url, &config, mon.checkAvail)
mon.dialer = &net.Dialer{
Timeout: config.Timeout,
FallbackDelay: -1,
}
return mon
}
func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) {
conn, dialErr := mon.dialer.DialContext(mon.ctx, mon.URL.Scheme, mon.URL.Host)
if dialErr != nil {
detail = dialErr.Error()
/* trunk-ignore(golangci-lint/nilerr) */
return
}
conn.Close()
avail = true
return
}

View file

@ -116,6 +116,61 @@
"type": "object"
}
}
},
"load_balance": {
"type": "object",
"properties": {
"link": {
"type": "string",
"description": "Name and subdomain of load-balancer",
"format": "uri"
},
"mode": {
"enum": [
"round_robin",
"least_conn",
"ip_hash"
],
"description": "Load-balance mode",
"default": "roundrobin"
},
"weight": {
"type": "integer",
"description": "Reserved for future use",
"minimum": 0,
"maximum": 100
},
"options": {
"type": "object",
"description": "load-balance mode specific options"
}
}
},
"healthcheck": {
"type": "object",
"properties": {
"disabled": {
"type": "boolean",
"default": false
},
"path": {
"type": "string",
"description": "Healthcheck path",
"default": "/",
"format": "uri"
},
"use_get": {
"type": "boolean",
"description": "Use GET instead of HEAD",
"default": false
},
"interval": {
"type": "string",
"description": "Interval for healthcheck (e.g. 5s, 1h25m30s)",
"pattern": "^([0-9]+(ms|s|m|h))+$",
"default": "5s"
}
}
}
},
"additionalProperties": false,