bug fixes

This commit is contained in:
yusing 2024-03-21 04:21:28 +00:00
parent b37e201ea8
commit 48a9e312f5
15 changed files with 198 additions and 145 deletions

Binary file not shown.

5
go.mod
View file

@ -3,8 +3,8 @@ module github.com/yusing/go-proxy
go 1.21.7 go 1.21.7
require ( require (
github.com/docker/cli v25.0.4+incompatible github.com/docker/cli v26.0.0+incompatible
github.com/docker/docker v25.0.4+incompatible github.com/docker/docker v26.0.0+incompatible
github.com/fsnotify/fsnotify v1.7.0 github.com/fsnotify/fsnotify v1.7.0
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.22.0 golang.org/x/net v0.22.0
@ -21,6 +21,7 @@ require (
github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/term v0.5.0 // indirect github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect

10
go.sum
View file

@ -11,10 +11,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0= github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/cli v25.0.4+incompatible h1:DatRkJ+nrFoYL2HZUzjM5Z5sAmcA5XGp+AW0oEw2+cA= github.com/docker/cli v26.0.0+incompatible h1:90BKrx1a1HKYpSnnBFR6AgDq/FqkHxwlUyzJVPxD30I=
github.com/docker/cli v25.0.4+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/cli v26.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/docker v25.0.4+incompatible h1:XITZTrq+52tZyZxUOtFIahUf3aH367FLxJzt9vZeAF8= github.com/docker/docker v26.0.0+incompatible h1:Ng2qi+gdKADUa/VM+6b6YaY2nlZhk/lVJiKR/2bMudU=
github.com/docker/docker v25.0.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/docker v26.0.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@ -38,6 +38,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 h1:Wqo399gCIufwto+VfwCSvsnfGpF
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0/go.mod h1:qmOFXW2epJhM0qSnUUYpldc7gVz2KMQwJ/QYCDIa7XU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0/go.mod h1:qmOFXW2epJhM0qSnUUYpldc7gVz2KMQwJ/QYCDIa7XU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=

View file

@ -35,24 +35,24 @@ func (cfg *config) Load() error {
defer cfg.mutex.Unlock() defer cfg.mutex.Unlock()
// unload if any // unload if any
if cfg.Providers != nil { cfg.StopProviders()
for _, p := range cfg.Providers {
p.StopAllRoutes()
}
}
cfg.Providers = make(map[string]*Provider)
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
if err != nil { if err != nil {
return fmt.Errorf("unable to read config file: %v", err) return fmt.Errorf("unable to read config file: %v", err)
} }
cfg.Providers = make(map[string]*Provider)
if err = yaml.Unmarshal(data, &cfg); err != nil { if err = yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("unable to parse config file: %v", err) return fmt.Errorf("unable to parse config file: %v", err)
} }
for name, p := range cfg.Providers { for name, p := range cfg.Providers {
p.name = name err := p.Init(name)
if err != nil {
cfgl.Errorf("failed to initialize provider %q %v", name, err)
cfg.Providers[name] = nil
}
} }
return nil return nil
@ -73,14 +73,19 @@ func (cfg *config) MustReload() {
} }
func (cfg *config) StartProviders() { func (cfg *config) StartProviders() {
if cfg.Providers == nil {
cfgl.Fatal("providers not loaded")
}
// Providers have their own mutex, no lock needed // Providers have their own mutex, no lock needed
ParallelForEachValue(cfg.Providers, (*Provider).StartAllRoutes) ParallelForEachValue(cfg.Providers, (*Provider).StartAllRoutes)
} }
func (cfg *config) StopProviders() { func (cfg *config) StopProviders() {
if cfg.Providers != nil {
// Providers have their own mutex, no lock needed // Providers have their own mutex, no lock needed
ParallelForEachValue(cfg.Providers, (*Provider).StopAllRoutes) ParallelForEachValue(cfg.Providers, (*Provider).StopAllRoutes)
} }
}
func (cfg *config) WatchChanges() { func (cfg *config) WatchChanges() {
cfg.watcher.Start() cfg.watcher.Start()

View file

@ -96,7 +96,7 @@ const (
templatePath = "templates/panel.html" templatePath = "templates/panel.html"
) )
const StreamStopListenTimeout = 1 * time.Second const StreamStopListenTimeout = 2 * time.Second
const udpBufferSize = 1500 const udpBufferSize = 1500

View file

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"reflect" "reflect"
"strings" "strings"
"time"
"github.com/docker/cli/cli/connhelper" "github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
@ -61,7 +62,7 @@ func (p *Provider) getContainerProxyConfigs(container types.Container, clientIP
} }
if config.Port == "0" { if config.Port == "0" {
// no ports exposed or specified // no ports exposed or specified
l.Info("no ports exposed, ignored") l.Debugf("no ports exposed, ignored")
continue continue
} }
if config.Scheme == "" { if config.Scheme == "" {
@ -116,26 +117,17 @@ func (p *Provider) getContainerProxyConfigs(container types.Container, clientIP
return cfgs return cfgs
} }
func (p *Provider) getDockerProxyConfigs() ([]*ProxyConfig, error) { func (p *Provider) getDockerClient() (*client.Client, error) {
var clientIP string var dockerOpts []client.Opt
var opts []client.Opt
var err error
if p.Value == clientUrlFromEnv { if p.Value == clientUrlFromEnv {
clientIP = "" dockerOpts = []client.Opt{
opts = []client.Opt{
client.WithHostFromEnv(), client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(), client.WithAPIVersionNegotiation(),
} }
} else { } else {
url, err := client.ParseHostURL(p.Value)
if err != nil {
return nil, fmt.Errorf("unable to parse docker host url: %v", err)
}
clientIP = strings.Split(url.Host, ":")[0]
helper, err := connhelper.GetConnectionHelper(p.Value) helper, err := connhelper.GetConnectionHelper(p.Value)
if err != nil { if err != nil {
return nil, fmt.Errorf("unexpected error: %v", err) p.l.Fatal("unexpected error: ", err)
} }
if helper != nil { if helper != nil {
httpClient := &http.Client{ httpClient := &http.Client{
@ -143,26 +135,44 @@ func (p *Provider) getDockerProxyConfigs() ([]*ProxyConfig, error) {
DialContext: helper.Dialer, DialContext: helper.Dialer,
}, },
} }
opts = []client.Opt{ dockerOpts = []client.Opt{
client.WithHTTPClient(httpClient), client.WithHTTPClient(httpClient),
client.WithHost(helper.Host), client.WithHost(helper.Host),
client.WithAPIVersionNegotiation(), client.WithAPIVersionNegotiation(),
client.WithDialContext(helper.Dialer), client.WithDialContext(helper.Dialer),
} }
} else { } else {
opts = []client.Opt{ dockerOpts = []client.Opt{
client.WithHost(p.Value), client.WithHost(p.Value),
client.WithAPIVersionNegotiation(), client.WithAPIVersionNegotiation(),
} }
} }
} }
return client.NewClientWithOpts(dockerOpts...)
}
func (p *Provider) getDockerProxyConfigs() ([]*ProxyConfig, error) {
var clientIP string
if p.Value == clientUrlFromEnv {
clientIP = ""
} else {
url, err := client.ParseHostURL(p.Value)
if err != nil {
return nil, fmt.Errorf("unable to parse docker host url: %v", err)
}
clientIP = strings.Split(url.Host, ":")[0]
}
dockerClient, err := p.getDockerClient()
p.dockerClient, err = client.NewClientWithOpts(opts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create docker client: %v", err) return nil, fmt.Errorf("unable to create docker client: %v", err)
} }
containerSlice, err := p.dockerClient.ContainerList(context.Background(), container.ListOptions{All: true}) ctx, _ := context.WithTimeout(context.Background(), 3*time.Second)
containerSlice, err := dockerClient.ContainerList(ctx, container.ListOptions{All: true})
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to list containers: %v", err) return nil, fmt.Errorf("unable to list containers: %v", err)
} }

View file

@ -103,12 +103,12 @@ func NewHTTPRoute(config *ProxyConfig) (*HTTPRoute, error) {
if logLevel == logrus.DebugLevel { if logLevel == logrus.DebugLevel {
route.Proxy.Rewrite = func(pr *ProxyRequest) { route.Proxy.Rewrite = func(pr *ProxyRequest) {
rewrite(pr) rewrite(pr)
route.l.Debug("Request URL: ", pr.In.Host, pr.In.URL.Path) route.l.Debug("request URL: ", pr.In.Host, pr.In.URL.Path)
route.l.Debug("Request headers: ", pr.In.Header) route.l.Debug("request headers: ", pr.In.Header)
} }
route.Proxy.ModifyResponse = func(r *http.Response) error { route.Proxy.ModifyResponse = func(r *http.Response) error {
route.l.Debug("Response URL: ", r.Request.URL.String()) route.l.Debug("response URL: ", r.Request.URL.String())
route.l.Debug("Response headers: ", r.Header) route.l.Debug("response headers: ", r.Header)
if modifyResponse != nil { if modifyResponse != nil {
return modifyResponse(r) return modifyResponse(r)
} }
@ -121,15 +121,11 @@ func NewHTTPRoute(config *ProxyConfig) (*HTTPRoute, error) {
return route, nil return route, nil
} }
func (r *HTTPRoute) RemoveFromRoutes() { func (r *HTTPRoute) Start() {}
func (r *HTTPRoute) Stop() {
httpRoutes.Delete(r.Alias) httpRoutes.Delete(r.Alias)
} }
// dummy implementation for Route interface
func (r *HTTPRoute) SetupListen() {}
func (r *HTTPRoute) Listen() {}
func (r *HTTPRoute) StopListening() {}
func isValidProxyPathMode(mode string) bool { func isValidProxyPathMode(mode string) bool {
switch mode { switch mode {
case ProxyPathMode_Forward, ProxyPathMode_Sub, ProxyPathMode_RemovedPath: case ProxyPathMode_Forward, ProxyPathMode_Sub, ProxyPathMode_RemovedPath:

View file

@ -39,14 +39,14 @@ func main() {
err = http.ListenAndServe(":80", http.HandlerFunc(httpProxyHandler)) err = http.ListenAndServe(":80", http.HandlerFunc(httpProxyHandler))
} }
if err != nil { if err != nil {
log.Fatal("HTTP server error: ", err) log.Fatal("http server error: ", err)
} }
}() }()
go func() { go func() {
log.Infof("starting http panel on port 8080") log.Infof("starting http panel on port 8080")
err = http.ListenAndServe(":8080", http.HandlerFunc(panelHandler)) err = http.ListenAndServe(":8080", http.HandlerFunc(panelHandler))
if err != nil { if err != nil {
log.Warning("HTTP panel error: ", err) log.Warning("http panel error: ", err)
} }
}() }()
@ -75,6 +75,6 @@ func main() {
<-sig <-sig
cfg.StopWatching() cfg.StopWatching()
cfg.StopProviders() cfg.StopProviders()
close(fsWatcherStop) StopFSWatcher()
close(dockerWatcherStop) StopDockerWatcher()
} }

View file

@ -38,7 +38,7 @@ func panelHandler(w http.ResponseWriter, r *http.Request) {
func panelIndex(w http.ResponseWriter, r *http.Request) { func panelIndex(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return return
} }
@ -67,7 +67,7 @@ func panelIndex(w http.ResponseWriter, r *http.Request) {
func panelCheckTargetHealth(w http.ResponseWriter, r *http.Request) { func panelCheckTargetHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodHead { if r.Method != http.MethodHead {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return return
} }

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/docker/docker/client"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -12,27 +11,57 @@ type Provider struct {
Kind string // docker, file Kind string // docker, file
Value string Value string
name string
watcher Watcher watcher Watcher
routes map[string]Route // id -> Route routes map[string]Route // id -> Route
dockerClient *client.Client
mutex sync.Mutex mutex sync.Mutex
l logrus.FieldLogger l logrus.FieldLogger
} }
func (p *Provider) Setup() error { // Init is called after LoadProxyConfig
func (p *Provider) Init(name string) error {
p.l = prlog.WithFields(logrus.Fields{"kind": p.Kind, "name": name})
if err := p.loadProxyConfig(); err != nil {
return err
}
p.initWatcher()
return nil
}
func (p *Provider) StartAllRoutes() {
ParallelForEachValue(p.routes, Route.Start)
p.watcher.Start()
}
func (p *Provider) StopAllRoutes() {
p.watcher.Stop()
ParallelForEachValue(p.routes, Route.Stop)
p.routes = make(map[string]Route)
}
func (p *Provider) ReloadRoutes() {
p.mutex.Lock()
defer p.mutex.Unlock()
p.StopAllRoutes()
err := p.loadProxyConfig()
if err != nil {
p.l.Error("failed to reload routes: ", err)
return
}
p.StartAllRoutes()
}
func (p *Provider) loadProxyConfig() error {
var cfgs []*ProxyConfig var cfgs []*ProxyConfig
var err error var err error
p.l = prlog.WithFields(logrus.Fields{"kind": p.Kind, "name": p.name})
switch p.Kind { switch p.Kind {
case ProviderKind_Docker: case ProviderKind_Docker:
cfgs, err = p.getDockerProxyConfigs() cfgs, err = p.getDockerProxyConfigs()
p.watcher = NewDockerWatcher(p.dockerClient, p.ReloadRoutes)
case ProviderKind_File: case ProviderKind_File:
cfgs, err = p.getFileProxyConfigs() cfgs, err = p.getFileProxyConfigs()
p.watcher = NewFileWatcher(p.Value, p.ReloadRoutes, p.StopAllRoutes)
default: default:
// this line should never be reached // this line should never be reached
return fmt.Errorf("unknown provider kind") return fmt.Errorf("unknown provider kind")
@ -43,45 +72,30 @@ func (p *Provider) Setup() error {
} }
p.l.Infof("loaded %d proxy configurations", len(cfgs)) p.l.Infof("loaded %d proxy configurations", len(cfgs))
p.routes = make(map[string]Route, len(cfgs))
for _, cfg := range cfgs { for _, cfg := range cfgs {
r, err := NewRoute(cfg) r, err := NewRoute(cfg)
if err != nil { if err != nil {
p.l.Errorf("error creating route %s: %v", cfg.Alias, err) p.l.Errorf("error creating route %s: %v", cfg.Alias, err)
continue continue
} }
r.SetupListen()
r.Listen()
p.routes[cfg.GetID()] = r p.routes[cfg.GetID()] = r
} }
return nil return nil
} }
func (p *Provider) StartAllRoutes() { func (p *Provider) initWatcher() error {
p.routes = make(map[string]Route) switch p.Kind {
err := p.Setup() case ProviderKind_Docker:
var err error
dockerClient, err := p.getDockerClient()
if err != nil { if err != nil {
p.l.Error(err) return fmt.Errorf("unable to create docker client: %v", err)
return
} }
p.watcher.Start() p.watcher = NewDockerWatcher(dockerClient, p.ReloadRoutes)
case ProviderKind_File:
p.watcher = NewFileWatcher(p.Value, p.ReloadRoutes, p.StopAllRoutes)
} }
return nil
func (p *Provider) StopAllRoutes() {
p.watcher.Stop()
p.dockerClient = nil
ParallelForEachValue(p.routes, func(r Route) {
r.StopListening()
r.RemoveFromRoutes()
})
p.routes = make(map[string]Route)
}
func (p *Provider) ReloadRoutes() {
p.mutex.Lock()
defer p.mutex.Unlock()
p.StopAllRoutes()
p.StartAllRoutes()
} }

View file

@ -5,10 +5,8 @@ import (
) )
type Route interface { type Route interface {
SetupListen() Start()
Listen() Stop()
StopListening()
RemoveFromRoutes()
} }
func NewRoute(cfg *ProxyConfig) (Route, error) { func NewRoute(cfg *ProxyConfig) (Route, error) {

View file

@ -47,7 +47,7 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
port_split := strings.Split(config.Port, ":") port_split := strings.Split(config.Port, ":")
if len(port_split) != 2 { if len(port_split) != 2 {
cfgl.Warnf("Invalid port %s, assuming it is target port", config.Port) cfgl.Warnf("invalid port %s, assuming it is target port", config.Port)
srcPort = "0" srcPort = "0"
dstPort = config.Port dstPort = config.Port
} else { } else {
@ -96,7 +96,7 @@ func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
id: config.GetID(), id: config.GetID(),
wg: sync.WaitGroup{}, wg: sync.WaitGroup{},
stopChann: make(chan struct{}), stopChann: make(chan struct{}, 1),
l: srlog.WithFields(logrus.Fields{ l: srlog.WithFields(logrus.Fields{
"alias": config.Alias, "alias": config.Alias,
"src": fmt.Sprintf("%s://:%d", srcScheme, srcPortInt), "src": fmt.Sprintf("%s://:%d", srcScheme, srcPortInt),
@ -128,7 +128,7 @@ func (route *StreamRouteBase) Logger() logrus.FieldLogger {
return route.l return route.l
} }
func (route *StreamRouteBase) SetupListen() { func (route *StreamRouteBase) setupListen() {
if route.ListeningPort == 0 { if route.ListeningPort == 0 {
freePort, err := utils.findUseFreePort(20000) freePort, err := utils.findUseFreePort(20000)
if err != nil { if err != nil {
@ -136,13 +136,10 @@ func (route *StreamRouteBase) SetupListen() {
return return
} }
route.ListeningPort = freePort route.ListeningPort = freePort
route.l.Info("Assigned free port ", route.ListeningPort) route.l.Info("listening on free port ", route.ListeningPort)
return
} }
route.l.Info("Listening on ", route.ListeningUrl()) route.l.Info("listening on ", route.ListeningUrl())
}
func (route *StreamRouteBase) RemoveFromRoutes() {
streamRoutes.Delete(route.id)
} }
func (route *StreamRouteBase) wait() { func (route *StreamRouteBase) wait() {
@ -159,9 +156,11 @@ func (route *StreamRouteBase) unmarkPort() {
func stopListening(route StreamRoute) { func stopListening(route StreamRoute) {
l := route.Logger() l := route.Logger()
l.Debug("Stopping listening") l.Debug("stopping listening")
// close channel -> wait -> close listeners
route.closeChannel() route.closeChannel()
route.closeListeners()
done := make(chan struct{}) done := make(chan struct{})
@ -173,10 +172,10 @@ func stopListening(route StreamRoute) {
select { select {
case <-done: case <-done:
l.Info("Stopped listening") l.Info("stopped listening")
return
case <-time.After(StreamStopListenTimeout): case <-time.After(StreamStopListenTimeout):
l.Error("timed out waiting for connections") l.Error("timed out waiting for connections")
return
} }
route.closeListeners()
} }

View file

@ -32,7 +32,8 @@ func NewTCPRoute(config *ProxyConfig) (StreamRoute, error) {
}, nil }, nil
} }
func (route *TCPRoute) Listen() { func (route *TCPRoute) Start() {
route.setupListen()
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.ListeningPort)) in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.ListeningPort))
if err != nil { if err != nil {
route.l.Error(err) route.l.Error(err)
@ -44,8 +45,9 @@ func (route *TCPRoute) Listen() {
go route.grHandleConnections() go route.grHandleConnections()
} }
func (route *TCPRoute) StopListening() { func (route *TCPRoute) Stop() {
stopListening(route) stopListening(route)
streamRoutes.Delete(route.id)
} }
func (route *TCPRoute) closeListeners() { func (route *TCPRoute) closeListeners() {

View file

@ -45,7 +45,9 @@ func NewUDPRoute(config *ProxyConfig) (StreamRoute, error) {
}, nil }, nil
} }
func (route *UDPRoute) Listen() { func (route *UDPRoute) Start() {
route.setupListen()
source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort)) source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort))
if err != nil { if err != nil {
route.l.Error(err) route.l.Error(err)
@ -67,22 +69,24 @@ func (route *UDPRoute) Listen() {
go route.grHandleConnections() go route.grHandleConnections()
} }
func (route *UDPRoute) StopListening() { func (route *UDPRoute) Stop() {
stopListening(route) stopListening(route)
streamRoutes.Delete(route.id)
} }
func (route *UDPRoute) closeListeners() { func (route *UDPRoute) closeListeners() {
if route.listeningConn != nil { if route.listeningConn != nil {
route.listeningConn.Close() route.listeningConn.Close()
route.listeningConn = nil
} }
if route.targetConn != nil { if route.targetConn != nil {
route.targetConn.Close() route.targetConn.Close()
}
route.listeningConn = nil
route.targetConn = nil route.targetConn = nil
}
for _, conn := range route.connMap { for _, conn := range route.connMap {
conn.(*net.UDPConn).Close() // TODO: change on non udp target conn.(*net.UDPConn).Close() // TODO: change on non udp target
} }
route.connMap = make(map[net.Addr]net.Conn)
} }
func (route *UDPRoute) grAcceptConnections() { func (route *UDPRoute) grAcceptConnections() {

View file

@ -18,6 +18,7 @@ import (
type Watcher interface { type Watcher interface {
Start() Start()
Stop() Stop()
Dispose()
} }
type watcherBase struct { type watcherBase struct {
@ -36,7 +37,7 @@ type fileWatcher struct {
type dockerWatcher struct { type dockerWatcher struct {
*watcherBase *watcherBase
client *client.Client client *client.Client
stop chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
} }
@ -60,7 +61,7 @@ func NewDockerWatcher(c *client.Client, onChange func()) Watcher {
return &dockerWatcher{ return &dockerWatcher{
watcherBase: newWatcher("Docker", c.DaemonHost(), onChange), watcherBase: newWatcher("Docker", c.DaemonHost(), onChange),
client: c, client: c,
stop: make(chan struct{}, 1), stopCh: make(chan struct{}, 1),
} }
} }
@ -71,11 +72,15 @@ func (w *fileWatcher) Start() {
err := fsWatcher.Add(w.path) err := fsWatcher.Add(w.path)
if err != nil { if err != nil {
w.l.Error("failed to start: ", err) w.l.Error("failed to start: ", err)
return
} }
fileWatchMap.Set(w.path, w) fileWatchMap.Set(w.path, w)
} }
func (w *fileWatcher) Stop() { func (w *fileWatcher) Stop() {
if fsWatcher == nil {
return
}
fileWatchMap.Delete(w.path) fileWatchMap.Delete(w.path)
err := fsWatcher.Remove(w.path) err := fsWatcher.Remove(w.path)
if err != nil { if err != nil {
@ -83,20 +88,29 @@ func (w *fileWatcher) Stop() {
} }
} }
func (w *fileWatcher) Dispose() {
w.Stop()
}
func (w *dockerWatcher) Start() { func (w *dockerWatcher) Start() {
dockerWatchMap.Set(w.name, w) dockerWatchMap.Set(w.name, w)
w.wg.Add(1) w.wg.Add(1)
go func() { go w.watch()
w.watch()
w.wg.Done()
}()
} }
func (w *dockerWatcher) Stop() { func (w *dockerWatcher) Stop() {
close(w.stop) if w.stopCh == nil {
w.stop = nil return
dockerWatchMap.Delete(w.name) }
close(w.stopCh)
w.wg.Wait() w.wg.Wait()
w.stopCh = nil
dockerWatchMap.Delete(w.name)
}
func (w *dockerWatcher) Dispose() {
w.Stop()
w.client.Close()
} }
func InitFSWatcher() { func InitFSWatcher() {
@ -106,33 +120,35 @@ func InitFSWatcher() {
return return
} }
fsWatcher = w fsWatcher = w
fsWatcherWg.Add(1)
go watchFiles() go watchFiles()
} }
func InitDockerWatcher() { func InitDockerWatcher() {
// stop all docker client on watcher stop // stop all docker client on watcher stop
go func() { go func() {
defer dockerWatcherWg.Done()
<-dockerWatcherStop <-dockerWatcherStop
stopAllDockerClients() ParallelForEachValue(
dockerWatchMap.Iterator(),
(*dockerWatcher).Dispose,
)
}() }()
} }
func stopAllDockerClients() { func StopFSWatcher() {
ParallelForEachValue( close(fsWatcherStop)
dockerWatchMap.Iterator(), fsWatcherWg.Wait()
func(w *dockerWatcher) {
w.Stop()
err := w.client.Close()
if err != nil {
w.l.WithField("action", "stop").Error(err)
} }
w.client = nil
}, func StopDockerWatcher() {
) close(dockerWatcherStop)
dockerWatcherWg.Wait()
} }
func watchFiles() { func watchFiles() {
defer fsWatcher.Close() defer fsWatcher.Close()
defer fsWatcherWg.Done()
for { for {
select { select {
case <-fsWatcherStop: case <-fsWatcherStop:
@ -148,11 +164,11 @@ func watchFiles() {
} }
switch { switch {
case event.Has(fsnotify.Write): case event.Has(fsnotify.Write):
w.l.Info("File change detected") w.l.Info("file changed")
w.onChange() go w.onChange()
case event.Has(fsnotify.Remove), event.Has(fsnotify.Rename): case event.Has(fsnotify.Remove), event.Has(fsnotify.Rename):
w.l.Info("File renamed / deleted") w.l.Info("file renamed / deleted")
w.onDelete() go w.onDelete()
} }
case err := <-fsWatcher.Errors: case err := <-fsWatcher.Errors:
wlog.Error(err) wlog.Error(err)
@ -161,6 +177,8 @@ func watchFiles() {
} }
func (w *dockerWatcher) watch() { func (w *dockerWatcher) watch() {
defer w.wg.Done()
filter := filters.NewArgs( filter := filters.NewArgs(
filters.Arg("type", "container"), filters.Arg("type", "container"),
filters.Arg("event", "start"), filters.Arg("event", "start"),
@ -173,11 +191,11 @@ func (w *dockerWatcher) watch() {
for { for {
select { select {
case <-w.stop: case <-w.stopCh:
return return
case msg := <-msgChan: case msg := <-msgChan:
w.l.Infof("container %s %s", msg.Actor.Attributes["name"], msg.Action) w.l.Infof("container %s %s", msg.Actor.Attributes["name"], msg.Action)
w.onChange() go w.onChange()
case err := <-errChan: case err := <-errChan:
w.l.Errorf("%s, retrying in 1s", err) w.l.Errorf("%s, retrying in 1s", err)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@ -195,3 +213,7 @@ var (
fsWatcherStop = make(chan struct{}, 1) fsWatcherStop = make(chan struct{}, 1)
dockerWatcherStop = make(chan struct{}, 1) dockerWatcherStop = make(chan struct{}, 1)
) )
var (
fsWatcherWg sync.WaitGroup
dockerWatcherWg sync.WaitGroup
)