tcp/udp fix

This commit is contained in:
yusing 2024-03-31 11:26:39 +00:00
parent cbe23d2ed1
commit 351bf84559
8 changed files with 152 additions and 179 deletions

View file

@ -1,20 +1,30 @@
FROM golang:1.22.1 as builder
FROM alpine:latest AS codemirror
RUN apk add --no-cache unzip wget make
COPY Makefile .
RUN make setup-codemirror
COPY go.mod /app/go.mod
COPY src/ /app/src
COPY Makefile /app
WORKDIR /app
RUN make get
RUN make build
FROM golang:1.22.1-alpine as builder
COPY src/ /src
COPY go.mod go.sum /src/go-proxy
WORKDIR /src/go-proxy
RUN --mount=type=cache,target="/go/pkg/mod" \
go mod download
ENV GOCACHE=/root/.cache/go-build
RUN --mount=type=cache,target="/go/pkg/mod" \
--mount=type=cache,target="/root/.cache/go-build" \
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy
FROM alpine:latest
LABEL maintainer="yusing@6uo.me"
RUN apk add --no-cache tzdata
COPY --from=builder /app/bin/go-proxy /app/
RUN mkdir -p /app/templates
COPY --from=codemirror templates/codemirror/ /app/templates/codemirror
COPY templates/ /app/templates
COPY schema/ /app/schema
COPY --from=builder /src/go-proxy /app/
RUN chmod +x /app/go-proxy
ENV DOCKER_HOST unix:///var/run/docker.sock

View file

@ -11,6 +11,7 @@ setup-codemirror:
wget https://codemirror.net/5/codemirror.zip
unzip codemirror.zip
rm codemirror.zip
mkdir -p templates
mv codemirror-* templates/codemirror
build:
@ -35,6 +36,6 @@ udp-server:
-p 9999:9999/udp \
--label proxy.test-udp.scheme=udp \
--label proxy.test-udp.port=20003:9999 \
--network data_default \
--network host \
--name test-udp \
$$(docker build -q -f udp-test-server.Dockerfile .)

View file

@ -1,23 +0,0 @@
package main
import "os"
type Reader interface {
Read() ([]byte, error)
}
type FileReader struct {
Path string
}
func (r *FileReader) Read() ([]byte, error) {
return os.ReadFile(r.Path)
}
type ByteReader struct {
Data []byte
}
func (r *ByteReader) Read() ([]byte, error) {
return r.Data, nil
}

View file

@ -2,13 +2,37 @@ package main
import (
"context"
"errors"
"fmt"
"io"
"sync"
"os"
"sync/atomic"
)
type Reader interface {
Read() ([]byte, error)
}
type FileReader struct {
Path string
}
func (r *FileReader) Read() ([]byte, error) {
return os.ReadFile(r.Path)
}
type ByteReader struct {
Data []byte
}
func (r *ByteReader) Read() ([]byte, error) {
return r.Data, nil
}
type ReadCloser struct {
ctx context.Context
r io.ReadCloser
ctx context.Context
r io.ReadCloser
closed atomic.Bool
}
func (r *ReadCloser) Read(p []byte) (int, error) {
@ -21,13 +45,16 @@ func (r *ReadCloser) Read(p []byte) (int, error) {
}
func (r *ReadCloser) Close() error {
if r.closed.Load() {
return nil
}
r.closed.Store(true)
return r.r.Close()
}
type Pipe struct {
r ReadCloser
w io.WriteCloser
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
@ -35,32 +62,24 @@ type Pipe struct {
func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe {
ctx, cancel := context.WithCancel(ctx)
return &Pipe{
r: ReadCloser{ctx, r},
r: ReadCloser{ctx: ctx, r: r},
w: w,
ctx: ctx,
cancel: cancel,
}
}
func (p *Pipe) Start() {
p.wg.Add(1)
go func() {
Copy(p.ctx, p.w, &p.r)
p.wg.Done()
}()
func (p *Pipe) Start() error {
return Copy(p.ctx, p.w, &p.r)
}
func (p *Pipe) Stop() {
func (p *Pipe) Stop() error {
p.cancel()
p.wg.Wait()
return errors.Join(fmt.Errorf("read: %w", p.r.Close()), fmt.Errorf("write: %w", p.w.Close()))
}
func (p *Pipe) Close() (error, error) {
return p.r.Close(), p.w.Close()
}
func (p *Pipe) Wait() {
p.wg.Wait()
func (p *Pipe) Write(b []byte) (int, error) {
return p.w.Write(b)
}
type BidirectionalPipe struct {
@ -75,26 +94,34 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
}
}
func (p *BidirectionalPipe) Start() {
p.pSrcDst.Start()
p.pDstSrc.Start()
func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe {
return &BidirectionalPipe{
pSrcDst: *NewPipe(ctx, listener, client),
pDstSrc: *NewPipe(ctx, client, target),
}
}
func (p *BidirectionalPipe) Stop() {
p.pSrcDst.Stop()
p.pDstSrc.Stop()
func (p *BidirectionalPipe) Start() error {
errCh := make(chan error, 2)
go func() {
errCh <- p.pSrcDst.Start()
}()
go func() {
errCh <- p.pDstSrc.Start()
}()
for err := range errCh {
if err != nil {
return err
}
}
return nil
}
func (p *BidirectionalPipe) Close() (error, error) {
return p.pSrcDst.Close()
}
func (p *BidirectionalPipe) Wait() {
p.pSrcDst.Wait()
p.pDstSrc.Wait()
func (p *BidirectionalPipe) Stop() error {
return errors.Join(p.pSrcDst.Stop(), p.pDstSrc.Stop())
}
func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error {
_, err := io.Copy(dst, &ReadCloser{ctx, src})
_, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src})
return err
}

View file

@ -15,7 +15,6 @@ func NewRoute(cfg *ProxyConfig) (Route, error) {
if err != nil {
return nil, NewNestedErrorFrom(err).Subject(cfg.Alias)
}
streamRoutes.Set(id, route)
return route, nil
} else {
httpRoutes.Ensure(cfg.Alias)

View file

@ -143,6 +143,7 @@ func (route *StreamRouteBase) Start() {
route.l.Errorf("failed to setup: %v", err)
return
}
streamRoutes.Set(route.id, route)
route.started = true
route.wg.Add(2)
go route.grAcceptConnections()

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"sync"
"time"
)
@ -14,12 +15,15 @@ type Pipes []*BidirectionalPipe
type TCPRoute struct {
*StreamRouteBase
listener net.Listener
pipe Pipes
mu sync.Mutex
}
func NewTCPRoute(base *StreamRouteBase) StreamImpl {
return &TCPRoute{
StreamRouteBase: base,
listener: nil,
pipe: make(Pipes, 0),
}
}
@ -40,7 +44,6 @@ func (route *TCPRoute) Handle(c interface{}) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
defer route.wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout)
defer cancel()
@ -58,11 +61,12 @@ func (route *TCPRoute) Handle(c interface{}) error {
<-route.stopCh
pipeCancel()
}()
route.mu.Lock()
pipe := NewBidirectionalPipe(pipeCtx, clientConn, serverConn)
pipe.Start()
pipe.Wait()
pipe.Close()
return nil
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start()
}
func (route *TCPRoute) CloseListeners() {
@ -71,4 +75,9 @@ func (route *TCPRoute) CloseListeners() {
}
route.listener.Close()
route.listener = nil
for _, pipe := range route.pipe {
if err := pipe.Stop(); err != nil {
route.l.Error(err)
}
}
}

View file

@ -1,52 +1,55 @@
package main
import (
"context"
"fmt"
"io"
"net"
"sync"
"github.com/sirupsen/logrus"
)
type UDPRoute struct {
*StreamRouteBase
connMap map[net.Addr]net.Conn
connMap UDPConnMap
connMapMutex sync.Mutex
listeningConn *net.UDPConn
targetConn *net.UDPConn
targetAddr *net.UDPAddr
}
type UDPConn struct {
remoteAddr net.Addr
buffer []byte
bytesReceived []byte
nReceived int
src *net.UDPConn
dst *net.UDPConn
*BidirectionalPipe
}
type UDPConnMap map[net.Addr]*UDPConn
func NewUDPRoute(base *StreamRouteBase) StreamImpl {
return &UDPRoute{
StreamRouteBase: base,
connMap: make(map[net.Addr]net.Conn),
connMap: make(UDPConnMap),
}
}
func (route *UDPRoute) Setup() error {
source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort))
laddr, err := net.ResolveUDPAddr(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort))
if err != nil {
return err
}
target, err := net.Dial(route.TargetScheme, fmt.Sprintf("%s:%v", route.TargetHost, route.TargetPort))
source, err := net.ListenUDP(route.ListeningScheme, laddr)
if err != nil {
return err
}
raddr, err := net.ResolveUDPAddr(route.TargetScheme, fmt.Sprintf("%s:%v", route.TargetHost, route.TargetPort))
if err != nil {
source.Close()
return err
}
route.listeningConn = source.(*net.UDPConn)
route.targetConn = target.(*net.UDPConn)
route.listeningConn = source
route.targetAddr = raddr
return nil
}
@ -64,71 +67,39 @@ func (route *UDPRoute) Accept() (interface{}, error) {
return nil, io.ErrShortBuffer
}
conn := &UDPConn{
remoteAddr: srcAddr,
buffer: buffer,
bytesReceived: buffer[:nRead],
nReceived: nRead,
}
return conn, nil
}
conn, ok := route.connMap[srcAddr]
func (route *UDPRoute) Handle(c interface{}) error {
var err error
conn := c.(*UDPConn)
srcConn, ok := route.connMap[conn.remoteAddr]
if !ok {
route.connMapMutex.Lock()
srcConn, err = net.DialUDP("udp", nil, conn.remoteAddr.(*net.UDPAddr))
srcConn, err := net.DialUDP("udp", nil, srcAddr)
if err != nil {
return err
return nil, err
}
route.connMap[conn.remoteAddr] = srcConn
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
if err != nil {
srcConn.Close()
return nil, err
}
pipeCtx, pipeCancel := context.WithCancel(context.Background())
go func() {
<-route.stopCh
pipeCancel()
}()
conn = &UDPConn{
srcConn,
dstConn,
NewBidirectionalPipe(pipeCtx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
}
route.connMap[srcAddr] = conn
route.connMapMutex.Unlock()
}
var forwarder func(*UDPConn, net.Conn) error
_, err = conn.dst.Write(buffer[:nRead])
return conn, err
}
if logLevel == logrus.DebugLevel {
forwarder = route.forwardReceivedDebug
} else {
forwarder = route.forwardReceivedReal
}
// initiate connection to target
err = forwarder(conn, route.targetConn)
if err != nil {
return err
}
for {
select {
case <-route.stopCh:
return nil
default:
// receive from target
conn, err = route.readFrom(route.targetConn, conn.buffer)
if err != nil {
return err
}
// forward to source
err = forwarder(conn, srcConn)
if err != nil {
return err
}
// read from source
conn, err = route.readFrom(srcConn, conn.buffer)
if err != nil {
continue
}
// forward to target
err = forwarder(conn, route.targetConn)
if err != nil {
return err
}
}
}
func (route *UDPRoute) Handle(c interface{}) error {
return c.(*UDPConn).Start()
}
func (route *UDPRoute) CloseListeners() {
@ -136,50 +107,28 @@ func (route *UDPRoute) CloseListeners() {
route.listeningConn.Close()
route.listeningConn = nil
}
if route.targetConn != nil {
route.targetConn.Close()
route.targetConn = nil
}
for _, conn := range route.connMap {
conn.(*net.UDPConn).Close() // TODO: change on non udp target
if err := conn.dst.Close(); err != nil {
route.l.Error(err)
}
}
route.connMap = make(map[net.Addr]net.Conn)
route.connMap = make(UDPConnMap)
}
func (route *UDPRoute) readFrom(src net.Conn, buffer []byte) (*UDPConn, error) {
nRead, err := src.Read(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
return &UDPConn{
remoteAddr: src.RemoteAddr(),
buffer: buffer,
bytesReceived: buffer[:nRead],
nReceived: nRead,
}, nil
type sourceRWCloser struct {
server *net.UDPConn
target *net.UDPConn
}
func (route *UDPRoute) forwardReceivedReal(receivedConn *UDPConn, dest net.Conn) error {
nWritten, err := dest.Write(receivedConn.bytesReceived)
if nWritten != receivedConn.nReceived {
err = io.ErrShortWrite
}
return err
func (w sourceRWCloser) Read(p []byte) (int, error) {
n, _, err := w.target.ReadFrom(p)
return n, err
}
func (route *UDPRoute) forwardReceivedDebug(receivedConn *UDPConn, dest net.Conn) error {
route.l.WithField("size", receivedConn.nReceived).Debugf(
"forwarding from %s to %s",
receivedConn.remoteAddr.String(),
dest.RemoteAddr().String(),
)
return route.forwardReceivedReal(receivedConn, dest)
func (w sourceRWCloser) Write(p []byte) (int, error) {
return w.server.WriteToUDP(p, w.target.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
}
func (w sourceRWCloser) Close() error {
return w.target.Close()
}