scripts moved to makefile, tcp/udp connections can now close gracefully, but udp is still failing testing with palworld server

This commit is contained in:
yusing 2024-03-04 19:09:36 +08:00
parent c94a13d273
commit a5c53a4f4f
16 changed files with 649 additions and 327 deletions

29
Makefile Normal file
View file

@ -0,0 +1,29 @@
.PHONY: build up restart logs get test-udp-container
all: build up logs
build:
mkdir -p bin
CGO_ENABLED=0 GOOS=linux go build -o bin/go-proxy src/go-proxy/*.go
up:
docker compose up -d --build go-proxy
restart:
docker compose down -t 0
docker compose up -d
logs:
docker compose logs -f
get:
go get -d -u ./src/go-proxy
udp-server:
docker run -it --rm \
-p 9999:9999/udp \
--label proxy.test-udp.scheme=udp \
--label proxy.test-udp.port=20003:9999 \
--network data_default \
--name test-udp \
$$(docker build -q -f udp-test-server.Dockerfile .)

View file

@ -26,7 +26,7 @@ In the examples domain `x.y.z` is used, replace them with your domain
- HTTP proxy
- TCP/UDP Proxy (experimental, unable to release port on hot-reload)
- Auto hot-reload when container start / die / stop.
- Simple panel to see all reverse proxies and health (visit port :8443 of go-proxy `https://*.y.z:8443`)
- Simple panel to see all reverse proxies and health (visit port [panel port] of go-proxy `https://*.y.z:[panel port]`)
![panel screenshot](screenshots/panel.png)
@ -45,7 +45,7 @@ In the examples domain `x.y.z` is used, replace them with your domain
4. Modify the path to your SSL certs. See [Getting SSL Certs](#getting-ssl-certs)
5. Start `go-proxy` with `docker compose up -d`.
5. Start `go-proxy` with `docker compose up -d` or `make up`.
6. (Optional) If you are using ufw with vpn that drop all inbound traffic except vpn, run below to allow docker containers to connect to `go-proxy`
@ -60,6 +60,8 @@ In the examples domain `x.y.z` is used, replace them with your domain
7. start your docker app, and visit <container_name>.y.z
8. check the logs with `docker compose logs` or `make logs` to see if there is any error, check panel at [panel port] for active proxies
## Configuration
With container name, no label needs to be added.
@ -196,14 +198,16 @@ It takes ~ 0.1-0.4MB for each HTTP Proxy, and <2MB for each TCP/UDP Proxy
## Build it yourself
1. [Install go](https://go.dev/doc/install) if not already
1. Install [go](https://go.dev/doc/install) and `make` if not already
2. get dependencies with `sh scripts/get.sh`
2. get dependencies with `make get`
3. build binary with `sh scripts/build.sh`
3. build binary with `make build`
4. start your container with `docker compose up -d`
## Getting SSL certs
I personally use `nginx-proxy-manager` to get SSL certs with auto renewal by Cloudflare DNS challenge. You may symlink the certs from `nginx-proxy-manager` to somewhere else, and mount them to `go-proxy`'s `/certs`
[panel port]: 8443

Binary file not shown.

View file

@ -1,8 +0,0 @@
#!/bin/sh
mkdir -p bin
CGO_ENABLED=0 GOOS=linux go build -o bin/go-proxy src/go-proxy/*.go || exit 1
if [ "$1" = "up" ]; then
docker compose up -d --build app && \
docker compose logs -f
fi

View file

@ -1,2 +0,0 @@
#!/bin/sh
go get -d -u ./src/go-proxy

View file

@ -1,10 +0,0 @@
#!/bin/sh
docker run -it --tty --rm \
-p 9999:9999/udp \
--label proxy.test-udp.scheme=udp \
--label proxy.test-udp.port=20003:9999 \
--network data_default \
--name test-udp \
debian:stable-slim \
/bin/bash -c \
"apt update && apt install -y netcat-openbsd && echo 'nc -u -l 9999' >> ~/.bashrc && bash"

View file

@ -18,7 +18,7 @@ import (
)
type ProxyConfig struct {
id string
id string
Alias string
Scheme string
Host string
@ -89,7 +89,6 @@ func buildContainerRoute(container types.Container) {
imageName := imageSplit[0]
_, isKnownImage := imageNamePortMap[imageName]
if isKnownImage {
log.Printf("[Build] Known image '%s' detected for %s", imageName, container_name)
config.Scheme = "tcp"
} else {
config.Scheme = "http"
@ -109,7 +108,7 @@ func buildContainerRoute(container types.Container) {
}
config.Alias = alias
config.UpdateId()
wg.Add(1)
go func() {
createRoute(&config)
@ -139,7 +138,7 @@ func buildRoutes() {
func findHTTPRoute(host string, path string) (*HTTPRoute, error) {
subdomain := strings.Split(host, ".")[0]
routeMap, ok := routes.HTTPRoutes[subdomain]
routeMap, ok := routes.HTTPRoutes.TryGet(subdomain)
if !ok {
return nil, fmt.Errorf("no matching route for subdomain %s", subdomain)
}

View file

@ -19,6 +19,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
buildRoutes()
log.Printf("[Build] built %v reverse proxies", countRoutes())
beginListenStreams()
@ -30,15 +31,21 @@ func main() {
filters.Arg("event", "die"), // stop seems like triggering die
// filters.Arg("event", "stop"),
)
msgs, _ := dockerClient.Events(context.Background(), types.EventsOptions{Filters: filter})
msgChan, errChan := dockerClient.Events(context.Background(), types.EventsOptions{Filters: filter})
for msg := range msgs {
// TODO: handle actor only
log.Printf("[Event] %s %s caused rebuild", msg.Action, msg.Actor.Attributes["name"])
endListenStreams()
buildRoutes()
log.Printf("[Build] rebuilt %v reverse proxies", countRoutes())
beginListenStreams()
for {
select {
case msg := <-msgChan:
// TODO: handle actor only
log.Printf("[Event] %s %s caused rebuild", msg.Action, msg.Actor.Attributes["name"])
endListenStreams()
buildRoutes()
log.Printf("[Build] rebuilt %v reverse proxies", countRoutes())
beginListenStreams()
case err := <-errChan:
log.Printf("[Event] %s", err)
msgChan, errChan = dockerClient.Events(context.Background(), types.EventsOptions{Filters: filter})
}
}
}()

94
src/go-proxy/map.go Normal file
View file

@ -0,0 +1,94 @@
package main
import "sync"
type SafeMapInterface[KT comparable, VT interface{}] interface {
Set(key KT, value VT)
Ensure(key KT)
Get(key KT) VT
TryGet(key KT) (VT, bool)
Clear()
Size() int
Contains(key KT) bool
ForEach(fn func(key KT, value VT))
Iterator() map[KT]VT
}
type SafeMap[KT comparable, VT interface{}] struct {
SafeMapInterface[KT, VT]
m map[KT]VT
mutex sync.Mutex
defaultFactory func() VT
}
func NewSafeMap[KT comparable, VT interface{}](df... func() VT) *SafeMap[KT, VT] {
if len(df) == 0 {
return &SafeMap[KT, VT]{
m: make(map[KT]VT),
}
}
return &SafeMap[KT, VT]{
m: make(map[KT]VT),
defaultFactory: df[0],
}
}
func (m *SafeMap[KT, VT]) Set(key KT, value VT) {
m.mutex.Lock()
m.m[key] = value
m.mutex.Unlock()
}
func (m *SafeMap[KT, VT]) Ensure(key KT) {
m.mutex.Lock()
if _, ok := m.m[key]; !ok {
m.m[key] = m.defaultFactory()
}
m.mutex.Unlock()
}
func (m *SafeMap[KT, VT]) Get(key KT) VT {
m.mutex.Lock()
value := m.m[key]
m.mutex.Unlock()
return value
}
func (m *SafeMap[KT, VT]) TryGet(key KT) (VT, bool) {
m.mutex.Lock()
value, ok := m.m[key]
m.mutex.Unlock()
return value, ok
}
func (m *SafeMap[KT, VT]) Clear() {
m.mutex.Lock()
m.m = make(map[KT]VT)
m.mutex.Unlock()
}
func (m *SafeMap[KT, VT]) Size() int {
m.mutex.Lock()
size := len(m.m)
m.mutex.Unlock()
return size
}
func (m *SafeMap[KT, VT]) Contains(key KT) bool {
m.mutex.Lock()
_, ok := m.m[key]
m.mutex.Unlock()
return ok
}
func (m *SafeMap[KT, VT]) ForEach(fn func(key KT, value VT)) {
m.mutex.Lock()
for k, v := range m.m {
fn(k, v)
}
m.mutex.Unlock()
}
func (m *SafeMap[KT, VT]) Iterator() map[KT]VT {
return m.m
}

View file

@ -8,16 +8,12 @@ import (
)
type Routes struct {
HTTPRoutes map[string][]HTTPRoute // id -> path
StreamRoutes map[string]*StreamRoute // id -> target
HTTPRoutes *SafeMap[string, []HTTPRoute] // id -> path
StreamRoutes *SafeMap[string, StreamRoute] // id -> target
Mutex sync.Mutex
}
var routes = Routes{
HTTPRoutes: make(map[string][]HTTPRoute),
StreamRoutes: make(map[string]*StreamRoute),
Mutex: sync.Mutex{},
}
var routes = Routes{}
var streamSchemes = []string{"tcp", "udp"} // TODO: support "tcp:udp", "udp:tcp"
var httpSchemes = []string{"http", "https"}
@ -43,23 +39,23 @@ func isStreamScheme(scheme string) bool {
}
func initRoutes() {
routes.Mutex.Lock()
defer routes.Mutex.Unlock()
utils.resetPortsInUse()
routes.StreamRoutes = make(map[string]*StreamRoute)
routes.HTTPRoutes = make(map[string][]HTTPRoute)
routes.HTTPRoutes = NewSafeMap[string, []HTTPRoute](
func() []HTTPRoute {
return make([]HTTPRoute, 0)
},
)
routes.StreamRoutes = NewSafeMap[string, StreamRoute]()
}
func countRoutes() int {
return len(routes.HTTPRoutes) + len(routes.StreamRoutes)
return routes.HTTPRoutes.Size() + routes.StreamRoutes.Size()
}
func createRoute(config *ProxyConfig) {
if isStreamScheme(config.Scheme) {
_, inMap := routes.StreamRoutes[config.id]
if inMap {
log.Printf("[Build] Duplicated stream %s, ignoring", config.id)
if routes.StreamRoutes.Contains(config.id) {
log.Printf("[Build] Duplicated %s stream %s, ignoring", config.Scheme, config.id)
return
}
route, err := NewStreamRoute(config)
@ -67,20 +63,15 @@ func createRoute(config *ProxyConfig) {
log.Println(err)
return
}
routes.Mutex.Lock()
routes.StreamRoutes[config.id] = route
routes.Mutex.Unlock()
routes.StreamRoutes.Set(config.id, route)
} else {
routes.Mutex.Lock()
_, inMap := routes.HTTPRoutes[config.Alias]
if !inMap {
routes.HTTPRoutes[config.Alias] = make([]HTTPRoute, 0)
}
routes.HTTPRoutes.Ensure(config.Alias)
url, err := url.Parse(fmt.Sprintf("%s://%s:%s", config.Scheme, config.Host, config.Port))
if err != nil {
log.Fatal(err)
log.Println(err)
return
}
routes.HTTPRoutes[config.Alias] = append(routes.HTTPRoutes[config.Alias], NewHTTPRoute(url, config.Path))
routes.Mutex.Unlock()
route := NewHTTPRoute(url, config.Path)
routes.HTTPRoutes.Set(config.Alias, append(routes.HTTPRoutes.Get(config.Alias), route))
}
}

View file

@ -1,18 +1,30 @@
package main
import (
"context"
"errors"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"unsafe"
"time"
)
type StreamRoute struct {
type StreamRoute interface {
SetupListen()
Listen()
StopListening()
Logf(string, ...interface{})
PrintError(error)
ListeningUrl() string
TargetUrl() string
closeListeners()
closeChannel()
wait()
}
type StreamRouteBase struct {
Alias string // to show in panel
Type string
ListeningScheme string
@ -21,8 +33,180 @@ type StreamRoute struct {
TargetHost string
TargetPort string
Context context.Context
Cancel context.CancelFunc
wg sync.WaitGroup
stopChann chan struct{}
}
func newStreamRouteBase(config *ProxyConfig) (*StreamRouteBase, error) {
var streamType string = TCPStreamType
var srcPort string
var dstPort string
var srcScheme string
var dstScheme string
port_split := strings.Split(config.Port, ":")
if len(port_split) != 2 {
log.Printf(`[Build] %s: Invalid stream port %s, `+
`assuming it's targetPort`, config.Alias, config.Port)
srcPort = "0"
dstPort = config.Port
} else {
srcPort = port_split[0]
dstPort = port_split[1]
}
port, hasName := namePortMap[dstPort]
if hasName {
dstPort = port
}
srcPortInt, err := strconv.Atoi(srcPort)
if err != nil {
return nil, fmt.Errorf(
"[Build] %s: Unrecognized stream source port %s, ignoring",
config.Alias, srcPort,
)
}
utils.markPortInUse(srcPortInt)
_, err = strconv.Atoi(dstPort)
if err != nil {
return nil, fmt.Errorf(
"[Build] %s: Unrecognized stream target port %s, ignoring",
config.Alias, dstPort,
)
}
scheme_split := strings.Split(config.Scheme, ":")
if len(scheme_split) == 2 {
srcScheme = scheme_split[0]
dstScheme = scheme_split[1]
} else {
srcScheme = config.Scheme
dstScheme = config.Scheme
}
return &StreamRouteBase{
Alias: config.Alias,
Type: streamType,
ListeningScheme: srcScheme,
ListeningPort: srcPort,
TargetScheme: dstScheme,
TargetHost: config.Host,
TargetPort: dstPort,
wg: sync.WaitGroup{},
stopChann: make(chan struct{}),
}, nil
}
func NewStreamRoute(config *ProxyConfig) (StreamRoute, error) {
switch config.Scheme {
case TCPStreamType:
return NewTCPRoute(config)
case UDPStreamType:
return NewUDPRoute(config)
default:
return nil, errors.New("unknown stream type")
}
}
func (route *StreamRouteBase) PrintError(err error) {
if err == nil {
return
}
route.Logf("Error: %s", err.Error())
}
func (route *StreamRouteBase) Logf(format string, v ...interface{}) {
log.Printf("[%s -> %s] %s: "+format,
append([]interface{}{
route.ListeningScheme,
route.TargetScheme,
route.Alias},
v...,
)...,
)
}
func (route *StreamRouteBase) ListeningUrl() string {
return fmt.Sprintf("%s:%s", route.ListeningScheme, route.ListeningPort)
}
func (route *StreamRouteBase) TargetUrl() string {
return fmt.Sprintf("%s://%s:%s", route.TargetScheme, route.TargetHost, route.TargetPort)
}
func (route *StreamRouteBase) SetupListen() {
if route.ListeningPort == "0" {
freePort, err := utils.findUseFreePort(20000)
if err != nil {
route.PrintError(err)
return
}
route.ListeningPort = fmt.Sprintf("%d", freePort)
route.Logf("Assigned free port %s", route.ListeningPort)
}
route.Logf("Listening on %s", route.ListeningUrl())
}
func (route *StreamRouteBase) wait() {
route.wg.Wait()
}
func (route *StreamRouteBase) closeChannel() {
close(route.stopChann)
}
func stopListening(route StreamRoute) {
route.Logf("Stopping listening")
route.closeChannel()
route.closeListeners()
done := make(chan struct{})
go func() {
route.wait()
close(done)
}()
select {
case <-done:
route.Logf("Stopped listening")
return
case <-time.After(streamStopListenTimeout):
route.Logf("timed out waiting for connections")
return
}
}
func allStreamsDo(msg string, fn ...func(StreamRoute)) {
log.Printf("[Stream] %s", msg)
var wg sync.WaitGroup
for _, route := range routes.StreamRoutes.Iterator() {
wg.Add(1)
go func(r StreamRoute) {
for _, f := range fn {
f(r)
}
wg.Done()
}(route)
}
wg.Wait()
log.Printf("[Stream] Finished %s", msg)
}
func beginListenStreams() {
allStreamsDo("Start", StreamRoute.SetupListen, StreamRoute.Listen)
}
func endListenStreams() {
allStreamsDo("Stop", StreamRoute.StopListening)
}
var imageNamePortMap = map[string]string{
@ -57,150 +241,5 @@ var namePortMap = func() map[string]string {
const UDPStreamType = "udp"
const TCPStreamType = "tcp"
func NewStreamRoute(config *ProxyConfig) (*StreamRoute, error) {
var streamType string = TCPStreamType
var srcPort string
var dstPort string
var srcScheme string
var dstScheme string
var srcUDPAddr *net.UDPAddr = nil
var dstUDPAddr *net.UDPAddr = nil
port_split := strings.Split(config.Port, ":")
if len(port_split) != 2 {
log.Printf(`[Build] Invalid stream port %s, `+
`should be <listeningPort>:<targetPort>, `+
`assuming it is targetPort`, config.Port)
srcPort = "0"
dstPort = config.Port
} else {
srcPort = port_split[0]
dstPort = port_split[1]
}
port, hasName := namePortMap[dstPort]
if hasName {
dstPort = port
}
_, err := strconv.Atoi(dstPort)
if err != nil {
return nil, fmt.Errorf(
"[Build] %s: Unrecognized stream target port %s, ignoring",
config.Alias, dstPort,
)
}
scheme_split := strings.Split(config.Scheme, ":")
if len(scheme_split) == 2 {
srcScheme = scheme_split[0]
dstScheme = scheme_split[1]
} else {
srcScheme = config.Scheme
dstScheme = config.Scheme
}
if srcScheme == "udp" {
streamType = UDPStreamType
srcUDPAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("0.0.0.0:%s", srcPort))
if err != nil {
return nil, err
}
}
if dstScheme == "udp" {
streamType = UDPStreamType
dstUDPAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%s", config.Host, dstPort))
if err != nil {
return nil, err
}
}
lsPort, err := strconv.Atoi(srcPort)
if err != nil {
return nil, err
}
utils.markPortInUse(lsPort)
ctx, cancel := context.WithCancel(context.Background())
route := StreamRoute{
Alias: config.Alias,
Type: streamType,
ListeningScheme: srcScheme,
TargetScheme: dstScheme,
TargetHost: config.Host,
ListeningPort: srcPort,
TargetPort: dstPort,
Context: ctx,
Cancel: cancel,
}
if streamType == TCPStreamType {
return &route, nil
}
return (*StreamRoute)(unsafe.Pointer(&UDPRoute{
StreamRoute: route,
ConnMap: make(map[net.Addr]*net.UDPConn),
ConnMapMutex: sync.Mutex{},
QueueSize: atomic.Int32{},
SourceUDPAddr: srcUDPAddr,
TargetUDPAddr: dstUDPAddr,
})), nil
}
func (route *StreamRoute) PrintError(err error) {
if err == nil {
return
}
log.Printf("[Stream] %s (%s => %s) error: %v", route.Alias, route.ListeningUrl(), route.TargetUrl(), err)
}
func (route *StreamRoute) ListeningUrl() string {
return fmt.Sprintf("%s:%s", route.ListeningScheme, route.ListeningPort)
}
func (route *StreamRoute) TargetUrl() string {
return fmt.Sprintf("%s://%s:%s", route.TargetScheme, route.TargetHost, route.TargetPort)
}
func (route *StreamRoute) listenStream() {
if route.ListeningPort == "0" {
freePort, err := utils.findFreePort(20000)
if err != nil {
route.PrintError(err)
return
}
route.ListeningPort = fmt.Sprintf("%d", freePort)
utils.markPortInUse(freePort)
}
if route.Type == UDPStreamType {
listenUDP((*UDPRoute)(unsafe.Pointer(route)))
} else {
listenTCP(route)
}
}
func beginListenStreams() {
for _, route := range routes.StreamRoutes {
go route.listenStream()
}
}
func endListenStreams() {
var wg sync.WaitGroup
wg.Add(len(routes.StreamRoutes))
defer wg.Wait()
routes.Mutex.Lock()
defer routes.Mutex.Unlock()
for _, route := range routes.StreamRoutes {
go func(r *StreamRoute) {
r.Cancel()
wg.Done()
}(route)
}
}
// const maxQueueSizePerStream = 100
const streamStopListenTimeout = 1 * time.Second

View file

@ -12,34 +12,87 @@ import (
const tcpDialTimeout = 5 * time.Second
func listenTCP(route *StreamRoute) {
in, err := net.Listen(
route.ListeningScheme,
fmt.Sprintf(":%s", route.ListeningPort),
)
type TCPRoute struct {
*StreamRouteBase
listener net.Listener
connChan chan net.Conn
}
func NewTCPRoute(config *ProxyConfig) (StreamRoute, error) {
base, err := newStreamRouteBase(config)
if err != nil {
log.Printf("[Stream Listen] %v", err)
return nil, err
}
if base.TargetScheme != TCPStreamType {
return nil, fmt.Errorf("tcp to %s not yet supported", base.TargetScheme)
}
return &TCPRoute{
StreamRouteBase: base,
listener: nil,
connChan: make(chan net.Conn),
}, nil
}
func (route *TCPRoute) Listen() {
in, err := net.Listen("tcp", ":"+route.ListeningPort)
if err != nil {
route.PrintError(err)
return
}
route.listener = in
route.wg.Add(2)
go route.grAcceptConnections()
go route.grHandleConnections()
}
defer in.Close()
func (route *TCPRoute) StopListening() {
stopListening(route)
}
func (route *TCPRoute) closeListeners() {
if route.listener == nil {
return
}
route.listener.Close()
route.listener = nil
}
func (route *TCPRoute) grAcceptConnections() {
defer route.wg.Done()
for {
select {
case <-route.Context.Done():
case <-route.stopChann:
return
default:
clientConn, err := in.Accept()
conn, err := route.listener.Accept()
if err != nil {
log.Printf("[Stream Accept] %v", err)
return
route.PrintError(err)
continue
}
go connectTCPPipe(route, clientConn)
route.connChan <- conn
}
}
}
func connectTCPPipe(route *StreamRoute, clientConn net.Conn) {
func (route *TCPRoute) grHandleConnections() {
defer route.wg.Done()
for {
select {
case <-route.stopChann:
return
case conn := <-route.connChan:
route.wg.Add(1)
go route.grHandleConnection(conn)
}
}
}
func (route *TCPRoute) grHandleConnection(clientConn net.Conn) {
defer clientConn.Close()
defer route.wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout)
defer cancel()
@ -50,25 +103,29 @@ func connectTCPPipe(route *StreamRoute, clientConn net.Conn) {
log.Printf("[Stream Dial] %v", err)
return
}
tcpPipe(route, clientConn, serverConn)
route.tcpPipe(clientConn, serverConn)
}
func tcpPipe(route *StreamRoute, src net.Conn, dest net.Conn) {
func (route *TCPRoute) tcpPipe(src net.Conn, dest net.Conn) {
close := func() {
src.Close()
dest.Close()
}
var wg sync.WaitGroup
wg.Add(2) // Number of goroutines
defer src.Close()
defer dest.Close()
go func() {
_, err := io.Copy(src, dest)
go route.PrintError(err)
route.PrintError(err)
close()
wg.Done()
}()
go func() {
_, err := io.Copy(dest, src)
go route.PrintError(err)
route.PrintError(err)
close()
wg.Done()
}()
wg.Wait()
}

View file

@ -1,125 +1,231 @@
package main
import (
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
)
const udpBufferSize = 1500
const udpMaxQueueSizePerStream = 100
const udpListenTimeout = 100 * time.Second
const udpConnectionTimeout = 30 * time.Second
// const udpListenTimeout = 100 * time.Second
// const udpConnectionTimeout = 30 * time.Second
type UDPRoute struct {
StreamRoute
*StreamRouteBase
ConnMap map[net.Addr]*net.UDPConn
ConnMapMutex sync.Mutex
QueueSize atomic.Int32
SourceUDPAddr *net.UDPAddr
TargetUDPAddr *net.UDPAddr
connMap map[net.Addr]net.Conn
connMapMutex sync.Mutex
listeningConn *net.UDPConn
targetConn *net.UDPConn
connChan chan *UDPConn
}
func listenUDP(route *UDPRoute) {
source, err := net.ListenUDP(route.ListeningScheme, route.SourceUDPAddr)
type UDPConn struct {
remoteAddr net.Addr
buffer []byte
bytesReceived []byte
nReceived int
}
func NewUDPRoute(config *ProxyConfig) (StreamRoute, error) {
base, err := newStreamRouteBase(config)
if err != nil {
return nil, err
}
if base.TargetScheme != UDPStreamType {
return nil, fmt.Errorf("udp to %s not yet supported", base.TargetScheme)
}
return &UDPRoute{
StreamRouteBase: base,
connMap: make(map[net.Addr]net.Conn),
connChan: make(chan *UDPConn),
}, nil
}
func (route *UDPRoute) Listen() {
source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%s", route.ListeningPort))
if err != nil {
route.PrintError(err)
return
}
target, err := net.DialUDP(route.TargetScheme, nil, route.TargetUDPAddr)
target, err := net.Dial(route.TargetScheme, fmt.Sprintf("%s:%s", route.TargetHost, route.TargetPort))
if err != nil {
route.PrintError(err)
source.Close()
return
}
var wg sync.WaitGroup
defer wg.Wait()
defer source.Close()
defer target.Close()
route.listeningConn = source.(*net.UDPConn)
route.targetConn = target.(*net.UDPConn)
var udpBuffers = [udpMaxQueueSizePerStream][udpBufferSize]byte{}
route.wg.Add(2)
go route.grAcceptConnections()
go route.grHandleConnections()
}
func (route *UDPRoute) StopListening() {
stopListening(route)
}
func (route *UDPRoute) closeListeners() {
if route.listeningConn != nil {
route.listeningConn.Close()
}
if route.targetConn != nil {
route.targetConn.Close()
}
route.listeningConn = nil
route.targetConn = nil
for _, conn := range route.connMap {
conn.(*net.UDPConn).Close() // TODO: change on non udp target
}
}
func (route *UDPRoute) grAcceptConnections() {
defer route.wg.Done()
for {
select {
case <-route.Context.Done():
case <-route.stopChann:
return
default:
if route.QueueSize.Load() >= udpMaxQueueSizePerStream {
wg.Wait()
conn, err := route.accept()
if err != nil {
route.PrintError(err)
continue
}
go udpLoop(
route,
source,
target,
udpBuffers[route.QueueSize.Load()][:],
&wg,
)
route.connChan <- conn
}
}
}
func udpLoop(route *UDPRoute, in *net.UDPConn, out *net.UDPConn, buffer []byte, wg *sync.WaitGroup) {
wg.Add(1)
route.QueueSize.Add(1)
defer route.QueueSize.Add(-1)
defer wg.Done()
func (route *UDPRoute) grHandleConnections() {
defer route.wg.Done()
var nRead int
var nWritten int
in.SetReadDeadline(time.Now().Add(udpListenTimeout))
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return
}
log.Printf("[Stream] received %d bytes from %s, forwarding to %s", nRead, srcAddr.String(), out.RemoteAddr().String())
out.SetWriteDeadline(time.Now().Add(udpConnectionTimeout))
nWritten, err = out.Write(buffer[:nRead])
if nWritten != nRead {
err = io.ErrShortWrite
}
if err != nil {
go route.PrintError(err)
return
}
err = udpPipe(route, out, srcAddr, buffer)
if err != nil {
go route.PrintError(err)
for {
select {
case <-route.stopChann:
return
case conn := <-route.connChan:
go func() {
err := route.handleConnection(conn)
if err != nil {
route.PrintError(err)
}
}()
}
}
}
func udpPipe(route *UDPRoute, src *net.UDPConn, destAddr *net.UDPAddr, buffer []byte) error {
src.SetReadDeadline(time.Now().Add(udpConnectionTimeout))
nRead, err := src.Read(buffer)
if err != nil || nRead == 0 {
return err
}
log.Printf("[Stream] received %d bytes from %s, forwarding to %s", nRead, src.RemoteAddr().String(), destAddr.String())
dest, ok := route.ConnMap[destAddr]
func (route *UDPRoute) handleConnection(conn *UDPConn) error {
var err error
srcConn, ok := route.connMap[conn.remoteAddr]
if !ok {
dest, err = net.DialUDP(src.LocalAddr().Network(), nil, destAddr)
route.connMapMutex.Lock()
srcConn, err = net.DialUDP("udp", nil, conn.remoteAddr.(*net.UDPAddr))
if err != nil {
return err
}
route.ConnMapMutex.Lock()
route.ConnMap[destAddr] = dest
route.ConnMapMutex.Unlock()
route.connMap[conn.remoteAddr] = srcConn
route.connMapMutex.Unlock()
}
dest.SetWriteDeadline(time.Now().Add(udpConnectionTimeout))
nWritten, err := dest.Write(buffer[:nRead])
// initiate connection to target
err = route.forwardReceived(conn, route.targetConn)
if err != nil {
return err
}
if nWritten != nRead {
return io.ErrShortWrite
for {
select {
case <-route.stopChann:
return nil
default:
// receive from target
conn, err = route.readFrom(route.targetConn, conn.buffer)
if err != nil {
return err
}
// forward to source
err = route.forwardReceived(conn, srcConn)
if err != nil {
return err
}
// read from source
conn, err = route.readFrom(srcConn, conn.buffer)
if err != nil {
continue
}
// forward to target
err = route.forwardReceived(conn, route.targetConn)
if err != nil {
return err
}
}
}
return nil
}
func (route *UDPRoute) accept() (*UDPConn, error) {
in := route.listeningConn
buffer := make([]byte, udpBufferSize)
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
return &UDPConn{
remoteAddr: srcAddr,
buffer: buffer,
bytesReceived: buffer[:nRead],
nReceived: nRead},
nil
}
func (route *UDPRoute) readFrom(src net.Conn, buffer []byte) (*UDPConn, error) {
nRead, err := src.Read(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
return &UDPConn{
remoteAddr: src.RemoteAddr(),
buffer: buffer,
bytesReceived: buffer[:nRead],
nReceived: nRead,
}, nil
}
func (route *UDPRoute) forwardReceived(receivedConn *UDPConn, dest net.Conn) error {
route.Logf(
"forwarding %d bytes %s -> %s",
receivedConn.nReceived,
receivedConn.remoteAddr.String(),
dest.RemoteAddr().String(),
)
nWritten, err := dest.Write(receivedConn.bytesReceived)
if nWritten != receivedConn.nReceived {
err = io.ErrShortWrite
}
return err
}

View file

@ -9,15 +9,18 @@ import (
)
type Utils struct {
PortsInUse map[int]bool
PortsInUse map[int]bool
portsInUseMutex sync.Mutex
}
var utils = &Utils{
PortsInUse: make(map[int]bool),
PortsInUse: make(map[int]bool),
portsInUseMutex: sync.Mutex{},
}
func (u *Utils) findFreePort(startingPort int) (int, error) {
func (u *Utils) findUseFreePort(startingPort int) (int, error) {
u.portsInUseMutex.Lock()
defer u.portsInUseMutex.Unlock()
for port := startingPort; port <= startingPort+100 && port <= 65535; port++ {
if u.PortsInUse[port] {
continue
@ -25,31 +28,34 @@ func (u *Utils) findFreePort(startingPort int) (int, error) {
addr := fmt.Sprintf(":%d", port)
l, err := net.Listen("tcp", addr)
if err == nil {
u.PortsInUse[port] = true
l.Close()
return port, nil
}
}
l, err := net.Listen("tcp", ":0")
if err == nil {
l.Close()
// NOTE: may not be after 20000
return l.Addr().(*net.TCPAddr).Port, nil
port := l.Addr().(*net.TCPAddr).Port
u.PortsInUse[port] = true
l.Close()
return port, nil
}
return -1, fmt.Errorf("unable to find free port: %v", err)
}
func (u *Utils) resetPortsInUse() {
u.portsInUseMutex.Lock()
defer u.portsInUseMutex.Unlock()
for port := range u.PortsInUse {
u.PortsInUse[port] = false
}
u.portsInUseMutex.Unlock()
}
func (u* Utils) markPortInUse(port int) {
func (u *Utils) markPortInUse(port int) {
u.portsInUseMutex.Lock()
defer u.portsInUseMutex.Unlock()
u.PortsInUse[port] = true
u.portsInUseMutex.Unlock()
}
func (*Utils) healthCheckHttp(targetUrl string) error {
@ -57,13 +63,13 @@ func (*Utils) healthCheckHttp(targetUrl string) error {
// if HEAD is not allowed, try GET
resp, err := healthCheckHttpClient.Head(targetUrl)
if resp != nil {
defer resp.Body.Close()
resp.Body.Close()
}
if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
_, err = healthCheckHttpClient.Get(targetUrl)
}
if resp != nil {
defer resp.Body.Close()
resp.Body.Close()
}
return err
}
@ -73,6 +79,6 @@ func (*Utils) healthCheckStream(scheme string, host string) error {
if err != nil {
return err
}
defer conn.Close()
conn.Close()
return nil
}

View file

@ -105,7 +105,7 @@
</tr>
</thead>
<tbody>
{{range $alias, $httpRoutes := .HTTPRoutes}}
{{range $alias, $httpRoutes := .HTTPRoutes.Iterator}}
{{range $route := $httpRoutes}}
<tr>
<td>{{$alias}}</td>
@ -132,7 +132,7 @@
</tr>
</thead>
<tbody>
{{range $_, $route := .StreamRoutes}}
{{range $_, $route := .StreamRoutes.Iterator}}
<tr>
<td>{{$route.Alias}}</td>
<td>{{$route.ListeningUrl}}</td>

View file

@ -0,0 +1,10 @@
FROM debian:stable-slim
RUN apt update && \
apt install -y netcat-openbsd && \
rm -rf /var/lib/apt/lists/*
RUN printf '#!/bin/bash\nclear; echo "Netcat UDP server started"; nc -u -l 9999; exit' >> /entrypoint.sh
RUN chmod +x /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]