v0.5.0-rc5: check release

This commit is contained in:
yusing 2024-09-19 20:40:03 +08:00
parent be7a766cb2
commit 4a2d42bfa9
68 changed files with 1971 additions and 1107 deletions

View file

@ -15,6 +15,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
- [go-proxy](#go-proxy)
- [Key Points](#key-points)
- [Getting Started](#getting-started)
- [Setup](#setup)
- [Commands line arguments](#commands-line-arguments)
- [Environment variables](#environment-variables)
- [Use JSON Schema in VSCode](#use-json-schema-in-vscode)
@ -27,10 +28,11 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
- Easy to use
- Effortless configuration
- Error messages is clear and detailed
- Error messages is clear and detailed, easy troubleshooting
- Auto certificate obtaining and renewal (See [Supported DNS Challenge Providers](docs/dns_providers.md))
- Auto configuration for docker containers
- Auto hot-reload on container state / config file changes
- Stop containers on idle, wake it up on traffic _(optional)_
- Support HTTP(s), TCP and UDP
- Web UI for configuration and monitoring (See [screenshots](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots))
- Written in **[Go](https://go.dev)**
@ -39,6 +41,8 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
## Getting Started
### Setup
1. Setup DNS Records, e.g.
- A Record: `*.y.z` -> `10.0.10.1`
@ -60,6 +64,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
| `validate` | validate config and exit | |
| `reload` | trigger a force reload of config | |
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
**run with `docker exec <container_name> /app/go-proxy <command>`**
@ -104,7 +109,7 @@ providers:
### Provider File
Fields are same as [docker labels](docs/docker.md#labels) starting from `scheme`
See [Fields](docs/docker.md#fields)
See [providers.example.yml](providers.example.yml) for examples

View file

@ -6,7 +6,7 @@ services:
network_mode: host
labels:
- proxy.aliases=gp
- proxy.gp.port=8888
- proxy.gp.port=3000
depends_on:
- app
app:

View file

@ -85,12 +85,17 @@
### Syntax
| Label | Description | Default |
| ----------------------- | -------------------------------------------------------- | ---------------- |
| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` |
| `proxy.exclude` | to be excluded from `go-proxy` | false |
| `proxy.<alias>.<field>` | set field for specific alias | N/A |
| `proxy.*.<field>` | set field for all aliases | N/A |
| Label | Description | Default | Accepted values |
| ----------------------- | --------------------------------------------------------------------- | -------------------- | ------------------------------------------------------------------------- |
| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` | any |
| `proxy.exclude` | to be excluded from `go-proxy` | false | boolean |
| `proxy.idle_timeout` | time for idle (no traffic) before put it into sleep **(http/s only)** | empty **(disabled)** | `number[unit]...`, e.g. `1m30s` |
| `proxy.wake_timeout` | time to wait for container to start before responding a loading page | empty | `number[unit]...` |
| `proxy.stop_method` | method to stop after `idle_timeout` | `stop` | `stop`, `pause`, `kill` |
| `proxy.stop_timeout` | time to wait for stop command | `10s` | `number[unit]...` |
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
| `proxy.<alias>.<field>` | set field for specific alias | N/A | N/A |
| `proxy.*.<field>` | set field for all aliases | N/A | N/A |
### Fields
@ -228,12 +233,18 @@ services:
volumes:
- adg-work:/opt/adguardhome/work
- adg-conf:/opt/adguardhome/conf
ports:
- 80
- 3000
- 53
mc:
image: itzg/minecraft-server
tty: true
stdin_open: true
container_name: mc
restart: unless-stopped
ports:
- 25565
labels:
- proxy.mc.scheme=tcp
- proxy.mc.port=20001:25565
@ -246,6 +257,9 @@ services:
restart: unless-stopped
container_name: pal
stop_grace_period: 30s
ports:
- 8211
- 27015
labels:
- proxy.aliases=pal1,pal2
- proxy.*.scheme=udp
@ -261,6 +275,8 @@ services:
- nginx:/usr/share/nginx/html
ports:
- 80
labels:
proxy.idle_timeout: 1m
go-proxy:
image: ghcr.io/yusing/go-proxy:latest
container_name: go-proxy

View file

@ -3,6 +3,7 @@ package v1
import (
"fmt"
"net/http"
"strings"
U "github.com/yusing/go-proxy/api/v1/utils"
"github.com/yusing/go-proxy/config"
@ -17,17 +18,19 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
}
var ok bool
route := cfg.FindRoute(target)
switch route := cfg.FindRoute(target).(type) {
case nil:
switch {
case route == nil:
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
return
case *R.HTTPRoute:
ok = U.IsSiteHealthy(route.TargetURL.String())
case *R.StreamRoute:
case route.Type() == R.RouteTypeReverseProxy:
ok = U.IsSiteHealthy(route.URL().String())
case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = U.IsStreamHealthy(
string(route.Scheme.ProxyScheme),
fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort),
strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
)
}

View file

@ -9,7 +9,6 @@ import (
U "github.com/yusing/go-proxy/api/v1/utils"
"github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/config"
E "github.com/yusing/go-proxy/error"
"github.com/yusing/go-proxy/proxy/provider"
)
@ -32,25 +31,25 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.ErrMissingKey("filename"), http.StatusBadRequest)
return
}
content, err := E.Check(io.ReadAll(r.Body))
if err.HasError() {
content, err := io.ReadAll(r.Body)
if err != nil {
U.HandleErr(w, r, err)
return
}
if filename == common.ConfigFileName {
err = config.Validate(content)
err = config.Validate(content).Error()
} else {
err = provider.Validate(content)
err = provider.Validate(content).Error()
}
if err.HasError() {
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
err = E.From(os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644))
if err.HasError() {
err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
if err != nil {
U.HandleErr(w, r, err)
return
}

View file

@ -8,7 +8,7 @@ import (
)
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err.HasError() {
if err := cfg.Reload().Error(); err != nil {
U.HandleErr(w, r, err)
return
}

View file

@ -9,14 +9,14 @@ import (
E "github.com/yusing/go-proxy/error"
)
func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) {
err = E.From(err).Subjectf("%s %s", r.Method, r.URL)
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
logrus.WithField("module", "api").Error(err)
if len(code) > 0 {
http.Error(w, err.Error(), code[0])
http.Error(w, err.String(), code[0])
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, err.String(), http.StatusInternalServerError)
}
func ErrMissingKey(k string) error {

View file

@ -44,7 +44,7 @@ func ReloadServer() E.NestedError {
if resp.StatusCode != http.StatusOK {
return E.Failure("server reload").Subjectf("status code: %v", resp.StatusCode)
}
return E.Nil()
return nil
}
var HttpClient = &http.Client{

View file

@ -26,33 +26,35 @@ func NewConfig(cfg *M.AutoCertConfig) *Config {
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (*Provider, E.NestedError) {
errors := E.NewBuilder("cannot create autocert provider")
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
errors.Addf("no domains specified")
b.Addf("no domains specified")
}
if cfg.Provider == "" {
errors.Addf("no provider specified")
b.Addf("no provider specified")
}
if cfg.Email == "" {
errors.Addf("no email specified")
b.Addf("no email specified")
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
errors.Addf("unknown provider: %q", cfg.Provider)
b.Addf("unknown provider: %q", cfg.Provider)
}
}
if err := errors.Build(); err.HasError() {
return nil, err
if b.HasError() {
return
}
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.HasError() {
return nil, E.Failure("generate private key").With(err)
b.Add(E.FailWith("generate private key", err))
return
}
user := &User{
@ -63,11 +65,11 @@ func (cfg *Config) GetProvider() (*Provider, E.NestedError) {
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
base := &Provider{
provider = &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
}
return base, E.Nil()
return
}

View file

@ -1,6 +1,8 @@
package autocert
import (
"errors"
"github.com/go-acme/lego/v4/providers/dns/clouddns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns"
@ -31,4 +33,8 @@ var providersGenMap = map[string]ProviderGenerator{
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
}
var (
ErrGetCertFailure = errors.New("get certificate failed")
)
var logger = logrus.WithField("module", "autocert")

View file

@ -33,7 +33,7 @@ type CertExpiries map[string]time.Time
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
return nil, E.Failure("get certificate")
return nil, ErrGetCertFailure
}
return p.tlsCert, nil
}
@ -54,52 +54,60 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries
}
func (p *Provider) ObtainCert() E.NestedError {
func (p *Provider) ObtainCert() (res E.NestedError) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
if p.cfg.Provider == ProviderLocal {
return E.FailureWhy("obtain cert", "provider is set to \"local\"")
b.Addf("provider is set to %q", ProviderLocal)
return
}
if p.client == nil {
if err := p.initClient(); err.HasError() {
return E.Failure("obtain cert").With(err)
b.Add(E.FailWith("init autocert client", err))
return
}
}
ne := E.Failure("obtain certificate")
client := p.client
if p.user.Registration == nil {
if err := p.loadRegistration(); err.HasError() {
ne = ne.With(err)
if err := p.registerACME(); err.HasError() {
return ne.With(err)
b.Add(E.FailWith("register ACME", err))
return
}
}
}
client := p.client
req := certificate.ObtainRequest{
Domains: p.cfg.Domains,
Bundle: true,
}
cert, err := E.Check(client.Certificate.Obtain(req))
if err.HasError() {
return ne.With(err)
b.Add(err)
return
}
err = p.saveCert(cert)
if err.HasError() {
return ne.With(E.Failure("save certificate").With(err))
b.Add(E.FailWith("save certificate", err))
return
}
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.HasError() {
return ne.With(E.Failure("parse obtained certificate").With(err))
b.Add(E.FailWith("parse obtained certificate", err))
return
}
expiries, err := getCertExpiries(&tlsCert)
if err.HasError() {
return ne.With(E.Failure("get certificate expiry").With(err))
b.Add(E.FailWith("get certificate expiry", err))
return
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
return E.Nil()
return nil
}
func (p *Provider) LoadCert() E.NestedError {
@ -152,50 +160,50 @@ func (p *Provider) ScheduleRenewal(ctx context.Context) {
func (p *Provider) initClient() E.NestedError {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.Failure("create lego client").With(err)
return E.FailWith("create lego client", err)
}
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() {
return E.Failure("create lego provider").With(err)
return E.FailWith("create lego provider", err)
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() {
return E.Failure("set challenge provider").With(err)
return E.FailWith("set challenge provider", err)
}
p.client = legoClient
return E.Nil()
return nil
}
func (p *Provider) registerACME() E.NestedError {
if p.user.Registration != nil {
return E.Nil()
return nil
}
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() {
return E.Failure("register ACME").With(err)
return err
}
p.user.Registration = reg
if err := p.saveRegistration(); err.HasError() {
logger.Warn(err)
}
return E.Nil()
return nil
}
func (p *Provider) loadRegistration() E.NestedError {
if p.user.Registration != nil {
return E.Nil()
return nil
}
reg := &registration.Resource{}
err := U.LoadJson(RegistrationFile, reg)
if err.HasError() {
return E.Failure("parse registration file").With(err)
return E.FailWith("parse registration file", err)
}
p.user.Registration = reg
return E.Nil()
return nil
}
func (p *Provider) saveRegistration() E.NestedError {
@ -205,13 +213,13 @@ func (p *Provider) saveRegistration() E.NestedError {
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.Failure("write key file").With(err)
return E.FailWith("write key file", err)
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil {
return E.Failure("write cert file").With(err)
return E.FailWith("write cert file", err)
}
return E.Nil()
return nil
}
func (p *Provider) certState() CertState {
@ -245,13 +253,13 @@ func (p *Provider) renewIfNeeded() E.NestedError {
case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing")
default:
return E.Nil()
return nil
}
if err := p.ObtainCert(); err.HasError() {
return E.Failure("renew certificate").With(err)
return E.FailWith("renew certificate", err)
}
return E.Nil()
return nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
@ -259,7 +267,7 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.HasError() {
return nil, E.Failure("parse certificate").With(err)
return nil, E.FailWith("parse certificate", err)
}
if x509Cert.IsCA {
continue
@ -269,7 +277,7 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
r[x509Cert.DNSNames[i]] = x509Cert.NotAfter
}
}
return r, E.Nil()
return r, nil
}
func providerGenerator[CT any, PT challenge.Provider](
@ -286,6 +294,6 @@ func providerGenerator[CT any, PT challenge.Provider](
if err.HasError() {
return nil, err
}
return p, E.Nil()
return p, nil
}
}

View file

@ -4,7 +4,8 @@ import (
"testing"
"github.com/go-acme/lego/v4/providers/dns/ovh"
. "github.com/yusing/go-proxy/utils"
U "github.com/yusing/go-proxy/utils"
. "github.com/yusing/go-proxy/utils/testing"
"gopkg.in/yaml.v3"
)
@ -44,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectNoError(t, Deserialize(opt, cfg))
ExpectEqual(t, cfg, cfgExpected)
ExpectTrue(t, U.Deserialize(opt, cfg).NoError())
ExpectDeepEqual(t, cfg, cfgExpected)
}

View file

@ -15,25 +15,32 @@ const (
CommandStart = ""
CommandValidate = "validate"
CommandListConfigs = "ls-config"
CommandListRoutes = "ls-routes"
CommandReload = "reload"
)
var ValidCommands = []string{CommandStart, CommandValidate, CommandListConfigs, CommandReload}
var ValidCommands = []string{
CommandStart,
CommandValidate,
CommandListConfigs,
CommandListRoutes,
CommandReload,
}
func GetArgs() Args {
var args Args
flag.Parse()
args.Command = flag.Arg(0)
if err := validateArgs(args.Command, ValidCommands); err.HasError() {
if err := validateArg(args.Command); err.HasError() {
logrus.Fatal(err)
}
return args
}
func validateArgs[T comparable](arg T, validArgs []T) E.NestedError {
for _, v := range validArgs {
func validateArg(arg string) E.NestedError {
for _, v := range ValidCommands {
if arg == v {
return E.Nil()
return nil
}
}
return E.Invalid("argument", arg)

View file

@ -41,7 +41,6 @@ const (
ProxyHTTPPort = ":80"
ProxyHTTPSPort = ":443"
APIHTTPPort = ":8888"
PanelHTTPPort = ":8080"
)
var WellKnownHTTPPorts = map[uint16]bool{
@ -53,7 +52,7 @@ var WellKnownHTTPPorts = map[uint16]bool{
}
var (
ImageNamePortMapTCP = map[string]int{
ServiceNamePortMapTCP = map[string]int{
"postgres": 5432,
"mysql": 3306,
"mariadb": 3306,
@ -62,8 +61,7 @@ var (
"memcached": 11211,
"rabbitmq": 5672,
"mongo": 27017,
}
ExtraNamePortMapTCP = map[string]int{
"dns": 53,
"ssh": 22,
"ftp": 21,
@ -71,20 +69,9 @@ var (
"pop3": 110,
"imap": 143,
}
NamePortMapTCP = func() map[string]int {
m := make(map[string]int)
for k, v := range ImageNamePortMapTCP {
m[k] = v
}
for k, v := range ExtraNamePortMapTCP {
m[k] = v
}
return m
}()
)
// docker library uses uint16, so followed here
var ImageNamePortMapHTTP = map[string]uint16{
var ImageNamePortMapHTTP = map[string]int{
"nginx": 80,
"httpd": 80,
"adguardhome": 3000,
@ -101,3 +88,10 @@ var ImageNamePortMapHTTP = map[string]uint16{
"dockge": 5001,
"nginx-proxy-manager": 81,
}
const (
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "10s"
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"
)

View file

@ -2,6 +2,7 @@ package config
import (
"context"
"os"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/autocert"
@ -17,32 +18,26 @@ import (
)
type Config struct {
value *M.Config
l logrus.FieldLogger
reader U.Reader
proxyProviders *F.Map[string, *PR.Provider]
value *M.Config
proxyProviders F.Map[string, *PR.Provider]
autocertProvider *autocert.Provider
l logrus.FieldLogger
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
reloadReq chan struct{}
}
func New() (*Config, E.NestedError) {
func Load() (*Config, E.NestedError) {
cfg := &Config{
l: logrus.WithField("module", "config"),
reader: U.NewFileReader(common.ConfigPath),
watcher: W.NewFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
proxyProviders: F.NewMapOf[string, *PR.Provider](),
l: logrus.WithField("module", "config"),
watcher: W.NewFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
if err := cfg.load(); err.HasError() {
return nil, err
}
cfg.startProviders()
cfg.watchChanges()
return cfg, E.Nil()
return cfg, cfg.load()
}
func Validate(data []byte) E.NestedError {
@ -57,11 +52,17 @@ func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
return cfg.autocertProvider
}
func (cfg *Config) StartProxyProviders() {
cfg.startProviders()
cfg.watchChanges()
}
func (cfg *Config) Dispose() {
cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
if cfg.watcherCancel != nil {
cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
}
cfg.stopProviders()
cfg.l.Debug("stopped providers")
}
func (cfg *Config) Reload() E.NestedError {
@ -70,46 +71,31 @@ func (cfg *Config) Reload() E.NestedError {
return err
}
cfg.startProviders()
return E.Nil()
return nil
}
func (cfg *Config) FindRoute(alias string) R.Route {
r := cfg.proxyProviders.Find(
func(p *PR.Provider) (any, bool) {
rs := p.GetCurrentRoutes()
if rs.Contains(alias) {
return rs.Get(alias), true
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}
return nil, false
},
)
if r == nil {
return nil
}
return r.(R.Route)
}
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.proxyProviders.Each(func(p *PR.Provider) {
prName := p.GetName()
p.GetCurrentRoutes().EachKV(func(a string, r R.Route) {
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = prName
switch r.(type) {
case *R.StreamRoute:
obj["type"] = "stream"
case *R.HTTPRoute:
obj["type"] = "reverse_proxy"
default:
panic("bug: should not reach here")
}
routes[a] = obj
})
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
routes[alias] = obj
})
return routes
}
@ -119,26 +105,23 @@ func (cfg *Config) Statistics() map[string]any {
nTotalRPs := 0
providerStats := make(map[string]any)
cfg.proxyProviders.Each(func(p *PR.Provider) {
stats := make(map[string]any)
nStreams := 0
nRPs := 0
p.GetCurrentRoutes().EachKV(func(a string, r R.Route) {
switch r.(type) {
case *R.StreamRoute:
nStreams++
nTotalStreams++
case *R.HTTPRoute:
nRPs++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
stats["type"] = p.GetType()
stats["num_streams"] = nStreams
stats["num_reverse_proxies"] = nRPs
providerStats[p.GetName()] = stats
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
s, ok := providerStats[p.GetName()]
if !ok {
s = make(map[string]int)
}
stats := s.(map[string]int)
switch r.Type() {
case R.RouteTypeStream:
stats["num_streams"]++
nTotalStreams++
case R.RouteTypeReverseProxy:
stats["num_reverse_proxies"]++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
return map[string]any{
@ -148,6 +131,14 @@ func (cfg *Config) Statistics() map[string]any {
}
}
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) {
do(a, r, p)
})
})
}
func (cfg *Config) watchChanges() {
cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background())
go func() {
@ -182,64 +173,82 @@ func (cfg *Config) watchChanges() {
}()
}
func (cfg *Config) load() E.NestedError {
func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
cfg.l.Debug("loading config")
defer cfg.l.Debug("loaded config")
data, err := cfg.reader.Read()
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
return E.Failure("read config").With(err)
}
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
return E.Failure("parse config").With(err)
b.Add(E.FailWith("read config", err))
return
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
return err
b.Add(E.FailWith("schema validation", err))
return
}
}
warnings := E.NewBuilder("errors loading config")
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
b.Add(E.FailWith("parse config", err))
return
}
cfg.l.Debug("initializing autocert")
ap, err := autocert.NewConfig(&model.AutoCert).GetProvider()
if err.HasError() {
warnings.Add(E.Failure("autocert provider").With(err))
} else {
cfg.l.Debug("initialized autocert")
}
cfg.autocertProvider = ap
cfg.l.Debug("loading providers")
cfg.proxyProviders = F.NewMap[string, *PR.Provider]()
for _, filename := range model.Providers.Files {
p := PR.NewFileProvider(filename)
cfg.proxyProviders.Set(p.GetName(), p)
}
for name, dockerHost := range model.Providers.Docker {
p := PR.NewDockerProvider(name, dockerHost)
cfg.proxyProviders.Set(p.GetName(), p)
}
cfg.l.Debug("loaded providers")
// errors are non fatal below
b.WithSeverity(E.SeverityWarning)
b.Add(cfg.initAutoCert(&model.AutoCert))
b.Add(cfg.loadProviders(&model.Providers))
cfg.value = model
return
}
if err := warnings.Build(); err.HasError() {
cfg.l.Warn(err)
func (cfg *Config) initAutoCert(autocertCfg *M.AutoCertConfig) (err E.NestedError) {
if cfg.autocertProvider != nil {
return
}
cfg.l.Debug("loaded config")
return E.Nil()
cfg.l.Debug("initializing autocert")
defer cfg.l.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err.HasError() {
err = E.FailWith("autocert provider", err)
}
return
}
func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError) {
cfg.l.Debug("loading providers")
defer cfg.l.Debug("loaded providers")
b := E.NewBuilder("errors loading providers")
defer b.To(&res)
for _, filename := range providers.Files {
p := PR.NewFileProvider(filename)
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes())
}
for name, dockerHost := range providers.Docker {
p := PR.NewDockerProvider(name, dockerHost)
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes())
}
return
}
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("cannot %s these providers", action)
cfg.proxyProviders.EachKVParallel(func(name string, p *PR.Provider) {
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() {
errors.Add(E.From(err).Subject(p))
errors.Add(err.Subject(p))
}
})

View file

@ -3,6 +3,7 @@ package docker
import (
"net/http"
"sync"
"sync/atomic"
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
@ -11,14 +12,37 @@ import (
E "github.com/yusing/go-proxy/error"
)
type Client = *client.Client
type Client struct {
key string
refCount *atomic.Int32
*client.Client
}
func (c Client) DaemonHostname() string {
url, _ := client.ParseHostURL(c.DaemonHost())
return url.Hostname()
}
// if the client is still referenced, this is no-op
func (c Client) Close() error {
if c.refCount.Load() > 0 {
c.refCount.Add(-1)
return nil
}
clientMapMu.Lock()
defer clientMapMu.Unlock()
delete(clientMap, c.key)
return c.Client.Close()
}
// ConnectClient creates a new Docker client connection to the specified host.
//
// Returns existing client if available.
//
// Parameters:
// - host: the host to connect to (either a URL or "FROM_ENV").
// - host: the host to connect to (either a URL or common.DockerHostFromEnv).
//
// Returns:
// - Client: the Docker client connection.
@ -29,7 +53,8 @@ func ConnectClient(host string) (Client, E.NestedError) {
// check if client exists
if client, ok := clientMap[host]; ok {
return client, E.Nil()
client.refCount.Add(1)
return client, nil
}
// create client
@ -41,7 +66,7 @@ func ConnectClient(host string) (Client, E.NestedError) {
default:
helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.HasError() {
logger.Fatalf("unexpected error: %s", err)
return Client{}, E.UnexpectedError(err.Error())
}
if helper != nil {
httpClient := &http.Client{
@ -66,11 +91,16 @@ func ConnectClient(host string) (Client, E.NestedError) {
client, err := E.Check(client.NewClientWithOpts(opt...))
if err.HasError() {
return nil, err
return Client{}, err
}
clientMap[host] = client
return client, E.Nil()
clientMap[host] = Client{
Client: client,
key: host,
refCount: &atomic.Int32{},
}
clientMap[host].refCount.Add(1)
return clientMap[host], nil
}
func CloseAllClients() {
@ -83,12 +113,13 @@ func CloseAllClients() {
logger.Debug("closed all clients")
}
var clientMap map[string]Client = make(map[string]Client)
var clientMapMu sync.Mutex
var (
clientMap map[string]Client = make(map[string]Client)
clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
var clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
var logger = logrus.WithField("module", "docker")
logger = logrus.WithField("module", "docker")
)

View file

@ -12,35 +12,41 @@ import (
)
type ClientInfo struct {
Host string
Client Client
Containers []types.Container
}
func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) {
var listOptions = container.ListOptions{
// Filters: filters.NewArgs(
// filters.Arg("health", "healthy"),
// filters.Arg("health", "none"),
// filters.Arg("health", "starting"),
// ),
All: true,
}
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
dockerClient, err := ConnectClient(clientHost)
if err.HasError() {
return nil, E.Failure("create docker client").With(err)
return nil, E.FailWith("connect to docker", err)
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
containers, err := E.Check(dockerClient.ContainerList(ctx, container.ListOptions{}))
if err.HasError() {
return nil, E.Failure("list containers").With(err)
var containers []types.Container
if getContainer {
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
}
// extract host from docker client url
// since the services being proxied to
// should have the same IP as the docker client
url, err := E.Check(client.ParseHostURL(dockerClient.DaemonHost()))
if err.HasError() {
return nil, E.Invalid("host url", dockerClient.DaemonHost()).With(err)
}
if url.Scheme == "unix" {
return &ClientInfo{Host: "localhost", Containers: containers}, E.Nil()
}
return &ClientInfo{Host: url.Hostname(), Containers: containers}, E.Nil()
return &ClientInfo{
Client: dockerClient,
Containers: containers,
}, nil
}
func IsErrConnectionFailed(err error) bool {

109
src/docker/container.go Normal file
View file

@ -0,0 +1,109 @@
package docker
import (
"fmt"
"strconv"
"strings"
"github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/utils"
)
type ProxyProperties struct {
DockerHost string `yaml:"docker_host" json:"docker_host"`
ContainerName string `yaml:"container_name" json:"container_name"`
ImageName string `yaml:"image_name" json:"image_name"`
Aliases []string `yaml:"aliases" json:"aliases"`
IsExcluded bool `yaml:"is_excluded" json:"is_excluded"`
FirstPort string `yaml:"first_port" json:"first_port"`
IdleTimeout string `yaml:"idle_timeout" json:"idle_timeout"`
WakeTimeout string `yaml:"wake_timeout" json:"wake_timeout"`
StopMethod string `yaml:"stop_method" json:"stop_method"`
StopTimeout string `yaml:"stop_timeout" json:"stop_timeout"` // stop_method = "stop" only
StopSignal string `yaml:"stop_signal" json:"stop_signal"` // stop_method = "stop" | "kill" only
}
type Container struct {
*types.Container
*ProxyProperties
}
func FromDocker(c *types.Container, dockerHost string) (res Container) {
res.Container = c
res.ProxyProperties = &ProxyProperties{
DockerHost: dockerHost,
ContainerName: res.getName(),
ImageName: res.getImageName(),
Aliases: res.getAliases(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LableExclude)),
FirstPort: res.firstPortOrEmpty(),
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
StopMethod: res.getDeleteLabel(LabelStopMethod),
StopTimeout: res.getDeleteLabel(LabelStopTimeout),
StopSignal: res.getDeleteLabel(LabelStopSignal),
}
return
}
func FromJson(json types.ContainerJSON, dockerHost string) Container {
ports := make([]types.Port, 0)
for k, bindings := range json.NetworkSettings.Ports {
for _, v := range bindings {
pubPort, _ := strconv.Atoi(v.HostPort)
privPort, _ := strconv.Atoi(k.Port())
ports = append(ports, types.Port{
IP: v.HostIP,
PublicPort: uint16(pubPort),
PrivatePort: uint16(privPort),
})
}
}
return FromDocker(&types.Container{
ID: json.ID,
Names: []string{json.Name},
Image: json.Image,
Ports: ports,
Labels: json.Config.Labels,
State: json.State.Status,
Status: json.State.Status,
}, dockerHost)
}
func (c Container) getDeleteLabel(label string) string {
if l, ok := c.Labels[label]; ok {
delete(c.Labels, label)
return l
}
return ""
}
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LableAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
}
}
func (c Container) getName() string {
return strings.TrimPrefix(c.Names[0], "/")
}
func (c Container) getImageName() string {
colonSep := strings.Split(c.Image, ":")
slashSep := strings.Split(colonSep[len(colonSep)-1], "/")
return slashSep[len(slashSep)-1]
}
func (c Container) firstPortOrEmpty() string {
if len(c.Ports) == 0 {
return ""
}
for _, p := range c.Ports {
if p.PublicPort != 0 {
return fmt.Sprint(p.PublicPort)
}
}
return ""
}

View file

@ -0,0 +1,14 @@
package idlewatcher
import "net/http"
type (
roundTripper struct {
patched roundTripFunc
}
roundTripFunc func(*http.Request) (*http.Response, error)
)
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.patched(req)
}

View file

@ -0,0 +1,329 @@
package idlewatcher
import (
"bytes"
"context"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields"
)
type watcher struct {
*P.ReverseProxyEntry
client D.Client
refCount atomic.Int32
stopByMethod StopCallback
wakeCh chan struct{}
wakeDone chan E.NestedError
ctx context.Context
cancel context.CancelFunc
l logrus.FieldLogger
}
type (
WakeDone <-chan error
WakeFunc func() WakeDone
StopCallback func() (bool, E.NestedError)
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
if entry.IdleTimeout == 0 {
return nil, failure.With(E.Invalid("idle_timeout", 0))
}
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[entry.ContainerName]; ok {
w.refCount.Add(1)
return w, nil
}
client, err := D.ConnectClient(entry.DockerHost)
if err.HasError() {
return nil, failure.With(err)
}
w := &watcher{
ReverseProxyEntry: entry,
client: client,
wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError, 1),
l: logger.WithField("container", entry.ContainerName),
}
w.refCount.Add(1)
w.stopByMethod = w.getStopCallback()
watcherMap[w.ContainerName] = w
go func() {
newWatcherCh <- w
}()
return w, nil
}
// If the container is not registered, this is no-op
func Unregister(containerName string) {
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[containerName]; ok {
if w.refCount.Load() == 0 {
w.cancel()
close(w.wakeCh)
delete(watcherMap, containerName)
} else {
w.refCount.Add(-1)
}
}
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
defer mainLoopWg.Wait()
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watch()
Unregister(w.ContainerName)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
}
func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper {
return roundTripper{patched: func(r *http.Request) (*http.Response, error) {
return w.roundTrip(rtp.RoundTrip, r)
}}
}
func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) {
timeout := time.After(w.WakeTimeout)
w.wakeCh <- struct{}{}
for {
select {
case err := <-w.wakeDone:
if err != nil {
return nil, err.Error()
}
return origRoundTrip(req)
case <-timeout:
resp := loadingResponse
resp.TLS = req.TLS
return &resp, nil
}
}
}
func (w *watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout})
}
func (w *watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerName)
}
func (w *watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerName, string(w.StopSignal))
}
func (w *watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerName)
}
func (w *watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerName, container.StartOptions{})
}
func (w *watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerName)
if err != nil {
return "", E.FailWith("inspect container", err)
}
return json.State.Status, nil
}
func (w *watcher) wakeIfStopped() (bool, E.NestedError) {
failure := E.Failure("wake")
status, err := w.containerStatus()
if err.HasError() {
return false, failure.With(err)
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
switch status {
case "exited", "dead":
err = E.From(w.containerStart())
case "paused":
err = E.From(w.containerUnpause())
case "running":
return false, nil
default:
return false, failure.With(E.Unexpected("container state", status))
}
if err.HasError() {
return false, failure.With(err)
}
status, err = w.containerStatus()
if err.HasError() {
return false, failure.With(err)
} else if status != "running" {
return false, failure.With(E.Unexpected("container state", status))
} else {
return true, nil
}
}
func (w *watcher) getStopCallback() StopCallback {
var cb func() error
switch w.StopMethod {
case PT.StopMethodPause:
cb = w.containerPause
case PT.StopMethodStop:
cb = w.containerStop
case PT.StopMethodKill:
cb = w.containerKill
default:
panic("should not reach here")
}
return func() (bool, E.NestedError) {
status, err := w.containerStatus()
if err.HasError() {
return false, E.FailWith("stop", err)
}
if status != "running" {
return false, nil
}
err = E.From(cb())
if err.HasError() {
return false, E.FailWith("stop", err)
}
return true, nil
}
}
func (w *watcher) watch() {
watcherCtx, watcherCancel := context.WithCancel(context.Background())
w.ctx = watcherCtx
w.cancel = watcherCancel
ticker := time.NewTicker(w.IdleTimeout)
defer ticker.Stop()
for {
select {
case <-mainLoopCtx.Done():
watcherCancel()
case <-watcherCtx.Done():
w.l.Debug("stopped")
return
case <-ticker.C:
w.l.Debug("timeout")
stopped, err := w.stopByMethod()
if err.HasError() {
w.l.Error(err.Extraf("stop method: %s", w.StopMethod))
} else if stopped {
w.l.Infof("%s: ok", w.StopMethod)
} else {
ticker.Stop()
}
case <-w.wakeCh:
w.l.Debug("wake received")
go func() {
started, err := w.wakeIfStopped()
if err != nil {
w.l.Error(err)
} else if started {
w.l.Infof("awaken")
ticker.Reset(w.IdleTimeout)
}
w.wakeDone <- err // this is passed to roundtrip
}()
}
}
}
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
loadingResponse = http.Response{
StatusCode: http.StatusAccepted,
Header: http.Header{
"Content-Type": {"text/html"},
"Cache-Control": {
"no-cache",
"no-store",
"must-revalidate",
},
},
Body: io.NopCloser(bytes.NewReader((loadingPage))),
ContentLength: int64(len(loadingPage)),
}
loadingPage = []byte(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Loading...</title>
</head>
<body>
<script>
window.onload = function() {
setTimeout(function() {
location.reload();
}, 1000); // 1000 milliseconds = 1 second
};
</script>
<p>Container is starting... Please wait</p>
</body>
</html>
`[1:])
)

19
src/docker/inspect.go Normal file
View file

@ -0,0 +1,19 @@
package docker
import (
"context"
"time"
E "github.com/yusing/go-proxy/error"
)
func (c Client) Inspect(containerID string) (Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)
if err != nil {
return Container{}, E.From(err)
}
return FromJson(json, c.key), nil
}

View file

@ -36,7 +36,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
return &Label{
Namespace: label,
Value: value,
}, E.Nil()
}, nil
}
l := &Label{
@ -54,12 +54,12 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
// find if namespace has value parser
pm, ok := labelValueParserMap[l.Namespace]
if !ok {
return l, E.Nil()
return l, nil
}
// find if attribute has value parser
p, ok := pm[l.Attribute]
if !ok {
return l, E.Nil()
return l, nil
}
// try to parse value
v, err := p(value)
@ -67,7 +67,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
return nil, err
}
l.Value = v
return l, E.Nil()
return l, nil
}
func RegisterNamespace(namespace string, pm ValueParserMap) {

View file

@ -10,7 +10,7 @@ import (
func yamlListParser(value string) (any, E.NestedError) {
value = strings.TrimSpace(value)
if value == "" {
return []string{}, E.Nil()
return []string{}, nil
}
var data []string
err := E.From(yaml.Unmarshal([]byte(value), &data))
@ -34,23 +34,15 @@ func yamlStringMappingParser(value string) (any, E.NestedError) {
h[key] = val
}
}
return h, E.Nil()
}
func commaSepParser(value string) (any, E.NestedError) {
v := strings.Split(value, ",")
for i := range v {
v[i] = strings.TrimSpace(v[i])
}
return v, E.Nil()
return h, nil
}
func boolParser(value string) (any, E.NestedError) {
switch strings.ToLower(value) {
case "true", "yes", "1":
return true, E.Nil()
return true, nil
case "false", "no", "0":
return false, E.Nil()
return false, nil
default:
return nil, E.Invalid("boolean value", value)
}

View file

@ -7,7 +7,7 @@ import (
"testing"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/utils"
. "github.com/yusing/go-proxy/utils/testing"
)
func makeLabel(namespace string, alias string, field string) string {
@ -19,7 +19,7 @@ func TestHomePageLabel(t *testing.T) {
field := "ip"
v := "bar"
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
ExpectNoError(t, err)
ExpectNoError(t, err.Error())
if pl.Target != alias {
t.Errorf("Expected alias=%s, got %s", alias, pl.Target)
}
@ -34,8 +34,8 @@ func TestHomePageLabel(t *testing.T) {
func TestStringProxyLabel(t *testing.T) {
v := "bar"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "ip"), v)
ExpectNoError(t, err)
ExpectEqual(t, pl.Value, v)
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(string), v)
}
func TestBoolProxyLabelValid(t *testing.T) {
@ -52,8 +52,8 @@ func TestBoolProxyLabelValid(t *testing.T) {
for k, v := range tests {
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k)
ExpectNoError(t, err)
ExpectEqual(t, pl.Value, v)
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(bool), v)
}
}
@ -78,7 +78,7 @@ X-Custom-Header2: boo`
}
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
ExpectNoError(t, err)
ExpectNoError(t, err.Error())
hGot := ExpectType[map[string]string](t, pl.Value)
if hGot != nil && !reflect.DeepEqual(h, hGot) {
t.Errorf("Expected %v, got %v", h, hGot)
@ -109,33 +109,32 @@ func TestHideHeadersProxyLabel(t *testing.T) {
`
v = strings.TrimPrefix(v, "\n")
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v)
ExpectNoError(t, err)
ExpectNoError(t, err.Error())
sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
if sGot != nil {
ExpectEqual(t, sGot, sWant)
ExpectDeepEqual(t, sGot, sWant)
}
}
func TestCommaSepProxyLabelSingle(t *testing.T) {
v := "a"
pl, err := ParseLabel("proxy.aliases", v)
ExpectNoError(t, err)
sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"a"}
if sGot != nil {
ExpectEqual(t, sGot, sWant)
}
// func TestCommaSepProxyLabelSingle(t *testing.T) {
// v := "a"
// pl, err := ParseLabel("proxy.aliases", v)
// ExpectNoError(t, err)
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"a"}
// if sGot != nil {
// ExpectEqual(t, sGot, sWant)
// }
// }
}
func TestCommaSepProxyLabelMulti(t *testing.T) {
v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
pl, err := ParseLabel("proxy.aliases", v)
ExpectNoError(t, err)
sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
if sGot != nil {
ExpectEqual(t, sGot, sWant)
}
}
// func TestCommaSepProxyLabelMulti(t *testing.T) {
// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
// pl, err := ParseLabel("proxy.aliases", v)
// ExpectNoError(t, err)
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
// if sGot != nil {
// ExpectEqual(t, sGot, sWant)
// }
// }

13
src/docker/labels.go Normal file
View file

@ -0,0 +1,13 @@
package docker
const (
WildcardAlias = "*"
LableAliases = NSProxy + ".aliases"
LableExclude = NSProxy + ".exclude"
LabelIdleTimeout = NSProxy + ".idle_timeout"
LabelWakeTimeout = NSProxy + ".wake_timeout"
LabelStopMethod = NSProxy + ".stop_method"
LabelStopTimeout = NSProxy + ".stop_timeout"
LabelStopSignal = NSProxy + ".stop_signal"
)

View file

@ -6,16 +6,23 @@ import (
)
type Builder struct {
message string
errors []error
*builder
}
type builder struct {
message string
errors []NestedError
severity Severity
sync.Mutex
}
func NewBuilder(format string, args ...any) *Builder {
return &Builder{message: fmt.Sprintf(format, args...)}
func NewBuilder(format string, args ...any) Builder {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
}
func (b *Builder) Add(err error) *Builder {
// adding nil / nil is no-op,
// you may safely pass expressions returning error to it
func (b Builder) Add(err NestedError) Builder {
if err != nil {
b.Lock()
b.errors = append(b.errors, err)
@ -24,8 +31,17 @@ func (b *Builder) Add(err error) *Builder {
return b
}
func (b *Builder) Addf(format string, args ...any) *Builder {
return b.Add(fmt.Errorf(format, args...))
func (b Builder) AddE(err error) Builder {
return b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
func (b Builder) WithSeverity(s Severity) Builder {
b.severity = s
return b
}
// Build builds a NestedError based on the errors collected in the Builder.
@ -35,9 +51,21 @@ func (b *Builder) Addf(format string, args ...any) *Builder {
//
// Returns:
// - NestedError: the built NestedError.
func (b *Builder) Build() NestedError {
func (b Builder) Build() NestedError {
if len(b.errors) == 0 {
return Nil()
return nil
}
return Join(b.message, b.errors...)
return Join(b.message, b.errors...).Severity(b.severity)
}
func (b Builder) To(ptr *NestedError) {
if *ptr == nil {
*ptr = b.Build()
} else {
**ptr = *b.Build()
}
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
}

View file

@ -1,13 +1,38 @@
package error
import "testing"
import (
"testing"
func TestBuilder(t *testing.T) {
. "github.com/yusing/go-proxy/utils/testing"
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer")
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
var err NestedError
for range 3 {
eb.Add(nil)
}
for range 3 {
eb.Add(err)
}
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
got := eb.Build().Error()
got := eb.Build().String()
expected1 :=
(`error occurred:
- Action 1 failed:

View file

@ -7,35 +7,37 @@ import (
)
type (
// NestedError is an error with an inner error
// and a list of extra nested errors.
//
// It is designed to be non nil.
//
// You can use it to join multiple errors,
// or to set a inner reason for a nested error.
//
// When a method returns both valid values and errors,
// You should return (Slice/Map, NestedError).
// Caller then should handle the nested error,
// and continue with the valid values.
NestedError struct {
subject string
err error // can be nil
extras []NestedError
NestedError = *nestedError
nestedError struct {
subject string
err error // can be nil
extras []nestedError
severity Severity
}
errorInterface struct {
*nestedError
}
Severity uint8
)
func Nil() NestedError { return NestedError{} }
const (
SeverityFatal Severity = iota
SeverityWarning
)
func (e errorInterface) Error() string {
return e.String()
}
func From(err error) NestedError {
if IsNil(err) {
return nil
}
switch err := err.(type) {
case nil:
return Nil()
case NestedError:
return err
case errorInterface:
return err.nestedError
default:
return NestedError{err: err}
return &nestedError{err: err}
}
}
@ -45,40 +47,84 @@ func Check[T any](obj T, err error) (T, NestedError) {
return obj, From(err)
}
func Join(message string, err ...error) NestedError {
extras := make([]NestedError, 0, len(err))
func Join(message string, err ...NestedError) NestedError {
extras := make([]nestedError, len(err))
nErr := 0
for _, e := range err {
if err == nil {
for i, e := range err {
if e == nil {
continue
}
extras = append(extras, From(e))
extras[i] = *e
nErr += 1
}
if nErr == 0 {
return Nil()
return nil
}
return NestedError{
return &nestedError{
err: errors.New(message),
extras: extras,
}
}
func (ne NestedError) Error() string {
func JoinE(message string, err ...error) NestedError {
b := NewBuilder(message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne NestedError) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne NestedError) Is(err error) bool {
return errors.Is(ne.err, err)
if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne NestedError) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne NestedError) Error() error {
if ne == nil {
return nil
}
return errorInterface{ne}
}
func (ne NestedError) With(s any) NestedError {
if ne == nil {
return ne
}
var msg string
switch ss := s.(type) {
case nil:
return ne
case *nestedError:
return ne.withError(ss.Error())
case error:
return ne.withError(ss)
case string:
@ -92,10 +138,13 @@ func (ne NestedError) With(s any) NestedError {
}
func (ne NestedError) Extraf(format string, args ...any) NestedError {
return ne.With(fmt.Errorf(format, args...))
return ne.With(errorf(format, args...))
}
func (ne NestedError) Subject(s any) NestedError {
if ne == nil {
return ne
}
switch ss := s.(type) {
case string:
ne.subject = ss
@ -108,6 +157,9 @@ func (ne NestedError) Subject(s any) NestedError {
}
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
if ne == nil {
return ne
}
if strings.Contains(format, "%q") {
panic("Subjectf format should not contain %q")
}
@ -118,12 +170,36 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
return ne
}
func (ne NestedError) Severity(s Severity) NestedError {
if ne == nil {
return ne
}
ne.severity = s
return ne
}
func (ne NestedError) Warn() NestedError {
if ne == nil {
return ne
}
ne.severity = SeverityWarning
return ne
}
func (ne NestedError) NoError() bool {
return ne.err == nil
return ne == nil
}
func (ne NestedError) HasError() bool {
return ne.err != nil
return ne != nil
}
func (ne NestedError) IsFatal() bool {
return ne != nil && ne.severity == SeverityFatal
}
func (ne NestedError) IsWarning() bool {
return ne != nil && ne.severity == SeverityWarning
}
func errorf(format string, args ...any) NestedError {
@ -131,11 +207,13 @@ func errorf(format string, args ...any) NestedError {
}
func (ne NestedError) withError(err error) NestedError {
ne.extras = append(ne.extras, From(err))
if ne != nil && IsNotNil(err) {
ne.extras = append(ne.extras, *From(err))
}
return ne
}
func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
ne.writeIndents(sb, level)
sb.WriteString(prefix)
@ -146,7 +224,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string)
sb.WriteString(ne.err.Error())
if ne.subject != "" {
if ne.err != nil {
if IsNotNil(ne.err) {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
} else {
sb.WriteString(fmt.Sprint(ne.subject))
@ -161,7 +239,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string)
}
}
func (ne *NestedError) writeIndents(sb *strings.Builder, level int) {
func (ne NestedError) writeIndents(sb *strings.Builder, level int) {
for i := 0; i < level; i++ {
sb.WriteString(" ")
}

View file

@ -4,7 +4,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/utils"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestErrorIs(t *testing.T) {
@ -16,27 +16,53 @@ func TestErrorIs(t *testing.T) {
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
ExpectTrue(t, Nil().Is(nil))
ExpectFalse(t, Nil().Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(nil))
}
func TestNil(t *testing.T) {
ExpectTrue(t, Nil().NoError())
ExpectFalse(t, Nil().HasError())
ExpectEqual(t, Nil().Error(), "nil")
func TestErrorNestedIs(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrAlreadyExist))
err.With(AlreadyExist("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrAlreadyExist))
ExpectFalse(t, err.Is(ErrInvalid))
}
func TestIsNil(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
ExpectFalse(t, err.HasError())
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
}
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
}
func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar")
ExpectEqual(t, ne.Error(), "foo bar failed")
ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz")
ExpectEqual(t, ne.Error(), "foo bar failed for \"baz\"")
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
}
func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz")
ExpectEqual(t, ne.Error(), "foo failed:\n - bar\n - baz")
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
}
func TestErrorNested(t *testing.T) {
@ -72,5 +98,5 @@ func TestErrorNested(t *testing.T) {
- inner3 failed for "action 3":
- 3
- 3`
ExpectEqual(t, ne.Error(), want)
ExpectEqual(t, ne.String(), want)
}

View file

@ -5,33 +5,48 @@ import (
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrNotExists = stderrors.New("does not exist")
ErrDuplicated = stderrors.New("duplicated")
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrAlreadyExist = stderrors.New("already exist")
)
const fmtSubjectWhat = "%w %v: %v"
func Failure(what string) NestedError {
return errorf("%s %w", what, ErrFailure)
}
func FailureWhy(what string, why string) NestedError {
func FailedWhy(what string, why string) NestedError {
return errorf("%s %w because %s", what, ErrFailure, why)
}
func FailWith(what string, err any) NestedError {
return Failure(what).With(err)
}
func Invalid(subject, what any) NestedError {
return errorf("%w %v - %v", ErrInvalid, subject, what)
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) NestedError {
return errorf("%w %v - %v", ErrUnsupported, subject, what)
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func NotExists(subject, what any) NestedError {
return errorf("%s %v - %v", subject, ErrNotExists, what)
func Unexpected(subject, what any) NestedError {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func Duplicated(subject, what any) NestedError {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
func UnexpectedError(err error) NestedError {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func AlreadyExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrAlreadyExist, what)
}

View file

@ -7,6 +7,7 @@ require (
github.com/docker/docker v27.2.1+incompatible
github.com/fsnotify/fsnotify v1.7.0
github.com/go-acme/lego/v4 v4.18.0
github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/santhosh-tekuri/jsonschema v1.2.4
github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.29.0

View file

@ -73,6 +73,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=

View file

@ -18,6 +18,7 @@ import (
"github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/config"
"github.com/yusing/go-proxy/docker"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error"
R "github.com/yusing/go-proxy/route"
"github.com/yusing/go-proxy/server"
@ -53,37 +54,40 @@ func main() {
// exit if only validate config
if args.Command == common.CommandValidate {
var err E.NestedError
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
l.WithError(err).Fatalf("config error")
data, err := os.ReadFile(common.ConfigPath)
if err == nil {
err = config.Validate(data).Error()
}
if err = config.Validate(data); err.HasError() {
l.WithError(err).Fatalf("config error")
if err != nil {
l.Fatal("config error: ", err)
}
l.Printf("config OK")
return
}
cfg, err := config.New()
if err.HasError() {
l.Fatalf("config error: %s", err)
cfg, err := config.Load()
if err.IsFatal() {
l.Fatal(err)
}
if args.Command == common.CommandListConfigs {
yml, err := E.Check(json.Marshal(cfg.Value()))
if err.HasError() {
panic(err)
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", yml) // raw output for convenience using "jq"
printJSON(cfg.Value())
return
}
onShutdown.Add(func() {
docker.CloseAllClients()
cfg.Dispose()
})
cfg.StartProxyProviders()
if args.Command == common.CommandListRoutes {
printJSON(cfg.RoutesByAlias())
return
}
if err.HasError() {
l.Warn(err)
}
onShutdown.Add(docker.CloseAllClients)
onShutdown.Add(cfg.Dispose)
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT)
@ -109,8 +113,9 @@ func main() {
onShutdown.Add(certRenewalCancel)
}
for name, expiry := range autocert.GetExpiries() {
l.Infof("certificate %q: expire on %s", name, expiry)
for _, expiry := range autocert.GetExpiries() {
l.Infof("certificate expire on %s", expiry)
break
}
} else {
l.Info("autocert not configured")
@ -137,6 +142,9 @@ func main() {
onShutdown.Add(proxyServer.Stop)
onShutdown.Add(apiServer.Stop)
go idlewatcher.Start()
onShutdown.Add(idlewatcher.Stop)
// wait for signal
<-sig
@ -164,3 +172,12 @@ func main() {
logrus.Info("timeout waiting for shutdown")
}
}
func printJSON(obj any) {
j, err := E.Check(json.Marshal(obj))
if err.HasError() {
logrus.Fatal(err)
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", j) // raw output for convenience using "jq"
}

View file

@ -1,13 +1,16 @@
package model
import (
"strconv"
"strings"
. "github.com/yusing/go-proxy/common"
D "github.com/yusing/go-proxy/docker"
F "github.com/yusing/go-proxy/utils/functional"
)
type (
ProxyEntry struct {
ProxyEntry struct { // raw entry object before validation
Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"`
@ -16,35 +19,66 @@ type (
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"-"`
}
ProxyEntries = *F.Map[string, *ProxyEntry]
ProxyEntries = F.Map[string, *ProxyEntry]
)
var NewProxyEntries = F.NewMap[string, *ProxyEntry]
var NewProxyEntries = F.NewMapOf[string, *ProxyEntry]
func (e *ProxyEntry) SetDefaults() {
if e.Scheme == "" {
if strings.ContainsRune(e.Port, ':') {
switch {
case strings.ContainsRune(e.Port, ':'):
e.Scheme = "tcp"
} else {
switch e.Port {
case "443", "8443":
e.Scheme = "https"
default:
e.Scheme = "http"
case e.ProxyProperties != nil:
if _, ok := ServiceNamePortMapTCP[e.ImageName]; ok {
e.Scheme = "tcp"
}
}
}
if e.Scheme == "" {
switch e.Port {
case "443", "8443":
e.Scheme = "https"
default:
e.Scheme = "http"
}
}
if e.Host == "" {
e.Host = "localhost"
}
if e.Port == "" {
switch e.Scheme {
case "http":
e.Port = "80"
case "https":
e.Port = "443"
e.Port = e.FirstPort
}
if e.Port == "" {
if port, ok := ServiceNamePortMapTCP[e.Port]; ok {
e.Port = strconv.Itoa(port)
} else if port, ok := ImageNamePortMapHTTP[e.Port]; ok {
e.Port = strconv.Itoa(port)
} else {
switch e.Scheme {
case "http":
e.Port = "80"
case "https":
e.Port = "443"
}
}
}
if e.IdleTimeout == "" {
e.IdleTimeout = IdleTimeoutDefault
}
if e.WakeTimeout == "" {
e.WakeTimeout = WakeTimeoutDefault
}
if e.StopTimeout == "" {
e.StopTimeout = StopTimeoutDefault
}
if e.StopMethod == "" {
e.StopMethod = StopMethodDefault
}
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"net/url"
"time"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
@ -11,16 +12,23 @@ import (
)
type (
Entry struct { // real model after validation
ReverseProxyEntry struct { // real model after validation
Alias T.Alias
Scheme T.Scheme
Host T.Host
Port T.Port
URL *url.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
SetHeaders http.Header
HideHeaders []string
/* Docker only */
IdleTimeout time.Duration
WakeTimeout time.Duration
StopMethod T.StopMethod
StopTimeout int
StopSignal T.Signal
DockerHost string
ContainerName string
}
StreamEntry struct {
Alias T.Alias `json:"alias"`
@ -30,69 +38,105 @@ type (
}
)
func NewEntry(m *M.ProxyEntry) (any, E.NestedError) {
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.DockerHost != ""
}
func ValidateEntry(m *M.ProxyEntry) (any, E.NestedError) {
m.SetDefaults()
scheme, err := T.NewScheme(m.Scheme)
if err.HasError() {
return nil, err
}
var entry any
e := E.NewBuilder("error validating proxy entry")
if scheme.IsStream() {
return validateStreamEntry(m)
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
return validateEntry(m, scheme)
if err := e.Build(); err.HasError() {
return nil, err
}
return entry, nil
}
func validateEntry(m *M.ProxyEntry, s T.Scheme) (*Entry, E.NestedError) {
host, err := T.NewHost(m.Host)
if err.HasError() {
return nil, err
}
port, err := T.NewPort(m.Port)
if err.HasError() {
return nil, err
}
pathPatterns, err := T.NewPathPatterns(m.PathPatterns)
if err.HasError() {
return nil, err
}
setHeaders, err := T.NewHTTPHeaders(m.SetHeaders)
if err.HasError() {
return nil, err
}
func validateRPEntry(m *M.ProxyEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
if err.HasError() {
return nil, err
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(m.StopSignal)
b.Add(err)
if err.HasError() {
return nil
}
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: m.DockerHost,
ContainerName: m.ContainerName,
}
return &Entry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
Host: host,
Port: port,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
}, E.Nil()
}
func validateStreamEntry(m *M.ProxyEntry) (*StreamEntry, E.NestedError) {
host, err := T.NewHost(m.Host)
if err.HasError() {
return nil, err
}
port, err := T.NewStreamPort(m.Port)
if err.HasError() {
return nil, err
}
scheme, err := T.NewStreamScheme(m.Scheme)
if err.HasError() {
return nil, err
func validateStreamEntry(m *M.ProxyEntry, b E.Builder) *StreamEntry {
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := T.ValidateStreamScheme(m.Scheme)
b.Add(err)
if b.HasError() {
return nil
}
return &StreamEntry{
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
}, E.Nil()
}
}

View file

@ -1,23 +1,6 @@
package fields
import (
"strings"
F "github.com/yusing/go-proxy/utils/functional"
type (
Alias string
NewAlias = Alias
)
type Alias string
type Aliases struct{ *F.Slice[Alias] }
func NewAlias(s string) Alias {
return Alias(s)
}
func NewAliases(s string) Aliases {
split := strings.Split(s, ",")
a := Aliases{F.NewSliceN[Alias](len(split))}
for i, v := range split {
a.Set(i, NewAlias(v))
}
return a
}

View file

@ -7,7 +7,7 @@ import (
E "github.com/yusing/go-proxy/error"
)
func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h := make(http.Header)
for k, v := range headers {
vSplit := strings.Split(v, ",")
@ -15,5 +15,5 @@ func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h.Add(k, strings.TrimSpace(header))
}
}
return h, E.Nil()
return h, nil
}

View file

@ -7,6 +7,6 @@ import (
type Host string
type Subdomain = Alias
func NewHost(s string) (Host, E.NestedError) {
return Host(s), E.Nil()
func ValidateHost(s string) (Host, E.NestedError) {
return Host(s), nil
}

View file

@ -9,7 +9,7 @@ type PathMode string
func NewPathMode(pm string) (PathMode, E.NestedError) {
switch pm {
case "", "forward":
return PathMode(pm), E.Nil()
return PathMode(pm), nil
default:
return "", E.Invalid("path mode", pm)
}

View file

@ -16,12 +16,12 @@ func NewPathPattern(s string) (PathPattern, E.NestedError) {
if !pathPattern.MatchString(string(s)) {
return "", E.Invalid("path pattern", s)
}
return PathPattern(s), E.Nil()
return PathPattern(s), nil
}
func NewPathPatterns(s []string) (PathPatterns, E.NestedError) {
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
if len(s) == 0 {
return []PathPattern{"/"}, E.Nil()
return []PathPattern{"/"}, nil
}
pp := make(PathPatterns, len(s))
for i, v := range s {
@ -31,7 +31,7 @@ func NewPathPatterns(s []string) (PathPatterns, E.NestedError) {
pp[i] = pattern
}
}
return pp, E.Nil()
return pp, nil
}
var pathPattern = regexp.MustCompile("^((GET|POST|DELETE|PUT|PATCH|HEAD|OPTIONS|CONNECT)\\s)?(/\\w*)+/?$")

View file

@ -8,7 +8,7 @@ import (
type Port int
func NewPort(v string) (Port, E.NestedError) {
func ValidatePort(v string) (Port, E.NestedError) {
p, err := strconv.Atoi(v)
if err != nil {
return ErrPort, E.Invalid("port number", v).With(err)
@ -21,14 +21,14 @@ func NewPortInt[Int int | uint16](v Int) (Port, E.NestedError) {
if err := pp.boundCheck(); err.HasError() {
return ErrPort, err
}
return pp, E.Nil()
return pp, nil
}
func (p Port) boundCheck() E.NestedError {
if p < MinPort || p > MaxPort {
return E.Invalid("port", p)
}
return E.Nil()
return nil
}
const (

View file

@ -1,8 +1,6 @@
package fields
import (
"strings"
E "github.com/yusing/go-proxy/error"
)
@ -11,24 +9,11 @@ type Scheme string
func NewScheme(s string) (Scheme, E.NestedError) {
switch s {
case "http", "https", "tcp", "udp":
return Scheme(s), E.Nil()
return Scheme(s), nil
}
return "", E.Invalid("scheme", s)
}
func NewSchemeFromPort(p string) (Scheme, E.NestedError) {
var s string
switch {
case strings.ContainsRune(p, ':'):
s = "tcp"
case strings.HasSuffix(p, "443"):
s = "https"
default:
s = "http"
}
return Scheme(s), E.Nil()
}
func (s Scheme) IsHTTP() bool { return s == "http" }
func (s Scheme) IsHTTPS() bool { return s == "https" }
func (s Scheme) IsTCP() bool { return s == "tcp" }

View file

@ -0,0 +1,17 @@
package fields
import (
E "github.com/yusing/go-proxy/error"
)
type Signal string
func ValidateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}

View file

@ -0,0 +1,23 @@
package fields
import (
E "github.com/yusing/go-proxy/error"
)
type StopMethod string
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View file

@ -12,13 +12,13 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"`
}
func NewStreamPort(p string) (StreamPort, E.NestedError) {
func ValidateStreamPort(p string) (StreamPort, E.NestedError) {
split := strings.Split(p, ":")
if len(split) != 2 {
return StreamPort{}, E.Invalid("stream port", p).With("should be in 'x:y' format")
}
listeningPort, err := NewPort(split[0])
listeningPort, err := ValidatePort(split[0])
if err.HasError() {
return StreamPort{}, err
}
@ -26,7 +26,7 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) {
return StreamPort{}, err
}
proxyPort, err := NewPort(split[1])
proxyPort, err := ValidatePort(split[1])
if err.HasError() {
proxyPort, err = parseNameToPort(split[1])
if err.HasError() {
@ -37,13 +37,13 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) {
return StreamPort{}, err
}
return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, E.Nil()
return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, nil
}
func parseNameToPort(name string) (Port, E.NestedError) {
port, ok := common.NamePortMapTCP[name]
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return -1, E.Unsupported("service", name)
}
return Port(port), E.Nil()
return Port(port), nil
}

View file

@ -12,7 +12,7 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"`
}
func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
ss = &StreamScheme{}
parts := strings.Split(s, ":")
if len(parts) == 1 {
@ -28,7 +28,7 @@ func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
if err.HasError() {
return nil, err
}
return ss, E.Nil()
return ss, nil
}
func (s StreamScheme) String() string {

View file

@ -0,0 +1,18 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/error"
)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}

View file

@ -1,168 +1,160 @@
package provider
import (
"fmt"
"strings"
"github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
PT "github.com/yusing/go-proxy/proxy/fields"
U "github.com/yusing/go-proxy/utils"
R "github.com/yusing/go-proxy/route"
W "github.com/yusing/go-proxy/watcher"
. "github.com/yusing/go-proxy/watcher/event"
)
type DockerProvider struct {
dockerHost string
dockerHost, hostname string
}
func DockerProviderImpl(dockerHost string) ProviderImpl {
return &DockerProvider{dockerHost: dockerHost}
}
// GetProxyEntries returns proxy entries from a docker client.
//
// It retrieves the docker client information using the dockerhelper.GetClientInfo method.
// Then, it iterates over the containers in the docker client information and calls
// the getEntriesFromLabels method to get the proxy entries for each container.
// Any errors encountered during the process are added to the ne error object.
// Finally, it returns the collected proxy entries and the ne error object.
//
// Parameters:
// - p: A pointer to the DockerProvider struct.
//
// Returns:
// - P.EntryModelSlice: (non-nil) A slice of EntryModel structs representing the proxy entries.
// - error: An error object if there was an error retrieving the docker client information or parsing the labels.
func (p DockerProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) {
func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
}
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
entries := M.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost)
info, err := D.GetClientInfo(p.dockerHost, true)
if err.HasError() {
return entries, err
return routes, E.FailWith("connect to docker", err)
}
errors := E.NewBuilder("errors when parse docker labels")
for _, container := range info.Containers {
en, err := p.getEntriesFromLabels(&container, info.Host)
for _, c := range info.Containers {
container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
}
newEntries, err := p.entriesFromContainerLabels(container)
if err.HasError() {
errors.Add(err)
}
// although err is not nil
// there may be some valid entries in `en`
dups := entries.MergeWith(en)
dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error
dups.EachKV(func(k string, v *M.ProxyEntry) {
dups.RangeAll(func(k string, v *M.ProxyEntry) {
errors.Addf("duplicate alias %s", k)
})
}
return entries, errors.Build()
entries.RangeAll(func(_ string, e *M.ProxyEntry) {
e.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
return routes, errors.Build()
}
func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
func (p *DockerProvider) OnEvent(event Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v R.Route) {
if v.Entry().ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
}
})
switch event.Action {
case ActionStarted, ActionCreated, ActionModified:
client, err := D.ConnectClient(p.dockerHost)
if err.HasError() {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err.HasError() {
b.Add(E.FailWith("inspect container", err))
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *M.ProxyEntry) {
if routes.Has(alias) {
b.Add(E.AlreadyExist("alias", alias))
} else {
if route, err := R.NewRoute(entry); err.HasError() {
b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
}
}
})
}
return
}
// Returns a list of proxy entries for a container.
// Always non-nil
func (p *DockerProvider) getEntriesFromLabels(container *types.Container, clientHost string) (M.ProxyEntries, E.NestedError) {
var mainAlias string
var aliases PT.Aliases
if exclude, ok := container.Labels[D.NSProxy+".exclude"]; ok {
if U.ParseBool(exclude) {
return M.NewProxyEntries(), E.Nil()
}
}
// set mainAlias to docker compose service name if available
if serviceName, ok := container.Labels["com.docker.compose.service"]; ok {
mainAlias = serviceName
}
// if mainAlias is not set,
// or container name is different from service name
// use container name
if containerName := strings.TrimPrefix(container.Names[0], "/"); containerName != mainAlias {
mainAlias = containerName
}
if l, ok := container.Labels[D.NSProxy+".aliases"]; ok {
aliases = PT.NewAliases(l)
delete(container.Labels, D.NSProxy+"proxy.aliases")
} else {
aliases = PT.NewAliases(mainAlias)
}
func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.ProxyEntries, E.NestedError) {
entries := M.NewProxyEntries()
// find first port, return if no port exposed
defaultPort, err := findFirstPort(container)
if err.HasError() {
logrus.Debug(mainAlias, " ", err.Error())
}
// init entries map for all aliases
aliases.ForEach(func(a PT.Alias) {
entries.Set(string(a), &M.ProxyEntry{
Alias: string(a),
Host: clientHost,
Port: defaultPort,
for _, a := range container.Aliases {
entries.Store(a, &M.ProxyEntry{
Alias: a,
Host: p.hostname,
ProxyProperties: container.ProxyProperties,
})
})
}
errors := E.NewBuilder("failed to apply label for %q", mainAlias)
errors := E.NewBuilder("failed to apply label")
for key, val := range container.Labels {
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
errors.Add(E.From(err).Subject(key))
continue
}
if lbl.Namespace != D.NSProxy {
continue
}
if lbl.Target == wildcardAlias {
// apply label for all aliases
entries.EachKV(func(a string, e *M.ProxyEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
errors.Add(E.From(err).Subject(lbl.Target))
}
})
} else {
config, ok := entries.UnsafeGet(lbl.Target)
if !ok {
errors.Add(E.NotExists("alias", lbl.Target))
continue
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
errors.Add(err.Subject(lbl.Target))
}
}
errors.Add(p.applyLabel(entries, key, val))
}
entries.EachKV(func(a string, e *M.ProxyEntry) {
if e.Port == "" {
entries.UnsafeDelete(a)
}
})
return entries, errors.Build()
return entries, errors.Build().Subject(container.ContainerName)
}
func findFirstPort(c *types.Container) (string, E.NestedError) {
if len(c.Ports) == 0 {
return "", E.FailureWhy("findFirstPort", "no port exposed")
func (p *DockerProvider) applyLabel(entries M.ProxyEntries, key, val string) (res E.NestedError) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
b.Add(err.Subject(key))
}
for _, p := range c.Ports {
if p.PublicPort != 0 {
return fmt.Sprint(p.PublicPort), E.Nil()
if lbl.Namespace != D.NSProxy {
return
}
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *M.ProxyEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
}
})
} else {
config, ok := entries.Load(lbl.Target)
if !ok {
b.Add(E.NotExist("alias", lbl.Target))
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
}
}
return "", E.Failure("findFirstPort")
return
}
const wildcardAlias = "*"

View file

@ -7,8 +7,10 @@ import (
"github.com/yusing/go-proxy/common"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
R "github.com/yusing/go-proxy/route"
U "github.com/yusing/go-proxy/utils"
W "github.com/yusing/go-proxy/watcher"
. "github.com/yusing/go-proxy/watcher/event"
)
type FileProvider struct {
@ -27,26 +29,53 @@ func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data)
}
func (p *FileProvider) String() string {
return p.fileName
func (p FileProvider) OnEvent(event Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err.HasError() {
b.Add(err)
return
}
routes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Start())
})
routes.MergeFrom(newRoutes)
return
}
func (p *FileProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) {
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res)
entries := M.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path))
if err.HasError() {
return entries, E.Failure("read file").Subject(p).With(err)
b.Add(E.FailWith("read file", err))
return
}
ne := E.Failure("validation").Subject(p)
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
return entries, ne.With(err)
b.Add(err)
return
}
}
if err = entries.UnmarshalFromYAML(data); err.HasError() {
return entries, ne.With(err)
b.Add(err)
return
}
return entries, E.Nil()
return R.FromEntries(entries)
}
func (p *FileProvider) NewWatcher() W.Watcher {

View file

@ -4,38 +4,40 @@ import (
"context"
"fmt"
"path"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
R "github.com/yusing/go-proxy/route"
W "github.com/yusing/go-proxy/watcher"
. "github.com/yusing/go-proxy/watcher/event"
)
type ProviderImpl interface {
GetProxyEntries() (M.ProxyEntries, E.NestedError)
NewWatcher() W.Watcher
}
type (
Provider struct {
ProviderImpl
type Provider struct {
ProviderImpl
name string
t ProviderType
routes R.Routes
name string
t ProviderType
routes *R.Routes
reloadReqCh chan struct{}
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
watcher W.Watcher
watcherCtx context.Context
watcherCancel context.CancelFunc
l *logrus.Entry
cooldownCh chan struct{}
}
type ProviderType string
l *logrus.Entry
}
ProviderImpl interface {
NewWatcher() W.Watcher
LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event Event, routes R.Routes) EventResult
}
ProviderType string
EventResult struct {
nRemoved int
nAdded int
err E.NestedError
}
)
const (
ProviderTypeDocker ProviderType = "docker"
@ -44,16 +46,14 @@ const (
func newProvider(name string, t ProviderType) *Provider {
p := &Provider{
name: name,
t: t,
routes: R.NewRoutes(),
reloadReqCh: make(chan struct{}, 1),
cooldownCh: make(chan struct{}, 1),
name: name,
t: t,
routes: R.NewRoutes(),
}
p.l = logrus.WithField("provider", p)
go p.processReloadRequests()
return p
}
func NewFileProvider(filename string) *Provider {
name := path.Base(filename)
p := newProvider(name, ProviderTypeFile)
@ -78,25 +78,21 @@ func (p *Provider) GetType() ProviderType {
}
func (p *Provider) String() string {
return fmt.Sprintf("%s: %s", p.t, p.name)
return fmt.Sprintf("%s-%s", p.t, p.name)
}
func (p *Provider) StartAllRoutes() E.NestedError {
err := p.loadRoutes()
func (p *Provider) StartAllRoutes() (res E.NestedError) {
errors := E.NewBuilder("errors in routes")
defer errors.To(&res)
// start watcher no matter load success or not
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
go p.watchEvents()
errors := E.NewBuilder("errors in routes")
nStarted := 0
nFailed := 0
if err.HasError() {
errors.Add(err)
}
p.routes.EachKVParallel(func(alias string, r R.Route) {
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
@ -106,18 +102,21 @@ func (p *Provider) StartAllRoutes() E.NestedError {
})
p.l.Debugf("%d routes started, %d failed", nStarted, nFailed)
return errors.Build()
return
}
func (p *Provider) StopAllRoutes() E.NestedError {
func (p *Provider) StopAllRoutes() (res E.NestedError) {
if p.watcherCancel != nil {
p.watcherCancel()
p.watcherCancel = nil
}
errors := E.NewBuilder("errors stopping routes for provider %q", p.name)
defer errors.To(&res)
nStopped := 0
nFailed := 0
p.routes.EachKVParallel(func(alias string, r R.Route) {
p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r))
nFailed++
@ -126,20 +125,24 @@ func (p *Provider) StopAllRoutes() E.NestedError {
}
})
p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed)
return errors.Build()
return
}
func (p *Provider) ReloadRoutes() {
select {
case p.reloadReqCh <- struct{}{}:
// Successfully sent reload request
default:
// Reload request already in progress, ignore this request
func (p *Provider) RangeRoutes(do func(string, R.Route)) {
p.routes.RangeAll(do)
}
func (p *Provider) GetRoute(alias string) (R.Route, bool) {
return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
routes, err := p.LoadRoutesImpl()
if err != nil {
return err
}
}
func (p *Provider) GetCurrentRoutes() *R.Routes {
return p.routes
p.routes = routes
return nil
}
func (p *Provider) watchEvents() {
@ -151,11 +154,15 @@ func (p *Provider) watchEvents() {
case <-p.watcherCtx.Done():
return
case event, ok := <-events:
if !ok {
if !ok { // channel closed
return
}
l.Info(event)
p.ReloadRoutes()
res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event)
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
if res.err.HasError() {
l.Error(res.err)
}
case err, ok := <-errs:
if !ok {
return
@ -167,50 +174,3 @@ func (p *Provider) watchEvents() {
}
}
}
func (p *Provider) processReloadRequests() {
for range p.reloadReqCh {
// prevent busy loop caused by a container
// repeating crashing and restarting
select {
case p.cooldownCh <- struct{}{}:
p.l.Info("Starting to reload routes")
nRoutes := p.routes.Size()
p.StopAllRoutes()
p.loadRoutes()
p.StartAllRoutes()
p.l.Infof("Routes reloaded (%d -> %d)", nRoutes, p.routes.Size())
go func() {
time.Sleep(reloadCooldown)
<-p.cooldownCh
}()
default:
}
}
}
func (p *Provider) loadRoutes() E.NestedError {
entries, err := p.GetProxyEntries()
if err.HasError() {
p.l.Warn(err.Subject(p))
}
p.routes = R.NewRoutes()
errors := E.NewBuilder("errors loading routes from %s", p)
entries.EachKV(func(a string, e *M.ProxyEntry) {
e.Alias = a
r, err := R.NewRoute(e)
if err.HasError() {
errors.Add(err.Subject(a))
} else {
p.routes.Set(a, r)
}
})
return errors.Build()
}
const reloadCooldown = 50 * time.Millisecond

View file

@ -207,7 +207,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// }
//
// TODO: headers in ModifyResponse
func NewReverseProxy(target *url.URL, transport *http.Transport, entry *Entry) *ReverseProxy {
func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy {
// check on init rather than on request
var setHeaders = func(r *http.Request) {}
var hideHeaders = func(r *http.Request) {}

View file

@ -2,8 +2,8 @@ package route
import (
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"net/http"
@ -11,6 +11,7 @@ import (
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields"
@ -23,57 +24,65 @@ type (
TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"`
entry *P.ReverseProxyEntry
mux *http.ServeMux
handler *P.ReverseProxy
regIdleWatcher func() E.NestedError
unregIdleWatcher func()
}
URL url.URL
PathKey = PT.PathPattern
SubdomainKey = PT.Alias
)
var httpRoutes = F.NewMap[SubdomainKey, *HTTPRoute]()
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans http.RoundTripper
var regIdleWatcher func() E.NestedError
var unregIdleWatcher func()
func NewHTTPRoute(entry *P.Entry) (*HTTPRoute, E.NestedError) {
var tr *http.Transport
if entry.NoTLSVerify {
tr = transportNoTLS
trans = transportNoTLS
} else {
tr = transport
trans = transport
}
rp := P.NewReverseProxy(entry.URL, tr, entry)
rp := P.NewReverseProxy(entry.URL, trans, entry)
httpRoutes.Lock()
defer httpRoutes.Unlock()
var r *HTTPRoute
r, ok := httpRoutes.UnsafeGet(entry.Alias)
if !ok {
r = &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
handler: rp,
if entry.UseIdleWatcher() {
regIdleWatcher = func() E.NestedError {
watcher, err := idlewatcher.Register(entry)
if err.HasError() {
return err
}
// patch round-tripper
rp.Transport = watcher.PatchRoundTripper(trans)
return nil
}
httpRoutes.UnsafeSet(entry.Alias, r)
}
rewrite := rp.Rewrite
if logrus.GetLevel() == logrus.DebugLevel {
l := logrus.WithField("alias", entry.Alias)
rp.Rewrite = func(pr *P.ProxyRequest) {
l.Debug("request URL: ", pr.In.Host, pr.In.URL.Path)
l.Debug("request headers: ", pr.In.Header)
rewrite(pr)
unregIdleWatcher = func() {
idlewatcher.Unregister(entry.ContainerName)
rp.Transport = trans
}
} else {
rp.Rewrite = rewrite
}
return r, E.Nil()
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
_, exists := httpRoutes.Load(entry.Alias)
if exists {
return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias)
}
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
handler: rp,
regIdleWatcher: regIdleWatcher,
unregIdleWatcher: unregIdleWatcher,
}
return r, nil
}
func (r *HTTPRoute) String() string {
@ -81,18 +90,35 @@ func (r *HTTPRoute) String() string {
}
func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.regIdleWatcher != nil {
if err := r.regIdleWatcher(); err.HasError() {
return err
}
}
r.mux = http.NewServeMux()
for _, p := range r.PathPatterns {
r.mux.HandleFunc(string(p), r.handler.ServeHTTP)
}
httpRoutes.Set(r.Alias, r)
return E.Nil()
httpRoutes.Store(r.Alias, r)
return nil
}
func (r *HTTPRoute) Stop() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.unregIdleWatcher != nil {
r.unregIdleWatcher()
}
r.mux = nil
httpRoutes.Delete(r.Alias)
return E.Nil()
return nil
}
func (u *URL) String() string {
@ -104,27 +130,26 @@ func (u *URL) MarshalText() (text []byte, err error) {
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMux(r.Host, PathKey(r.URL.Path))
mux, err := findMux(r.Host)
if err != nil {
err = E.Failure("request").
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
With(err)
http.Error(w, err.Error(), http.StatusNotFound)
http.Error(w, err.String(), http.StatusNotFound)
logrus.Error(err)
return
}
mux.ServeHTTP(w, r)
}
func findMux(host string, path PathKey) (*http.ServeMux, error) {
func findMux(host string) (*http.ServeMux, E.NestedError) {
sd := strings.Split(host, ".")[0]
if r, ok := httpRoutes.UnsafeGet(PT.Alias(sd)); ok {
if r, ok := httpRoutes.Load(PT.Alias(sd)); ok {
return r.mux, nil
}
return nil, E.NotExists("route", fmt.Sprintf("subdomain: %s, path: %s", sd, path))
return nil, E.NotExist("route", sd)
}
// TODO: default + per proxy
var (
transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -135,10 +160,12 @@ var (
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
}
transportNoTLS = func() *http.Transport {
var clone = transport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
)

View file

@ -1,6 +1,9 @@
package route
import (
"fmt"
"net/url"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
P "github.com/yusing/go-proxy/proxy"
@ -9,27 +12,81 @@ import (
type (
Route interface {
RouteImpl
Entry() *M.ProxyEntry
Type() RouteType
URL() *url.URL
}
Routes = F.Map[string, Route]
RouteType string
RouteImpl interface {
Start() E.NestedError
Stop() E.NestedError
String() string
}
Routes = F.Map[string, Route]
route struct {
RouteImpl
type_ RouteType
entry *M.ProxyEntry
}
)
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
)
// function alias
var NewRoutes = F.NewMap[string, Route]
var NewRoutes = F.NewMapOf[string, Route]
func NewRoute(en *M.ProxyEntry) (Route, E.NestedError) {
entry, err := P.NewEntry(en)
rt, err := P.ValidateEntry(en)
if err.HasError() {
return nil, err
}
switch e := entry.(type) {
var t RouteType
switch e := rt.(type) {
case *P.StreamEntry:
return NewStreamRoute(e)
case *P.Entry:
return NewHTTPRoute(e)
rt, err = NewStreamRoute(e)
t = RouteTypeStream
case *P.ReverseProxyEntry:
rt, err = NewHTTPRoute(e)
t = RouteTypeReverseProxy
default:
panic("bug: should not reach here")
}
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, err
}
func (rt *route) Entry() *M.ProxyEntry {
return rt.entry
}
func (rt *route) Type() RouteType {
return rt.type_
}
func (rt *route) URL() *url.URL {
url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host))
return url
}
func FromEntries(entries M.ProxyEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()
entries.RangeAll(func(alias string, entry *M.ProxyEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
if err.HasError() {
b.Add(err.Subject(alias))
} else {
routes.Store(alias, r)
}
})
return routes, b.Build()
}

View file

@ -12,7 +12,7 @@ import (
)
type StreamRoute struct {
*P.StreamEntry
P.StreamEntry
StreamImpl `json:"-"`
wg sync.WaitGroup
@ -35,7 +35,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
base := &StreamRoute{
StreamEntry: entry,
StreamEntry: *entry,
wg: sync.WaitGroup{},
connCh: make(chan any),
}
@ -45,11 +45,11 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
base.StreamImpl = NewUDPRoute(base)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, E.Nil()
return base, nil
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("%s-stream: %s", r.Scheme, r.Alias)
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
}
func (r *StreamRoute) Start() E.NestedError {
@ -59,13 +59,13 @@ func (r *StreamRoute) Start() E.NestedError {
r.stopCh = make(chan struct{}, 1)
r.wg.Wait()
if err := r.Setup(); err != nil {
return E.Failure("setup").With(err)
return E.FailWith("setup", err)
}
r.started.Store(true)
r.wg.Add(2)
go r.grAcceptConnections()
go r.grHandleConnections()
return E.Nil()
return nil
}
func (r *StreamRoute) Stop() E.NestedError {
@ -88,7 +88,7 @@ func (r *StreamRoute) Stop() E.NestedError {
case <-time.After(streamStopListenTimeout):
l.Error("timed out waiting for connections")
}
return E.Nil()
return nil
}
func (r *StreamRoute) grAcceptConnections() {

View file

@ -65,9 +65,10 @@ func (route *TCPRoute) Handle(c any) error {
}()
route.mu.Lock()
defer route.mu.Unlock()
pipe := U.NewBidirectionalPipe(pipeCtx, clientConn, serverConn)
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start()
}
@ -78,7 +79,7 @@ func (route *TCPRoute) CloseListeners() {
route.listener.Close()
route.listener = nil
for _, pipe := range route.pipe {
if err := pipe.Stop(); err.HasError() {
if err := pipe.Stop(); err != nil {
route.l.Error(err)
}
}

View file

@ -1,229 +1,116 @@
package functional
import (
"context"
"sync"
"github.com/puzpuzpuz/xsync/v3"
"gopkg.in/yaml.v3"
E "github.com/yusing/go-proxy/error"
)
type Map[KT comparable, VT any] struct {
m map[KT]VT
defVals map[KT]VT
sync.RWMutex
*xsync.MapOf[KT, VT]
}
// NewMap creates a new Map with the given map as its initial values.
//
// Parameters:
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMap[KT comparable, VT any](dv ...map[KT]VT) *Map[KT, VT] {
return NewMapFrom(make(map[KT]VT), dv...)
func NewMapOf[KT comparable, VT any](options ...func(*xsync.MapConfig)) Map[KT, VT] {
return Map[KT, VT]{xsync.NewMapOf[KT, VT](options...)}
}
// NewMapOf creates a new Map with the given map as its initial values.
//
// Type parameters:
// - M: type for the new map.
//
// Parameters:
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMapOf[M Map[KT, VT], KT comparable, VT any](dv ...map[KT]VT) *Map[KT, VT] {
return NewMapFrom(make(map[KT]VT), dv...)
}
// NewMapFrom creates a new Map with the given map as its initial values.
//
// Parameters:
// - from: a map of type KT to VT, which will be the initial values of the Map.
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMapFrom[KT comparable, VT any](from map[KT]VT, dv ...map[KT]VT) *Map[KT, VT] {
if len(dv) > 0 {
return &Map[KT, VT]{m: from, defVals: dv[0]}
func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) {
res = NewMapOf[KT, VT](xsync.WithPresize(len(m)))
for k, v := range m {
res.Store(k, v)
}
return &Map[KT, VT]{m: from}
return
}
func (m *Map[KT, VT]) Set(key KT, value VT) {
m.Lock()
m.m[key] = value
m.Unlock()
}
func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bool)) (_ CT) {
result := make(chan CT, 1)
func (m *Map[KT, VT]) Get(key KT) VT {
m.RLock()
defer m.RUnlock()
value, ok := m.m[key]
if !ok && m.defVals != nil {
return m.defVals[key]
}
return value
}
// Find searches for the first element in the map that satisfies the given criteria.
//
// Parameters:
// - criteria: a function that takes a value of type VT and returns a tuple of any type and a boolean.
//
// Return:
// - any: the first value that satisfies the criteria, or nil if no match is found.
func (m *Map[KT, VT]) Find(criteria func(VT) (any, bool)) any {
m.RLock()
defer m.RUnlock()
result := make(chan any)
wg := sync.WaitGroup{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, v := range m.m {
wg.Add(1)
go func(val VT) {
defer wg.Done()
if value, ok := criteria(val); ok {
select {
case result <- value:
cancel() // Cancel other goroutines if a result is found
case <-ctx.Done(): // If already cancelled
return
}
m.Range(func(key KT, value VT) bool {
select {
case <-result: // already have a result
return false // stop iteration
default:
if got, ok := criteria(value); ok {
result <- got
return false
}
}(v)
}
go func() {
wg.Wait()
close(result)
}()
// The first valid match, if any
select {
case res, ok := <-result:
if ok {
return res
return true
}
case <-ctx.Done():
})
select {
case v := <-result:
return v
default:
return
}
return nil // Return nil if no matches found
}
func (m *Map[KT, VT]) UnsafeGet(key KT) (VT, bool) {
value, ok := m.m[key]
return value, ok
}
func (m *Map[KT, VT]) UnsafeSet(key KT, value VT) {
m.m[key] = value
}
func (m *Map[KT, VT]) Delete(key KT) {
m.Lock()
delete(m.m, key)
m.Unlock()
}
func (m *Map[KT, VT]) UnsafeDelete(key KT) {
delete(m.m, key)
}
// MergeWith merges the contents of another Map[KT, VT]
// into the current Map[KT, VT] and
// returns a map that were duplicated.
// MergeFrom add contents from another `Map`, ignore duplicated keys
//
// Parameters:
// - other: a pointer to another Map[KT, VT] to be merged into the current Map[KT, VT].
// - other: `Map` of values to add from
//
// Return:
// - Map[KT, VT]: a map of key-value pairs that were duplicated during the merge.
func (m *Map[KT, VT]) MergeWith(other *Map[KT, VT]) Map[KT, VT] {
dups := make(map[KT]VT)
// - Map: a `Map` of duplicated keys-value pairs
func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] {
dups := NewMapOf[KT, VT]()
m.Lock()
for k, v := range other.m {
if _, isDup := m.m[k]; !isDup {
m.m[k] = v
other.Range(func(k KT, v VT) bool {
if _, ok := m.Load(k); ok {
dups.Store(k, v)
} else {
dups[k] = v
m.Store(k, v)
}
}
m.Unlock()
return Map[KT, VT]{m: dups}
return true
})
return dups
}
func (m *Map[KT, VT]) Clear() {
m.Lock()
m.m = make(map[KT]VT)
m.Unlock()
func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) {
m.Range(func(k KT, v VT) bool {
do(k, v)
return true
})
}
func (m *Map[KT, VT]) Size() int {
m.RLock()
defer m.RUnlock()
return len(m.m)
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.Range(func(k KT, v VT) bool {
if criteria(v) {
m.Delete(k)
}
return true
})
}
func (m *Map[KT, VT]) Contains(key KT) bool {
m.RLock()
_, ok := m.m[key]
m.RUnlock()
func (m Map[KT, VT]) Has(k KT) bool {
_, ok := m.Load(k)
return ok
}
func (m *Map[KT, VT]) Clone() *Map[KT, VT] {
m.RLock()
defer m.RUnlock()
clone := make(map[KT]VT, len(m.m))
for k, v := range m.m {
clone[k] = v
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
if m.Size() != 0 {
return E.FailedWhy("unmarshal from yaml", "map is not empty")
}
return &Map[KT, VT]{m: clone, defVals: m.defVals}
}
func (m *Map[KT, VT]) EachKV(fn func(k KT, v VT)) {
m.Lock()
for k, v := range m.m {
fn(k, v)
tmp := make(map[KT]VT)
if err := E.From(yaml.Unmarshal(data, tmp)); err.HasError() {
return err
}
m.Unlock()
}
func (m *Map[KT, VT]) Each(fn func(v VT)) {
m.Lock()
for _, v := range m.m {
fn(v)
for k, v := range tmp {
m.Store(k, v)
}
m.Unlock()
return nil
}
func (m *Map[KT, VT]) EachParallel(fn func(v VT)) {
m.Lock()
ParallelForEachValue(m.m, fn)
m.Unlock()
}
func (m *Map[KT, VT]) EachKVParallel(fn func(k KT, v VT)) {
m.Lock()
ParallelForEachKV(m.m, fn)
m.Unlock()
}
func (m *Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
return E.From(yaml.Unmarshal(data, m.m))
}
func (m *Map[KT, VT]) Iterator() map[KT]VT {
return m.m
func (m Map[KT, VT]) String() string {
tmp := make(map[KT]VT, m.Size())
m.RangeAll(func(k KT, v VT) {
tmp[k] = v
})
data, err := yaml.Marshal(tmp)
if err != nil {
return err.Error()
}
return string(data)
}

View file

@ -0,0 +1,75 @@
package functional_test
import (
"testing"
. "github.com/yusing/go-proxy/utils/functional"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestNewMapFrom(t *testing.T) {
m := NewMapFrom(map[string]int{
"a": 1,
"b": 2,
"c": 3,
})
ExpectEqual(t, m.Size(), 3)
ExpectTrue(t, m.Has("a"))
ExpectTrue(t, m.Has("b"))
ExpectTrue(t, m.Has("c"))
}
func TestMapFind(t *testing.T) {
m := NewMapFrom(map[string]map[string]int{
"a": {
"a": 1,
},
"b": {
"a": 1,
"b": 2,
},
"c": {
"b": 2,
"c": 3,
},
})
res := MapFind(m, func(inner map[string]int) (int, bool) {
if _, ok := inner["c"]; ok && inner["c"] == 3 {
return inner["c"], true
}
return 0, false
})
ExpectEqual(t, res, 3)
}
func TestMergeFrom(t *testing.T) {
m1 := NewMapFrom(map[string]int{
"a": 1,
"b": 2,
"c": 3,
"d": 4,
})
m2 := NewMapFrom(map[string]int{
"a": 1,
"c": 123,
"e": 456,
"f": 6,
})
dup := m1.MergeFrom(m2)
ExpectEqual(t, m1.Size(), 6)
ExpectTrue(t, m1.Has("e"))
ExpectTrue(t, m1.Has("f"))
c, _ := m1.Load("c")
d, _ := m1.Load("d")
e, _ := m1.Load("e")
f, _ := m1.Load("f")
ExpectEqual(t, c, 3)
ExpectEqual(t, d, 4)
ExpectEqual(t, e, 456)
ExpectEqual(t, f, 6)
ExpectEqual(t, dup.Size(), 2)
ExpectTrue(t, dup.Has("a"))
ExpectTrue(t, dup.Has("c"))
}

View file

@ -10,15 +10,8 @@ import (
E "github.com/yusing/go-proxy/error"
)
// TODO: move to "utils/io"
type (
Reader interface {
Read() ([]byte, E.NestedError)
}
StdReader struct {
r Reader
}
FileReader struct {
Path string
}
@ -29,13 +22,6 @@ type (
closed atomic.Bool
}
StdReadCloser struct {
r *ReadCloser
}
ByteReader []byte
NewByteReader = ByteReader
Pipe struct {
r ReadCloser
w io.WriteCloser
@ -44,49 +30,25 @@ type (
}
BidirectionalPipe struct {
pSrcDst Pipe
pDstSrc Pipe
pSrcDst *Pipe
pDstSrc *Pipe
}
)
func NewFileReader(path string) *FileReader {
return &FileReader{Path: path}
}
func (r StdReader) Read() ([]byte, error) {
return r.r.Read()
}
func (r *FileReader) Read() ([]byte, E.NestedError) {
return E.Check(os.ReadFile(r.Path))
}
func (r ByteReader) Read() ([]byte, E.NestedError) {
return r, E.Nil()
}
func (r *ReadCloser) Read(p []byte) (int, E.NestedError) {
func (r *ReadCloser) Read(p []byte) (int, error) {
select {
case <-r.ctx.Done():
return 0, E.From(r.ctx.Err())
return 0, r.ctx.Err()
default:
return E.Check(r.r.Read(p))
return r.r.Read(p)
}
}
func (r *ReadCloser) Close() E.NestedError {
func (r *ReadCloser) Close() error {
if r.closed.Load() {
return E.Nil()
return nil
}
r.closed.Store(true)
return E.From(r.r.Close())
}
func (r StdReadCloser) Read(p []byte) (int, error) {
return r.r.Read(p)
}
func (r StdReadCloser) Close() error {
return r.r.Close()
}
@ -100,35 +62,35 @@ func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe {
}
}
func (p *Pipe) Start() E.NestedError {
return Copy(p.ctx, p.w, &StdReadCloser{&p.r})
func (p *Pipe) Start() error {
return Copy(p.ctx, p.w, &p.r)
}
func (p *Pipe) Stop() E.NestedError {
func (p *Pipe) Stop() error {
p.cancel()
return E.Join("error stopping pipe", p.r.Close(), p.w.Close())
return E.JoinE("error stopping pipe", p.r.Close(), p.w.Close()).Error()
}
func (p *Pipe) Write(b []byte) (int, E.NestedError) {
return E.Check(p.w.Write(b))
func (p *Pipe) Write(b []byte) (int, error) {
return p.w.Write(b)
}
func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, rw1, rw2),
pDstSrc: *NewPipe(ctx, rw2, rw1),
pSrcDst: NewPipe(ctx, rw1, rw2),
pDstSrc: NewPipe(ctx, rw2, rw1),
}
}
func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, listener, client),
pDstSrc: *NewPipe(ctx, client, target),
pSrcDst: NewPipe(ctx, listener, client),
pDstSrc: NewPipe(ctx, client, target),
}
}
func (p *BidirectionalPipe) Start() E.NestedError {
errCh := make(chan E.NestedError, 2)
func (p *BidirectionalPipe) Start() error {
errCh := make(chan error, 2)
go func() {
errCh <- p.pSrcDst.Start()
}()
@ -136,34 +98,34 @@ func (p *BidirectionalPipe) Start() E.NestedError {
errCh <- p.pDstSrc.Start()
}()
for err := range errCh {
if err.HasError() {
if err != nil {
return err
}
}
return E.Nil()
return nil
}
func (p *BidirectionalPipe) Stop() E.NestedError {
return E.Join("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop())
func (p *BidirectionalPipe) Stop() error {
return E.JoinE("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()).Error()
}
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) E.NestedError {
_, err := io.Copy(dst, StdReadCloser{&ReadCloser{ctx: ctx, r: src}})
return E.From(err)
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error {
_, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src})
return err
}
func LoadJson[T any](path string, pointer *T) E.NestedError {
data, err := os.ReadFile(path)
if err != nil {
return E.From(err)
data, err := E.Check(os.ReadFile(path))
if err.HasError() {
return err
}
return E.From(json.Unmarshal(data, pointer))
}
func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
data, err := json.Marshal(pointer)
if err != nil {
return E.From(err)
data, err := E.Check(json.Marshal(pointer))
if err.HasError() {
return err
}
return E.From(os.WriteFile(path, data, perm))
}

View file

@ -20,5 +20,5 @@ func SetFieldFromSnake[T, VT any](obj *T, field string, value VT) E.NestedError
return E.Invalid("field", field)
}
prop.Set(reflect.ValueOf(value))
return E.Nil()
return nil
}

View file

@ -17,45 +17,26 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
err := yaml.Unmarshal(data, &i)
if err != nil {
return E.Failure("unmarshal yaml").With(err)
return E.FailWith("unmarshal yaml", err)
}
m, err := json.Marshal(i)
if err != nil {
return E.Failure("marshal json").With(err)
return E.FailWith("marshal json", err)
}
err = schema.Validate(bytes.NewReader(m))
if err == nil {
return E.Nil()
return nil
}
errors := E.NewBuilder("yaml validation error")
for _, e := range err.(*jsonschema.ValidationError).Causes {
errors.Add(e)
errors.AddE(e)
}
return errors.Build()
}
// TryJsonStringify converts the given object to a JSON string.
//
// It takes an object of any type and attempts to marshal it into a JSON string.
// If the marshaling is successful, the JSON string is returned.
// If the marshaling fails, the object is converted to a string using fmt.Sprint and returned.
//
// Parameters:
// - o: The object to be converted to a JSON string.
//
// Return type:
// - string: The JSON string representation of the object.
func TryJsonStringify(o any) string {
b, err := json.Marshal(o)
if err != nil {
return fmt.Sprint(o)
}
return string(b)
}
// Serialize converts the given data into a map[string]any representation.
//
// It uses reflection to inspect the data type and handle different kinds of data.
@ -123,7 +104,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
return nil, E.Unsupported("type", value.Kind())
}
return result, E.Nil()
return result, nil
}
func Deserialize(src map[string]any, target any) E.NestedError {
@ -166,7 +147,7 @@ func Deserialize(src map[string]any, target any) E.NestedError {
propNew := reflect.New(propType.Elem())
err := Deserialize(vSerialized, propNew.Interface())
if err.HasError() {
return E.Failure("set field").With(k).With(err)
return E.Failure("set field").With(err).Subject(k)
}
prop.Set(propNew)
default:
@ -180,7 +161,15 @@ func Deserialize(src map[string]any, target any) E.NestedError {
}
}
return E.Nil()
return nil
}
func DeserializeJson(j map[string]string, target any) E.NestedError {
data, err := E.Check(json.Marshal(j))
if err.HasError() {
return err
}
return E.From(json.Unmarshal(data, target))
}
func toLowerNoSnake(s string) string {

11
src/utils/string.go Normal file
View file

@ -0,0 +1,11 @@
package utils
import "strings"
func CommaSeperatedList(s string) []string {
res := strings.Split(s, ",")
for i, part := range res {
res[i] = strings.TrimSpace(part)
}
return res
}

View file

@ -3,25 +3,23 @@ package utils
import (
"reflect"
"testing"
E "github.com/yusing/go-proxy/error"
)
func ExpectNoError(t *testing.T, err error) {
t.Helper()
var noError bool
switch t := err.(type) {
case E.NestedError:
noError = t.NoError()
default:
noError = err == nil
}
if !noError {
if err != nil && !reflect.ValueOf(err).IsNil() {
t.Errorf("expected err=nil, got %s", err.Error())
}
}
func ExpectEqual(t *testing.T, got, want any) {
func ExpectEqual[T comparable](t *testing.T, got T, want T) {
t.Helper()
if got != want {
t.Errorf("expected:\n%v, got\n%v", want, got)
}
}
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Errorf("expected:\n%v, got\n%v", want, got)
@ -47,7 +45,7 @@ func ExpectType[T any](t *testing.T, got any) T {
tExpect := reflect.TypeFor[T]()
_, ok := got.(T)
if !ok {
t.Errorf("expected type %T, got %T", tExpect, got)
t.Errorf("expected type %s, got %T", tExpect, got)
}
return got.(T)
}

View file

@ -2,13 +2,13 @@ package watcher
import (
"context"
"fmt"
"time"
"github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
)
type DockerWatcher struct {
@ -34,13 +34,14 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest
if err.NoError() {
break
}
errCh <- E.From(err)
errCh <- err
time.Sleep(1 * time.Second)
}
if err.HasError() {
errCh <- E.Failure("connecting to docker")
return
}
defer cl.Close()
cEventCh, cErrCh := cl.Events(ctx, dwOptions)
started <- struct{}{}
@ -58,13 +59,16 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest
case events.ActionStart:
Action = ActionCreated
case events.ActionDie:
Action = ActionDeleted
Action = ActionStopped
default: // NOTE: should not happen
Action = ActionModified
}
eventCh <- Event{
ActorName: fmt.Sprintf("container %q", msg.Actor.Attributes["name"]),
Action: Action,
Type: EventTypeDocker,
ActorID: msg.Actor.ID,
ActorAttributes: msg.Actor.Attributes, // labels
ActorName: msg.Actor.Attributes["name"],
Action: Action,
}
case err := <-cErrCh:
if err == nil {

View file

@ -1,26 +0,0 @@
package watcher
import "fmt"
type (
Event struct {
ActorName string
Action Action
}
Action string
)
const (
ActionModified Action = "MODIFIED"
ActionCreated Action = "CREATED"
ActionStarted Action = "STARTED"
ActionDeleted Action = "DELETED"
)
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
}
func (a Action) IsDelete() bool {
return a == ActionDeleted
}

View file

@ -0,0 +1,34 @@
package event
import "fmt"
type (
Event struct {
Type EventType
ActorName string
ActorID string
ActorAttributes map[string]string
Action Action
}
Action string
EventType string
)
const (
ActionModified Action = "modified"
ActionCreated Action = "created"
ActionStarted Action = "started"
ActionDeleted Action = "deleted"
ActionStopped Action = "stopped"
EventTypeDocker EventType = "docker"
EventTypeFile EventType = "file"
)
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
}
func (a Action) IsDelete() bool {
return a == ActionDeleted
}

View file

@ -6,6 +6,7 @@ import (
"github.com/yusing/go-proxy/common"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
)
type fileWatcher struct {

View file

@ -9,6 +9,7 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
)
type fileWatcherHelper struct {
@ -93,7 +94,10 @@ func (h *fileWatcherHelper) start() {
continue
}
msg := Event{ActorName: w.filename}
msg := Event{
Type: EventTypeFile,
ActorName: w.filename,
}
switch {
case event.Has(fsnotify.Create):
msg.Action = ActionCreated

View file

@ -4,6 +4,7 @@ import (
"context"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
)
type Watcher interface {