v0.5.0-rc5: check release

This commit is contained in:
yusing 2024-09-19 20:40:03 +08:00
parent be7a766cb2
commit 4a2d42bfa9
68 changed files with 1971 additions and 1107 deletions

View file

@ -15,6 +15,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
- [go-proxy](#go-proxy) - [go-proxy](#go-proxy)
- [Key Points](#key-points) - [Key Points](#key-points)
- [Getting Started](#getting-started) - [Getting Started](#getting-started)
- [Setup](#setup)
- [Commands line arguments](#commands-line-arguments) - [Commands line arguments](#commands-line-arguments)
- [Environment variables](#environment-variables) - [Environment variables](#environment-variables)
- [Use JSON Schema in VSCode](#use-json-schema-in-vscode) - [Use JSON Schema in VSCode](#use-json-schema-in-vscode)
@ -27,10 +28,11 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
- Easy to use - Easy to use
- Effortless configuration - Effortless configuration
- Error messages is clear and detailed - Error messages is clear and detailed, easy troubleshooting
- Auto certificate obtaining and renewal (See [Supported DNS Challenge Providers](docs/dns_providers.md)) - Auto certificate obtaining and renewal (See [Supported DNS Challenge Providers](docs/dns_providers.md))
- Auto configuration for docker containers - Auto configuration for docker containers
- Auto hot-reload on container state / config file changes - Auto hot-reload on container state / config file changes
- Stop containers on idle, wake it up on traffic _(optional)_
- Support HTTP(s), TCP and UDP - Support HTTP(s), TCP and UDP
- Web UI for configuration and monitoring (See [screenshots](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots)) - Web UI for configuration and monitoring (See [screenshots](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots))
- Written in **[Go](https://go.dev)** - Written in **[Go](https://go.dev)**
@ -39,6 +41,8 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
## Getting Started ## Getting Started
### Setup
1. Setup DNS Records, e.g. 1. Setup DNS Records, e.g.
- A Record: `*.y.z` -> `10.0.10.1` - A Record: `*.y.z` -> `10.0.10.1`
@ -60,6 +64,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr
| `validate` | validate config and exit | | | `validate` | validate config and exit | |
| `reload` | trigger a force reload of config | | | `reload` | trigger a force reload of config | |
| `ls-config` | list config and exit | `go-proxy ls-config \| jq` | | `ls-config` | list config and exit | `go-proxy ls-config \| jq` |
| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` |
**run with `docker exec <container_name> /app/go-proxy <command>`** **run with `docker exec <container_name> /app/go-proxy <command>`**
@ -104,7 +109,7 @@ providers:
### Provider File ### Provider File
Fields are same as [docker labels](docs/docker.md#labels) starting from `scheme` See [Fields](docs/docker.md#fields)
See [providers.example.yml](providers.example.yml) for examples See [providers.example.yml](providers.example.yml) for examples

View file

@ -6,7 +6,7 @@ services:
network_mode: host network_mode: host
labels: labels:
- proxy.aliases=gp - proxy.aliases=gp
- proxy.gp.port=8888 - proxy.gp.port=3000
depends_on: depends_on:
- app - app
app: app:

View file

@ -85,12 +85,17 @@
### Syntax ### Syntax
| Label | Description | Default | | Label | Description | Default | Accepted values |
| ----------------------- | -------------------------------------------------------- | ---------------- | | ----------------------- | --------------------------------------------------------------------- | -------------------- | ------------------------------------------------------------------------- |
| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` | | `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` | any |
| `proxy.exclude` | to be excluded from `go-proxy` | false | | `proxy.exclude` | to be excluded from `go-proxy` | false | boolean |
| `proxy.<alias>.<field>` | set field for specific alias | N/A | | `proxy.idle_timeout` | time for idle (no traffic) before put it into sleep **(http/s only)** | empty **(disabled)** | `number[unit]...`, e.g. `1m30s` |
| `proxy.*.<field>` | set field for all aliases | N/A | | `proxy.wake_timeout` | time to wait for container to start before responding a loading page | empty | `number[unit]...` |
| `proxy.stop_method` | method to stop after `idle_timeout` | `stop` | `stop`, `pause`, `kill` |
| `proxy.stop_timeout` | time to wait for stop command | `10s` | `number[unit]...` |
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
| `proxy.<alias>.<field>` | set field for specific alias | N/A | N/A |
| `proxy.*.<field>` | set field for all aliases | N/A | N/A |
### Fields ### Fields
@ -228,12 +233,18 @@ services:
volumes: volumes:
- adg-work:/opt/adguardhome/work - adg-work:/opt/adguardhome/work
- adg-conf:/opt/adguardhome/conf - adg-conf:/opt/adguardhome/conf
ports:
- 80
- 3000
- 53
mc: mc:
image: itzg/minecraft-server image: itzg/minecraft-server
tty: true tty: true
stdin_open: true stdin_open: true
container_name: mc container_name: mc
restart: unless-stopped restart: unless-stopped
ports:
- 25565
labels: labels:
- proxy.mc.scheme=tcp - proxy.mc.scheme=tcp
- proxy.mc.port=20001:25565 - proxy.mc.port=20001:25565
@ -246,6 +257,9 @@ services:
restart: unless-stopped restart: unless-stopped
container_name: pal container_name: pal
stop_grace_period: 30s stop_grace_period: 30s
ports:
- 8211
- 27015
labels: labels:
- proxy.aliases=pal1,pal2 - proxy.aliases=pal1,pal2
- proxy.*.scheme=udp - proxy.*.scheme=udp
@ -261,6 +275,8 @@ services:
- nginx:/usr/share/nginx/html - nginx:/usr/share/nginx/html
ports: ports:
- 80 - 80
labels:
proxy.idle_timeout: 1m
go-proxy: go-proxy:
image: ghcr.io/yusing/go-proxy:latest image: ghcr.io/yusing/go-proxy:latest
container_name: go-proxy container_name: go-proxy

View file

@ -3,6 +3,7 @@ package v1
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
U "github.com/yusing/go-proxy/api/v1/utils" U "github.com/yusing/go-proxy/api/v1/utils"
"github.com/yusing/go-proxy/config" "github.com/yusing/go-proxy/config"
@ -17,17 +18,19 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
} }
var ok bool var ok bool
route := cfg.FindRoute(target)
switch route := cfg.FindRoute(target).(type) { switch {
case nil: case route == nil:
U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound) U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound)
return return
case *R.HTTPRoute: case route.Type() == R.RouteTypeReverseProxy:
ok = U.IsSiteHealthy(route.TargetURL.String()) ok = U.IsSiteHealthy(route.URL().String())
case *R.StreamRoute: case route.Type() == R.RouteTypeStream:
entry := route.Entry()
ok = U.IsStreamHealthy( ok = U.IsStreamHealthy(
string(route.Scheme.ProxyScheme), strings.Split(entry.Scheme, ":")[1], // target scheme
fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort), fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]),
) )
} }

View file

@ -9,7 +9,6 @@ import (
U "github.com/yusing/go-proxy/api/v1/utils" U "github.com/yusing/go-proxy/api/v1/utils"
"github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/config" "github.com/yusing/go-proxy/config"
E "github.com/yusing/go-proxy/error"
"github.com/yusing/go-proxy/proxy/provider" "github.com/yusing/go-proxy/proxy/provider"
) )
@ -32,25 +31,25 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, U.ErrMissingKey("filename"), http.StatusBadRequest) U.HandleErr(w, r, U.ErrMissingKey("filename"), http.StatusBadRequest)
return return
} }
content, err := E.Check(io.ReadAll(r.Body)) content, err := io.ReadAll(r.Body)
if err.HasError() { if err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, err)
return return
} }
if filename == common.ConfigFileName { if filename == common.ConfigFileName {
err = config.Validate(content) err = config.Validate(content).Error()
} else { } else {
err = provider.Validate(content) err = provider.Validate(content).Error()
} }
if err.HasError() { if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest) U.HandleErr(w, r, err, http.StatusBadRequest)
return return
} }
err = E.From(os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)) err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)
if err.HasError() { if err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, err)
return return
} }

View file

@ -8,7 +8,7 @@ import (
) )
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err.HasError() { if err := cfg.Reload().Error(); err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, err)
return return
} }

View file

@ -9,14 +9,14 @@ import (
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) { func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
err = E.From(err).Subjectf("%s %s", r.Method, r.URL) err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL)
logrus.WithField("module", "api").Error(err) logrus.WithField("module", "api").Error(err)
if len(code) > 0 { if len(code) > 0 {
http.Error(w, err.Error(), code[0]) http.Error(w, err.String(), code[0])
return return
} }
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.String(), http.StatusInternalServerError)
} }
func ErrMissingKey(k string) error { func ErrMissingKey(k string) error {

View file

@ -44,7 +44,7 @@ func ReloadServer() E.NestedError {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return E.Failure("server reload").Subjectf("status code: %v", resp.StatusCode) return E.Failure("server reload").Subjectf("status code: %v", resp.StatusCode)
} }
return E.Nil() return nil
} }
var HttpClient = &http.Client{ var HttpClient = &http.Client{

View file

@ -26,33 +26,35 @@ func NewConfig(cfg *M.AutoCertConfig) *Config {
return (*Config)(cfg) return (*Config)(cfg)
} }
func (cfg *Config) GetProvider() (*Provider, E.NestedError) { func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
errors := E.NewBuilder("cannot create autocert provider") b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)
if cfg.Provider != ProviderLocal { if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 { if len(cfg.Domains) == 0 {
errors.Addf("no domains specified") b.Addf("no domains specified")
} }
if cfg.Provider == "" { if cfg.Provider == "" {
errors.Addf("no provider specified") b.Addf("no provider specified")
} }
if cfg.Email == "" { if cfg.Email == "" {
errors.Addf("no email specified") b.Addf("no email specified")
} }
// check if provider is implemented // check if provider is implemented
_, ok := providersGenMap[cfg.Provider] _, ok := providersGenMap[cfg.Provider]
if !ok { if !ok {
errors.Addf("unknown provider: %q", cfg.Provider) b.Addf("unknown provider: %q", cfg.Provider)
} }
} }
if err := errors.Build(); err.HasError() { if b.HasError() {
return nil, err return
} }
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.HasError() { if err.HasError() {
return nil, E.Failure("generate private key").With(err) b.Add(E.FailWith("generate private key", err))
return
} }
user := &User{ user := &User{
@ -63,11 +65,11 @@ func (cfg *Config) GetProvider() (*Provider, E.NestedError) {
legoCfg := lego.NewConfig(user) legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048 legoCfg.Certificate.KeyType = certcrypto.RSA2048
base := &Provider{ provider = &Provider{
cfg: cfg, cfg: cfg,
user: user, user: user,
legoCfg: legoCfg, legoCfg: legoCfg,
} }
return base, E.Nil() return
} }

View file

@ -1,6 +1,8 @@
package autocert package autocert
import ( import (
"errors"
"github.com/go-acme/lego/v4/providers/dns/clouddns" "github.com/go-acme/lego/v4/providers/dns/clouddns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns" "github.com/go-acme/lego/v4/providers/dns/duckdns"
@ -31,4 +33,8 @@ var providersGenMap = map[string]ProviderGenerator{
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig), ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
} }
var (
ErrGetCertFailure = errors.New("get certificate failed")
)
var logger = logrus.WithField("module", "autocert") var logger = logrus.WithField("module", "autocert")

View file

@ -33,7 +33,7 @@ type CertExpiries map[string]time.Time
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil { if p.tlsCert == nil {
return nil, E.Failure("get certificate") return nil, ErrGetCertFailure
} }
return p.tlsCert, nil return p.tlsCert, nil
} }
@ -54,52 +54,60 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries return p.certExpiries
} }
func (p *Provider) ObtainCert() E.NestedError { func (p *Provider) ObtainCert() (res E.NestedError) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
if p.cfg.Provider == ProviderLocal { if p.cfg.Provider == ProviderLocal {
return E.FailureWhy("obtain cert", "provider is set to \"local\"") b.Addf("provider is set to %q", ProviderLocal)
return
} }
if p.client == nil { if p.client == nil {
if err := p.initClient(); err.HasError() { if err := p.initClient(); err.HasError() {
return E.Failure("obtain cert").With(err) b.Add(E.FailWith("init autocert client", err))
return
} }
} }
ne := E.Failure("obtain certificate")
client := p.client
if p.user.Registration == nil { if p.user.Registration == nil {
if err := p.loadRegistration(); err.HasError() { if err := p.loadRegistration(); err.HasError() {
ne = ne.With(err)
if err := p.registerACME(); err.HasError() { if err := p.registerACME(); err.HasError() {
return ne.With(err) b.Add(E.FailWith("register ACME", err))
return
} }
} }
} }
client := p.client
req := certificate.ObtainRequest{ req := certificate.ObtainRequest{
Domains: p.cfg.Domains, Domains: p.cfg.Domains,
Bundle: true, Bundle: true,
} }
cert, err := E.Check(client.Certificate.Obtain(req)) cert, err := E.Check(client.Certificate.Obtain(req))
if err.HasError() { if err.HasError() {
return ne.With(err) b.Add(err)
return
} }
err = p.saveCert(cert) err = p.saveCert(cert)
if err.HasError() { if err.HasError() {
return ne.With(E.Failure("save certificate").With(err)) b.Add(E.FailWith("save certificate", err))
return
} }
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey)) tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.HasError() { if err.HasError() {
return ne.With(E.Failure("parse obtained certificate").With(err)) b.Add(E.FailWith("parse obtained certificate", err))
return
} }
expiries, err := getCertExpiries(&tlsCert) expiries, err := getCertExpiries(&tlsCert)
if err.HasError() { if err.HasError() {
return ne.With(E.Failure("get certificate expiry").With(err)) b.Add(E.FailWith("get certificate expiry", err))
return
} }
p.tlsCert = &tlsCert p.tlsCert = &tlsCert
p.certExpiries = expiries p.certExpiries = expiries
return E.Nil() return nil
} }
func (p *Provider) LoadCert() E.NestedError { func (p *Provider) LoadCert() E.NestedError {
@ -152,50 +160,50 @@ func (p *Provider) ScheduleRenewal(ctx context.Context) {
func (p *Provider) initClient() E.NestedError { func (p *Provider) initClient() E.NestedError {
legoClient, err := E.Check(lego.NewClient(p.legoCfg)) legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() { if err.HasError() {
return E.Failure("create lego client").With(err) return E.FailWith("create lego client", err)
} }
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options) legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() { if err.HasError() {
return E.Failure("create lego provider").With(err) return E.FailWith("create lego provider", err)
} }
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() { if err.HasError() {
return E.Failure("set challenge provider").With(err) return E.FailWith("set challenge provider", err)
} }
p.client = legoClient p.client = legoClient
return E.Nil() return nil
} }
func (p *Provider) registerACME() E.NestedError { func (p *Provider) registerACME() E.NestedError {
if p.user.Registration != nil { if p.user.Registration != nil {
return E.Nil() return nil
} }
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() { if err.HasError() {
return E.Failure("register ACME").With(err) return err
} }
p.user.Registration = reg p.user.Registration = reg
if err := p.saveRegistration(); err.HasError() { if err := p.saveRegistration(); err.HasError() {
logger.Warn(err) logger.Warn(err)
} }
return E.Nil() return nil
} }
func (p *Provider) loadRegistration() E.NestedError { func (p *Provider) loadRegistration() E.NestedError {
if p.user.Registration != nil { if p.user.Registration != nil {
return E.Nil() return nil
} }
reg := &registration.Resource{} reg := &registration.Resource{}
err := U.LoadJson(RegistrationFile, reg) err := U.LoadJson(RegistrationFile, reg)
if err.HasError() { if err.HasError() {
return E.Failure("parse registration file").With(err) return E.FailWith("parse registration file", err)
} }
p.user.Registration = reg p.user.Registration = reg
return E.Nil() return nil
} }
func (p *Provider) saveRegistration() E.NestedError { func (p *Provider) saveRegistration() E.NestedError {
@ -205,13 +213,13 @@ func (p *Provider) saveRegistration() E.NestedError {
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil { if err != nil {
return E.Failure("write key file").With(err) return E.FailWith("write key file", err)
} }
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r-- err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil { if err != nil {
return E.Failure("write cert file").With(err) return E.FailWith("write cert file", err)
} }
return E.Nil() return nil
} }
func (p *Provider) certState() CertState { func (p *Provider) certState() CertState {
@ -245,13 +253,13 @@ func (p *Provider) renewIfNeeded() E.NestedError {
case CertStateMismatch: case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing") logger.Info("cert domains mismatch with config, renewing")
default: default:
return E.Nil() return nil
} }
if err := p.ObtainCert(); err.HasError() { if err := p.ObtainCert(); err.HasError() {
return E.Failure("renew certificate").With(err) return E.FailWith("renew certificate", err)
} }
return E.Nil() return nil
} }
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
@ -259,7 +267,7 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
for _, cert := range cert.Certificate { for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert)) x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.HasError() { if err.HasError() {
return nil, E.Failure("parse certificate").With(err) return nil, E.FailWith("parse certificate", err)
} }
if x509Cert.IsCA { if x509Cert.IsCA {
continue continue
@ -269,7 +277,7 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
r[x509Cert.DNSNames[i]] = x509Cert.NotAfter r[x509Cert.DNSNames[i]] = x509Cert.NotAfter
} }
} }
return r, E.Nil() return r, nil
} }
func providerGenerator[CT any, PT challenge.Provider]( func providerGenerator[CT any, PT challenge.Provider](
@ -286,6 +294,6 @@ func providerGenerator[CT any, PT challenge.Provider](
if err.HasError() { if err.HasError() {
return nil, err return nil, err
} }
return p, E.Nil() return p, nil
} }
} }

View file

@ -4,7 +4,8 @@ import (
"testing" "testing"
"github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/go-acme/lego/v4/providers/dns/ovh"
. "github.com/yusing/go-proxy/utils" U "github.com/yusing/go-proxy/utils"
. "github.com/yusing/go-proxy/utils/testing"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -44,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any) opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt)) ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectNoError(t, Deserialize(opt, cfg)) ExpectTrue(t, U.Deserialize(opt, cfg).NoError())
ExpectEqual(t, cfg, cfgExpected) ExpectDeepEqual(t, cfg, cfgExpected)
} }

View file

@ -15,25 +15,32 @@ const (
CommandStart = "" CommandStart = ""
CommandValidate = "validate" CommandValidate = "validate"
CommandListConfigs = "ls-config" CommandListConfigs = "ls-config"
CommandListRoutes = "ls-routes"
CommandReload = "reload" CommandReload = "reload"
) )
var ValidCommands = []string{CommandStart, CommandValidate, CommandListConfigs, CommandReload} var ValidCommands = []string{
CommandStart,
CommandValidate,
CommandListConfigs,
CommandListRoutes,
CommandReload,
}
func GetArgs() Args { func GetArgs() Args {
var args Args var args Args
flag.Parse() flag.Parse()
args.Command = flag.Arg(0) args.Command = flag.Arg(0)
if err := validateArgs(args.Command, ValidCommands); err.HasError() { if err := validateArg(args.Command); err.HasError() {
logrus.Fatal(err) logrus.Fatal(err)
} }
return args return args
} }
func validateArgs[T comparable](arg T, validArgs []T) E.NestedError { func validateArg(arg string) E.NestedError {
for _, v := range validArgs { for _, v := range ValidCommands {
if arg == v { if arg == v {
return E.Nil() return nil
} }
} }
return E.Invalid("argument", arg) return E.Invalid("argument", arg)

View file

@ -41,7 +41,6 @@ const (
ProxyHTTPPort = ":80" ProxyHTTPPort = ":80"
ProxyHTTPSPort = ":443" ProxyHTTPSPort = ":443"
APIHTTPPort = ":8888" APIHTTPPort = ":8888"
PanelHTTPPort = ":8080"
) )
var WellKnownHTTPPorts = map[uint16]bool{ var WellKnownHTTPPorts = map[uint16]bool{
@ -53,7 +52,7 @@ var WellKnownHTTPPorts = map[uint16]bool{
} }
var ( var (
ImageNamePortMapTCP = map[string]int{ ServiceNamePortMapTCP = map[string]int{
"postgres": 5432, "postgres": 5432,
"mysql": 3306, "mysql": 3306,
"mariadb": 3306, "mariadb": 3306,
@ -62,8 +61,7 @@ var (
"memcached": 11211, "memcached": 11211,
"rabbitmq": 5672, "rabbitmq": 5672,
"mongo": 27017, "mongo": 27017,
}
ExtraNamePortMapTCP = map[string]int{
"dns": 53, "dns": 53,
"ssh": 22, "ssh": 22,
"ftp": 21, "ftp": 21,
@ -71,20 +69,9 @@ var (
"pop3": 110, "pop3": 110,
"imap": 143, "imap": 143,
} }
NamePortMapTCP = func() map[string]int {
m := make(map[string]int)
for k, v := range ImageNamePortMapTCP {
m[k] = v
}
for k, v := range ExtraNamePortMapTCP {
m[k] = v
}
return m
}()
) )
// docker library uses uint16, so followed here var ImageNamePortMapHTTP = map[string]int{
var ImageNamePortMapHTTP = map[string]uint16{
"nginx": 80, "nginx": 80,
"httpd": 80, "httpd": 80,
"adguardhome": 3000, "adguardhome": 3000,
@ -101,3 +88,10 @@ var ImageNamePortMapHTTP = map[string]uint16{
"dockge": 5001, "dockge": 5001,
"nginx-proxy-manager": 81, "nginx-proxy-manager": 81,
} }
const (
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "10s"
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"
)

View file

@ -2,6 +2,7 @@ package config
import ( import (
"context" "context"
"os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/autocert" "github.com/yusing/go-proxy/autocert"
@ -17,32 +18,26 @@ import (
) )
type Config struct { type Config struct {
value *M.Config value *M.Config
proxyProviders F.Map[string, *PR.Provider]
l logrus.FieldLogger
reader U.Reader
proxyProviders *F.Map[string, *PR.Provider]
autocertProvider *autocert.Provider autocertProvider *autocert.Provider
l logrus.FieldLogger
watcher W.Watcher watcher W.Watcher
watcherCtx context.Context watcherCtx context.Context
watcherCancel context.CancelFunc watcherCancel context.CancelFunc
reloadReq chan struct{} reloadReq chan struct{}
} }
func New() (*Config, E.NestedError) { func Load() (*Config, E.NestedError) {
cfg := &Config{ cfg := &Config{
l: logrus.WithField("module", "config"), proxyProviders: F.NewMapOf[string, *PR.Provider](),
reader: U.NewFileReader(common.ConfigPath), l: logrus.WithField("module", "config"),
watcher: W.NewFileWatcher(common.ConfigFileName), watcher: W.NewFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1), reloadReq: make(chan struct{}, 1),
} }
if err := cfg.load(); err.HasError() { return cfg, cfg.load()
return nil, err
}
cfg.startProviders()
cfg.watchChanges()
return cfg, E.Nil()
} }
func Validate(data []byte) E.NestedError { func Validate(data []byte) E.NestedError {
@ -57,11 +52,17 @@ func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
return cfg.autocertProvider return cfg.autocertProvider
} }
func (cfg *Config) StartProxyProviders() {
cfg.startProviders()
cfg.watchChanges()
}
func (cfg *Config) Dispose() { func (cfg *Config) Dispose() {
cfg.watcherCancel() if cfg.watcherCancel != nil {
cfg.l.Debug("stopped watcher") cfg.watcherCancel()
cfg.l.Debug("stopped watcher")
}
cfg.stopProviders() cfg.stopProviders()
cfg.l.Debug("stopped providers")
} }
func (cfg *Config) Reload() E.NestedError { func (cfg *Config) Reload() E.NestedError {
@ -70,46 +71,31 @@ func (cfg *Config) Reload() E.NestedError {
return err return err
} }
cfg.startProviders() cfg.startProviders()
return E.Nil() return nil
} }
func (cfg *Config) FindRoute(alias string) R.Route { func (cfg *Config) FindRoute(alias string) R.Route {
r := cfg.proxyProviders.Find( return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (any, bool) { func(p *PR.Provider) (R.Route, bool) {
rs := p.GetCurrentRoutes() if route, ok := p.GetRoute(alias); ok {
if rs.Contains(alias) { return route, true
return rs.Get(alias), true
} }
return nil, false return nil, false
}, },
) )
if r == nil {
return nil
}
return r.(R.Route)
} }
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject) routes := make(map[string]U.SerializedObject)
cfg.proxyProviders.Each(func(p *PR.Provider) { cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
prName := p.GetName() obj, err := U.Serialize(r)
p.GetCurrentRoutes().EachKV(func(a string, r R.Route) { if err.HasError() {
obj, err := U.Serialize(r) cfg.l.Error(err)
if err.HasError() { return
cfg.l.Error(err) }
return obj["provider"] = p.GetName()
} obj["type"] = string(r.Type())
obj["provider"] = prName routes[alias] = obj
switch r.(type) {
case *R.StreamRoute:
obj["type"] = "stream"
case *R.HTTPRoute:
obj["type"] = "reverse_proxy"
default:
panic("bug: should not reach here")
}
routes[a] = obj
})
}) })
return routes return routes
} }
@ -119,26 +105,23 @@ func (cfg *Config) Statistics() map[string]any {
nTotalRPs := 0 nTotalRPs := 0
providerStats := make(map[string]any) providerStats := make(map[string]any)
cfg.proxyProviders.Each(func(p *PR.Provider) { cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
stats := make(map[string]any) s, ok := providerStats[p.GetName()]
nStreams := 0 if !ok {
nRPs := 0 s = make(map[string]int)
p.GetCurrentRoutes().EachKV(func(a string, r R.Route) { }
switch r.(type) {
case *R.StreamRoute: stats := s.(map[string]int)
nStreams++ switch r.Type() {
nTotalStreams++ case R.RouteTypeStream:
case *R.HTTPRoute: stats["num_streams"]++
nRPs++ nTotalStreams++
nTotalRPs++ case R.RouteTypeReverseProxy:
default: stats["num_reverse_proxies"]++
panic("bug: should not reach here") nTotalRPs++
} default:
}) panic("bug: should not reach here")
stats["type"] = p.GetType() }
stats["num_streams"] = nStreams
stats["num_reverse_proxies"] = nRPs
providerStats[p.GetName()] = stats
}) })
return map[string]any{ return map[string]any{
@ -148,6 +131,14 @@ func (cfg *Config) Statistics() map[string]any {
} }
} }
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) {
do(a, r, p)
})
})
}
func (cfg *Config) watchChanges() { func (cfg *Config) watchChanges() {
cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background()) cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background())
go func() { go func() {
@ -182,64 +173,82 @@ func (cfg *Config) watchChanges() {
}() }()
} }
func (cfg *Config) load() E.NestedError { func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
cfg.l.Debug("loading config") cfg.l.Debug("loading config")
defer cfg.l.Debug("loaded config")
data, err := cfg.reader.Read() data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() { if err.HasError() {
return E.Failure("read config").With(err) b.Add(E.FailWith("read config", err))
} return
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
return E.Failure("parse config").With(err)
} }
if !common.NoSchemaValidation { if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() { if err = Validate(data); err.HasError() {
return err b.Add(E.FailWith("schema validation", err))
return
} }
} }
warnings := E.NewBuilder("errors loading config") model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
b.Add(E.FailWith("parse config", err))
return
}
cfg.l.Debug("initializing autocert") // errors are non fatal below
ap, err := autocert.NewConfig(&model.AutoCert).GetProvider() b.WithSeverity(E.SeverityWarning)
if err.HasError() { b.Add(cfg.initAutoCert(&model.AutoCert))
warnings.Add(E.Failure("autocert provider").With(err)) b.Add(cfg.loadProviders(&model.Providers))
} else {
cfg.l.Debug("initialized autocert")
}
cfg.autocertProvider = ap
cfg.l.Debug("loading providers")
cfg.proxyProviders = F.NewMap[string, *PR.Provider]()
for _, filename := range model.Providers.Files {
p := PR.NewFileProvider(filename)
cfg.proxyProviders.Set(p.GetName(), p)
}
for name, dockerHost := range model.Providers.Docker {
p := PR.NewDockerProvider(name, dockerHost)
cfg.proxyProviders.Set(p.GetName(), p)
}
cfg.l.Debug("loaded providers")
cfg.value = model cfg.value = model
return
}
if err := warnings.Build(); err.HasError() { func (cfg *Config) initAutoCert(autocertCfg *M.AutoCertConfig) (err E.NestedError) {
cfg.l.Warn(err) if cfg.autocertProvider != nil {
return
} }
cfg.l.Debug("loaded config") cfg.l.Debug("initializing autocert")
return E.Nil() defer cfg.l.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err.HasError() {
err = E.FailWith("autocert provider", err)
}
return
}
func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError) {
cfg.l.Debug("loading providers")
defer cfg.l.Debug("loaded providers")
b := E.NewBuilder("errors loading providers")
defer b.To(&res)
for _, filename := range providers.Files {
p := PR.NewFileProvider(filename)
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes())
}
for name, dockerHost := range providers.Docker {
p := PR.NewDockerProvider(name, dockerHost)
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes())
}
return
} }
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("cannot %s these providers", action) errors := E.NewBuilder("cannot %s these providers", action)
cfg.proxyProviders.EachKVParallel(func(name string, p *PR.Provider) { cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() { if err := do(p); err.HasError() {
errors.Add(E.From(err).Subject(p)) errors.Add(err.Subject(p))
} }
}) })

View file

@ -3,6 +3,7 @@ package docker
import ( import (
"net/http" "net/http"
"sync" "sync"
"sync/atomic"
"github.com/docker/cli/cli/connhelper" "github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client" "github.com/docker/docker/client"
@ -11,14 +12,37 @@ import (
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
type Client = *client.Client type Client struct {
key string
refCount *atomic.Int32
*client.Client
}
func (c Client) DaemonHostname() string {
url, _ := client.ParseHostURL(c.DaemonHost())
return url.Hostname()
}
// if the client is still referenced, this is no-op
func (c Client) Close() error {
if c.refCount.Load() > 0 {
c.refCount.Add(-1)
return nil
}
clientMapMu.Lock()
defer clientMapMu.Unlock()
delete(clientMap, c.key)
return c.Client.Close()
}
// ConnectClient creates a new Docker client connection to the specified host. // ConnectClient creates a new Docker client connection to the specified host.
// //
// Returns existing client if available. // Returns existing client if available.
// //
// Parameters: // Parameters:
// - host: the host to connect to (either a URL or "FROM_ENV"). // - host: the host to connect to (either a URL or common.DockerHostFromEnv).
// //
// Returns: // Returns:
// - Client: the Docker client connection. // - Client: the Docker client connection.
@ -29,7 +53,8 @@ func ConnectClient(host string) (Client, E.NestedError) {
// check if client exists // check if client exists
if client, ok := clientMap[host]; ok { if client, ok := clientMap[host]; ok {
return client, E.Nil() client.refCount.Add(1)
return client, nil
} }
// create client // create client
@ -41,7 +66,7 @@ func ConnectClient(host string) (Client, E.NestedError) {
default: default:
helper, err := E.Check(connhelper.GetConnectionHelper(host)) helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.HasError() { if err.HasError() {
logger.Fatalf("unexpected error: %s", err) return Client{}, E.UnexpectedError(err.Error())
} }
if helper != nil { if helper != nil {
httpClient := &http.Client{ httpClient := &http.Client{
@ -66,11 +91,16 @@ func ConnectClient(host string) (Client, E.NestedError) {
client, err := E.Check(client.NewClientWithOpts(opt...)) client, err := E.Check(client.NewClientWithOpts(opt...))
if err.HasError() { if err.HasError() {
return nil, err return Client{}, err
} }
clientMap[host] = client clientMap[host] = Client{
return client, E.Nil() Client: client,
key: host,
refCount: &atomic.Int32{},
}
clientMap[host].refCount.Add(1)
return clientMap[host], nil
} }
func CloseAllClients() { func CloseAllClients() {
@ -83,12 +113,13 @@ func CloseAllClients() {
logger.Debug("closed all clients") logger.Debug("closed all clients")
} }
var clientMap map[string]Client = make(map[string]Client) var (
var clientMapMu sync.Mutex clientMap map[string]Client = make(map[string]Client)
clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(),
}
var clientOptEnvHost = []client.Opt{ logger = logrus.WithField("module", "docker")
client.WithHostFromEnv(), )
client.WithAPIVersionNegotiation(),
}
var logger = logrus.WithField("module", "docker")

View file

@ -12,35 +12,41 @@ import (
) )
type ClientInfo struct { type ClientInfo struct {
Host string Client Client
Containers []types.Container Containers []types.Container
} }
func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) { var listOptions = container.ListOptions{
// Filters: filters.NewArgs(
// filters.Arg("health", "healthy"),
// filters.Arg("health", "none"),
// filters.Arg("health", "starting"),
// ),
All: true,
}
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
dockerClient, err := ConnectClient(clientHost) dockerClient, err := ConnectClient(clientHost)
if err.HasError() { if err.HasError() {
return nil, E.Failure("create docker client").With(err) return nil, E.FailWith("connect to docker", err)
} }
defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
containers, err := E.Check(dockerClient.ContainerList(ctx, container.ListOptions{})) var containers []types.Container
if err.HasError() { if getContainer {
return nil, E.Failure("list containers").With(err) containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
} }
// extract host from docker client url return &ClientInfo{
// since the services being proxied to Client: dockerClient,
// should have the same IP as the docker client Containers: containers,
url, err := E.Check(client.ParseHostURL(dockerClient.DaemonHost())) }, nil
if err.HasError() {
return nil, E.Invalid("host url", dockerClient.DaemonHost()).With(err)
}
if url.Scheme == "unix" {
return &ClientInfo{Host: "localhost", Containers: containers}, E.Nil()
}
return &ClientInfo{Host: url.Hostname(), Containers: containers}, E.Nil()
} }
func IsErrConnectionFailed(err error) bool { func IsErrConnectionFailed(err error) bool {

109
src/docker/container.go Normal file
View file

@ -0,0 +1,109 @@
package docker
import (
"fmt"
"strconv"
"strings"
"github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/utils"
)
type ProxyProperties struct {
DockerHost string `yaml:"docker_host" json:"docker_host"`
ContainerName string `yaml:"container_name" json:"container_name"`
ImageName string `yaml:"image_name" json:"image_name"`
Aliases []string `yaml:"aliases" json:"aliases"`
IsExcluded bool `yaml:"is_excluded" json:"is_excluded"`
FirstPort string `yaml:"first_port" json:"first_port"`
IdleTimeout string `yaml:"idle_timeout" json:"idle_timeout"`
WakeTimeout string `yaml:"wake_timeout" json:"wake_timeout"`
StopMethod string `yaml:"stop_method" json:"stop_method"`
StopTimeout string `yaml:"stop_timeout" json:"stop_timeout"` // stop_method = "stop" only
StopSignal string `yaml:"stop_signal" json:"stop_signal"` // stop_method = "stop" | "kill" only
}
type Container struct {
*types.Container
*ProxyProperties
}
func FromDocker(c *types.Container, dockerHost string) (res Container) {
res.Container = c
res.ProxyProperties = &ProxyProperties{
DockerHost: dockerHost,
ContainerName: res.getName(),
ImageName: res.getImageName(),
Aliases: res.getAliases(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LableExclude)),
FirstPort: res.firstPortOrEmpty(),
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
StopMethod: res.getDeleteLabel(LabelStopMethod),
StopTimeout: res.getDeleteLabel(LabelStopTimeout),
StopSignal: res.getDeleteLabel(LabelStopSignal),
}
return
}
func FromJson(json types.ContainerJSON, dockerHost string) Container {
ports := make([]types.Port, 0)
for k, bindings := range json.NetworkSettings.Ports {
for _, v := range bindings {
pubPort, _ := strconv.Atoi(v.HostPort)
privPort, _ := strconv.Atoi(k.Port())
ports = append(ports, types.Port{
IP: v.HostIP,
PublicPort: uint16(pubPort),
PrivatePort: uint16(privPort),
})
}
}
return FromDocker(&types.Container{
ID: json.ID,
Names: []string{json.Name},
Image: json.Image,
Ports: ports,
Labels: json.Config.Labels,
State: json.State.Status,
Status: json.State.Status,
}, dockerHost)
}
func (c Container) getDeleteLabel(label string) string {
if l, ok := c.Labels[label]; ok {
delete(c.Labels, label)
return l
}
return ""
}
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LableAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
}
}
func (c Container) getName() string {
return strings.TrimPrefix(c.Names[0], "/")
}
func (c Container) getImageName() string {
colonSep := strings.Split(c.Image, ":")
slashSep := strings.Split(colonSep[len(colonSep)-1], "/")
return slashSep[len(slashSep)-1]
}
func (c Container) firstPortOrEmpty() string {
if len(c.Ports) == 0 {
return ""
}
for _, p := range c.Ports {
if p.PublicPort != 0 {
return fmt.Sprint(p.PublicPort)
}
}
return ""
}

View file

@ -0,0 +1,14 @@
package idlewatcher
import "net/http"
type (
roundTripper struct {
patched roundTripFunc
}
roundTripFunc func(*http.Request) (*http.Response, error)
)
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.patched(req)
}

View file

@ -0,0 +1,329 @@
package idlewatcher
import (
"bytes"
"context"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields"
)
type watcher struct {
*P.ReverseProxyEntry
client D.Client
refCount atomic.Int32
stopByMethod StopCallback
wakeCh chan struct{}
wakeDone chan E.NestedError
ctx context.Context
cancel context.CancelFunc
l logrus.FieldLogger
}
type (
WakeDone <-chan error
WakeFunc func() WakeDone
StopCallback func() (bool, E.NestedError)
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
if entry.IdleTimeout == 0 {
return nil, failure.With(E.Invalid("idle_timeout", 0))
}
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[entry.ContainerName]; ok {
w.refCount.Add(1)
return w, nil
}
client, err := D.ConnectClient(entry.DockerHost)
if err.HasError() {
return nil, failure.With(err)
}
w := &watcher{
ReverseProxyEntry: entry,
client: client,
wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError, 1),
l: logger.WithField("container", entry.ContainerName),
}
w.refCount.Add(1)
w.stopByMethod = w.getStopCallback()
watcherMap[w.ContainerName] = w
go func() {
newWatcherCh <- w
}()
return w, nil
}
// If the container is not registered, this is no-op
func Unregister(containerName string) {
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[containerName]; ok {
if w.refCount.Load() == 0 {
w.cancel()
close(w.wakeCh)
delete(watcherMap, containerName)
} else {
w.refCount.Add(-1)
}
}
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
defer mainLoopWg.Wait()
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watch()
Unregister(w.ContainerName)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
}
func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper {
return roundTripper{patched: func(r *http.Request) (*http.Response, error) {
return w.roundTrip(rtp.RoundTrip, r)
}}
}
func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) {
timeout := time.After(w.WakeTimeout)
w.wakeCh <- struct{}{}
for {
select {
case err := <-w.wakeDone:
if err != nil {
return nil, err.Error()
}
return origRoundTrip(req)
case <-timeout:
resp := loadingResponse
resp.TLS = req.TLS
return &resp, nil
}
}
}
func (w *watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout})
}
func (w *watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerName)
}
func (w *watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerName, string(w.StopSignal))
}
func (w *watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerName)
}
func (w *watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerName, container.StartOptions{})
}
func (w *watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerName)
if err != nil {
return "", E.FailWith("inspect container", err)
}
return json.State.Status, nil
}
func (w *watcher) wakeIfStopped() (bool, E.NestedError) {
failure := E.Failure("wake")
status, err := w.containerStatus()
if err.HasError() {
return false, failure.With(err)
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
switch status {
case "exited", "dead":
err = E.From(w.containerStart())
case "paused":
err = E.From(w.containerUnpause())
case "running":
return false, nil
default:
return false, failure.With(E.Unexpected("container state", status))
}
if err.HasError() {
return false, failure.With(err)
}
status, err = w.containerStatus()
if err.HasError() {
return false, failure.With(err)
} else if status != "running" {
return false, failure.With(E.Unexpected("container state", status))
} else {
return true, nil
}
}
func (w *watcher) getStopCallback() StopCallback {
var cb func() error
switch w.StopMethod {
case PT.StopMethodPause:
cb = w.containerPause
case PT.StopMethodStop:
cb = w.containerStop
case PT.StopMethodKill:
cb = w.containerKill
default:
panic("should not reach here")
}
return func() (bool, E.NestedError) {
status, err := w.containerStatus()
if err.HasError() {
return false, E.FailWith("stop", err)
}
if status != "running" {
return false, nil
}
err = E.From(cb())
if err.HasError() {
return false, E.FailWith("stop", err)
}
return true, nil
}
}
func (w *watcher) watch() {
watcherCtx, watcherCancel := context.WithCancel(context.Background())
w.ctx = watcherCtx
w.cancel = watcherCancel
ticker := time.NewTicker(w.IdleTimeout)
defer ticker.Stop()
for {
select {
case <-mainLoopCtx.Done():
watcherCancel()
case <-watcherCtx.Done():
w.l.Debug("stopped")
return
case <-ticker.C:
w.l.Debug("timeout")
stopped, err := w.stopByMethod()
if err.HasError() {
w.l.Error(err.Extraf("stop method: %s", w.StopMethod))
} else if stopped {
w.l.Infof("%s: ok", w.StopMethod)
} else {
ticker.Stop()
}
case <-w.wakeCh:
w.l.Debug("wake received")
go func() {
started, err := w.wakeIfStopped()
if err != nil {
w.l.Error(err)
} else if started {
w.l.Infof("awaken")
ticker.Reset(w.IdleTimeout)
}
w.wakeDone <- err // this is passed to roundtrip
}()
}
}
}
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
loadingResponse = http.Response{
StatusCode: http.StatusAccepted,
Header: http.Header{
"Content-Type": {"text/html"},
"Cache-Control": {
"no-cache",
"no-store",
"must-revalidate",
},
},
Body: io.NopCloser(bytes.NewReader((loadingPage))),
ContentLength: int64(len(loadingPage)),
}
loadingPage = []byte(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Loading...</title>
</head>
<body>
<script>
window.onload = function() {
setTimeout(function() {
location.reload();
}, 1000); // 1000 milliseconds = 1 second
};
</script>
<p>Container is starting... Please wait</p>
</body>
</html>
`[1:])
)

19
src/docker/inspect.go Normal file
View file

@ -0,0 +1,19 @@
package docker
import (
"context"
"time"
E "github.com/yusing/go-proxy/error"
)
func (c Client) Inspect(containerID string) (Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)
if err != nil {
return Container{}, E.From(err)
}
return FromJson(json, c.key), nil
}

View file

@ -36,7 +36,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
return &Label{ return &Label{
Namespace: label, Namespace: label,
Value: value, Value: value,
}, E.Nil() }, nil
} }
l := &Label{ l := &Label{
@ -54,12 +54,12 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
// find if namespace has value parser // find if namespace has value parser
pm, ok := labelValueParserMap[l.Namespace] pm, ok := labelValueParserMap[l.Namespace]
if !ok { if !ok {
return l, E.Nil() return l, nil
} }
// find if attribute has value parser // find if attribute has value parser
p, ok := pm[l.Attribute] p, ok := pm[l.Attribute]
if !ok { if !ok {
return l, E.Nil() return l, nil
} }
// try to parse value // try to parse value
v, err := p(value) v, err := p(value)
@ -67,7 +67,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
return nil, err return nil, err
} }
l.Value = v l.Value = v
return l, E.Nil() return l, nil
} }
func RegisterNamespace(namespace string, pm ValueParserMap) { func RegisterNamespace(namespace string, pm ValueParserMap) {

View file

@ -10,7 +10,7 @@ import (
func yamlListParser(value string) (any, E.NestedError) { func yamlListParser(value string) (any, E.NestedError) {
value = strings.TrimSpace(value) value = strings.TrimSpace(value)
if value == "" { if value == "" {
return []string{}, E.Nil() return []string{}, nil
} }
var data []string var data []string
err := E.From(yaml.Unmarshal([]byte(value), &data)) err := E.From(yaml.Unmarshal([]byte(value), &data))
@ -34,23 +34,15 @@ func yamlStringMappingParser(value string) (any, E.NestedError) {
h[key] = val h[key] = val
} }
} }
return h, E.Nil() return h, nil
}
func commaSepParser(value string) (any, E.NestedError) {
v := strings.Split(value, ",")
for i := range v {
v[i] = strings.TrimSpace(v[i])
}
return v, E.Nil()
} }
func boolParser(value string) (any, E.NestedError) { func boolParser(value string) (any, E.NestedError) {
switch strings.ToLower(value) { switch strings.ToLower(value) {
case "true", "yes", "1": case "true", "yes", "1":
return true, E.Nil() return true, nil
case "false", "no", "0": case "false", "no", "0":
return false, E.Nil() return false, nil
default: default:
return nil, E.Invalid("boolean value", value) return nil, E.Invalid("boolean value", value)
} }

View file

@ -7,7 +7,7 @@ import (
"testing" "testing"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/utils" . "github.com/yusing/go-proxy/utils/testing"
) )
func makeLabel(namespace string, alias string, field string) string { func makeLabel(namespace string, alias string, field string) string {
@ -19,7 +19,7 @@ func TestHomePageLabel(t *testing.T) {
field := "ip" field := "ip"
v := "bar" v := "bar"
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v) pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
ExpectNoError(t, err) ExpectNoError(t, err.Error())
if pl.Target != alias { if pl.Target != alias {
t.Errorf("Expected alias=%s, got %s", alias, pl.Target) t.Errorf("Expected alias=%s, got %s", alias, pl.Target)
} }
@ -34,8 +34,8 @@ func TestHomePageLabel(t *testing.T) {
func TestStringProxyLabel(t *testing.T) { func TestStringProxyLabel(t *testing.T) {
v := "bar" v := "bar"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "ip"), v) pl, err := ParseLabel(makeLabel(NSProxy, "foo", "ip"), v)
ExpectNoError(t, err) ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value, v) ExpectEqual(t, pl.Value.(string), v)
} }
func TestBoolProxyLabelValid(t *testing.T) { func TestBoolProxyLabelValid(t *testing.T) {
@ -52,8 +52,8 @@ func TestBoolProxyLabelValid(t *testing.T) {
for k, v := range tests { for k, v := range tests {
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k) pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k)
ExpectNoError(t, err) ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value, v) ExpectEqual(t, pl.Value.(bool), v)
} }
} }
@ -78,7 +78,7 @@ X-Custom-Header2: boo`
} }
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v) pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
ExpectNoError(t, err) ExpectNoError(t, err.Error())
hGot := ExpectType[map[string]string](t, pl.Value) hGot := ExpectType[map[string]string](t, pl.Value)
if hGot != nil && !reflect.DeepEqual(h, hGot) { if hGot != nil && !reflect.DeepEqual(h, hGot) {
t.Errorf("Expected %v, got %v", h, hGot) t.Errorf("Expected %v, got %v", h, hGot)
@ -109,33 +109,32 @@ func TestHideHeadersProxyLabel(t *testing.T) {
` `
v = strings.TrimPrefix(v, "\n") v = strings.TrimPrefix(v, "\n")
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v) pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v)
ExpectNoError(t, err) ExpectNoError(t, err.Error())
sGot := ExpectType[[]string](t, pl.Value) sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
if sGot != nil { if sGot != nil {
ExpectEqual(t, sGot, sWant) ExpectDeepEqual(t, sGot, sWant)
} }
} }
func TestCommaSepProxyLabelSingle(t *testing.T) { // func TestCommaSepProxyLabelSingle(t *testing.T) {
v := "a" // v := "a"
pl, err := ParseLabel("proxy.aliases", v) // pl, err := ParseLabel("proxy.aliases", v)
ExpectNoError(t, err) // ExpectNoError(t, err)
sGot := ExpectType[[]string](t, pl.Value) // sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"a"} // sWant := []string{"a"}
if sGot != nil { // if sGot != nil {
ExpectEqual(t, sGot, sWant) // ExpectEqual(t, sGot, sWant)
} // }
// }
} // func TestCommaSepProxyLabelMulti(t *testing.T) {
// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
func TestCommaSepProxyLabelMulti(t *testing.T) { // pl, err := ParseLabel("proxy.aliases", v)
v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3" // ExpectNoError(t, err)
pl, err := ParseLabel("proxy.aliases", v) // sGot := ExpectType[[]string](t, pl.Value)
ExpectNoError(t, err) // sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
sGot := ExpectType[[]string](t, pl.Value) // if sGot != nil {
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} // ExpectEqual(t, sGot, sWant)
if sGot != nil { // }
ExpectEqual(t, sGot, sWant) // }
}
}

13
src/docker/labels.go Normal file
View file

@ -0,0 +1,13 @@
package docker
const (
WildcardAlias = "*"
LableAliases = NSProxy + ".aliases"
LableExclude = NSProxy + ".exclude"
LabelIdleTimeout = NSProxy + ".idle_timeout"
LabelWakeTimeout = NSProxy + ".wake_timeout"
LabelStopMethod = NSProxy + ".stop_method"
LabelStopTimeout = NSProxy + ".stop_timeout"
LabelStopSignal = NSProxy + ".stop_signal"
)

View file

@ -6,16 +6,23 @@ import (
) )
type Builder struct { type Builder struct {
message string *builder
errors []error }
type builder struct {
message string
errors []NestedError
severity Severity
sync.Mutex sync.Mutex
} }
func NewBuilder(format string, args ...any) *Builder { func NewBuilder(format string, args ...any) Builder {
return &Builder{message: fmt.Sprintf(format, args...)} return Builder{&builder{message: fmt.Sprintf(format, args...)}}
} }
func (b *Builder) Add(err error) *Builder { // adding nil / nil is no-op,
// you may safely pass expressions returning error to it
func (b Builder) Add(err NestedError) Builder {
if err != nil { if err != nil {
b.Lock() b.Lock()
b.errors = append(b.errors, err) b.errors = append(b.errors, err)
@ -24,8 +31,17 @@ func (b *Builder) Add(err error) *Builder {
return b return b
} }
func (b *Builder) Addf(format string, args ...any) *Builder { func (b Builder) AddE(err error) Builder {
return b.Add(fmt.Errorf(format, args...)) return b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
func (b Builder) WithSeverity(s Severity) Builder {
b.severity = s
return b
} }
// Build builds a NestedError based on the errors collected in the Builder. // Build builds a NestedError based on the errors collected in the Builder.
@ -35,9 +51,21 @@ func (b *Builder) Addf(format string, args ...any) *Builder {
// //
// Returns: // Returns:
// - NestedError: the built NestedError. // - NestedError: the built NestedError.
func (b *Builder) Build() NestedError { func (b Builder) Build() NestedError {
if len(b.errors) == 0 { if len(b.errors) == 0 {
return Nil() return nil
} }
return Join(b.message, b.errors...) return Join(b.message, b.errors...).Severity(b.severity)
}
func (b Builder) To(ptr *NestedError) {
if *ptr == nil {
*ptr = b.Build()
} else {
**ptr = *b.Build()
}
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
} }

View file

@ -1,13 +1,38 @@
package error package error
import "testing" import (
"testing"
func TestBuilder(t *testing.T) { . "github.com/yusing/go-proxy/utils/testing"
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer")
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
var err NestedError
for range 3 {
eb.Add(nil)
}
for range 3 {
eb.Add(err)
}
ExpectTrue(t, eb.Build() == nil)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError())
}
func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred") eb := NewBuilder("error occurred")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2"))) eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2")))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3"))) eb.Add(Failure("Action 2").With(Invalid("Inner", "3")))
got := eb.Build().Error() got := eb.Build().String()
expected1 := expected1 :=
(`error occurred: (`error occurred:
- Action 1 failed: - Action 1 failed:

View file

@ -7,35 +7,37 @@ import (
) )
type ( type (
// NestedError is an error with an inner error NestedError = *nestedError
// and a list of extra nested errors. nestedError struct {
// subject string
// It is designed to be non nil. err error // can be nil
// extras []nestedError
// You can use it to join multiple errors, severity Severity
// or to set a inner reason for a nested error.
//
// When a method returns both valid values and errors,
// You should return (Slice/Map, NestedError).
// Caller then should handle the nested error,
// and continue with the valid values.
NestedError struct {
subject string
err error // can be nil
extras []NestedError
} }
errorInterface struct {
*nestedError
}
Severity uint8
) )
func Nil() NestedError { return NestedError{} } const (
SeverityFatal Severity = iota
SeverityWarning
)
func (e errorInterface) Error() string {
return e.String()
}
func From(err error) NestedError { func From(err error) NestedError {
if IsNil(err) {
return nil
}
switch err := err.(type) { switch err := err.(type) {
case nil: case errorInterface:
return Nil() return err.nestedError
case NestedError:
return err
default: default:
return NestedError{err: err} return &nestedError{err: err}
} }
} }
@ -45,40 +47,84 @@ func Check[T any](obj T, err error) (T, NestedError) {
return obj, From(err) return obj, From(err)
} }
func Join(message string, err ...error) NestedError { func Join(message string, err ...NestedError) NestedError {
extras := make([]NestedError, 0, len(err)) extras := make([]nestedError, len(err))
nErr := 0 nErr := 0
for _, e := range err { for i, e := range err {
if err == nil { if e == nil {
continue continue
} }
extras = append(extras, From(e)) extras[i] = *e
nErr += 1 nErr += 1
} }
if nErr == 0 { if nErr == 0 {
return Nil() return nil
} }
return NestedError{ return &nestedError{
err: errors.New(message), err: errors.New(message),
extras: extras, extras: extras,
} }
} }
func (ne NestedError) Error() string { func JoinE(message string, err ...error) NestedError {
b := NewBuilder(message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne NestedError) String() string {
var buf strings.Builder var buf strings.Builder
ne.writeToSB(&buf, 0, "") ne.writeToSB(&buf, 0, "")
return buf.String() return buf.String()
} }
func (ne NestedError) Is(err error) bool { func (ne NestedError) Is(err error) bool {
return errors.Is(ne.err, err) if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne NestedError) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne NestedError) Error() error {
if ne == nil {
return nil
}
return errorInterface{ne}
} }
func (ne NestedError) With(s any) NestedError { func (ne NestedError) With(s any) NestedError {
if ne == nil {
return ne
}
var msg string var msg string
switch ss := s.(type) { switch ss := s.(type) {
case nil: case nil:
return ne return ne
case *nestedError:
return ne.withError(ss.Error())
case error: case error:
return ne.withError(ss) return ne.withError(ss)
case string: case string:
@ -92,10 +138,13 @@ func (ne NestedError) With(s any) NestedError {
} }
func (ne NestedError) Extraf(format string, args ...any) NestedError { func (ne NestedError) Extraf(format string, args ...any) NestedError {
return ne.With(fmt.Errorf(format, args...)) return ne.With(errorf(format, args...))
} }
func (ne NestedError) Subject(s any) NestedError { func (ne NestedError) Subject(s any) NestedError {
if ne == nil {
return ne
}
switch ss := s.(type) { switch ss := s.(type) {
case string: case string:
ne.subject = ss ne.subject = ss
@ -108,6 +157,9 @@ func (ne NestedError) Subject(s any) NestedError {
} }
func (ne NestedError) Subjectf(format string, args ...any) NestedError { func (ne NestedError) Subjectf(format string, args ...any) NestedError {
if ne == nil {
return ne
}
if strings.Contains(format, "%q") { if strings.Contains(format, "%q") {
panic("Subjectf format should not contain %q") panic("Subjectf format should not contain %q")
} }
@ -118,12 +170,36 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
return ne return ne
} }
func (ne NestedError) Severity(s Severity) NestedError {
if ne == nil {
return ne
}
ne.severity = s
return ne
}
func (ne NestedError) Warn() NestedError {
if ne == nil {
return ne
}
ne.severity = SeverityWarning
return ne
}
func (ne NestedError) NoError() bool { func (ne NestedError) NoError() bool {
return ne.err == nil return ne == nil
} }
func (ne NestedError) HasError() bool { func (ne NestedError) HasError() bool {
return ne.err != nil return ne != nil
}
func (ne NestedError) IsFatal() bool {
return ne != nil && ne.severity == SeverityFatal
}
func (ne NestedError) IsWarning() bool {
return ne != nil && ne.severity == SeverityWarning
} }
func errorf(format string, args ...any) NestedError { func errorf(format string, args ...any) NestedError {
@ -131,11 +207,13 @@ func errorf(format string, args ...any) NestedError {
} }
func (ne NestedError) withError(err error) NestedError { func (ne NestedError) withError(err error) NestedError {
ne.extras = append(ne.extras, From(err)) if ne != nil && IsNotNil(err) {
ne.extras = append(ne.extras, *From(err))
}
return ne return ne
} }
func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
ne.writeIndents(sb, level) ne.writeIndents(sb, level)
sb.WriteString(prefix) sb.WriteString(prefix)
@ -146,7 +224,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string)
sb.WriteString(ne.err.Error()) sb.WriteString(ne.err.Error())
if ne.subject != "" { if ne.subject != "" {
if ne.err != nil { if IsNotNil(ne.err) {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject)) sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
} else { } else {
sb.WriteString(fmt.Sprint(ne.subject)) sb.WriteString(fmt.Sprint(ne.subject))
@ -161,7 +239,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string)
} }
} }
func (ne *NestedError) writeIndents(sb *strings.Builder, level int) { func (ne NestedError) writeIndents(sb *strings.Builder, level int) {
for i := 0; i < level; i++ { for i := 0; i < level; i++ {
sb.WriteString(" ") sb.WriteString(" ")
} }

View file

@ -4,7 +4,7 @@ import (
"testing" "testing"
. "github.com/yusing/go-proxy/error" . "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/utils" . "github.com/yusing/go-proxy/utils/testing"
) )
func TestErrorIs(t *testing.T) { func TestErrorIs(t *testing.T) {
@ -16,27 +16,53 @@ func TestErrorIs(t *testing.T) {
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid)) ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure)) ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure))
ExpectTrue(t, Nil().Is(nil))
ExpectFalse(t, Nil().Is(ErrInvalid))
ExpectFalse(t, Invalid("foo", "bar").Is(nil)) ExpectFalse(t, Invalid("foo", "bar").Is(nil))
} }
func TestNil(t *testing.T) { func TestErrorNestedIs(t *testing.T) {
ExpectTrue(t, Nil().NoError()) var err NestedError
ExpectFalse(t, Nil().HasError()) ExpectTrue(t, err.Is(nil))
ExpectEqual(t, Nil().Error(), "nil")
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrAlreadyExist))
err.With(AlreadyExist("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrAlreadyExist))
ExpectFalse(t, err.Is(ErrInvalid))
}
func TestIsNil(t *testing.T) {
var err NestedError
ExpectTrue(t, err.Is(nil))
ExpectFalse(t, err.HasError())
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
}
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
} }
func TestErrorSimple(t *testing.T) { func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar") ne := Failure("foo bar")
ExpectEqual(t, ne.Error(), "foo bar failed") ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz") ne = ne.Subject("baz")
ExpectEqual(t, ne.Error(), "foo bar failed for \"baz\"") ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
} }
func TestErrorWith(t *testing.T) { func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz") ne := Failure("foo").With("bar").With("baz")
ExpectEqual(t, ne.Error(), "foo failed:\n - bar\n - baz") ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz")
} }
func TestErrorNested(t *testing.T) { func TestErrorNested(t *testing.T) {
@ -72,5 +98,5 @@ func TestErrorNested(t *testing.T) {
- inner3 failed for "action 3": - inner3 failed for "action 3":
- 3 - 3
- 3` - 3`
ExpectEqual(t, ne.Error(), want) ExpectEqual(t, ne.String(), want)
} }

View file

@ -5,33 +5,48 @@ import (
) )
var ( var (
ErrFailure = stderrors.New("failed") ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid") ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported") ErrUnsupported = stderrors.New("unsupported")
ErrNotExists = stderrors.New("does not exist") ErrUnexpected = stderrors.New("unexpected")
ErrDuplicated = stderrors.New("duplicated") ErrNotExists = stderrors.New("does not exist")
ErrAlreadyExist = stderrors.New("already exist")
) )
const fmtSubjectWhat = "%w %v: %v"
func Failure(what string) NestedError { func Failure(what string) NestedError {
return errorf("%s %w", what, ErrFailure) return errorf("%s %w", what, ErrFailure)
} }
func FailureWhy(what string, why string) NestedError { func FailedWhy(what string, why string) NestedError {
return errorf("%s %w because %s", what, ErrFailure, why) return errorf("%s %w because %s", what, ErrFailure, why)
} }
func FailWith(what string, err any) NestedError {
return Failure(what).With(err)
}
func Invalid(subject, what any) NestedError { func Invalid(subject, what any) NestedError {
return errorf("%w %v - %v", ErrInvalid, subject, what) return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
} }
func Unsupported(subject, what any) NestedError { func Unsupported(subject, what any) NestedError {
return errorf("%w %v - %v", ErrUnsupported, subject, what) return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
} }
func NotExists(subject, what any) NestedError { func Unexpected(subject, what any) NestedError {
return errorf("%s %v - %v", subject, ErrNotExists, what) return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
} }
func Duplicated(subject, what any) NestedError { func UnexpectedError(err error) NestedError {
return errorf("%w %v: %v", ErrDuplicated, subject, what) return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func AlreadyExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrAlreadyExist, what)
} }

View file

@ -7,6 +7,7 @@ require (
github.com/docker/docker v27.2.1+incompatible github.com/docker/docker v27.2.1+incompatible
github.com/fsnotify/fsnotify v1.7.0 github.com/fsnotify/fsnotify v1.7.0
github.com/go-acme/lego/v4 v4.18.0 github.com/go-acme/lego/v4 v4.18.0
github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/santhosh-tekuri/jsonschema v1.2.4 github.com/santhosh-tekuri/jsonschema v1.2.4
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.29.0 golang.org/x/net v0.29.0

View file

@ -73,6 +73,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis= github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=

View file

@ -18,6 +18,7 @@ import (
"github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/config" "github.com/yusing/go-proxy/config"
"github.com/yusing/go-proxy/docker" "github.com/yusing/go-proxy/docker"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
R "github.com/yusing/go-proxy/route" R "github.com/yusing/go-proxy/route"
"github.com/yusing/go-proxy/server" "github.com/yusing/go-proxy/server"
@ -53,37 +54,40 @@ func main() {
// exit if only validate config // exit if only validate config
if args.Command == common.CommandValidate { if args.Command == common.CommandValidate {
var err E.NestedError data, err := os.ReadFile(common.ConfigPath)
data, err := E.Check(os.ReadFile(common.ConfigPath)) if err == nil {
if err.HasError() { err = config.Validate(data).Error()
l.WithError(err).Fatalf("config error")
} }
if err = config.Validate(data); err.HasError() { if err != nil {
l.WithError(err).Fatalf("config error") l.Fatal("config error: ", err)
} }
l.Printf("config OK") l.Printf("config OK")
return return
} }
cfg, err := config.New() cfg, err := config.Load()
if err.HasError() { if err.IsFatal() {
l.Fatalf("config error: %s", err) l.Fatal(err)
} }
if args.Command == common.CommandListConfigs { if args.Command == common.CommandListConfigs {
yml, err := E.Check(json.Marshal(cfg.Value())) printJSON(cfg.Value())
if err.HasError() {
panic(err)
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", yml) // raw output for convenience using "jq"
return return
} }
onShutdown.Add(func() { cfg.StartProxyProviders()
docker.CloseAllClients()
cfg.Dispose() if args.Command == common.CommandListRoutes {
}) printJSON(cfg.RoutesByAlias())
return
}
if err.HasError() {
l.Warn(err)
}
onShutdown.Add(docker.CloseAllClients)
onShutdown.Add(cfg.Dispose)
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGINT)
@ -109,8 +113,9 @@ func main() {
onShutdown.Add(certRenewalCancel) onShutdown.Add(certRenewalCancel)
} }
for name, expiry := range autocert.GetExpiries() { for _, expiry := range autocert.GetExpiries() {
l.Infof("certificate %q: expire on %s", name, expiry) l.Infof("certificate expire on %s", expiry)
break
} }
} else { } else {
l.Info("autocert not configured") l.Info("autocert not configured")
@ -137,6 +142,9 @@ func main() {
onShutdown.Add(proxyServer.Stop) onShutdown.Add(proxyServer.Stop)
onShutdown.Add(apiServer.Stop) onShutdown.Add(apiServer.Stop)
go idlewatcher.Start()
onShutdown.Add(idlewatcher.Stop)
// wait for signal // wait for signal
<-sig <-sig
@ -164,3 +172,12 @@ func main() {
logrus.Info("timeout waiting for shutdown") logrus.Info("timeout waiting for shutdown")
} }
} }
func printJSON(obj any) {
j, err := E.Check(json.Marshal(obj))
if err.HasError() {
logrus.Fatal(err)
}
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", j) // raw output for convenience using "jq"
}

View file

@ -1,13 +1,16 @@
package model package model
import ( import (
"strconv"
"strings" "strings"
. "github.com/yusing/go-proxy/common"
D "github.com/yusing/go-proxy/docker"
F "github.com/yusing/go-proxy/utils/functional" F "github.com/yusing/go-proxy/utils/functional"
) )
type ( type (
ProxyEntry struct { ProxyEntry struct { // raw entry object before validation
Alias string `yaml:"-" json:"-"` Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"` Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"` Host string `yaml:"host" json:"host"`
@ -16,35 +19,66 @@ type (
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"-"`
} }
ProxyEntries = *F.Map[string, *ProxyEntry] ProxyEntries = F.Map[string, *ProxyEntry]
) )
var NewProxyEntries = F.NewMap[string, *ProxyEntry] var NewProxyEntries = F.NewMapOf[string, *ProxyEntry]
func (e *ProxyEntry) SetDefaults() { func (e *ProxyEntry) SetDefaults() {
if e.Scheme == "" { if e.Scheme == "" {
if strings.ContainsRune(e.Port, ':') { switch {
case strings.ContainsRune(e.Port, ':'):
e.Scheme = "tcp" e.Scheme = "tcp"
} else { case e.ProxyProperties != nil:
switch e.Port { if _, ok := ServiceNamePortMapTCP[e.ImageName]; ok {
case "443", "8443": e.Scheme = "tcp"
e.Scheme = "https"
default:
e.Scheme = "http"
} }
} }
} }
if e.Scheme == "" {
switch e.Port {
case "443", "8443":
e.Scheme = "https"
default:
e.Scheme = "http"
}
}
if e.Host == "" { if e.Host == "" {
e.Host = "localhost" e.Host = "localhost"
} }
if e.Port == "" { if e.Port == "" {
switch e.Scheme { e.Port = e.FirstPort
case "http": }
e.Port = "80" if e.Port == "" {
case "https": if port, ok := ServiceNamePortMapTCP[e.Port]; ok {
e.Port = "443" e.Port = strconv.Itoa(port)
} else if port, ok := ImageNamePortMapHTTP[e.Port]; ok {
e.Port = strconv.Itoa(port)
} else {
switch e.Scheme {
case "http":
e.Port = "80"
case "https":
e.Port = "443"
}
} }
} }
if e.IdleTimeout == "" {
e.IdleTimeout = IdleTimeoutDefault
}
if e.WakeTimeout == "" {
e.WakeTimeout = WakeTimeoutDefault
}
if e.StopTimeout == "" {
e.StopTimeout = StopTimeoutDefault
}
if e.StopMethod == "" {
e.StopMethod = StopMethodDefault
}
} }

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"time"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models" M "github.com/yusing/go-proxy/models"
@ -11,16 +12,23 @@ import (
) )
type ( type (
Entry struct { // real model after validation ReverseProxyEntry struct { // real model after validation
Alias T.Alias Alias T.Alias
Scheme T.Scheme Scheme T.Scheme
Host T.Host
Port T.Port
URL *url.URL URL *url.URL
NoTLSVerify bool NoTLSVerify bool
PathPatterns T.PathPatterns PathPatterns T.PathPatterns
SetHeaders http.Header SetHeaders http.Header
HideHeaders []string HideHeaders []string
/* Docker only */
IdleTimeout time.Duration
WakeTimeout time.Duration
StopMethod T.StopMethod
StopTimeout int
StopSignal T.Signal
DockerHost string
ContainerName string
} }
StreamEntry struct { StreamEntry struct {
Alias T.Alias `json:"alias"` Alias T.Alias `json:"alias"`
@ -30,69 +38,105 @@ type (
} }
) )
func NewEntry(m *M.ProxyEntry) (any, E.NestedError) { func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.DockerHost != ""
}
func ValidateEntry(m *M.ProxyEntry) (any, E.NestedError) {
m.SetDefaults() m.SetDefaults()
scheme, err := T.NewScheme(m.Scheme) scheme, err := T.NewScheme(m.Scheme)
if err.HasError() { if err.HasError() {
return nil, err return nil, err
} }
var entry any
e := E.NewBuilder("error validating proxy entry")
if scheme.IsStream() { if scheme.IsStream() {
return validateStreamEntry(m) entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
} }
return validateEntry(m, scheme) if err := e.Build(); err.HasError() {
return nil, err
}
return entry, nil
} }
func validateEntry(m *M.ProxyEntry, s T.Scheme) (*Entry, E.NestedError) { func validateRPEntry(m *M.ProxyEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
host, err := T.NewHost(m.Host) var stopTimeOut time.Duration
if err.HasError() {
return nil, err host, err := T.ValidateHost(m.Host)
} b.Add(err)
port, err := T.NewPort(m.Port)
if err.HasError() { port, err := T.ValidatePort(m.Port)
return nil, err b.Add(err)
}
pathPatterns, err := T.NewPathPatterns(m.PathPatterns) pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
if err.HasError() { b.Add(err)
return nil, err
} setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders)
setHeaders, err := T.NewHTTPHeaders(m.SetHeaders) b.Add(err)
if err.HasError() {
return nil, err
}
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
if err.HasError() { b.Add(err)
return nil, err
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(m.StopSignal)
b.Add(err)
if err.HasError() {
return nil
}
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: m.DockerHost,
ContainerName: m.ContainerName,
} }
return &Entry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
Host: host,
Port: port,
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
}, E.Nil()
} }
func validateStreamEntry(m *M.ProxyEntry) (*StreamEntry, E.NestedError) { func validateStreamEntry(m *M.ProxyEntry, b E.Builder) *StreamEntry {
host, err := T.NewHost(m.Host) host, err := T.ValidateHost(m.Host)
if err.HasError() { b.Add(err)
return nil, err
} port, err := T.ValidateStreamPort(m.Port)
port, err := T.NewStreamPort(m.Port) b.Add(err)
if err.HasError() {
return nil, err scheme, err := T.ValidateStreamScheme(m.Scheme)
} b.Add(err)
scheme, err := T.NewStreamScheme(m.Scheme)
if err.HasError() { if b.HasError() {
return nil, err return nil
} }
return &StreamEntry{ return &StreamEntry{
Alias: T.NewAlias(m.Alias), Alias: T.NewAlias(m.Alias),
Scheme: *scheme, Scheme: *scheme,
Host: host, Host: host,
Port: port, Port: port,
}, E.Nil() }
} }

View file

@ -1,23 +1,6 @@
package fields package fields
import ( type (
"strings" Alias string
NewAlias = Alias
F "github.com/yusing/go-proxy/utils/functional"
) )
type Alias string
type Aliases struct{ *F.Slice[Alias] }
func NewAlias(s string) Alias {
return Alias(s)
}
func NewAliases(s string) Aliases {
split := strings.Split(s, ",")
a := Aliases{F.NewSliceN[Alias](len(split))}
for i, v := range split {
a.Set(i, NewAlias(v))
}
return a
}

View file

@ -7,7 +7,7 @@ import (
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) { func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h := make(http.Header) h := make(http.Header)
for k, v := range headers { for k, v := range headers {
vSplit := strings.Split(v, ",") vSplit := strings.Split(v, ",")
@ -15,5 +15,5 @@ func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
h.Add(k, strings.TrimSpace(header)) h.Add(k, strings.TrimSpace(header))
} }
} }
return h, E.Nil() return h, nil
} }

View file

@ -7,6 +7,6 @@ import (
type Host string type Host string
type Subdomain = Alias type Subdomain = Alias
func NewHost(s string) (Host, E.NestedError) { func ValidateHost(s string) (Host, E.NestedError) {
return Host(s), E.Nil() return Host(s), nil
} }

View file

@ -9,7 +9,7 @@ type PathMode string
func NewPathMode(pm string) (PathMode, E.NestedError) { func NewPathMode(pm string) (PathMode, E.NestedError) {
switch pm { switch pm {
case "", "forward": case "", "forward":
return PathMode(pm), E.Nil() return PathMode(pm), nil
default: default:
return "", E.Invalid("path mode", pm) return "", E.Invalid("path mode", pm)
} }

View file

@ -16,12 +16,12 @@ func NewPathPattern(s string) (PathPattern, E.NestedError) {
if !pathPattern.MatchString(string(s)) { if !pathPattern.MatchString(string(s)) {
return "", E.Invalid("path pattern", s) return "", E.Invalid("path pattern", s)
} }
return PathPattern(s), E.Nil() return PathPattern(s), nil
} }
func NewPathPatterns(s []string) (PathPatterns, E.NestedError) { func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
if len(s) == 0 { if len(s) == 0 {
return []PathPattern{"/"}, E.Nil() return []PathPattern{"/"}, nil
} }
pp := make(PathPatterns, len(s)) pp := make(PathPatterns, len(s))
for i, v := range s { for i, v := range s {
@ -31,7 +31,7 @@ func NewPathPatterns(s []string) (PathPatterns, E.NestedError) {
pp[i] = pattern pp[i] = pattern
} }
} }
return pp, E.Nil() return pp, nil
} }
var pathPattern = regexp.MustCompile("^((GET|POST|DELETE|PUT|PATCH|HEAD|OPTIONS|CONNECT)\\s)?(/\\w*)+/?$") var pathPattern = regexp.MustCompile("^((GET|POST|DELETE|PUT|PATCH|HEAD|OPTIONS|CONNECT)\\s)?(/\\w*)+/?$")

View file

@ -8,7 +8,7 @@ import (
type Port int type Port int
func NewPort(v string) (Port, E.NestedError) { func ValidatePort(v string) (Port, E.NestedError) {
p, err := strconv.Atoi(v) p, err := strconv.Atoi(v)
if err != nil { if err != nil {
return ErrPort, E.Invalid("port number", v).With(err) return ErrPort, E.Invalid("port number", v).With(err)
@ -21,14 +21,14 @@ func NewPortInt[Int int | uint16](v Int) (Port, E.NestedError) {
if err := pp.boundCheck(); err.HasError() { if err := pp.boundCheck(); err.HasError() {
return ErrPort, err return ErrPort, err
} }
return pp, E.Nil() return pp, nil
} }
func (p Port) boundCheck() E.NestedError { func (p Port) boundCheck() E.NestedError {
if p < MinPort || p > MaxPort { if p < MinPort || p > MaxPort {
return E.Invalid("port", p) return E.Invalid("port", p)
} }
return E.Nil() return nil
} }
const ( const (

View file

@ -1,8 +1,6 @@
package fields package fields
import ( import (
"strings"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
@ -11,24 +9,11 @@ type Scheme string
func NewScheme(s string) (Scheme, E.NestedError) { func NewScheme(s string) (Scheme, E.NestedError) {
switch s { switch s {
case "http", "https", "tcp", "udp": case "http", "https", "tcp", "udp":
return Scheme(s), E.Nil() return Scheme(s), nil
} }
return "", E.Invalid("scheme", s) return "", E.Invalid("scheme", s)
} }
func NewSchemeFromPort(p string) (Scheme, E.NestedError) {
var s string
switch {
case strings.ContainsRune(p, ':'):
s = "tcp"
case strings.HasSuffix(p, "443"):
s = "https"
default:
s = "http"
}
return Scheme(s), E.Nil()
}
func (s Scheme) IsHTTP() bool { return s == "http" } func (s Scheme) IsHTTP() bool { return s == "http" }
func (s Scheme) IsHTTPS() bool { return s == "https" } func (s Scheme) IsHTTPS() bool { return s == "https" }
func (s Scheme) IsTCP() bool { return s == "tcp" } func (s Scheme) IsTCP() bool { return s == "tcp" }

View file

@ -0,0 +1,17 @@
package fields
import (
E "github.com/yusing/go-proxy/error"
)
type Signal string
func ValidateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}

View file

@ -0,0 +1,23 @@
package fields
import (
E "github.com/yusing/go-proxy/error"
)
type StopMethod string
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View file

@ -12,13 +12,13 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"` ProxyPort Port `json:"proxy"`
} }
func NewStreamPort(p string) (StreamPort, E.NestedError) { func ValidateStreamPort(p string) (StreamPort, E.NestedError) {
split := strings.Split(p, ":") split := strings.Split(p, ":")
if len(split) != 2 { if len(split) != 2 {
return StreamPort{}, E.Invalid("stream port", p).With("should be in 'x:y' format") return StreamPort{}, E.Invalid("stream port", p).With("should be in 'x:y' format")
} }
listeningPort, err := NewPort(split[0]) listeningPort, err := ValidatePort(split[0])
if err.HasError() { if err.HasError() {
return StreamPort{}, err return StreamPort{}, err
} }
@ -26,7 +26,7 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) {
return StreamPort{}, err return StreamPort{}, err
} }
proxyPort, err := NewPort(split[1]) proxyPort, err := ValidatePort(split[1])
if err.HasError() { if err.HasError() {
proxyPort, err = parseNameToPort(split[1]) proxyPort, err = parseNameToPort(split[1])
if err.HasError() { if err.HasError() {
@ -37,13 +37,13 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) {
return StreamPort{}, err return StreamPort{}, err
} }
return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, E.Nil() return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, nil
} }
func parseNameToPort(name string) (Port, E.NestedError) { func parseNameToPort(name string) (Port, E.NestedError) {
port, ok := common.NamePortMapTCP[name] port, ok := common.ServiceNamePortMapTCP[name]
if !ok { if !ok {
return -1, E.Unsupported("service", name) return -1, E.Unsupported("service", name)
} }
return Port(port), E.Nil() return Port(port), nil
} }

View file

@ -12,7 +12,7 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"` ProxyScheme Scheme `json:"proxy"`
} }
func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
ss = &StreamScheme{} ss = &StreamScheme{}
parts := strings.Split(s, ":") parts := strings.Split(s, ":")
if len(parts) == 1 { if len(parts) == 1 {
@ -28,7 +28,7 @@ func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
if err.HasError() { if err.HasError() {
return nil, err return nil, err
} }
return ss, E.Nil() return ss, nil
} }
func (s StreamScheme) String() string { func (s StreamScheme) String() string {

View file

@ -0,0 +1,18 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/error"
)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}

View file

@ -1,168 +1,160 @@
package provider package provider
import ( import (
"fmt"
"strings"
"github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker" D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models" M "github.com/yusing/go-proxy/models"
PT "github.com/yusing/go-proxy/proxy/fields" R "github.com/yusing/go-proxy/route"
U "github.com/yusing/go-proxy/utils"
W "github.com/yusing/go-proxy/watcher" W "github.com/yusing/go-proxy/watcher"
. "github.com/yusing/go-proxy/watcher/event"
) )
type DockerProvider struct { type DockerProvider struct {
dockerHost string dockerHost, hostname string
} }
func DockerProviderImpl(dockerHost string) ProviderImpl { func DockerProviderImpl(dockerHost string) ProviderImpl {
return &DockerProvider{dockerHost: dockerHost} return &DockerProvider{dockerHost: dockerHost}
} }
// GetProxyEntries returns proxy entries from a docker client. func (p *DockerProvider) NewWatcher() W.Watcher {
// return W.NewDockerWatcher(p.dockerHost)
// It retrieves the docker client information using the dockerhelper.GetClientInfo method. }
// Then, it iterates over the containers in the docker client information and calls
// the getEntriesFromLabels method to get the proxy entries for each container. func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
// Any errors encountered during the process are added to the ne error object.
// Finally, it returns the collected proxy entries and the ne error object.
//
// Parameters:
// - p: A pointer to the DockerProvider struct.
//
// Returns:
// - P.EntryModelSlice: (non-nil) A slice of EntryModel structs representing the proxy entries.
// - error: An error object if there was an error retrieving the docker client information or parsing the labels.
func (p DockerProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) {
entries := M.NewProxyEntries() entries := M.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost) info, err := D.GetClientInfo(p.dockerHost, true)
if err.HasError() { if err.HasError() {
return entries, err return routes, E.FailWith("connect to docker", err)
} }
errors := E.NewBuilder("errors when parse docker labels") errors := E.NewBuilder("errors when parse docker labels")
for _, container := range info.Containers { for _, c := range info.Containers {
en, err := p.getEntriesFromLabels(&container, info.Host) container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
}
newEntries, err := p.entriesFromContainerLabels(container)
if err.HasError() { if err.HasError() {
errors.Add(err) errors.Add(err)
} }
// although err is not nil // although err is not nil
// there may be some valid entries in `en` // there may be some valid entries in `en`
dups := entries.MergeWith(en) dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error // add the duplicate proxy entries to the error
dups.EachKV(func(k string, v *M.ProxyEntry) { dups.RangeAll(func(k string, v *M.ProxyEntry) {
errors.Addf("duplicate alias %s", k) errors.Addf("duplicate alias %s", k)
}) })
} }
return entries, errors.Build() entries.RangeAll(func(_ string, e *M.ProxyEntry) {
e.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
return routes, errors.Build()
} }
func (p *DockerProvider) NewWatcher() W.Watcher { func (p *DockerProvider) OnEvent(event Event, routes R.Routes) (res EventResult) {
return W.NewDockerWatcher(p.dockerHost) b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v R.Route) {
if v.Entry().ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
}
})
switch event.Action {
case ActionStarted, ActionCreated, ActionModified:
client, err := D.ConnectClient(p.dockerHost)
if err.HasError() {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err.HasError() {
b.Add(E.FailWith("inspect container", err))
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *M.ProxyEntry) {
if routes.Has(alias) {
b.Add(E.AlreadyExist("alias", alias))
} else {
if route, err := R.NewRoute(entry); err.HasError() {
b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
}
}
})
}
return
} }
// Returns a list of proxy entries for a container. // Returns a list of proxy entries for a container.
// Always non-nil // Always non-nil
func (p *DockerProvider) getEntriesFromLabels(container *types.Container, clientHost string) (M.ProxyEntries, E.NestedError) { func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.ProxyEntries, E.NestedError) {
var mainAlias string
var aliases PT.Aliases
if exclude, ok := container.Labels[D.NSProxy+".exclude"]; ok {
if U.ParseBool(exclude) {
return M.NewProxyEntries(), E.Nil()
}
}
// set mainAlias to docker compose service name if available
if serviceName, ok := container.Labels["com.docker.compose.service"]; ok {
mainAlias = serviceName
}
// if mainAlias is not set,
// or container name is different from service name
// use container name
if containerName := strings.TrimPrefix(container.Names[0], "/"); containerName != mainAlias {
mainAlias = containerName
}
if l, ok := container.Labels[D.NSProxy+".aliases"]; ok {
aliases = PT.NewAliases(l)
delete(container.Labels, D.NSProxy+"proxy.aliases")
} else {
aliases = PT.NewAliases(mainAlias)
}
entries := M.NewProxyEntries() entries := M.NewProxyEntries()
// find first port, return if no port exposed
defaultPort, err := findFirstPort(container)
if err.HasError() {
logrus.Debug(mainAlias, " ", err.Error())
}
// init entries map for all aliases // init entries map for all aliases
aliases.ForEach(func(a PT.Alias) { for _, a := range container.Aliases {
entries.Set(string(a), &M.ProxyEntry{ entries.Store(a, &M.ProxyEntry{
Alias: string(a), Alias: a,
Host: clientHost, Host: p.hostname,
Port: defaultPort, ProxyProperties: container.ProxyProperties,
}) })
}) }
errors := E.NewBuilder("failed to apply label for %q", mainAlias) errors := E.NewBuilder("failed to apply label")
for key, val := range container.Labels { for key, val := range container.Labels {
lbl, err := D.ParseLabel(key, val) errors.Add(p.applyLabel(entries, key, val))
if err.HasError() {
errors.Add(E.From(err).Subject(key))
continue
}
if lbl.Namespace != D.NSProxy {
continue
}
if lbl.Target == wildcardAlias {
// apply label for all aliases
entries.EachKV(func(a string, e *M.ProxyEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
errors.Add(E.From(err).Subject(lbl.Target))
}
})
} else {
config, ok := entries.UnsafeGet(lbl.Target)
if !ok {
errors.Add(E.NotExists("alias", lbl.Target))
continue
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
errors.Add(err.Subject(lbl.Target))
}
}
} }
entries.EachKV(func(a string, e *M.ProxyEntry) { return entries, errors.Build().Subject(container.ContainerName)
if e.Port == "" {
entries.UnsafeDelete(a)
}
})
return entries, errors.Build()
} }
func findFirstPort(c *types.Container) (string, E.NestedError) { func (p *DockerProvider) applyLabel(entries M.ProxyEntries, key, val string) (res E.NestedError) {
if len(c.Ports) == 0 { b := E.NewBuilder("errors in label %s", key)
return "", E.FailureWhy("findFirstPort", "no port exposed") defer b.To(&res)
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
b.Add(err.Subject(key))
} }
for _, p := range c.Ports { if lbl.Namespace != D.NSProxy {
if p.PublicPort != 0 { return
return fmt.Sprint(p.PublicPort), E.Nil() }
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *M.ProxyEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
}
})
} else {
config, ok := entries.Load(lbl.Target)
if !ok {
b.Add(E.NotExist("alias", lbl.Target))
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
} }
} }
return "", E.Failure("findFirstPort") return
} }
const wildcardAlias = "*"

View file

@ -7,8 +7,10 @@ import (
"github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/common"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models" M "github.com/yusing/go-proxy/models"
R "github.com/yusing/go-proxy/route"
U "github.com/yusing/go-proxy/utils" U "github.com/yusing/go-proxy/utils"
W "github.com/yusing/go-proxy/watcher" W "github.com/yusing/go-proxy/watcher"
. "github.com/yusing/go-proxy/watcher/event"
) )
type FileProvider struct { type FileProvider struct {
@ -27,26 +29,53 @@ func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data) return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data)
} }
func (p *FileProvider) String() string { func (p FileProvider) OnEvent(event Event, routes R.Routes) (res EventResult) {
return p.fileName b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err.HasError() {
b.Add(err)
return
}
routes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) {
b.Add(v.Start())
})
routes.MergeFrom(newRoutes)
return
} }
func (p *FileProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) { func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res)
entries := M.NewProxyEntries() entries := M.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path)) data, err := E.Check(os.ReadFile(p.path))
if err.HasError() { if err.HasError() {
return entries, E.Failure("read file").Subject(p).With(err) b.Add(E.FailWith("read file", err))
return
} }
ne := E.Failure("validation").Subject(p)
if !common.NoSchemaValidation { if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() { if err = Validate(data); err.HasError() {
return entries, ne.With(err) b.Add(err)
return
} }
} }
if err = entries.UnmarshalFromYAML(data); err.HasError() { if err = entries.UnmarshalFromYAML(data); err.HasError() {
return entries, ne.With(err) b.Add(err)
return
} }
return entries, E.Nil()
return R.FromEntries(entries)
} }
func (p *FileProvider) NewWatcher() W.Watcher { func (p *FileProvider) NewWatcher() W.Watcher {

View file

@ -4,38 +4,40 @@ import (
"context" "context"
"fmt" "fmt"
"path" "path"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
R "github.com/yusing/go-proxy/route" R "github.com/yusing/go-proxy/route"
W "github.com/yusing/go-proxy/watcher" W "github.com/yusing/go-proxy/watcher"
. "github.com/yusing/go-proxy/watcher/event"
) )
type ProviderImpl interface { type (
GetProxyEntries() (M.ProxyEntries, E.NestedError) Provider struct {
NewWatcher() W.Watcher ProviderImpl
}
type Provider struct { name string
ProviderImpl t ProviderType
routes R.Routes
name string watcher W.Watcher
t ProviderType watcherCtx context.Context
routes *R.Routes watcherCancel context.CancelFunc
reloadReqCh chan struct{}
watcher W.Watcher l *logrus.Entry
watcherCtx context.Context }
watcherCancel context.CancelFunc ProviderImpl interface {
NewWatcher() W.Watcher
l *logrus.Entry LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event Event, routes R.Routes) EventResult
cooldownCh chan struct{} }
} ProviderType string
EventResult struct {
type ProviderType string nRemoved int
nAdded int
err E.NestedError
}
)
const ( const (
ProviderTypeDocker ProviderType = "docker" ProviderTypeDocker ProviderType = "docker"
@ -44,16 +46,14 @@ const (
func newProvider(name string, t ProviderType) *Provider { func newProvider(name string, t ProviderType) *Provider {
p := &Provider{ p := &Provider{
name: name, name: name,
t: t, t: t,
routes: R.NewRoutes(), routes: R.NewRoutes(),
reloadReqCh: make(chan struct{}, 1),
cooldownCh: make(chan struct{}, 1),
} }
p.l = logrus.WithField("provider", p) p.l = logrus.WithField("provider", p)
go p.processReloadRequests()
return p return p
} }
func NewFileProvider(filename string) *Provider { func NewFileProvider(filename string) *Provider {
name := path.Base(filename) name := path.Base(filename)
p := newProvider(name, ProviderTypeFile) p := newProvider(name, ProviderTypeFile)
@ -78,25 +78,21 @@ func (p *Provider) GetType() ProviderType {
} }
func (p *Provider) String() string { func (p *Provider) String() string {
return fmt.Sprintf("%s: %s", p.t, p.name) return fmt.Sprintf("%s-%s", p.t, p.name)
} }
func (p *Provider) StartAllRoutes() E.NestedError { func (p *Provider) StartAllRoutes() (res E.NestedError) {
err := p.loadRoutes() errors := E.NewBuilder("errors in routes")
defer errors.To(&res)
// start watcher no matter load success or not // start watcher no matter load success or not
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
go p.watchEvents() go p.watchEvents()
errors := E.NewBuilder("errors in routes")
nStarted := 0 nStarted := 0
nFailed := 0 nFailed := 0
if err.HasError() { p.routes.RangeAll(func(alias string, r R.Route) {
errors.Add(err)
}
p.routes.EachKVParallel(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() { if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r)) errors.Add(err.Subject(r))
nFailed++ nFailed++
@ -106,18 +102,21 @@ func (p *Provider) StartAllRoutes() E.NestedError {
}) })
p.l.Debugf("%d routes started, %d failed", nStarted, nFailed) p.l.Debugf("%d routes started, %d failed", nStarted, nFailed)
return errors.Build() return
} }
func (p *Provider) StopAllRoutes() E.NestedError { func (p *Provider) StopAllRoutes() (res E.NestedError) {
if p.watcherCancel != nil { if p.watcherCancel != nil {
p.watcherCancel() p.watcherCancel()
p.watcherCancel = nil p.watcherCancel = nil
} }
errors := E.NewBuilder("errors stopping routes for provider %q", p.name) errors := E.NewBuilder("errors stopping routes for provider %q", p.name)
defer errors.To(&res)
nStopped := 0 nStopped := 0
nFailed := 0 nFailed := 0
p.routes.EachKVParallel(func(alias string, r R.Route) { p.routes.RangeAll(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() { if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r)) errors.Add(err.Subject(r))
nFailed++ nFailed++
@ -126,20 +125,24 @@ func (p *Provider) StopAllRoutes() E.NestedError {
} }
}) })
p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed) p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed)
return errors.Build() return
} }
func (p *Provider) ReloadRoutes() { func (p *Provider) RangeRoutes(do func(string, R.Route)) {
select { p.routes.RangeAll(do)
case p.reloadReqCh <- struct{}{}: }
// Successfully sent reload request
default: func (p *Provider) GetRoute(alias string) (R.Route, bool) {
// Reload request already in progress, ignore this request return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
routes, err := p.LoadRoutesImpl()
if err != nil {
return err
} }
} p.routes = routes
return nil
func (p *Provider) GetCurrentRoutes() *R.Routes {
return p.routes
} }
func (p *Provider) watchEvents() { func (p *Provider) watchEvents() {
@ -151,11 +154,15 @@ func (p *Provider) watchEvents() {
case <-p.watcherCtx.Done(): case <-p.watcherCtx.Done():
return return
case event, ok := <-events: case event, ok := <-events:
if !ok { if !ok { // channel closed
return return
} }
l.Info(event) res := p.OnEvent(event, p.routes)
p.ReloadRoutes() l.Infof("%s event %q", event.Type, event)
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
if res.err.HasError() {
l.Error(res.err)
}
case err, ok := <-errs: case err, ok := <-errs:
if !ok { if !ok {
return return
@ -167,50 +174,3 @@ func (p *Provider) watchEvents() {
} }
} }
} }
func (p *Provider) processReloadRequests() {
for range p.reloadReqCh {
// prevent busy loop caused by a container
// repeating crashing and restarting
select {
case p.cooldownCh <- struct{}{}:
p.l.Info("Starting to reload routes")
nRoutes := p.routes.Size()
p.StopAllRoutes()
p.loadRoutes()
p.StartAllRoutes()
p.l.Infof("Routes reloaded (%d -> %d)", nRoutes, p.routes.Size())
go func() {
time.Sleep(reloadCooldown)
<-p.cooldownCh
}()
default:
}
}
}
func (p *Provider) loadRoutes() E.NestedError {
entries, err := p.GetProxyEntries()
if err.HasError() {
p.l.Warn(err.Subject(p))
}
p.routes = R.NewRoutes()
errors := E.NewBuilder("errors loading routes from %s", p)
entries.EachKV(func(a string, e *M.ProxyEntry) {
e.Alias = a
r, err := R.NewRoute(e)
if err.HasError() {
errors.Add(err.Subject(a))
} else {
p.routes.Set(a, r)
}
})
return errors.Build()
}
const reloadCooldown = 50 * time.Millisecond

View file

@ -207,7 +207,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// } // }
// //
// TODO: headers in ModifyResponse // TODO: headers in ModifyResponse
func NewReverseProxy(target *url.URL, transport *http.Transport, entry *Entry) *ReverseProxy { func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy {
// check on init rather than on request // check on init rather than on request
var setHeaders = func(r *http.Request) {} var setHeaders = func(r *http.Request) {}
var hideHeaders = func(r *http.Request) {} var hideHeaders = func(r *http.Request) {}

View file

@ -2,8 +2,8 @@ package route
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"sync"
"time" "time"
"net/http" "net/http"
@ -11,6 +11,7 @@ import (
"strings" "strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy" P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields" PT "github.com/yusing/go-proxy/proxy/fields"
@ -23,57 +24,65 @@ type (
TargetURL *URL `json:"target_url"` TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"` PathPatterns PT.PathPatterns `json:"path_patterns"`
entry *P.ReverseProxyEntry
mux *http.ServeMux mux *http.ServeMux
handler *P.ReverseProxy handler *P.ReverseProxy
regIdleWatcher func() E.NestedError
unregIdleWatcher func()
} }
URL url.URL URL url.URL
PathKey = PT.PathPattern
SubdomainKey = PT.Alias SubdomainKey = PT.Alias
) )
var httpRoutes = F.NewMap[SubdomainKey, *HTTPRoute]() func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans http.RoundTripper
var regIdleWatcher func() E.NestedError
var unregIdleWatcher func()
func NewHTTPRoute(entry *P.Entry) (*HTTPRoute, E.NestedError) {
var tr *http.Transport
if entry.NoTLSVerify { if entry.NoTLSVerify {
tr = transportNoTLS trans = transportNoTLS
} else { } else {
tr = transport trans = transport
} }
rp := P.NewReverseProxy(entry.URL, tr, entry) rp := P.NewReverseProxy(entry.URL, trans, entry)
httpRoutes.Lock() if entry.UseIdleWatcher() {
defer httpRoutes.Unlock() regIdleWatcher = func() E.NestedError {
watcher, err := idlewatcher.Register(entry)
var r *HTTPRoute if err.HasError() {
r, ok := httpRoutes.UnsafeGet(entry.Alias) return err
if !ok { }
r = &HTTPRoute{ // patch round-tripper
Alias: entry.Alias, rp.Transport = watcher.PatchRoundTripper(trans)
TargetURL: (*URL)(entry.URL), return nil
PathPatterns: entry.PathPatterns,
handler: rp,
} }
httpRoutes.UnsafeSet(entry.Alias, r) unregIdleWatcher = func() {
} idlewatcher.Unregister(entry.ContainerName)
rp.Transport = trans
rewrite := rp.Rewrite
if logrus.GetLevel() == logrus.DebugLevel {
l := logrus.WithField("alias", entry.Alias)
rp.Rewrite = func(pr *P.ProxyRequest) {
l.Debug("request URL: ", pr.In.Host, pr.In.URL.Path)
l.Debug("request headers: ", pr.In.Header)
rewrite(pr)
} }
} else {
rp.Rewrite = rewrite
} }
return r, E.Nil() httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
_, exists := httpRoutes.Load(entry.Alias)
if exists {
return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias)
}
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
handler: rp,
regIdleWatcher: regIdleWatcher,
unregIdleWatcher: unregIdleWatcher,
}
return r, nil
} }
func (r *HTTPRoute) String() string { func (r *HTTPRoute) String() string {
@ -81,18 +90,35 @@ func (r *HTTPRoute) String() string {
} }
func (r *HTTPRoute) Start() E.NestedError { func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.regIdleWatcher != nil {
if err := r.regIdleWatcher(); err.HasError() {
return err
}
}
r.mux = http.NewServeMux() r.mux = http.NewServeMux()
for _, p := range r.PathPatterns { for _, p := range r.PathPatterns {
r.mux.HandleFunc(string(p), r.handler.ServeHTTP) r.mux.HandleFunc(string(p), r.handler.ServeHTTP)
} }
httpRoutes.Set(r.Alias, r)
return E.Nil() httpRoutes.Store(r.Alias, r)
return nil
} }
func (r *HTTPRoute) Stop() E.NestedError { func (r *HTTPRoute) Stop() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.unregIdleWatcher != nil {
r.unregIdleWatcher()
}
r.mux = nil r.mux = nil
httpRoutes.Delete(r.Alias) httpRoutes.Delete(r.Alias)
return E.Nil() return nil
} }
func (u *URL) String() string { func (u *URL) String() string {
@ -104,27 +130,26 @@ func (u *URL) MarshalText() (text []byte, err error) {
} }
func ProxyHandler(w http.ResponseWriter, r *http.Request) { func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMux(r.Host, PathKey(r.URL.Path)) mux, err := findMux(r.Host)
if err != nil { if err != nil {
err = E.Failure("request"). err = E.Failure("request").
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path). Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
With(err) With(err)
http.Error(w, err.Error(), http.StatusNotFound) http.Error(w, err.String(), http.StatusNotFound)
logrus.Error(err) logrus.Error(err)
return return
} }
mux.ServeHTTP(w, r) mux.ServeHTTP(w, r)
} }
func findMux(host string, path PathKey) (*http.ServeMux, error) { func findMux(host string) (*http.ServeMux, E.NestedError) {
sd := strings.Split(host, ".")[0] sd := strings.Split(host, ".")[0]
if r, ok := httpRoutes.UnsafeGet(PT.Alias(sd)); ok { if r, ok := httpRoutes.Load(PT.Alias(sd)); ok {
return r.mux, nil return r.mux, nil
} }
return nil, E.NotExists("route", fmt.Sprintf("subdomain: %s, path: %s", sd, path)) return nil, E.NotExist("route", sd)
} }
// TODO: default + per proxy
var ( var (
transport = &http.Transport{ transport = &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
@ -135,10 +160,12 @@ var (
MaxIdleConns: 1000, MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000, MaxIdleConnsPerHost: 1000,
} }
transportNoTLS = func() *http.Transport { transportNoTLS = func() *http.Transport {
var clone = transport.Clone() var clone = transport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone return clone
}() }()
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
) )

View file

@ -1,6 +1,9 @@
package route package route
import ( import (
"fmt"
"net/url"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models" M "github.com/yusing/go-proxy/models"
P "github.com/yusing/go-proxy/proxy" P "github.com/yusing/go-proxy/proxy"
@ -9,27 +12,81 @@ import (
type ( type (
Route interface { Route interface {
RouteImpl
Entry() *M.ProxyEntry
Type() RouteType
URL() *url.URL
}
Routes = F.Map[string, Route]
RouteType string
RouteImpl interface {
Start() E.NestedError Start() E.NestedError
Stop() E.NestedError Stop() E.NestedError
String() string String() string
} }
Routes = F.Map[string, Route] route struct {
RouteImpl
type_ RouteType
entry *M.ProxyEntry
}
)
const (
RouteTypeStream RouteType = "stream"
RouteTypeReverseProxy RouteType = "reverse_proxy"
) )
// function alias // function alias
var NewRoutes = F.NewMap[string, Route] var NewRoutes = F.NewMapOf[string, Route]
func NewRoute(en *M.ProxyEntry) (Route, E.NestedError) { func NewRoute(en *M.ProxyEntry) (Route, E.NestedError) {
entry, err := P.NewEntry(en) rt, err := P.ValidateEntry(en)
if err.HasError() { if err.HasError() {
return nil, err return nil, err
} }
switch e := entry.(type) {
var t RouteType
switch e := rt.(type) {
case *P.StreamEntry: case *P.StreamEntry:
return NewStreamRoute(e) rt, err = NewStreamRoute(e)
case *P.Entry: t = RouteTypeStream
return NewHTTPRoute(e) case *P.ReverseProxyEntry:
rt, err = NewHTTPRoute(e)
t = RouteTypeReverseProxy
default: default:
panic("bug: should not reach here") panic("bug: should not reach here")
} }
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, err
}
func (rt *route) Entry() *M.ProxyEntry {
return rt.entry
}
func (rt *route) Type() RouteType {
return rt.type_
}
func (rt *route) URL() *url.URL {
url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host))
return url
}
func FromEntries(entries M.ProxyEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()
entries.RangeAll(func(alias string, entry *M.ProxyEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
if err.HasError() {
b.Add(err.Subject(alias))
} else {
routes.Store(alias, r)
}
})
return routes, b.Build()
} }

View file

@ -12,7 +12,7 @@ import (
) )
type StreamRoute struct { type StreamRoute struct {
*P.StreamEntry P.StreamEntry
StreamImpl `json:"-"` StreamImpl `json:"-"`
wg sync.WaitGroup wg sync.WaitGroup
@ -35,7 +35,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
} }
base := &StreamRoute{ base := &StreamRoute{
StreamEntry: entry, StreamEntry: *entry,
wg: sync.WaitGroup{}, wg: sync.WaitGroup{},
connCh: make(chan any), connCh: make(chan any),
} }
@ -45,11 +45,11 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
base.StreamImpl = NewUDPRoute(base) base.StreamImpl = NewUDPRoute(base)
} }
base.l = logrus.WithField("route", base.StreamImpl) base.l = logrus.WithField("route", base.StreamImpl)
return base, E.Nil() return base, nil
} }
func (r *StreamRoute) String() string { func (r *StreamRoute) String() string {
return fmt.Sprintf("%s-stream: %s", r.Scheme, r.Alias) return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
} }
func (r *StreamRoute) Start() E.NestedError { func (r *StreamRoute) Start() E.NestedError {
@ -59,13 +59,13 @@ func (r *StreamRoute) Start() E.NestedError {
r.stopCh = make(chan struct{}, 1) r.stopCh = make(chan struct{}, 1)
r.wg.Wait() r.wg.Wait()
if err := r.Setup(); err != nil { if err := r.Setup(); err != nil {
return E.Failure("setup").With(err) return E.FailWith("setup", err)
} }
r.started.Store(true) r.started.Store(true)
r.wg.Add(2) r.wg.Add(2)
go r.grAcceptConnections() go r.grAcceptConnections()
go r.grHandleConnections() go r.grHandleConnections()
return E.Nil() return nil
} }
func (r *StreamRoute) Stop() E.NestedError { func (r *StreamRoute) Stop() E.NestedError {
@ -88,7 +88,7 @@ func (r *StreamRoute) Stop() E.NestedError {
case <-time.After(streamStopListenTimeout): case <-time.After(streamStopListenTimeout):
l.Error("timed out waiting for connections") l.Error("timed out waiting for connections")
} }
return E.Nil() return nil
} }
func (r *StreamRoute) grAcceptConnections() { func (r *StreamRoute) grAcceptConnections() {

View file

@ -65,9 +65,10 @@ func (route *TCPRoute) Handle(c any) error {
}() }()
route.mu.Lock() route.mu.Lock()
defer route.mu.Unlock()
pipe := U.NewBidirectionalPipe(pipeCtx, clientConn, serverConn) pipe := U.NewBidirectionalPipe(pipeCtx, clientConn, serverConn)
route.pipe = append(route.pipe, pipe) route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start() return pipe.Start()
} }
@ -78,7 +79,7 @@ func (route *TCPRoute) CloseListeners() {
route.listener.Close() route.listener.Close()
route.listener = nil route.listener = nil
for _, pipe := range route.pipe { for _, pipe := range route.pipe {
if err := pipe.Stop(); err.HasError() { if err := pipe.Stop(); err != nil {
route.l.Error(err) route.l.Error(err)
} }
} }

View file

@ -1,229 +1,116 @@
package functional package functional
import ( import (
"context" "github.com/puzpuzpuz/xsync/v3"
"sync"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
type Map[KT comparable, VT any] struct { type Map[KT comparable, VT any] struct {
m map[KT]VT *xsync.MapOf[KT, VT]
defVals map[KT]VT
sync.RWMutex
} }
// NewMap creates a new Map with the given map as its initial values. func NewMapOf[KT comparable, VT any](options ...func(*xsync.MapConfig)) Map[KT, VT] {
// return Map[KT, VT]{xsync.NewMapOf[KT, VT](options...)}
// Parameters:
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMap[KT comparable, VT any](dv ...map[KT]VT) *Map[KT, VT] {
return NewMapFrom(make(map[KT]VT), dv...)
} }
// NewMapOf creates a new Map with the given map as its initial values. func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) {
// res = NewMapOf[KT, VT](xsync.WithPresize(len(m)))
// Type parameters: for k, v := range m {
// - M: type for the new map. res.Store(k, v)
//
// Parameters:
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMapOf[M Map[KT, VT], KT comparable, VT any](dv ...map[KT]VT) *Map[KT, VT] {
return NewMapFrom(make(map[KT]VT), dv...)
}
// NewMapFrom creates a new Map with the given map as its initial values.
//
// Parameters:
// - from: a map of type KT to VT, which will be the initial values of the Map.
// - dv: optional default values for the Map
//
// Return:
// - *Map[KT, VT]: a pointer to the newly created Map.
func NewMapFrom[KT comparable, VT any](from map[KT]VT, dv ...map[KT]VT) *Map[KT, VT] {
if len(dv) > 0 {
return &Map[KT, VT]{m: from, defVals: dv[0]}
} }
return &Map[KT, VT]{m: from} return
} }
func (m *Map[KT, VT]) Set(key KT, value VT) { func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bool)) (_ CT) {
m.Lock() result := make(chan CT, 1)
m.m[key] = value
m.Unlock()
}
func (m *Map[KT, VT]) Get(key KT) VT { m.Range(func(key KT, value VT) bool {
m.RLock() select {
defer m.RUnlock() case <-result: // already have a result
value, ok := m.m[key] return false // stop iteration
if !ok && m.defVals != nil { default:
return m.defVals[key] if got, ok := criteria(value); ok {
} result <- got
return value return false
}
// Find searches for the first element in the map that satisfies the given criteria.
//
// Parameters:
// - criteria: a function that takes a value of type VT and returns a tuple of any type and a boolean.
//
// Return:
// - any: the first value that satisfies the criteria, or nil if no match is found.
func (m *Map[KT, VT]) Find(criteria func(VT) (any, bool)) any {
m.RLock()
defer m.RUnlock()
result := make(chan any)
wg := sync.WaitGroup{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, v := range m.m {
wg.Add(1)
go func(val VT) {
defer wg.Done()
if value, ok := criteria(val); ok {
select {
case result <- value:
cancel() // Cancel other goroutines if a result is found
case <-ctx.Done(): // If already cancelled
return
}
} }
}(v) return true
}
go func() {
wg.Wait()
close(result)
}()
// The first valid match, if any
select {
case res, ok := <-result:
if ok {
return res
} }
case <-ctx.Done(): })
select {
case v := <-result:
return v
default:
return
} }
return nil // Return nil if no matches found
} }
func (m *Map[KT, VT]) UnsafeGet(key KT) (VT, bool) { // MergeFrom add contents from another `Map`, ignore duplicated keys
value, ok := m.m[key]
return value, ok
}
func (m *Map[KT, VT]) UnsafeSet(key KT, value VT) {
m.m[key] = value
}
func (m *Map[KT, VT]) Delete(key KT) {
m.Lock()
delete(m.m, key)
m.Unlock()
}
func (m *Map[KT, VT]) UnsafeDelete(key KT) {
delete(m.m, key)
}
// MergeWith merges the contents of another Map[KT, VT]
// into the current Map[KT, VT] and
// returns a map that were duplicated.
// //
// Parameters: // Parameters:
// - other: a pointer to another Map[KT, VT] to be merged into the current Map[KT, VT]. // - other: `Map` of values to add from
// //
// Return: // Return:
// - Map[KT, VT]: a map of key-value pairs that were duplicated during the merge. // - Map: a `Map` of duplicated keys-value pairs
func (m *Map[KT, VT]) MergeWith(other *Map[KT, VT]) Map[KT, VT] { func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] {
dups := make(map[KT]VT) dups := NewMapOf[KT, VT]()
m.Lock() other.Range(func(k KT, v VT) bool {
for k, v := range other.m { if _, ok := m.Load(k); ok {
if _, isDup := m.m[k]; !isDup { dups.Store(k, v)
m.m[k] = v
} else { } else {
dups[k] = v m.Store(k, v)
} }
} return true
m.Unlock() })
return Map[KT, VT]{m: dups} return dups
} }
func (m *Map[KT, VT]) Clear() { func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) {
m.Lock() m.Range(func(k KT, v VT) bool {
m.m = make(map[KT]VT) do(k, v)
m.Unlock() return true
})
} }
func (m *Map[KT, VT]) Size() int { func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.RLock() m.Range(func(k KT, v VT) bool {
defer m.RUnlock() if criteria(v) {
return len(m.m) m.Delete(k)
}
return true
})
} }
func (m *Map[KT, VT]) Contains(key KT) bool { func (m Map[KT, VT]) Has(k KT) bool {
m.RLock() _, ok := m.Load(k)
_, ok := m.m[key]
m.RUnlock()
return ok return ok
} }
func (m *Map[KT, VT]) Clone() *Map[KT, VT] { func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
m.RLock() if m.Size() != 0 {
defer m.RUnlock() return E.FailedWhy("unmarshal from yaml", "map is not empty")
clone := make(map[KT]VT, len(m.m))
for k, v := range m.m {
clone[k] = v
} }
return &Map[KT, VT]{m: clone, defVals: m.defVals} tmp := make(map[KT]VT)
} if err := E.From(yaml.Unmarshal(data, tmp)); err.HasError() {
return err
func (m *Map[KT, VT]) EachKV(fn func(k KT, v VT)) {
m.Lock()
for k, v := range m.m {
fn(k, v)
} }
m.Unlock() for k, v := range tmp {
} m.Store(k, v)
func (m *Map[KT, VT]) Each(fn func(v VT)) {
m.Lock()
for _, v := range m.m {
fn(v)
} }
m.Unlock() return nil
} }
func (m *Map[KT, VT]) EachParallel(fn func(v VT)) { func (m Map[KT, VT]) String() string {
m.Lock() tmp := make(map[KT]VT, m.Size())
ParallelForEachValue(m.m, fn) m.RangeAll(func(k KT, v VT) {
m.Unlock() tmp[k] = v
} })
data, err := yaml.Marshal(tmp)
func (m *Map[KT, VT]) EachKVParallel(fn func(k KT, v VT)) { if err != nil {
m.Lock() return err.Error()
ParallelForEachKV(m.m, fn) }
m.Unlock() return string(data)
}
func (m *Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
return E.From(yaml.Unmarshal(data, m.m))
}
func (m *Map[KT, VT]) Iterator() map[KT]VT {
return m.m
} }

View file

@ -0,0 +1,75 @@
package functional_test
import (
"testing"
. "github.com/yusing/go-proxy/utils/functional"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestNewMapFrom(t *testing.T) {
m := NewMapFrom(map[string]int{
"a": 1,
"b": 2,
"c": 3,
})
ExpectEqual(t, m.Size(), 3)
ExpectTrue(t, m.Has("a"))
ExpectTrue(t, m.Has("b"))
ExpectTrue(t, m.Has("c"))
}
func TestMapFind(t *testing.T) {
m := NewMapFrom(map[string]map[string]int{
"a": {
"a": 1,
},
"b": {
"a": 1,
"b": 2,
},
"c": {
"b": 2,
"c": 3,
},
})
res := MapFind(m, func(inner map[string]int) (int, bool) {
if _, ok := inner["c"]; ok && inner["c"] == 3 {
return inner["c"], true
}
return 0, false
})
ExpectEqual(t, res, 3)
}
func TestMergeFrom(t *testing.T) {
m1 := NewMapFrom(map[string]int{
"a": 1,
"b": 2,
"c": 3,
"d": 4,
})
m2 := NewMapFrom(map[string]int{
"a": 1,
"c": 123,
"e": 456,
"f": 6,
})
dup := m1.MergeFrom(m2)
ExpectEqual(t, m1.Size(), 6)
ExpectTrue(t, m1.Has("e"))
ExpectTrue(t, m1.Has("f"))
c, _ := m1.Load("c")
d, _ := m1.Load("d")
e, _ := m1.Load("e")
f, _ := m1.Load("f")
ExpectEqual(t, c, 3)
ExpectEqual(t, d, 4)
ExpectEqual(t, e, 456)
ExpectEqual(t, f, 6)
ExpectEqual(t, dup.Size(), 2)
ExpectTrue(t, dup.Has("a"))
ExpectTrue(t, dup.Has("c"))
}

View file

@ -10,15 +10,8 @@ import (
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
// TODO: move to "utils/io"
type ( type (
Reader interface {
Read() ([]byte, E.NestedError)
}
StdReader struct {
r Reader
}
FileReader struct { FileReader struct {
Path string Path string
} }
@ -29,13 +22,6 @@ type (
closed atomic.Bool closed atomic.Bool
} }
StdReadCloser struct {
r *ReadCloser
}
ByteReader []byte
NewByteReader = ByteReader
Pipe struct { Pipe struct {
r ReadCloser r ReadCloser
w io.WriteCloser w io.WriteCloser
@ -44,49 +30,25 @@ type (
} }
BidirectionalPipe struct { BidirectionalPipe struct {
pSrcDst Pipe pSrcDst *Pipe
pDstSrc Pipe pDstSrc *Pipe
} }
) )
func NewFileReader(path string) *FileReader { func (r *ReadCloser) Read(p []byte) (int, error) {
return &FileReader{Path: path}
}
func (r StdReader) Read() ([]byte, error) {
return r.r.Read()
}
func (r *FileReader) Read() ([]byte, E.NestedError) {
return E.Check(os.ReadFile(r.Path))
}
func (r ByteReader) Read() ([]byte, E.NestedError) {
return r, E.Nil()
}
func (r *ReadCloser) Read(p []byte) (int, E.NestedError) {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
return 0, E.From(r.ctx.Err()) return 0, r.ctx.Err()
default: default:
return E.Check(r.r.Read(p)) return r.r.Read(p)
} }
} }
func (r *ReadCloser) Close() E.NestedError { func (r *ReadCloser) Close() error {
if r.closed.Load() { if r.closed.Load() {
return E.Nil() return nil
} }
r.closed.Store(true) r.closed.Store(true)
return E.From(r.r.Close())
}
func (r StdReadCloser) Read(p []byte) (int, error) {
return r.r.Read(p)
}
func (r StdReadCloser) Close() error {
return r.r.Close() return r.r.Close()
} }
@ -100,35 +62,35 @@ func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe {
} }
} }
func (p *Pipe) Start() E.NestedError { func (p *Pipe) Start() error {
return Copy(p.ctx, p.w, &StdReadCloser{&p.r}) return Copy(p.ctx, p.w, &p.r)
} }
func (p *Pipe) Stop() E.NestedError { func (p *Pipe) Stop() error {
p.cancel() p.cancel()
return E.Join("error stopping pipe", p.r.Close(), p.w.Close()) return E.JoinE("error stopping pipe", p.r.Close(), p.w.Close()).Error()
} }
func (p *Pipe) Write(b []byte) (int, E.NestedError) { func (p *Pipe) Write(b []byte) (int, error) {
return E.Check(p.w.Write(b)) return p.w.Write(b)
} }
func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe { func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{ return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, rw1, rw2), pSrcDst: NewPipe(ctx, rw1, rw2),
pDstSrc: *NewPipe(ctx, rw2, rw1), pDstSrc: NewPipe(ctx, rw2, rw1),
} }
} }
func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe { func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{ return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, listener, client), pSrcDst: NewPipe(ctx, listener, client),
pDstSrc: *NewPipe(ctx, client, target), pDstSrc: NewPipe(ctx, client, target),
} }
} }
func (p *BidirectionalPipe) Start() E.NestedError { func (p *BidirectionalPipe) Start() error {
errCh := make(chan E.NestedError, 2) errCh := make(chan error, 2)
go func() { go func() {
errCh <- p.pSrcDst.Start() errCh <- p.pSrcDst.Start()
}() }()
@ -136,34 +98,34 @@ func (p *BidirectionalPipe) Start() E.NestedError {
errCh <- p.pDstSrc.Start() errCh <- p.pDstSrc.Start()
}() }()
for err := range errCh { for err := range errCh {
if err.HasError() { if err != nil {
return err return err
} }
} }
return E.Nil() return nil
} }
func (p *BidirectionalPipe) Stop() E.NestedError { func (p *BidirectionalPipe) Stop() error {
return E.Join("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()) return E.JoinE("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()).Error()
} }
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) E.NestedError { func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error {
_, err := io.Copy(dst, StdReadCloser{&ReadCloser{ctx: ctx, r: src}}) _, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src})
return E.From(err) return err
} }
func LoadJson[T any](path string, pointer *T) E.NestedError { func LoadJson[T any](path string, pointer *T) E.NestedError {
data, err := os.ReadFile(path) data, err := E.Check(os.ReadFile(path))
if err != nil { if err.HasError() {
return E.From(err) return err
} }
return E.From(json.Unmarshal(data, pointer)) return E.From(json.Unmarshal(data, pointer))
} }
func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError { func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
data, err := json.Marshal(pointer) data, err := E.Check(json.Marshal(pointer))
if err != nil { if err.HasError() {
return E.From(err) return err
} }
return E.From(os.WriteFile(path, data, perm)) return E.From(os.WriteFile(path, data, perm))
} }

View file

@ -20,5 +20,5 @@ func SetFieldFromSnake[T, VT any](obj *T, field string, value VT) E.NestedError
return E.Invalid("field", field) return E.Invalid("field", field)
} }
prop.Set(reflect.ValueOf(value)) prop.Set(reflect.ValueOf(value))
return E.Nil() return nil
} }

View file

@ -17,45 +17,26 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
err := yaml.Unmarshal(data, &i) err := yaml.Unmarshal(data, &i)
if err != nil { if err != nil {
return E.Failure("unmarshal yaml").With(err) return E.FailWith("unmarshal yaml", err)
} }
m, err := json.Marshal(i) m, err := json.Marshal(i)
if err != nil { if err != nil {
return E.Failure("marshal json").With(err) return E.FailWith("marshal json", err)
} }
err = schema.Validate(bytes.NewReader(m)) err = schema.Validate(bytes.NewReader(m))
if err == nil { if err == nil {
return E.Nil() return nil
} }
errors := E.NewBuilder("yaml validation error") errors := E.NewBuilder("yaml validation error")
for _, e := range err.(*jsonschema.ValidationError).Causes { for _, e := range err.(*jsonschema.ValidationError).Causes {
errors.Add(e) errors.AddE(e)
} }
return errors.Build() return errors.Build()
} }
// TryJsonStringify converts the given object to a JSON string.
//
// It takes an object of any type and attempts to marshal it into a JSON string.
// If the marshaling is successful, the JSON string is returned.
// If the marshaling fails, the object is converted to a string using fmt.Sprint and returned.
//
// Parameters:
// - o: The object to be converted to a JSON string.
//
// Return type:
// - string: The JSON string representation of the object.
func TryJsonStringify(o any) string {
b, err := json.Marshal(o)
if err != nil {
return fmt.Sprint(o)
}
return string(b)
}
// Serialize converts the given data into a map[string]any representation. // Serialize converts the given data into a map[string]any representation.
// //
// It uses reflection to inspect the data type and handle different kinds of data. // It uses reflection to inspect the data type and handle different kinds of data.
@ -123,7 +104,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
return nil, E.Unsupported("type", value.Kind()) return nil, E.Unsupported("type", value.Kind())
} }
return result, E.Nil() return result, nil
} }
func Deserialize(src map[string]any, target any) E.NestedError { func Deserialize(src map[string]any, target any) E.NestedError {
@ -166,7 +147,7 @@ func Deserialize(src map[string]any, target any) E.NestedError {
propNew := reflect.New(propType.Elem()) propNew := reflect.New(propType.Elem())
err := Deserialize(vSerialized, propNew.Interface()) err := Deserialize(vSerialized, propNew.Interface())
if err.HasError() { if err.HasError() {
return E.Failure("set field").With(k).With(err) return E.Failure("set field").With(err).Subject(k)
} }
prop.Set(propNew) prop.Set(propNew)
default: default:
@ -180,7 +161,15 @@ func Deserialize(src map[string]any, target any) E.NestedError {
} }
} }
return E.Nil() return nil
}
func DeserializeJson(j map[string]string, target any) E.NestedError {
data, err := E.Check(json.Marshal(j))
if err.HasError() {
return err
}
return E.From(json.Unmarshal(data, target))
} }
func toLowerNoSnake(s string) string { func toLowerNoSnake(s string) string {

11
src/utils/string.go Normal file
View file

@ -0,0 +1,11 @@
package utils
import "strings"
func CommaSeperatedList(s string) []string {
res := strings.Split(s, ",")
for i, part := range res {
res[i] = strings.TrimSpace(part)
}
return res
}

View file

@ -3,25 +3,23 @@ package utils
import ( import (
"reflect" "reflect"
"testing" "testing"
E "github.com/yusing/go-proxy/error"
) )
func ExpectNoError(t *testing.T, err error) { func ExpectNoError(t *testing.T, err error) {
t.Helper() t.Helper()
var noError bool if err != nil && !reflect.ValueOf(err).IsNil() {
switch t := err.(type) {
case E.NestedError:
noError = t.NoError()
default:
noError = err == nil
}
if !noError {
t.Errorf("expected err=nil, got %s", err.Error()) t.Errorf("expected err=nil, got %s", err.Error())
} }
} }
func ExpectEqual(t *testing.T, got, want any) { func ExpectEqual[T comparable](t *testing.T, got T, want T) {
t.Helper()
if got != want {
t.Errorf("expected:\n%v, got\n%v", want, got)
}
}
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
t.Helper() t.Helper()
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("expected:\n%v, got\n%v", want, got) t.Errorf("expected:\n%v, got\n%v", want, got)
@ -47,7 +45,7 @@ func ExpectType[T any](t *testing.T, got any) T {
tExpect := reflect.TypeFor[T]() tExpect := reflect.TypeFor[T]()
_, ok := got.(T) _, ok := got.(T)
if !ok { if !ok {
t.Errorf("expected type %T, got %T", tExpect, got) t.Errorf("expected type %s, got %T", tExpect, got)
} }
return got.(T) return got.(T)
} }

View file

@ -2,13 +2,13 @@ package watcher
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/filters"
D "github.com/yusing/go-proxy/docker" D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
) )
type DockerWatcher struct { type DockerWatcher struct {
@ -34,13 +34,14 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest
if err.NoError() { if err.NoError() {
break break
} }
errCh <- E.From(err) errCh <- err
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
if err.HasError() { if err.HasError() {
errCh <- E.Failure("connecting to docker") errCh <- E.Failure("connecting to docker")
return return
} }
defer cl.Close()
cEventCh, cErrCh := cl.Events(ctx, dwOptions) cEventCh, cErrCh := cl.Events(ctx, dwOptions)
started <- struct{}{} started <- struct{}{}
@ -58,13 +59,16 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest
case events.ActionStart: case events.ActionStart:
Action = ActionCreated Action = ActionCreated
case events.ActionDie: case events.ActionDie:
Action = ActionDeleted Action = ActionStopped
default: // NOTE: should not happen default: // NOTE: should not happen
Action = ActionModified Action = ActionModified
} }
eventCh <- Event{ eventCh <- Event{
ActorName: fmt.Sprintf("container %q", msg.Actor.Attributes["name"]), Type: EventTypeDocker,
Action: Action, ActorID: msg.Actor.ID,
ActorAttributes: msg.Actor.Attributes, // labels
ActorName: msg.Actor.Attributes["name"],
Action: Action,
} }
case err := <-cErrCh: case err := <-cErrCh:
if err == nil { if err == nil {

View file

@ -1,26 +0,0 @@
package watcher
import "fmt"
type (
Event struct {
ActorName string
Action Action
}
Action string
)
const (
ActionModified Action = "MODIFIED"
ActionCreated Action = "CREATED"
ActionStarted Action = "STARTED"
ActionDeleted Action = "DELETED"
)
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
}
func (a Action) IsDelete() bool {
return a == ActionDeleted
}

View file

@ -0,0 +1,34 @@
package event
import "fmt"
type (
Event struct {
Type EventType
ActorName string
ActorID string
ActorAttributes map[string]string
Action Action
}
Action string
EventType string
)
const (
ActionModified Action = "modified"
ActionCreated Action = "created"
ActionStarted Action = "started"
ActionDeleted Action = "deleted"
ActionStopped Action = "stopped"
EventTypeDocker EventType = "docker"
EventTypeFile EventType = "file"
)
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
}
func (a Action) IsDelete() bool {
return a == ActionDeleted
}

View file

@ -6,6 +6,7 @@ import (
"github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/common"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
) )
type fileWatcher struct { type fileWatcher struct {

View file

@ -9,6 +9,7 @@ import (
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
) )
type fileWatcherHelper struct { type fileWatcherHelper struct {
@ -93,7 +94,10 @@ func (h *fileWatcherHelper) start() {
continue continue
} }
msg := Event{ActorName: w.filename} msg := Event{
Type: EventTypeFile,
ActorName: w.filename,
}
switch { switch {
case event.Has(fsnotify.Create): case event.Has(fsnotify.Create):
msg.Action = ActionCreated msg.Action = ActionCreated

View file

@ -4,6 +4,7 @@ import (
"context" "context"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/watcher/event"
) )
type Watcher interface { type Watcher interface {