diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..498a49f --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +.PHONY: build up restart logs get test-udp-container + +all: build up logs + +build: + mkdir -p bin + CGO_ENABLED=0 GOOS=linux go build -o bin/go-proxy src/go-proxy/*.go + +up: + docker compose up -d --build go-proxy + +restart: + docker compose down -t 0 + docker compose up -d + +logs: + docker compose logs -f + +get: + go get -d -u ./src/go-proxy + +udp-server: + docker run -it --rm \ + -p 9999:9999/udp \ + --label proxy.test-udp.scheme=udp \ + --label proxy.test-udp.port=20003:9999 \ + --network data_default \ + --name test-udp \ + $$(docker build -q -f udp-test-server.Dockerfile .) \ No newline at end of file diff --git a/README.md b/README.md index 7f10f4a..0d44b6f 100755 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ In the examples domain `x.y.z` is used, replace them with your domain - HTTP proxy - TCP/UDP Proxy (experimental, unable to release port on hot-reload) - Auto hot-reload when container start / die / stop. -- Simple panel to see all reverse proxies and health (visit port :8443 of go-proxy `https://*.y.z:8443`) +- Simple panel to see all reverse proxies and health (visit port [panel port] of go-proxy `https://*.y.z:[panel port]`) ![panel screenshot](screenshots/panel.png) @@ -45,7 +45,7 @@ In the examples domain `x.y.z` is used, replace them with your domain 4. Modify the path to your SSL certs. See [Getting SSL Certs](#getting-ssl-certs) -5. Start `go-proxy` with `docker compose up -d`. +5. Start `go-proxy` with `docker compose up -d` or `make up`. 6. (Optional) If you are using ufw with vpn that drop all inbound traffic except vpn, run below to allow docker containers to connect to `go-proxy` @@ -60,6 +60,8 @@ In the examples domain `x.y.z` is used, replace them with your domain 7. start your docker app, and visit .y.z +8. check the logs with `docker compose logs` or `make logs` to see if there is any error, check panel at [panel port] for active proxies + ## Configuration With container name, no label needs to be added. @@ -196,14 +198,16 @@ It takes ~ 0.1-0.4MB for each HTTP Proxy, and <2MB for each TCP/UDP Proxy ## Build it yourself -1. [Install go](https://go.dev/doc/install) if not already +1. Install [go](https://go.dev/doc/install) and `make` if not already -2. get dependencies with `sh scripts/get.sh` +2. get dependencies with `make get` -3. build binary with `sh scripts/build.sh` +3. build binary with `make build` 4. start your container with `docker compose up -d` ## Getting SSL certs I personally use `nginx-proxy-manager` to get SSL certs with auto renewal by Cloudflare DNS challenge. You may symlink the certs from `nginx-proxy-manager` to somewhere else, and mount them to `go-proxy`'s `/certs` + +[panel port]: 8443 diff --git a/bin/go-proxy b/bin/go-proxy index a5afef0..23d49a6 100755 Binary files a/bin/go-proxy and b/bin/go-proxy differ diff --git a/scripts/build.sh b/scripts/build.sh deleted file mode 100755 index f445a54..0000000 --- a/scripts/build.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/sh -mkdir -p bin -CGO_ENABLED=0 GOOS=linux go build -o bin/go-proxy src/go-proxy/*.go || exit 1 - -if [ "$1" = "up" ]; then - docker compose up -d --build app && \ - docker compose logs -f -fi diff --git a/scripts/get.sh b/scripts/get.sh deleted file mode 100644 index de57f0e..0000000 --- a/scripts/get.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/sh -go get -d -u ./src/go-proxy \ No newline at end of file diff --git a/scripts/udp-test-container.sh b/scripts/udp-test-container.sh deleted file mode 100644 index 2241cbb..0000000 --- a/scripts/udp-test-container.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/sh -docker run -it --tty --rm \ - -p 9999:9999/udp \ - --label proxy.test-udp.scheme=udp \ - --label proxy.test-udp.port=20003:9999 \ - --network data_default \ - --name test-udp \ - debian:stable-slim \ - /bin/bash -c \ - "apt update && apt install -y netcat-openbsd && echo 'nc -u -l 9999' >> ~/.bashrc && bash" diff --git a/src/go-proxy/docker.go b/src/go-proxy/docker.go index 824714b..f81f673 100644 --- a/src/go-proxy/docker.go +++ b/src/go-proxy/docker.go @@ -18,7 +18,7 @@ import ( ) type ProxyConfig struct { - id string + id string Alias string Scheme string Host string @@ -89,7 +89,6 @@ func buildContainerRoute(container types.Container) { imageName := imageSplit[0] _, isKnownImage := imageNamePortMap[imageName] if isKnownImage { - log.Printf("[Build] Known image '%s' detected for %s", imageName, container_name) config.Scheme = "tcp" } else { config.Scheme = "http" @@ -109,7 +108,7 @@ func buildContainerRoute(container types.Container) { } config.Alias = alias config.UpdateId() - + wg.Add(1) go func() { createRoute(&config) @@ -139,7 +138,7 @@ func buildRoutes() { func findHTTPRoute(host string, path string) (*HTTPRoute, error) { subdomain := strings.Split(host, ".")[0] - routeMap, ok := routes.HTTPRoutes[subdomain] + routeMap, ok := routes.HTTPRoutes.TryGet(subdomain) if !ok { return nil, fmt.Errorf("no matching route for subdomain %s", subdomain) } diff --git a/src/go-proxy/main.go b/src/go-proxy/main.go index 527d95e..b7a473d 100644 --- a/src/go-proxy/main.go +++ b/src/go-proxy/main.go @@ -19,6 +19,7 @@ func main() { if err != nil { log.Fatal(err) } + buildRoutes() log.Printf("[Build] built %v reverse proxies", countRoutes()) beginListenStreams() @@ -30,15 +31,21 @@ func main() { filters.Arg("event", "die"), // stop seems like triggering die // filters.Arg("event", "stop"), ) - msgs, _ := dockerClient.Events(context.Background(), types.EventsOptions{Filters: filter}) + msgChan, errChan := dockerClient.Events(context.Background(), types.EventsOptions{Filters: filter}) - for msg := range msgs { - // TODO: handle actor only - log.Printf("[Event] %s %s caused rebuild", msg.Action, msg.Actor.Attributes["name"]) - endListenStreams() - buildRoutes() - log.Printf("[Build] rebuilt %v reverse proxies", countRoutes()) - beginListenStreams() + for { + select { + case msg := <-msgChan: + // TODO: handle actor only + log.Printf("[Event] %s %s caused rebuild", msg.Action, msg.Actor.Attributes["name"]) + endListenStreams() + buildRoutes() + log.Printf("[Build] rebuilt %v reverse proxies", countRoutes()) + beginListenStreams() + case err := <-errChan: + log.Printf("[Event] %s", err) + msgChan, errChan = dockerClient.Events(context.Background(), types.EventsOptions{Filters: filter}) + } } }() diff --git a/src/go-proxy/map.go b/src/go-proxy/map.go new file mode 100644 index 0000000..640c7cc --- /dev/null +++ b/src/go-proxy/map.go @@ -0,0 +1,94 @@ +package main + +import "sync" + +type SafeMapInterface[KT comparable, VT interface{}] interface { + Set(key KT, value VT) + Ensure(key KT) + Get(key KT) VT + TryGet(key KT) (VT, bool) + Clear() + Size() int + Contains(key KT) bool + ForEach(fn func(key KT, value VT)) + Iterator() map[KT]VT +} + +type SafeMap[KT comparable, VT interface{}] struct { + SafeMapInterface[KT, VT] + m map[KT]VT + mutex sync.Mutex + defaultFactory func() VT +} + +func NewSafeMap[KT comparable, VT interface{}](df... func() VT) *SafeMap[KT, VT] { + if len(df) == 0 { + return &SafeMap[KT, VT]{ + m: make(map[KT]VT), + } + } + return &SafeMap[KT, VT]{ + m: make(map[KT]VT), + defaultFactory: df[0], + } +} + +func (m *SafeMap[KT, VT]) Set(key KT, value VT) { + m.mutex.Lock() + m.m[key] = value + m.mutex.Unlock() +} + +func (m *SafeMap[KT, VT]) Ensure(key KT) { + m.mutex.Lock() + if _, ok := m.m[key]; !ok { + m.m[key] = m.defaultFactory() + } + m.mutex.Unlock() +} + +func (m *SafeMap[KT, VT]) Get(key KT) VT { + m.mutex.Lock() + value := m.m[key] + m.mutex.Unlock() + return value +} + +func (m *SafeMap[KT, VT]) TryGet(key KT) (VT, bool) { + m.mutex.Lock() + value, ok := m.m[key] + m.mutex.Unlock() + return value, ok +} + +func (m *SafeMap[KT, VT]) Clear() { + m.mutex.Lock() + m.m = make(map[KT]VT) + m.mutex.Unlock() +} + +func (m *SafeMap[KT, VT]) Size() int { + m.mutex.Lock() + size := len(m.m) + m.mutex.Unlock() + return size +} + +func (m *SafeMap[KT, VT]) Contains(key KT) bool { + m.mutex.Lock() + _, ok := m.m[key] + m.mutex.Unlock() + return ok +} + +func (m *SafeMap[KT, VT]) ForEach(fn func(key KT, value VT)) { + m.mutex.Lock() + for k, v := range m.m { + fn(k, v) + } + m.mutex.Unlock() +} + +func (m *SafeMap[KT, VT]) Iterator() map[KT]VT { + return m.m +} \ No newline at end of file diff --git a/src/go-proxy/route.go b/src/go-proxy/route.go index bebebe0..c7ef6e3 100644 --- a/src/go-proxy/route.go +++ b/src/go-proxy/route.go @@ -8,16 +8,12 @@ import ( ) type Routes struct { - HTTPRoutes map[string][]HTTPRoute // id -> path - StreamRoutes map[string]*StreamRoute // id -> target + HTTPRoutes *SafeMap[string, []HTTPRoute] // id -> path + StreamRoutes *SafeMap[string, StreamRoute] // id -> target Mutex sync.Mutex } -var routes = Routes{ - HTTPRoutes: make(map[string][]HTTPRoute), - StreamRoutes: make(map[string]*StreamRoute), - Mutex: sync.Mutex{}, -} +var routes = Routes{} var streamSchemes = []string{"tcp", "udp"} // TODO: support "tcp:udp", "udp:tcp" var httpSchemes = []string{"http", "https"} @@ -43,23 +39,23 @@ func isStreamScheme(scheme string) bool { } func initRoutes() { - routes.Mutex.Lock() - defer routes.Mutex.Unlock() - utils.resetPortsInUse() - routes.StreamRoutes = make(map[string]*StreamRoute) - routes.HTTPRoutes = make(map[string][]HTTPRoute) + routes.HTTPRoutes = NewSafeMap[string, []HTTPRoute]( + func() []HTTPRoute { + return make([]HTTPRoute, 0) + }, + ) + routes.StreamRoutes = NewSafeMap[string, StreamRoute]() } func countRoutes() int { - return len(routes.HTTPRoutes) + len(routes.StreamRoutes) + return routes.HTTPRoutes.Size() + routes.StreamRoutes.Size() } func createRoute(config *ProxyConfig) { if isStreamScheme(config.Scheme) { - _, inMap := routes.StreamRoutes[config.id] - if inMap { - log.Printf("[Build] Duplicated stream %s, ignoring", config.id) + if routes.StreamRoutes.Contains(config.id) { + log.Printf("[Build] Duplicated %s stream %s, ignoring", config.Scheme, config.id) return } route, err := NewStreamRoute(config) @@ -67,20 +63,15 @@ func createRoute(config *ProxyConfig) { log.Println(err) return } - routes.Mutex.Lock() - routes.StreamRoutes[config.id] = route - routes.Mutex.Unlock() + routes.StreamRoutes.Set(config.id, route) } else { - routes.Mutex.Lock() - _, inMap := routes.HTTPRoutes[config.Alias] - if !inMap { - routes.HTTPRoutes[config.Alias] = make([]HTTPRoute, 0) - } + routes.HTTPRoutes.Ensure(config.Alias) url, err := url.Parse(fmt.Sprintf("%s://%s:%s", config.Scheme, config.Host, config.Port)) if err != nil { - log.Fatal(err) + log.Println(err) + return } - routes.HTTPRoutes[config.Alias] = append(routes.HTTPRoutes[config.Alias], NewHTTPRoute(url, config.Path)) - routes.Mutex.Unlock() + route := NewHTTPRoute(url, config.Path) + routes.HTTPRoutes.Set(config.Alias, append(routes.HTTPRoutes.Get(config.Alias), route)) } } diff --git a/src/go-proxy/stream.go b/src/go-proxy/stream.go index e9b17d1..c665bbe 100644 --- a/src/go-proxy/stream.go +++ b/src/go-proxy/stream.go @@ -1,18 +1,30 @@ package main import ( - "context" + "errors" "fmt" "log" - "net" "strconv" "strings" "sync" - "sync/atomic" - "unsafe" + "time" ) -type StreamRoute struct { +type StreamRoute interface { + SetupListen() + Listen() + StopListening() + Logf(string, ...interface{}) + PrintError(error) + ListeningUrl() string + TargetUrl() string + + closeListeners() + closeChannel() + wait() +} + +type StreamRouteBase struct { Alias string // to show in panel Type string ListeningScheme string @@ -21,8 +33,180 @@ type StreamRoute struct { TargetHost string TargetPort string - Context context.Context - Cancel context.CancelFunc + wg sync.WaitGroup + stopChann chan struct{} +} + +func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) { + var streamType string = TCPStreamType + var srcPort string + var dstPort string + var srcScheme string + var dstScheme string + + port_split := strings.Split(config.Port, ":") + if len(port_split) != 2 { + log.Printf(`[Build] %s: Invalid stream port %s, `+ + `assuming it's targetPort`, config.Alias, config.Port) + srcPort = "0" + dstPort = config.Port + } else { + srcPort = port_split[0] + dstPort = port_split[1] + } + + port, hasName := namePortMap[dstPort] + if hasName { + dstPort = port + } + + srcPortInt, err := strconv.Atoi(srcPort) + if err != nil { + return nil, fmt.Errorf( + "[Build] %s: Unrecognized stream source port %s, ignoring", + config.Alias, srcPort, + ) + } + + utils.markPortInUse(srcPortInt) + + _, err = strconv.Atoi(dstPort) + if err != nil { + return nil, fmt.Errorf( + "[Build] %s: Unrecognized stream target port %s, ignoring", + config.Alias, dstPort, + ) + } + + scheme_split := strings.Split(config.Scheme, ":") + + if len(scheme_split) == 2 { + srcScheme = scheme_split[0] + dstScheme = scheme_split[1] + } else { + srcScheme = config.Scheme + dstScheme = config.Scheme + } + + return &StreamRouteBase{ + Alias: config.Alias, + Type: streamType, + ListeningScheme: srcScheme, + ListeningPort: srcPort, + TargetScheme: dstScheme, + TargetHost: config.Host, + TargetPort: dstPort, + + wg: sync.WaitGroup{}, + stopChann: make(chan struct{}), + }, nil +} + +func NewStreamRoute(config *ProxyConfig) (StreamRoute, error) { + switch config.Scheme { + case TCPStreamType: + return NewTCPRoute(config) + case UDPStreamType: + return NewUDPRoute(config) + default: + return nil, errors.New("unknown stream type") + } +} + +func (route *StreamRouteBase) PrintError(err error) { + if err == nil { + return + } + route.Logf("Error: %s", err.Error()) +} + +func (route *StreamRouteBase) Logf(format string, v ...interface{}) { + log.Printf("[%s -> %s] %s: "+format, + append([]interface{}{ + route.ListeningScheme, + route.TargetScheme, + route.Alias}, + v..., + )..., + ) +} + +func (route *StreamRouteBase) ListeningUrl() string { + return fmt.Sprintf("%s:%s", route.ListeningScheme, route.ListeningPort) +} + +func (route *StreamRouteBase) TargetUrl() string { + return fmt.Sprintf("%s://%s:%s", route.TargetScheme, route.TargetHost, route.TargetPort) +} + +func (route *StreamRouteBase) SetupListen() { + if route.ListeningPort == "0" { + freePort, err := utils.findUseFreePort(20000) + if err != nil { + route.PrintError(err) + return + } + route.ListeningPort = fmt.Sprintf("%d", freePort) + route.Logf("Assigned free port %s", route.ListeningPort) + } + route.Logf("Listening on %s", route.ListeningUrl()) +} + +func (route *StreamRouteBase) wait() { + route.wg.Wait() +} + +func (route *StreamRouteBase) closeChannel() { + close(route.stopChann) +} + +func stopListening(route StreamRoute) { + route.Logf("Stopping listening") + route.closeChannel() + route.closeListeners() + + done := make(chan struct{}) + + go func() { + route.wait() + close(done) + }() + + select { + case <-done: + route.Logf("Stopped listening") + return + case <-time.After(streamStopListenTimeout): + route.Logf("timed out waiting for connections") + return + } +} + +func allStreamsDo(msg string, fn ...func(StreamRoute)) { + log.Printf("[Stream] %s", msg) + + var wg sync.WaitGroup + + for _, route := range routes.StreamRoutes.Iterator() { + wg.Add(1) + go func(r StreamRoute) { + for _, f := range fn { + f(r) + } + wg.Done() + }(route) + } + + wg.Wait() + log.Printf("[Stream] Finished %s", msg) +} + +func beginListenStreams() { + allStreamsDo("Start", StreamRoute.SetupListen, StreamRoute.Listen) +} + +func endListenStreams() { + allStreamsDo("Stop", StreamRoute.StopListening) } var imageNamePortMap = map[string]string{ @@ -57,150 +241,5 @@ var namePortMap = func() map[string]string { const UDPStreamType = "udp" const TCPStreamType = "tcp" -func NewStreamRoute(config *ProxyConfig) (*StreamRoute, error) { - var streamType string = TCPStreamType - var srcPort string - var dstPort string - var srcScheme string - var dstScheme string - var srcUDPAddr *net.UDPAddr = nil - var dstUDPAddr *net.UDPAddr = nil - - port_split := strings.Split(config.Port, ":") - if len(port_split) != 2 { - log.Printf(`[Build] Invalid stream port %s, `+ - `should be :, `+ - `assuming it is targetPort`, config.Port) - srcPort = "0" - dstPort = config.Port - } else { - srcPort = port_split[0] - dstPort = port_split[1] - } - - port, hasName := namePortMap[dstPort] - if hasName { - dstPort = port - } - _, err := strconv.Atoi(dstPort) - if err != nil { - return nil, fmt.Errorf( - "[Build] %s: Unrecognized stream target port %s, ignoring", - config.Alias, dstPort, - ) - } - - scheme_split := strings.Split(config.Scheme, ":") - - if len(scheme_split) == 2 { - srcScheme = scheme_split[0] - dstScheme = scheme_split[1] - } else { - srcScheme = config.Scheme - dstScheme = config.Scheme - } - - if srcScheme == "udp" { - streamType = UDPStreamType - srcUDPAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%s", srcPort)) - if err != nil { - return nil, err - } - } - - if dstScheme == "udp" { - streamType = UDPStreamType - dstUDPAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%s", config.Host, dstPort)) - if err != nil { - return nil, err - } - } - - lsPort, err := strconv.Atoi(srcPort) - if err != nil { - return nil, err - } - utils.markPortInUse(lsPort) - - ctx, cancel := context.WithCancel(context.Background()) - - route := StreamRoute{ - Alias: config.Alias, - Type: streamType, - ListeningScheme: srcScheme, - TargetScheme: dstScheme, - TargetHost: config.Host, - ListeningPort: srcPort, - TargetPort: dstPort, - - Context: ctx, - Cancel: cancel, - } - - if streamType == TCPStreamType { - return &route, nil - } - - return (*StreamRoute)(unsafe.Pointer(&UDPRoute{ - StreamRoute: route, - ConnMap: make(map[net.Addr]*net.UDPConn), - ConnMapMutex: sync.Mutex{}, - QueueSize: atomic.Int32{}, - SourceUDPAddr: srcUDPAddr, - TargetUDPAddr: dstUDPAddr, - })), nil -} - -func (route *StreamRoute) PrintError(err error) { - if err == nil { - return - } - log.Printf("[Stream] %s (%s => %s) error: %v", route.Alias, route.ListeningUrl(), route.TargetUrl(), err) -} - -func (route *StreamRoute) ListeningUrl() string { - return fmt.Sprintf("%s:%s", route.ListeningScheme, route.ListeningPort) -} - -func (route *StreamRoute) TargetUrl() string { - return fmt.Sprintf("%s://%s:%s", route.TargetScheme, route.TargetHost, route.TargetPort) -} - -func (route *StreamRoute) listenStream() { - if route.ListeningPort == "0" { - freePort, err := utils.findFreePort(20000) - if err != nil { - route.PrintError(err) - return - } - route.ListeningPort = fmt.Sprintf("%d", freePort) - utils.markPortInUse(freePort) - } - if route.Type == UDPStreamType { - listenUDP((*UDPRoute)(unsafe.Pointer(route))) - } else { - listenTCP(route) - } -} - -func beginListenStreams() { - for _, route := range routes.StreamRoutes { - go route.listenStream() - } -} - -func endListenStreams() { - var wg sync.WaitGroup - wg.Add(len(routes.StreamRoutes)) - defer wg.Wait() - - routes.Mutex.Lock() - defer routes.Mutex.Unlock() - - for _, route := range routes.StreamRoutes { - go func(r *StreamRoute) { - r.Cancel() - wg.Done() - }(route) - } -} +// const maxQueueSizePerStream = 100 +const streamStopListenTimeout = 1 * time.Second diff --git a/src/go-proxy/tcp.go b/src/go-proxy/tcp.go index 3419b26..7d3df92 100644 --- a/src/go-proxy/tcp.go +++ b/src/go-proxy/tcp.go @@ -12,34 +12,87 @@ import ( const tcpDialTimeout = 5 * time.Second -func listenTCP(route *StreamRoute) { - in, err := net.Listen( - route.ListeningScheme, - fmt.Sprintf(":%s", route.ListeningPort), - ) +type TCPRoute struct { + *StreamRouteBase + listener net.Listener + connChan chan net.Conn +} + +func NewTCPRoute(config *ProxyConfig) (StreamRoute, error) { + base, err := newStreamRouteBase(config) if err != nil { - log.Printf("[Stream Listen] %v", err) + return nil, err + } + if base.TargetScheme != TCPStreamType { + return nil, fmt.Errorf("tcp to %s not yet supported", base.TargetScheme) + } + return &TCPRoute{ + StreamRouteBase: base, + listener: nil, + connChan: make(chan net.Conn), + }, nil +} + +func (route *TCPRoute) Listen() { + in, err := net.Listen("tcp", ":"+route.ListeningPort) + if err != nil { + route.PrintError(err) return } + route.listener = in + route.wg.Add(2) + go route.grAcceptConnections() + go route.grHandleConnections() +} - defer in.Close() +func (route *TCPRoute) StopListening() { + stopListening(route) +} + +func (route *TCPRoute) closeListeners() { + if route.listener == nil { + return + } + route.listener.Close() + route.listener = nil +} + +func (route *TCPRoute) grAcceptConnections() { + defer route.wg.Done() for { select { - case <-route.Context.Done(): + case <-route.stopChann: return default: - clientConn, err := in.Accept() + conn, err := route.listener.Accept() if err != nil { - log.Printf("[Stream Accept] %v", err) - return + route.PrintError(err) + continue } - go connectTCPPipe(route, clientConn) + route.connChan <- conn } } } -func connectTCPPipe(route *StreamRoute, clientConn net.Conn) { +func (route *TCPRoute) grHandleConnections() { + defer route.wg.Done() + + for { + select { + case <-route.stopChann: + return + case conn := <-route.connChan: + route.wg.Add(1) + go route.grHandleConnection(conn) + } + } +} + +func (route *TCPRoute) grHandleConnection(clientConn net.Conn) { + defer clientConn.Close() + defer route.wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout) defer cancel() @@ -50,25 +103,29 @@ func connectTCPPipe(route *StreamRoute, clientConn net.Conn) { log.Printf("[Stream Dial] %v", err) return } - tcpPipe(route, clientConn, serverConn) + route.tcpPipe(clientConn, serverConn) } -func tcpPipe(route *StreamRoute, src net.Conn, dest net.Conn) { +func (route *TCPRoute) tcpPipe(src net.Conn, dest net.Conn) { + close := func() { + src.Close() + dest.Close() + } + var wg sync.WaitGroup wg.Add(2) // Number of goroutines - defer src.Close() - defer dest.Close() go func() { _, err := io.Copy(src, dest) - go route.PrintError(err) + route.PrintError(err) + close() wg.Done() }() go func() { _, err := io.Copy(dest, src) - go route.PrintError(err) + route.PrintError(err) + close() wg.Done() }() - wg.Wait() } diff --git a/src/go-proxy/udp.go b/src/go-proxy/udp.go index a07d737..44a8c66 100644 --- a/src/go-proxy/udp.go +++ b/src/go-proxy/udp.go @@ -1,125 +1,231 @@ package main import ( + "fmt" "io" - "log" "net" "sync" - "sync/atomic" - "time" ) const udpBufferSize = 1500 -const udpMaxQueueSizePerStream = 100 -const udpListenTimeout = 100 * time.Second -const udpConnectionTimeout = 30 * time.Second + +// const udpListenTimeout = 100 * time.Second +// const udpConnectionTimeout = 30 * time.Second type UDPRoute struct { - StreamRoute + *StreamRouteBase - ConnMap map[net.Addr]*net.UDPConn - ConnMapMutex sync.Mutex - QueueSize atomic.Int32 - SourceUDPAddr *net.UDPAddr - TargetUDPAddr *net.UDPAddr + connMap map[net.Addr]net.Conn + connMapMutex sync.Mutex + + listeningConn *net.UDPConn + targetConn *net.UDPConn + + connChan chan *UDPConn } -func listenUDP(route *UDPRoute) { - source, err := net.ListenUDP(route.ListeningScheme, route.SourceUDPAddr) +type UDPConn struct { + remoteAddr net.Addr + buffer []byte + bytesReceived []byte + nReceived int +} + +func NewUDPRoute(config *ProxyConfig) (StreamRoute, error) { + base, err := newStreamRouteBase(config) + if err != nil { + return nil, err + } + + if base.TargetScheme != UDPStreamType { + return nil, fmt.Errorf("udp to %s not yet supported", base.TargetScheme) + } + + return &UDPRoute{ + StreamRouteBase: base, + connMap: make(map[net.Addr]net.Conn), + connChan: make(chan *UDPConn), + }, nil +} + +func (route *UDPRoute) Listen() { + source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%s", route.ListeningPort)) if err != nil { route.PrintError(err) return } - target, err := net.DialUDP(route.TargetScheme, nil, route.TargetUDPAddr) + target, err := net.Dial(route.TargetScheme, fmt.Sprintf("%s:%s", route.TargetHost, route.TargetPort)) if err != nil { route.PrintError(err) + source.Close() return } - var wg sync.WaitGroup - defer wg.Wait() - defer source.Close() - defer target.Close() + route.listeningConn = source.(*net.UDPConn) + route.targetConn = target.(*net.UDPConn) - var udpBuffers = [udpMaxQueueSizePerStream][udpBufferSize]byte{} + route.wg.Add(2) + go route.grAcceptConnections() + go route.grHandleConnections() +} + +func (route *UDPRoute) StopListening() { + stopListening(route) +} + +func (route *UDPRoute) closeListeners() { + if route.listeningConn != nil { + route.listeningConn.Close() + } + if route.targetConn != nil { + route.targetConn.Close() + } + route.listeningConn = nil + route.targetConn = nil + for _, conn := range route.connMap { + conn.(*net.UDPConn).Close() // TODO: change on non udp target + } +} + +func (route *UDPRoute) grAcceptConnections() { + defer route.wg.Done() for { select { - case <-route.Context.Done(): + case <-route.stopChann: return default: - if route.QueueSize.Load() >= udpMaxQueueSizePerStream { - wg.Wait() + conn, err := route.accept() + if err != nil { + route.PrintError(err) + continue } - go udpLoop( - route, - source, - target, - udpBuffers[route.QueueSize.Load()][:], - &wg, - ) + route.connChan <- conn } } } -func udpLoop(route *UDPRoute, in *net.UDPConn, out *net.UDPConn, buffer []byte, wg *sync.WaitGroup) { - wg.Add(1) - route.QueueSize.Add(1) - defer route.QueueSize.Add(-1) - defer wg.Done() +func (route *UDPRoute) grHandleConnections() { + defer route.wg.Done() - var nRead int - var nWritten int - - in.SetReadDeadline(time.Now().Add(udpListenTimeout)) - nRead, srcAddr, err := in.ReadFromUDP(buffer) - - if err != nil { - return - } - - log.Printf("[Stream] received %d bytes from %s, forwarding to %s", nRead, srcAddr.String(), out.RemoteAddr().String()) - out.SetWriteDeadline(time.Now().Add(udpConnectionTimeout)) - nWritten, err = out.Write(buffer[:nRead]) - if nWritten != nRead { - err = io.ErrShortWrite - } - if err != nil { - go route.PrintError(err) - return - } - - err = udpPipe(route, out, srcAddr, buffer) - if err != nil { - go route.PrintError(err) + for { + select { + case <-route.stopChann: + return + case conn := <-route.connChan: + go func() { + err := route.handleConnection(conn) + if err != nil { + route.PrintError(err) + } + }() + } } } -func udpPipe(route *UDPRoute, src *net.UDPConn, destAddr *net.UDPAddr, buffer []byte) error { - src.SetReadDeadline(time.Now().Add(udpConnectionTimeout)) - nRead, err := src.Read(buffer) - if err != nil || nRead == 0 { - return err - } - log.Printf("[Stream] received %d bytes from %s, forwarding to %s", nRead, src.RemoteAddr().String(), destAddr.String()) - dest, ok := route.ConnMap[destAddr] +func (route *UDPRoute) handleConnection(conn *UDPConn) error { + var err error + + srcConn, ok := route.connMap[conn.remoteAddr] if !ok { - dest, err = net.DialUDP(src.LocalAddr().Network(), nil, destAddr) + route.connMapMutex.Lock() + srcConn, err = net.DialUDP("udp", nil, conn.remoteAddr.(*net.UDPAddr)) if err != nil { return err } - route.ConnMapMutex.Lock() - route.ConnMap[destAddr] = dest - route.ConnMapMutex.Unlock() + route.connMap[conn.remoteAddr] = srcConn + route.connMapMutex.Unlock() } - dest.SetWriteDeadline(time.Now().Add(udpConnectionTimeout)) - nWritten, err := dest.Write(buffer[:nRead]) + + // initiate connection to target + err = route.forwardReceived(conn, route.targetConn) if err != nil { return err } - if nWritten != nRead { - return io.ErrShortWrite + + for { + select { + case <-route.stopChann: + return nil + default: + // receive from target + conn, err = route.readFrom(route.targetConn, conn.buffer) + if err != nil { + return err + } + // forward to source + err = route.forwardReceived(conn, srcConn) + if err != nil { + return err + } + // read from source + conn, err = route.readFrom(srcConn, conn.buffer) + if err != nil { + continue + } + // forward to target + err = route.forwardReceived(conn, route.targetConn) + if err != nil { + return err + } + } } - return nil +} + +func (route *UDPRoute) accept() (*UDPConn, error) { + in := route.listeningConn + + buffer := make([]byte, udpBufferSize) + nRead, srcAddr, err := in.ReadFromUDP(buffer) + + if err != nil { + return nil, err + } + + if nRead == 0 { + return nil, io.ErrShortBuffer + } + + return &UDPConn{ + remoteAddr: srcAddr, + buffer: buffer, + bytesReceived: buffer[:nRead], + nReceived: nRead}, + nil +} + +func (route *UDPRoute) readFrom(src net.Conn, buffer []byte) (*UDPConn, error) { + nRead, err := src.Read(buffer) + + if err != nil { + return nil, err + } + + if nRead == 0 { + return nil, io.ErrShortBuffer + } + + return &UDPConn{ + remoteAddr: src.RemoteAddr(), + buffer: buffer, + bytesReceived: buffer[:nRead], + nReceived: nRead, + }, nil +} + +func (route *UDPRoute) forwardReceived(receivedConn *UDPConn, dest net.Conn) error { + route.Logf( + "forwarding %d bytes %s -> %s", + receivedConn.nReceived, + receivedConn.remoteAddr.String(), + dest.RemoteAddr().String(), + ) + nWritten, err := dest.Write(receivedConn.bytesReceived) + + if nWritten != receivedConn.nReceived { + err = io.ErrShortWrite + } + + return err } diff --git a/src/go-proxy/utils.go b/src/go-proxy/utils.go index fef734d..4584539 100644 --- a/src/go-proxy/utils.go +++ b/src/go-proxy/utils.go @@ -9,15 +9,18 @@ import ( ) type Utils struct { - PortsInUse map[int]bool + PortsInUse map[int]bool portsInUseMutex sync.Mutex } + var utils = &Utils{ - PortsInUse: make(map[int]bool), + PortsInUse: make(map[int]bool), portsInUseMutex: sync.Mutex{}, } -func (u *Utils) findFreePort(startingPort int) (int, error) { +func (u *Utils) findUseFreePort(startingPort int) (int, error) { + u.portsInUseMutex.Lock() + defer u.portsInUseMutex.Unlock() for port := startingPort; port <= startingPort+100 && port <= 65535; port++ { if u.PortsInUse[port] { continue @@ -25,31 +28,34 @@ func (u *Utils) findFreePort(startingPort int) (int, error) { addr := fmt.Sprintf(":%d", port) l, err := net.Listen("tcp", addr) if err == nil { + u.PortsInUse[port] = true l.Close() return port, nil } } l, err := net.Listen("tcp", ":0") if err == nil { - l.Close() // NOTE: may not be after 20000 - return l.Addr().(*net.TCPAddr).Port, nil + port := l.Addr().(*net.TCPAddr).Port + u.PortsInUse[port] = true + l.Close() + return port, nil } return -1, fmt.Errorf("unable to find free port: %v", err) } func (u *Utils) resetPortsInUse() { u.portsInUseMutex.Lock() - defer u.portsInUseMutex.Unlock() for port := range u.PortsInUse { u.PortsInUse[port] = false } + u.portsInUseMutex.Unlock() } -func (u* Utils) markPortInUse(port int) { +func (u *Utils) markPortInUse(port int) { u.portsInUseMutex.Lock() - defer u.portsInUseMutex.Unlock() u.PortsInUse[port] = true + u.portsInUseMutex.Unlock() } func (*Utils) healthCheckHttp(targetUrl string) error { @@ -57,13 +63,13 @@ func (*Utils) healthCheckHttp(targetUrl string) error { // if HEAD is not allowed, try GET resp, err := healthCheckHttpClient.Head(targetUrl) if resp != nil { - defer resp.Body.Close() + resp.Body.Close() } if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed { _, err = healthCheckHttpClient.Get(targetUrl) } if resp != nil { - defer resp.Body.Close() + resp.Body.Close() } return err } @@ -73,6 +79,6 @@ func (*Utils) healthCheckStream(scheme string, host string) error { if err != nil { return err } - defer conn.Close() + conn.Close() return nil } diff --git a/templates/panel.html b/templates/panel.html index 0bf366c..461b82f 100644 --- a/templates/panel.html +++ b/templates/panel.html @@ -105,7 +105,7 @@ - {{range $alias, $httpRoutes := .HTTPRoutes}} + {{range $alias, $httpRoutes := .HTTPRoutes.Iterator}} {{range $route := $httpRoutes}} {{$alias}} @@ -132,7 +132,7 @@ - {{range $_, $route := .StreamRoutes}} + {{range $_, $route := .StreamRoutes.Iterator}} {{$route.Alias}} {{$route.ListeningUrl}} diff --git a/udp-test-server.Dockerfile b/udp-test-server.Dockerfile new file mode 100644 index 0000000..3ba9a09 --- /dev/null +++ b/udp-test-server.Dockerfile @@ -0,0 +1,10 @@ +FROM debian:stable-slim + +RUN apt update && \ + apt install -y netcat-openbsd && \ + rm -rf /var/lib/apt/lists/* + +RUN printf '#!/bin/bash\nclear; echo "Netcat UDP server started"; nc -u -l 9999; exit' >> /entrypoint.sh +RUN chmod +x /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"]