fixed tcp/udp I/O, deadlock, nil dereference; improved docker watcher, idlewatcher, loading page

This commit is contained in:
yusing 2024-09-23 00:49:46 +08:00
parent 96bce79e4b
commit 090b73d287
31 changed files with 687 additions and 468 deletions

View file

@ -250,7 +250,6 @@ services:
ports: ports:
- 25565 - 25565
labels: labels:
- proxy.mc.scheme=tcp
- proxy.mc.port=20001:25565 - proxy.mc.port=20001:25565
environment: environment:
- EULA=TRUE - EULA=TRUE

View file

@ -4,45 +4,26 @@ For docker client on other machine, set this up, then add `name: tcp://<machine_
```yml ```yml
# compose.yml on remote machine (e.g. server1) # compose.yml on remote machine (e.g. server1)
services: docker-proxy:
docker-proxy:
container_name: docker-proxy container_name: docker-proxy
image: ghcr.io/linuxserver/socket-proxy image: tecnativa/docker-socket-proxy
privileged: true
environment: environment:
- ALLOW_START=1 #optional - ALLOW_START=1
- ALLOW_STOP=1 #optional - ALLOW_STOP=1
- ALLOW_RESTARTS=0 #optional - ALLOW_RESTARTS=1
- AUTH=0 #optional - CONTAINERS=1
- BUILD=0 #optional - EVENTS=1
- COMMIT=0 #optional - PING=1
- CONFIGS=0 #optional - POST=1
- CONTAINERS=1 #optional - VERSION=1
- DISABLE_IPV6=1 #optional
- DISTRIBUTION=0 #optional
- EVENTS=1 #optional
- EXEC=0 #optional
- IMAGES=0 #optional
- INFO=0 #optional
- NETWORKS=0 #optional
- NODES=0 #optional
- PING=1 #optional
- POST=1 #optional
- PLUGINS=0 #optional
- SECRETS=0 #optional
- SERVICES=0 #optional
- SESSION=0 #optional
- SWARM=0 #optional
- SYSTEM=0 #optional
- TASKS=0 #optional
- VERSION=1 #optional
- VOLUMES=0 #optional
volumes: volumes:
- /var/run/docker.sock:/var/run/docker.sock - /var/run/docker.sock:/var/run/docker.sock
restart: always restart: always
tmpfs:
- /run
ports: ports:
- 2375:2375 - 2375:2375
# or more secure
- <machine_ip>:2375:2375
``` ```
```yml ```yml

View file

@ -59,7 +59,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
defer b.To(&res) defer b.To(&res)
if p.cfg.Provider == ProviderLocal { if p.cfg.Provider == ProviderLocal {
b.Addf("provider is set to %q", ProviderLocal) b.Addf("provider is set to %q", ProviderLocal).WithSeverity(E.SeverityWarning)
return return
} }

29
src/autocert/setup.go Normal file
View file

@ -0,0 +1,29 @@
package autocert
import (
"context"
"os"
E "github.com/yusing/go-proxy/error"
)
func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
if err = p.LoadCert(); err != nil {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
logger.Debug("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err.Warn()
}
}
go p.ScheduleRenewal(ctx)
for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry)
break
}
return nil
}

View file

@ -18,6 +18,7 @@ const (
CommandListRoutes = "ls-routes" CommandListRoutes = "ls-routes"
CommandReload = "reload" CommandReload = "reload"
CommandDebugListEntries = "debug-ls-entries" CommandDebugListEntries = "debug-ls-entries"
CommandDebugListProviders = "debug-ls-providers"
) )
var ValidCommands = []string{ var ValidCommands = []string{
@ -27,6 +28,7 @@ var ValidCommands = []string{
CommandListRoutes, CommandListRoutes,
CommandReload, CommandReload,
CommandDebugListEntries, CommandDebugListEntries,
CommandDebugListProviders,
} }
func GetArgs() Args { func GetArgs() Args {

View file

@ -32,7 +32,7 @@ const (
const ( const (
SchemaBasePath = "schema/" SchemaBasePath = "schema/"
ConfigSchemaPath = SchemaBasePath + "config.schema.json" ConfigSchemaPath = SchemaBasePath + "config.schema.json"
ProvidersSchemaPath = SchemaBasePath + "providers.schema.json" FileProviderSchemaPath = SchemaBasePath + "providers.schema.json"
) )
const DockerHostFromEnv = "$DOCKER_HOST" const DockerHostFromEnv = "$DOCKER_HOST"

View file

@ -14,6 +14,7 @@ import (
U "github.com/yusing/go-proxy/utils" U "github.com/yusing/go-proxy/utils"
F "github.com/yusing/go-proxy/utils/functional" F "github.com/yusing/go-proxy/utils/functional"
W "github.com/yusing/go-proxy/watcher" W "github.com/yusing/go-proxy/watcher"
"github.com/yusing/go-proxy/watcher/events"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -94,7 +95,7 @@ func (cfg *Config) WatchChanges() {
case <-cfg.watcherCtx.Done(): case <-cfg.watcherCtx.Done():
return return
case event := <-eventCh: case event := <-eventCh:
if event.Action.IsDelete() { if event.Action == events.ActionFileDeleted {
cfg.stopProviders() cfg.stopProviders()
} else { } else {
cfg.reloadReq <- struct{}{} cfg.reloadReq <- struct{}{}
@ -107,71 +108,6 @@ func (cfg *Config) WatchChanges() {
}() }()
} }
func (cfg *Config) FindRoute(alias string) R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}
return nil, false
},
)
}
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
routes[alias] = obj
})
return routes
}
func (cfg *Config) Statistics() map[string]any {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]any)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
s, ok := providerStats[p.GetName()]
if !ok {
s = make(map[string]int)
}
stats := s.(map[string]int)
switch r.Type() {
case R.RouteTypeStream:
stats["num_streams"]++
nTotalStreams++
case R.RouteTypeReverseProxy:
stats["num_reverse_proxies"]++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
return map[string]any{
"num_total_streams": nTotalStreams,
"num_total_reverse_proxies": nTotalRPs,
"providers": providerStats,
}
}
func (cfg *Config) DumpEntries() map[string]*M.RawEntry {
entries := make(map[string]*M.RawEntry)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
entries[alias] = r.Entry()
})
return entries
}
func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) { func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) { cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r R.Route) { p.RangeRoutes(func(a string, r R.Route) {
@ -259,7 +195,7 @@ func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError
} }
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("cannot %s these providers", action) errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() { if err := do(p); err.HasError() {

82
src/config/query.go Normal file
View file

@ -0,0 +1,82 @@
package config
import (
M "github.com/yusing/go-proxy/models"
PR "github.com/yusing/go-proxy/proxy/provider"
R "github.com/yusing/go-proxy/route"
U "github.com/yusing/go-proxy/utils"
F "github.com/yusing/go-proxy/utils/functional"
)
func (cfg *Config) DumpEntries() map[string]*M.RawEntry {
entries := make(map[string]*M.RawEntry)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
entries[alias] = r.Entry()
})
return entries
}
func (cfg *Config) DumpProviders() map[string]*PR.Provider {
entries := make(map[string]*PR.Provider)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
entries[name] = p
})
return entries
}
func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
obj, err := U.Serialize(r)
if err.HasError() {
cfg.l.Error(err)
return
}
obj["provider"] = p.GetName()
obj["type"] = string(r.Type())
routes[alias] = obj
})
return routes
}
func (cfg *Config) Statistics() map[string]any {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]any)
cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) {
s, ok := providerStats[p.GetName()]
if !ok {
s = make(map[string]int)
}
stats := s.(map[string]int)
switch r.Type() {
case R.RouteTypeStream:
stats["num_streams"]++
nTotalStreams++
case R.RouteTypeReverseProxy:
stats["num_reverse_proxies"]++
nTotalRPs++
default:
panic("bug: should not reach here")
}
})
return map[string]any{
"num_total_streams": nTotalStreams,
"num_total_reverse_proxies": nTotalRPs,
"providers": providerStats,
}
}
func (cfg *Config) FindRoute(alias string) R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (R.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}
return nil, false
},
)
}

View file

@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/common"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
F "github.com/yusing/go-proxy/utils/functional"
) )
type Client struct { type Client struct {
@ -48,9 +49,7 @@ func (c *Client) Close() error {
return nil return nil
} }
clientMapMu.Lock() clientMap.Delete(c.key)
defer clientMapMu.Unlock()
delete(clientMap, c.key)
client := c.Client client := c.Client
c.Client = nil c.Client = nil
@ -78,7 +77,7 @@ func ConnectClient(host string) (Client, E.NestedError) {
defer clientMapMu.Unlock() defer clientMapMu.Unlock()
// check if client exists // check if client exists
if client, ok := clientMap[host]; ok { if client, ok := clientMap.Load(host); ok {
client.refCount.Add(1) client.refCount.Add(1)
return client, nil return client, nil
} }
@ -129,23 +128,22 @@ func ConnectClient(host string) (Client, E.NestedError) {
c.refCount.Add(1) c.refCount.Add(1)
c.l.Debugf("client connected") c.l.Debugf("client connected")
clientMap[host] = c clientMap.Store(host, c)
return clientMap[host], nil return c, nil
} }
func CloseAllClients() { func CloseAllClients() {
clientMapMu.Lock() clientMap.RangeAll(func(_ string, c Client) {
defer clientMapMu.Unlock() c.Client.Close()
for _, client := range clientMap { })
client.Close() clientMap.Clear()
}
clientMap = make(map[string]Client)
logger.Debug("closed all clients") logger.Debug("closed all clients")
} }
var ( var (
clientMap map[string]Client = make(map[string]Client) clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
clientMapMu sync.Mutex clientMapMu sync.Mutex
clientOptEnvHost = []client.Opt{ clientOptEnvHost = []client.Opt{
client.WithHostFromEnv(), client.WithHostFromEnv(),
client.WithAPIVersionNegotiation(), client.WithAPIVersionNegotiation(),

View file

@ -0,0 +1,87 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>{{.Title}}</title>
<style>
/* Global Styles */
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: Inter, Arial, sans-serif;
font-size: 16px;
line-height: 1.5;
color: #fff;
background-color: #212121;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
/* Spinner Styles */
.spinner {
width: 120px;
height: 120px;
border: 16px solid #333;
border-radius: 50%;
border-top: 16px solid #66d9ef;
animation: spin 2s linear infinite;
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
/* Error Styles */
.error {
display: inline-block;
text-align: center;
justify-content: center;
}
.error::before {
content: "\26A0"; /* Unicode for warning symbol */
font-size: 40px;
color: #ff9900;
}
/* Message Styles */
.message {
font-size: 24px;
font-weight: bold;
padding-left: 32px;
text-align: center;
}
</style>
</head>
<body>
<script>
window.onload = async function () {
let result = await fetch(window.location.href, {
headers: {
{{ range $key, $value := .RequestHeaders }}
'{{ $key }}' : {{ $value }}
{{ end }}
},
}).then((resp) => resp.text())
.catch((err) => {
document.getElementById("message").innerText = err;
});
if (result) {
document.documentElement.innerHTML = result
}
};
</script>
<div class="{{.SpinnerClass}}"></div>
<div class="message">{{.Message}}</div>
</body>
</html>

View file

@ -0,0 +1,93 @@
package idlewatcher
import (
"bytes"
_ "embed"
"fmt"
"io"
"net/http"
"strings"
"text/template"
)
type templateData struct {
Title string
Message string
RequestHeaders http.Header
SpinnerClass string
}
//go:embed html/loading_page.html
var loadingPage []byte
var loadingPageTmpl = func() *template.Template {
tmpl, err := template.New("loading").Parse(string(loadingPage))
if err != nil {
panic(err)
}
return tmpl
}()
const (
htmlContentType = "text/html; charset=utf-8"
errPrefix = "\u1000"
headerGoProxyTargetURL = "X-GoProxy-Target"
headerContentType = "Content-Type"
spinnerClassSpinner = "spinner"
spinnerClassErrorSign = "error"
)
func (w *watcher) makeSuccResp(redirectURL string, resp *http.Response) (*http.Response, error) {
h := make(http.Header)
h.Set("Location", redirectURL)
h.Set("Content-Length", "0")
h.Set(headerContentType, htmlContentType)
return &http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: h,
Body: http.NoBody,
TLS: resp.TLS,
}, nil
}
func (w *watcher) makeErrResp(errFmt string, args ...any) (*http.Response, error) {
return w.makeResp(errPrefix+errFmt, args...)
}
func (w *watcher) makeResp(format string, args ...any) (*http.Response, error) {
msg := fmt.Sprintf(format, args...)
data := new(templateData)
data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, "\n", "<br>")
data.Message = strings.ReplaceAll(data.Message, " ", "&ensp;")
data.RequestHeaders = make(http.Header)
data.RequestHeaders.Add(headerGoProxyTargetURL, "window.location.href")
if strings.HasPrefix(data.Message, errPrefix) {
data.Message = strings.TrimLeft(data.Message, errPrefix)
data.SpinnerClass = spinnerClassErrorSign
} else {
data.SpinnerClass = spinnerClassSpinner
}
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough
err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen
panic(err)
}
return &http.Response{
StatusCode: http.StatusAccepted,
Header: http.Header{
headerContentType: {htmlContentType},
"Cache-Control": {
"no-cache",
"no-store",
"must-revalidate",
},
},
Body: io.NopCloser(buf),
ContentLength: int64(buf.Len()),
}, nil
}

View file

@ -1,6 +1,10 @@
package idlewatcher package idlewatcher
import "net/http" import (
"context"
"net/http"
"time"
)
type ( type (
roundTripper struct { roundTripper struct {
@ -12,3 +16,63 @@ type (
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.patched(req) return rt.patched(req)
} }
func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) {
// target site is ready, passthrough
if w.ready.Load() {
return origRoundTrip(req)
}
// wake the container
w.wakeCh <- struct{}{}
// initial request
targetUrl := req.Header.Get(headerGoProxyTargetURL)
if targetUrl == "" {
return w.makeResp(
"%s is starting... Please wait",
w.ContainerName,
)
}
w.l.Debug("serving event")
// stream request
rtDone := make(chan *http.Response, 1)
ctx, cancel := context.WithTimeout(req.Context(), w.WakeTimeout)
defer cancel()
// loop original round trip until success in a goroutine
go func() {
for {
select {
case <-ctx.Done():
return
case <-w.ctx.Done():
return
default:
resp, err := origRoundTrip(req)
if err == nil {
w.ready.Store(true)
rtDone <- resp
return
}
time.Sleep(time.Millisecond * 200)
}
}
}()
for {
select {
case resp := <-rtDone:
return w.makeSuccResp(targetUrl, resp)
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
return w.makeErrResp("Timed out waiting for %s to fully wake", w.ContainerName)
}
return w.makeErrResp("idlewatcher has stopped\n%s", w.ctx.Err().Error())
case <-w.ctx.Done():
return w.makeErrResp("idlewatcher has stopped\n%s", w.ctx.Err().Error())
}
}
}

View file

@ -1,9 +1,7 @@
package idlewatcher package idlewatcher
import ( import (
"bytes"
"context" "context"
"io"
"net/http" "net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -16,33 +14,45 @@ import (
P "github.com/yusing/go-proxy/proxy" P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields" PT "github.com/yusing/go-proxy/proxy/fields"
W "github.com/yusing/go-proxy/watcher" W "github.com/yusing/go-proxy/watcher"
event "github.com/yusing/go-proxy/watcher/events"
) )
type watcher struct { type (
watcher struct {
*P.ReverseProxyEntry *P.ReverseProxyEntry
client D.Client client D.Client
refCount atomic.Int32 ready atomic.Bool // whether the site is ready to accept connection
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
stopByMethod StopCallback
wakeCh chan struct{} wakeCh chan struct{}
wakeDone chan E.NestedError wakeDone chan E.NestedError
running atomic.Bool
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
refCount *sync.WaitGroup
l logrus.FieldLogger l logrus.FieldLogger
} }
type (
WakeDone <-chan error WakeDone <-chan error
WakeFunc func() WakeDone WakeFunc func() WakeDone
StopCallback func() E.NestedError StopCallback func() E.NestedError
) )
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
)
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
failure := E.Failure("idle_watcher register") failure := E.Failure("idle_watcher register")
@ -67,12 +77,12 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
w := &watcher{ w := &watcher{
ReverseProxyEntry: entry, ReverseProxyEntry: entry,
client: client, client: client,
refCount: &sync.WaitGroup{},
wakeCh: make(chan struct{}, 1), wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError, 1), wakeDone: make(chan E.NestedError, 1),
l: logger.WithField("container", entry.ContainerName), l: logger.WithField("container", entry.ContainerName),
} }
w.refCount.Add(1) w.refCount.Add(1)
w.running.Store(entry.ContainerRunning)
w.stopByMethod = w.getStopCallback() w.stopByMethod = w.getStopCallback()
watcherMap[w.ContainerName] = w watcherMap[w.ContainerName] = w
@ -84,20 +94,9 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
return w, nil return w, nil
} }
// If the container is not registered, this is no-op
func Unregister(containerName string) { func Unregister(containerName string) {
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
if w, ok := watcherMap[containerName]; ok { if w, ok := watcherMap[containerName]; ok {
if w.refCount.Add(-1) > 0 { w.refCount.Add(-1)
return
}
if w.cancel != nil {
w.cancel()
}
w.client.Close()
delete(watcherMap, containerName)
} }
} }
@ -107,8 +106,6 @@ func Start() {
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background()) mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
defer mainLoopWg.Wait()
for { for {
select { select {
case <-mainLoopCtx.Done(): case <-mainLoopCtx.Done():
@ -117,8 +114,11 @@ func Start() {
w.l.Debug("registered") w.l.Debug("registered")
mainLoopWg.Add(1) mainLoopWg.Add(1)
go func() { go func() {
w.watch() w.watchUntilCancel()
Unregister(w.ContainerName) w.refCount.Wait() // wait for 0 ref count
w.client.Close()
delete(watcherMap, w.ContainerName)
w.l.Debug("unregistered") w.l.Debug("unregistered")
mainLoopWg.Done() mainLoopWg.Done()
}() }()
@ -137,31 +137,6 @@ func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper {
}} }}
} }
func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) {
w.wakeCh <- struct{}{}
if w.running.Load() {
return origRoundTrip(req)
}
timeout := time.After(w.WakeTimeout)
for {
if w.running.Load() {
return origRoundTrip(req)
}
select {
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-w.wakeDone:
if err != nil {
return nil, err.Error()
}
case <-timeout:
return getLoadingResponse(), nil
}
}
}
func (w *watcher) containerStop() error { func (w *watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{ return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{
Signal: string(w.StopSignal), Signal: string(w.StopSignal),
@ -205,7 +180,6 @@ func (w *watcher) wakeIfStopped() E.NestedError {
case "paused": case "paused":
return E.From(w.containerUnpause()) return E.From(w.containerUnpause())
case "running": case "running":
w.running.Store(true)
return nil return nil
default: default:
return E.Unexpected("container state", status) return E.Unexpected("container state", status)
@ -236,15 +210,12 @@ func (w *watcher) getStopCallback() StopCallback {
} }
} }
func (w *watcher) watch() { func (w *watcher) watchUntilCancel() {
watcherCtx, watcherCancel := context.WithCancel(context.Background())
w.ctx = watcherCtx
w.cancel = watcherCancel
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
defer close(w.wakeCh) defer close(w.wakeCh)
w.ctx, w.cancel = context.WithCancel(context.Background())
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{ dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
Filters: W.NewDockerFilter( Filters: W.NewDockerFilter(
W.DockerFilterContainer, W.DockerFilterContainer,
@ -265,7 +236,7 @@ func (w *watcher) watch() {
select { select {
case <-mainLoopCtx.Done(): case <-mainLoopCtx.Done():
w.cancel() w.cancel()
case <-watcherCtx.Done(): case <-w.ctx.Done():
w.l.Debug("stopped") w.l.Debug("stopped")
return return
case err := <-dockerEventErrCh: case err := <-dockerEventErrCh:
@ -273,16 +244,18 @@ func (w *watcher) watch() {
w.l.Error(E.FailWith("docker watcher", err)) w.l.Error(E.FailWith("docker watcher", err))
} }
case e := <-dockerEventCh: case e := <-dockerEventCh:
switch e.Action { switch {
case event.ActionDockerStartUnpause: // create / start / unpause
w.running.Store(true) case e.Action.IsContainerWake():
w.l.Infof("%s %s", e.ActorName, e.Action) ticker.Reset(w.IdleTimeout)
case event.ActionDockerStopPause: w.l.Info(e)
w.running.Store(false) default: // stop / pause / kill
w.l.Infof("%s %s", e.ActorName, e.Action) ticker.Stop()
w.ready.Store(false)
w.l.Info(e)
} }
case <-ticker.C: case <-ticker.C:
w.l.Debug("timeout") w.l.Debug("idle timeout")
ticker.Stop() ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) { if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod)) w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
@ -301,57 +274,3 @@ func (w *watcher) watch() {
} }
} }
} }
func getLoadingResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusAccepted,
Header: http.Header{
"Content-Type": {"text/html"},
"Cache-Control": {
"no-cache",
"no-store",
"must-revalidate",
},
},
Body: io.NopCloser(bytes.NewReader((loadingPage))),
ContentLength: int64(len(loadingPage)),
}
}
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = make(map[string]*watcher)
watcherMapMu sync.Mutex
newWatcherCh = make(chan *watcher)
logger = logrus.WithField("module", "idle_watcher")
loadingPage = []byte(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Loading...</title>
</head>
<body>
<script>
window.onload = function() {
setTimeout(function() {
window.location.reload()
}, 1000)
// fetch(window.location.href)
// .then(resp => resp.text())
// .then(data => { document.body.innerHTML = data; })
// .catch(err => { document.body.innerHTML = 'Error: ' + err; });
};
</script>
<h1>Container is starting... Please wait</h1>
</body>
</html>
`[1:])
)

View file

@ -25,6 +25,7 @@ func NewBuilder(format string, args ...any) Builder {
func (b Builder) Add(err NestedError) Builder { func (b Builder) Add(err NestedError) Builder {
if err != nil { if err != nil {
b.Lock() b.Lock()
// TODO: if err severity is higher than b.severity, update b.severity
b.errors = append(b.errors, err) b.errors = append(b.errors, err)
b.Unlock() b.Unlock()
} }

View file

@ -18,8 +18,8 @@ type (
) )
const ( const (
SeverityFatal Severity = iota SeverityWarning Severity = iota
SeverityWarning SeverityFatal
) )
func From(err error) NestedError { func From(err error) NestedError {

View file

@ -20,7 +20,7 @@ func Failure(what string) NestedError {
} }
func FailedWhy(what string, why string) NestedError { func FailedWhy(what string, why string) NestedError {
return errorf("%s %w because %s", what, ErrFailure, why) return Failure(what).With(why)
} }
func FailWith(what string, err any) NestedError { func FailWith(what string, err any) NestedError {

View file

@ -8,6 +8,9 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"reflect"
"runtime"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -28,6 +31,7 @@ import (
func main() { func main() {
args := common.GetArgs() args := common.GetArgs()
l := logrus.WithField("module", "main") l := logrus.WithField("module", "main")
onShutdown := F.NewSlice[func()]()
if common.IsDebug { if common.IsDebug {
logrus.SetLevel(logrus.DebugLevel) logrus.SetLevel(logrus.DebugLevel)
@ -40,20 +44,18 @@ func main() {
DisableSorting: true, DisableSorting: true,
DisableLevelTruncation: true, DisableLevelTruncation: true,
FullTimestamp: true, FullTimestamp: true,
ForceColors: true,
TimestampFormat: "01-02 15:04:05", TimestampFormat: "01-02 15:04:05",
}) })
} }
if args.Command == common.CommandReload { if args.Command == common.CommandReload {
if err := apiUtils.ReloadServer(); err.HasError() { if err := apiUtils.ReloadServer(); err.HasError() {
l.Fatal(err) log.Fatal(err)
} }
log.Print("ok")
return return
} }
onShutdown := F.NewSlice[func()]()
// exit if only validate config // exit if only validate config
if args.Command == common.CommandValidate { if args.Command == common.CommandValidate {
data, err := os.ReadFile(common.ConfigPath) data, err := os.ReadFile(common.ConfigPath)
@ -72,19 +74,19 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
if args.Command == common.CommandListConfigs { switch args.Command {
case common.CommandListConfigs:
printJSON(cfg.Value()) printJSON(cfg.Value())
return return
} case common.CommandListRoutes:
if args.Command == common.CommandListRoutes {
printJSON(cfg.RoutesByAlias()) printJSON(cfg.RoutesByAlias())
return return
} case common.CommandDebugListEntries:
if args.Command == common.CommandDebugListEntries {
printJSON(cfg.DumpEntries()) printJSON(cfg.DumpEntries())
return return
case common.CommandDebugListProviders:
printJSON(cfg.DumpProviders())
return
} }
cfg.StartProxyProviders() cfg.StartProxyProviders()
@ -106,25 +108,14 @@ func main() {
autocert := cfg.GetAutoCertProvider() autocert := cfg.GetAutoCertProvider()
if autocert != nil { if autocert != nil {
if err = autocert.LoadCert(); err.HasError() { ctx, cancel := context.WithCancel(context.Background())
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist if err = autocert.Setup(ctx); err != nil && err.IsWarning() {
l.Error(err) cancel()
}
l.Debug("obtaining cert due to error loading cert")
if err = autocert.ObtainCert(); err.HasError() {
l.Warn(err) l.Warn(err)
} } else if err.IsFatal() {
} l.Fatal(err)
} else {
if err.NoError() { onShutdown.Add(cancel)
ctx, certRenewalCancel := context.WithCancel(context.Background())
go autocert.ScheduleRenewal(ctx)
onShutdown.Add(certRenewalCancel)
}
for _, expiry := range autocert.GetExpiries() {
l.Infof("certificate expire on %s", expiry)
break
} }
} else { } else {
l.Info("autocert not configured") l.Info("autocert not configured")
@ -165,7 +156,9 @@ func main() {
wg.Add(onShutdown.Size()) wg.Add(onShutdown.Size())
onShutdown.ForEach(func(f func()) { onShutdown.ForEach(func(f func()) {
go func() { go func() {
l.Debugf("waiting for %s to complete...", funcName(f))
f() f()
l.Debugf("%s done", funcName(f))
wg.Done() wg.Done()
}() }()
}) })
@ -180,9 +173,17 @@ func main() {
logrus.Info("shutdown complete") logrus.Info("shutdown complete")
case <-timeout: case <-timeout:
logrus.Info("timeout waiting for shutdown") logrus.Info("timeout waiting for shutdown")
onShutdown.ForEach(func(f func()) {
l.Warnf("%s() is still running", funcName(f))
})
} }
} }
func funcName(f func()) string {
parts := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "/go-proxy/")
return parts[len(parts)-1]
}
func printJSON(obj any) { func printJSON(obj any) {
j, err := E.Check(json.Marshal(obj)) j, err := E.Check(json.Marshal(obj))
if err.HasError() { if err.HasError() {

View file

@ -32,7 +32,7 @@ func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
} }
func (s StreamScheme) String() string { func (s StreamScheme) String() string {
return fmt.Sprintf("%s -> %s", s.ListeningScheme, s.ProxyScheme) return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme)
} }
// IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal. // IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal.

View file

@ -1,6 +1,7 @@
package provider package provider
import ( import (
"fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -26,6 +27,10 @@ func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
return &DockerProvider{dockerHost: dockerHost, hostname: hostname}, nil return &DockerProvider{dockerHost: dockerHost, hostname: hostname}, nil
} }
func (p *DockerProvider) String() string {
return fmt.Sprintf("docker:%s", p.dockerHost)
}
func (p *DockerProvider) NewWatcher() W.Watcher { func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost) return W.NewDockerWatcher(p.dockerHost)
} }
@ -145,7 +150,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.Ra
entryPortSplit := strings.Split(entry.Port, ":") entryPortSplit := strings.Split(entry.Port, ":")
if len(entryPortSplit) == 2 && entryPortSplit[1] == containerPort { if len(entryPortSplit) == 2 && entryPortSplit[1] == containerPort {
entryPortSplit[1] = publicPort entryPortSplit[1] = publicPort
} else if entryPortSplit[0] == containerPort { } else if len(entryPortSplit) == 1 && entryPortSplit[0] == containerPort {
entryPortSplit[0] = publicPort entryPortSplit[0] = publicPort
} }
entry.Port = strings.Join(entryPortSplit, ":") entry.Port = strings.Join(entryPortSplit, ":")

View file

@ -35,7 +35,11 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
} }
func Validate(data []byte) E.NestedError { func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data) return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
}
func (p FileProvider) String() string {
return p.fileName
} }
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) { func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {

View file

@ -2,7 +2,6 @@ package provider
import ( import (
"context" "context"
"fmt"
"path" "path"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -13,7 +12,7 @@ import (
type ( type (
Provider struct { Provider struct {
ProviderImpl ProviderImpl `json:"-"`
name string name string
t ProviderType t ProviderType
@ -30,6 +29,7 @@ type (
// even returns error, routes must be non-nil // even returns error, routes must be non-nil
LoadRoutesImpl() (R.Routes, E.NestedError) LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event W.Event, routes R.Routes) EventResult OnEvent(event W.Event, routes R.Routes) EventResult
String() string
} }
ProviderType string ProviderType string
EventResult struct { EventResult struct {
@ -83,8 +83,9 @@ func (p *Provider) GetType() ProviderType {
return p.t return p.t
} }
func (p *Provider) String() string { // to work with json marshaller
return fmt.Sprintf("%s-%s", p.t, p.name) func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
} }
func (p *Provider) StartAllRoutes() (res E.NestedError) { func (p *Provider) StartAllRoutes() (res E.NestedError) {
@ -92,7 +93,6 @@ func (p *Provider) StartAllRoutes() (res E.NestedError) {
defer errors.To(&res) defer errors.To(&res)
// start watcher no matter load success or not // start watcher no matter load success or not
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
go p.watchEvents() go p.watchEvents()
nStarted := 0 nStarted := 0
@ -153,6 +153,7 @@ func (p *Provider) LoadRoutes() E.NestedError {
} }
func (p *Provider) watchEvents() { func (p *Provider) watchEvents() {
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
events, errs := p.watcher.Events(p.watcherCtx) events, errs := p.watcher.Events(p.watcherCtx)
l := p.l.WithField("module", "watcher") l := p.l.WithField("module", "watcher")
@ -160,21 +161,15 @@ func (p *Provider) watchEvents() {
select { select {
case <-p.watcherCtx.Done(): case <-p.watcherCtx.Done():
return return
case event, ok := <-events: case event := <-events:
if !ok { // channel closed
return
}
res := p.OnEvent(event, p.routes) res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event) l.Infof("%s event %q", event.Type, event)
l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved) l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved)
if res.err.HasError() { if res.err.HasError() {
l.Error(res.err) l.Error(res.err)
} }
case err, ok := <-errs: case err := <-errs:
if !ok { if err == nil || err.Is(context.Canceled) {
return
}
if err.Is(context.Canceled) {
continue continue
} }
l.Errorf("watcher error: %s", err) l.Errorf("watcher error: %s", err)

View file

@ -232,7 +232,7 @@ func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *Revers
} }
return &ReverseProxy{Rewrite: func(pr *ProxyRequest) { return &ReverseProxy{Rewrite: func(pr *ProxyRequest) {
rewriteRequestURL(pr.Out, target) rewriteRequestURL(pr.Out, target)
pr.SetXForwarded() // pr.SetXForwarded()
setHeaders(pr.Out) setHeaders(pr.Out)
hideHeaders(pr.Out) hideHeaders(pr.Out)
}, Transport: transport} }, Transport: transport}
@ -348,9 +348,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
outreq.Header.Del("Forwarded") outreq.Header.Del("Forwarded")
// outreq.Header.Del("X-Forwarded-For") outreq.Header.Del("X-Forwarded-For")
// outreq.Header.Del("X-Forwarded-Host") outreq.Header.Del("X-Forwarded-Host")
// outreq.Header.Del("X-Forwarded-Proto") outreq.Header.Del("X-Forwarded-Proto")
pr := &ProxyRequest{ pr := &ProxyRequest{
In: req, In: req,

View file

@ -4,5 +4,5 @@ import (
"time" "time"
) )
const udpBufferSize = 1500 const udpBufferSize = 8192
const streamStopListenTimeout = 1 * time.Second const streamStopListenTimeout = 1 * time.Second

View file

@ -37,19 +37,23 @@ type (
) )
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans http.RoundTripper var trans *http.Transport
var regIdleWatcher func() E.NestedError var regIdleWatcher func() E.NestedError
var unregIdleWatcher func() var unregIdleWatcher func()
if entry.NoTLSVerify { if entry.NoTLSVerify {
trans = transportNoTLS trans = transportNoTLS.Clone()
} else { } else {
trans = transport trans = transport.Clone()
} }
rp := P.NewReverseProxy(entry.URL, trans, entry) rp := P.NewReverseProxy(entry.URL, trans, entry)
if entry.UseIdleWatcher() { if entry.UseIdleWatcher() {
// allow time for response header up to `WakeTimeout`
if entry.WakeTimeout > trans.ResponseHeaderTimeout {
trans.ResponseHeaderTimeout = entry.WakeTimeout
}
regIdleWatcher = func() E.NestedError { regIdleWatcher = func() E.NestedError {
watcher, err := idlewatcher.Register(entry) watcher, err := idlewatcher.Register(entry)
if err.HasError() { if err.HasError() {
@ -114,6 +118,7 @@ func (r *HTTPRoute) Stop() E.NestedError {
if r.unregIdleWatcher != nil { if r.unregIdleWatcher != nil {
r.unregIdleWatcher() r.unregIdleWatcher()
r.unregIdleWatcher = nil
} }
r.mux = nil r.mux = nil
@ -151,13 +156,13 @@ func findMux(host string) (*http.ServeMux, E.NestedError) {
} }
var ( var (
transport = &http.Transport{ defaultDialer = net.Dialer{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 60 * time.Second, Timeout: 60 * time.Second,
KeepAlive: 60 * time.Second, KeepAlive: 60 * time.Second,
}).DialContext, }
MaxIdleConns: 1000, transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
MaxIdleConnsPerHost: 1000, MaxIdleConnsPerHost: 1000,
} }
transportNoTLS = func() *http.Transport { transportNoTLS = func() *http.Transport {

View file

@ -2,6 +2,7 @@ package route
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -129,7 +130,7 @@ func (r *StreamRoute) grHandleConnections() {
case conn := <-r.connCh: case conn := <-r.connCh:
go func() { go func() {
err := r.Handle(conn) err := r.Handle(conn)
if err != nil { if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err) r.l.Error(err)
} }
}() }()

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
"syscall"
"time" "time"
U "github.com/yusing/go-proxy/utils" U "github.com/yusing/go-proxy/utils"
@ -13,14 +12,16 @@ import (
const tcpDialTimeout = 5 * time.Second const tcpDialTimeout = 5 * time.Second
type Pipes []*U.BidirectionalPipe type (
Pipes []U.BidirectionalPipe
type TCPRoute struct { TCPRoute struct {
*StreamRoute *StreamRoute
listener net.Listener listener net.Listener
pipe Pipes pipe Pipes
mu sync.Mutex mu sync.Mutex
} }
)
func NewTCPRoute(base *StreamRoute) StreamImpl { func NewTCPRoute(base *StreamRoute) StreamImpl {
return &TCPRoute{ return &TCPRoute{
@ -59,10 +60,11 @@ func (route *TCPRoute) Handle(c any) error {
} }
route.mu.Lock() route.mu.Lock()
defer route.mu.Unlock()
pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn) pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn)
route.pipe = append(route.pipe, pipe) route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start() return pipe.Start()
} }
@ -72,16 +74,4 @@ func (route *TCPRoute) CloseListeners() {
} }
route.listener.Close() route.listener.Close()
route.listener = nil route.listener = nil
for _, pipe := range route.pipe {
if err := pipe.Stop(); err != nil {
switch err {
// target closing connection
// TODO: handle this by fixing utils/io.go
case net.ErrClosed, syscall.EPIPE:
return
default:
route.l.Error(err)
}
}
}
} }

View file

@ -4,33 +4,34 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"github.com/yusing/go-proxy/utils" U "github.com/yusing/go-proxy/utils"
F "github.com/yusing/go-proxy/utils/functional"
) )
type UDPRoute struct { type (
UDPRoute struct {
*StreamRoute *StreamRoute
connMap UDPConnMap connMap UDPConnMap
connMapMutex sync.Mutex
listeningConn *net.UDPConn listeningConn *net.UDPConn
targetAddr *net.UDPAddr targetAddr *net.UDPAddr
} }
UDPConn struct {
type UDPConn struct {
src *net.UDPConn src *net.UDPConn
dst *net.UDPConn dst *net.UDPConn
*utils.BidirectionalPipe U.BidirectionalPipe
} }
UDPConnMap = F.Map[string, *UDPConn]
)
type UDPConnMap map[string]*UDPConn var NewUDPConnMap = F.NewMapOf[string, *UDPConn]
func NewUDPRoute(base *StreamRoute) StreamImpl { func NewUDPRoute(base *StreamRoute) StreamImpl {
return &UDPRoute{ return &UDPRoute{
StreamRoute: base, StreamRoute: base,
connMap: make(UDPConnMap), connMap: NewUDPConnMap(),
} }
} }
@ -69,11 +70,9 @@ func (route *UDPRoute) Accept() (any, error) {
} }
key := srcAddr.String() key := srcAddr.String()
conn, ok := route.connMap[key] conn, ok := route.connMap.Load(key)
if !ok { if !ok {
route.connMapMutex.Lock()
if conn, ok = route.connMap[key]; !ok {
srcConn, err := net.DialUDP("udp", nil, srcAddr) srcConn, err := net.DialUDP("udp", nil, srcAddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -86,11 +85,9 @@ func (route *UDPRoute) Accept() (any, error) {
conn = &UDPConn{ conn = &UDPConn{
srcConn, srcConn,
dstConn, dstConn,
utils.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), U.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
} }
route.connMap[key] = conn route.connMap.Store(key, conn)
}
route.connMapMutex.Unlock()
} }
_, err = conn.dst.Write(buffer[:nRead]) _, err = conn.dst.Write(buffer[:nRead])
@ -106,15 +103,15 @@ func (route *UDPRoute) CloseListeners() {
route.listeningConn.Close() route.listeningConn.Close()
route.listeningConn = nil route.listeningConn = nil
} }
for _, conn := range route.connMap { route.connMap.RangeAll(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil { if err := conn.src.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err) route.l.Errorf("error closing src conn: %s", err)
} }
if err := conn.dst.Close(); err != nil { if err := conn.dst.Close(); err != nil {
route.l.Error("error closing dst conn: %s", err) route.l.Error("error closing dst conn: %s", err)
} }
} })
route.connMap = make(UDPConnMap) route.connMap.Clear()
} }
type sourceRWCloser struct { type sourceRWCloser struct {

View file

@ -3,9 +3,10 @@ package utils
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"os" "os"
"sync/atomic" "syscall"
E "github.com/yusing/go-proxy/error" E "github.com/yusing/go-proxy/error"
) )
@ -16,15 +17,19 @@ type (
Path string Path string
} }
ReadCloser struct { ContextReader struct {
ctx context.Context ctx context.Context
r io.ReadCloser io.Reader
closed atomic.Bool }
ContextWriter struct {
ctx context.Context
io.Writer
} }
Pipe struct { Pipe struct {
r ReadCloser r ContextReader
w io.WriteCloser w ContextWriter
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -35,48 +40,48 @@ type (
} }
) )
func (r *ReadCloser) Read(p []byte) (int, error) { func (r *ContextReader) Read(p []byte) (int, error) {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
return 0, r.ctx.Err() return 0, r.ctx.Err()
default: default:
return r.r.Read(p) return r.Reader.Read(p)
} }
} }
func (r *ReadCloser) Close() error { func (w *ContextWriter) Write(p []byte) (int, error) {
if r.closed.Load() { select {
return nil case <-w.ctx.Done():
return 0, w.ctx.Err()
default:
return w.Writer.Write(p)
} }
r.closed.Store(true)
return r.r.Close()
} }
func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe {
ctx, cancel := context.WithCancel(ctx) _, cancel := context.WithCancel(ctx)
return &Pipe{ return &Pipe{
r: ReadCloser{ctx: ctx, r: r}, r: ContextReader{ctx: ctx, Reader: r},
w: w, w: ContextWriter{ctx: ctx, Writer: w},
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
} }
func (p *Pipe) Start() error { func (p *Pipe) Start() (err error) {
return Copy(p.ctx, p.w, &p.r) err = Copy(&p.w, &p.r)
switch {
case
// NOTE: ignoring broken pipe and connection reset by peer
errors.Is(err, syscall.EPIPE),
errors.Is(err, syscall.ECONNRESET):
return nil
}
return err
} }
func (p *Pipe) Stop() error { func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) BidirectionalPipe {
p.cancel() return BidirectionalPipe{
return E.JoinE("error stopping pipe", p.r.Close(), p.w.Close()).Error()
}
func (p *Pipe) Write(b []byte) (int, error) {
return p.w.Write(b)
}
func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: NewPipe(ctx, rw1, rw2), pSrcDst: NewPipe(ctx, rw1, rw2),
pDstSrc: NewPipe(ctx, rw2, rw1), pDstSrc: NewPipe(ctx, rw2, rw1),
} }
@ -89,7 +94,7 @@ func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadClose
} }
} }
func (p *BidirectionalPipe) Start() error { func (p BidirectionalPipe) Start() error {
errCh := make(chan error, 2) errCh := make(chan error, 2)
go func() { go func() {
errCh <- p.pSrcDst.Start() errCh <- p.pSrcDst.Start()
@ -97,20 +102,11 @@ func (p *BidirectionalPipe) Start() error {
go func() { go func() {
errCh <- p.pDstSrc.Start() errCh <- p.pDstSrc.Start()
}() }()
for err := range errCh { return E.JoinE("bidirectional pipe error", <-errCh, <-errCh).Error()
if err != nil {
return err
}
}
return nil
} }
func (p *BidirectionalPipe) Stop() error { func Copy(dst *ContextWriter, src *ContextReader) error {
return E.JoinE("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()).Error() _, err := io.Copy(dst, src)
}
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error {
_, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src})
return err return err
} }

View file

@ -4,13 +4,10 @@ import (
"github.com/santhosh-tekuri/jsonschema" "github.com/santhosh-tekuri/jsonschema"
) )
var schemaCompiler = func() *jsonschema.Compiler { var (
c := jsonschema.NewCompiler() schemaCompiler = jsonschema.NewCompiler()
c.Draft = jsonschema.Draft7 schemaStorage = make(map[string]*jsonschema.Schema)
return c )
}()
var schemaStorage = make(map[string]*jsonschema.Schema)
func GetSchema(path string) *jsonschema.Schema { func GetSchema(path string) *jsonschema.Schema {
if schema, ok := schemaStorage[path]; ok { if schema, ok := schemaStorage[path]; ok {

View file

@ -42,11 +42,19 @@ func DockerrFilterContainerName(name string) filters.KeyValuePair {
} }
func NewDockerWatcher(host string) DockerWatcher { func NewDockerWatcher(host string) DockerWatcher {
return DockerWatcher{host: host, FieldLogger: logrus.WithField("module", "docker_watcher")} return DockerWatcher{
host: host,
FieldLogger: (logrus.
WithField("module", "docker_watcher").
WithField("host", host))}
} }
func NewDockerWatcherWithClient(client D.Client) DockerWatcher { func NewDockerWatcherWithClient(client D.Client) DockerWatcher {
return DockerWatcher{client: client, FieldLogger: logrus.WithField("module", "docker_watcher")} return DockerWatcher{
client: client,
FieldLogger: (logrus.
WithField("module", "docker_watcher").
WithField("host", client.DaemonHost()))}
} }
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) { func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
@ -56,7 +64,6 @@ func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Neste
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) { func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) {
eventCh := make(chan Event) eventCh := make(chan Event)
errCh := make(chan E.NestedError) errCh := make(chan E.NestedError)
started := make(chan struct{})
eventsCtx, eventsCancel := context.WithCancel(ctx) eventsCtx, eventsCancel := context.WithCancel(ctx)
@ -75,7 +82,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
attempts := 0 attempts := 0
for { for {
w.client, err = D.ConnectClient(w.host) w.client, err = D.ConnectClient(w.host)
if err != nil { if err == nil {
break break
} }
attempts++ attempts++
@ -89,8 +96,11 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
} }
} }
w.Debugf("client connected")
cEventCh, cErrCh := w.client.Events(eventsCtx, options) cEventCh, cErrCh := w.client.Events(eventsCtx, options)
started <- struct{}{}
w.Debugf("watcher started")
for { for {
select { select {
@ -130,7 +140,6 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
} }
} }
}() }()
<-started
return eventCh, errCh return eventCh, errCh
} }

View file

@ -14,36 +14,64 @@ type (
ActorAttributes map[string]string ActorAttributes map[string]string
Action Action Action Action
} }
Action string Action uint16
EventType string EventType string
) )
const ( const (
ActionFileModified Action = "modified" ActionFileModified Action = (1 << iota)
ActionFileCreated Action = "created" ActionFileCreated
ActionFileDeleted Action = "deleted" ActionFileDeleted
ActionDockerStartUnpause Action = "start" ActionContainerCreate
ActionDockerStopPause Action = "stop" ActionContainerStart
ActionContainerUnpause
ActionContainerKill
ActionContainerStop
ActionContainerPause
ActionContainerDie
actionContainerWakeMask = ActionContainerCreate | ActionContainerStart | ActionContainerUnpause
actionContainerSleepMask = ActionContainerKill | ActionContainerStop | ActionContainerPause | ActionContainerDie
)
const (
EventTypeDocker EventType = "docker" EventTypeDocker EventType = "docker"
EventTypeFile EventType = "file" EventTypeFile EventType = "file"
) )
var DockerEventMap = map[dockerEvents.Action]Action{ var DockerEventMap = map[dockerEvents.Action]Action{
dockerEvents.ActionCreate: ActionDockerStartUnpause, dockerEvents.ActionCreate: ActionContainerCreate,
dockerEvents.ActionStart: ActionDockerStartUnpause, dockerEvents.ActionStart: ActionContainerStart,
dockerEvents.ActionPause: ActionDockerStartUnpause, dockerEvents.ActionUnPause: ActionContainerUnpause,
dockerEvents.ActionDie: ActionDockerStopPause,
dockerEvents.ActionStop: ActionDockerStopPause, dockerEvents.ActionKill: ActionContainerKill,
dockerEvents.ActionUnPause: ActionDockerStopPause, dockerEvents.ActionStop: ActionContainerStop,
dockerEvents.ActionKill: ActionDockerStopPause, dockerEvents.ActionPause: ActionContainerPause,
dockerEvents.ActionDie: ActionContainerDie,
} }
var dockerActionNameMap = func() (m map[Action]string) {
m = make(map[Action]string, len(DockerEventMap))
for k, v := range DockerEventMap {
m[v] = string(k)
}
return
}()
func (e Event) String() string { func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action) return fmt.Sprintf("%s %s", e.ActorName, e.Action)
} }
func (a Action) IsDelete() bool { func (a Action) String() string {
return a == ActionFileDeleted return dockerActionNameMap[a]
}
func (a Action) IsContainerWake() bool {
return a&actionContainerWakeMask != 0
}
func (a Action) IsContainerSleep() bool {
return a&actionContainerSleepMask != 0
} }