diff --git a/src/docker/container.go b/src/docker/container.go index accd211..eabf328 100644 --- a/src/docker/container.go +++ b/src/docker/container.go @@ -10,18 +10,19 @@ import ( ) type ProxyProperties struct { - DockerHost string `yaml:"-" json:"docker_host"` - ContainerName string `yaml:"-" json:"container_name"` - ImageName string `yaml:"-" json:"image_name"` - Aliases []string `yaml:"-" json:"aliases"` - IsExcluded bool `yaml:"-" json:"is_excluded"` - FirstPort string `yaml:"-" json:"first_port"` - IdleTimeout string `yaml:"-" json:"idle_timeout"` - WakeTimeout string `yaml:"-" json:"wake_timeout"` - StopMethod string `yaml:"-" json:"stop_method"` - StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only - StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only - Running bool `yaml:"-" json:"running"` + DockerHost string `yaml:"-" json:"docker_host"` + ContainerName string `yaml:"-" json:"container_name"` + ImageName string `yaml:"-" json:"image_name"` + PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port + PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port + Aliases []string `yaml:"-" json:"aliases"` + IsExcluded bool `yaml:"-" json:"is_excluded"` + IdleTimeout string `yaml:"-" json:"idle_timeout"` + WakeTimeout string `yaml:"-" json:"wake_timeout"` + StopMethod string `yaml:"-" json:"stop_method"` + StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only + StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only + Running bool `yaml:"-" json:"running"` } type Container struct { @@ -29,21 +30,24 @@ type Container struct { *ProxyProperties } +type PortMapping = map[string]types.Port + func FromDocker(c *types.Container, dockerHost string) (res Container) { res.Container = c res.ProxyProperties = &ProxyProperties{ - DockerHost: dockerHost, - ContainerName: res.getName(), - ImageName: res.getImageName(), - Aliases: res.getAliases(), - IsExcluded: U.ParseBool(res.getDeleteLabel(LableExclude)), - FirstPort: res.firstPortOrEmpty(), - IdleTimeout: res.getDeleteLabel(LabelIdleTimeout), - WakeTimeout: res.getDeleteLabel(LabelWakeTimeout), - StopMethod: res.getDeleteLabel(LabelStopMethod), - StopTimeout: res.getDeleteLabel(LabelStopTimeout), - StopSignal: res.getDeleteLabel(LabelStopSignal), - Running: c.Status == "running" || c.State == "running", + DockerHost: dockerHost, + ContainerName: res.getName(), + ImageName: res.getImageName(), + PublicPortMapping: res.getPublicPortMapping(), + PrivatePortMapping: res.getPrivatePortMapping(), + Aliases: res.getAliases(), + IsExcluded: U.ParseBool(res.getDeleteLabel(LabelExclude)), + IdleTimeout: res.getDeleteLabel(LabelIdleTimeout), + WakeTimeout: res.getDeleteLabel(LabelWakeTimeout), + StopMethod: res.getDeleteLabel(LabelStopMethod), + StopTimeout: res.getDeleteLabel(LabelStopTimeout), + StopSignal: res.getDeleteLabel(LabelStopSignal), + Running: c.Status == "running" || c.State == "running", } return } @@ -81,7 +85,7 @@ func (c Container) getDeleteLabel(label string) string { } func (c Container) getAliases() []string { - if l := c.getDeleteLabel(LableAliases); l != "" { + if l := c.getDeleteLabel(LabelAliases); l != "" { return U.CommaSeperatedList(l) } else { return []string{c.getName()} @@ -98,14 +102,24 @@ func (c Container) getImageName() string { return slashSep[len(slashSep)-1] } -func (c Container) firstPortOrEmpty() string { - if len(c.Ports) == 0 { - return "" - } - for _, p := range c.Ports { - if p.PublicPort != 0 { - return fmt.Sprint(p.PublicPort) +func (c Container) getPublicPortMapping() PortMapping { + res := make(PortMapping) + for _, v := range c.Ports { + if v.PublicPort == 0 { + continue } + res[fmt.Sprint(v.PublicPort)] = v } - return "" + return res +} + +func (c Container) getPrivatePortMapping() PortMapping { + res := make(PortMapping) + for _, v := range c.Ports { + if v.PublicPort == 0 { + continue + } + res[fmt.Sprint(v.PrivatePort)] = v + } + return res } diff --git a/src/docker/labels.go b/src/docker/labels.go index 6b26197..444db36 100644 --- a/src/docker/labels.go +++ b/src/docker/labels.go @@ -3,8 +3,8 @@ package docker const ( WildcardAlias = "*" - LableAliases = NSProxy + ".aliases" - LableExclude = NSProxy + ".exclude" + LabelAliases = NSProxy + ".aliases" + LabelExclude = NSProxy + ".exclude" LabelIdleTimeout = NSProxy + ".idle_timeout" LabelWakeTimeout = NSProxy + ".wake_timeout" LabelStopMethod = NSProxy + ".stop_method" diff --git a/src/error/error.go b/src/error/error.go index 32ad66d..f5206e1 100644 --- a/src/error/error.go +++ b/src/error/error.go @@ -118,9 +118,9 @@ func (ne NestedError) With(s any) NestedError { case string: msg = ss case fmt.Stringer: - msg = ss.String() + return ne.append(ss.String()) default: - msg = fmt.Sprint(s) + return ne.append(fmt.Sprint(s)) } return ne.withError(From(errors.New(msg))) } @@ -201,6 +201,14 @@ func (ne NestedError) withError(err NestedError) NestedError { return ne } +func (ne NestedError) append(msg string) NestedError { + if ne == nil { + return nil + } + ne.err = fmt.Errorf("%w %s", ne.err, msg) + return ne +} + func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { for i := 0; i < level; i++ { sb.WriteString(" ") diff --git a/src/models/raw_entry.go b/src/models/raw_entry.go index ea91c1f..00f3cdb 100644 --- a/src/models/raw_entry.go +++ b/src/models/raw_entry.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "strconv" "strings" @@ -53,11 +54,17 @@ func (e *RawEntry) FillMissingFields() bool { } } - if e.Port == "" { - if e.FirstPort == "" { - return false + if e.PublicPortMapping != nil { + if _, ok := e.PublicPortMapping[e.Port]; !ok { // port is not exposed, but specified + // try to fallback to first public port + if len(e.PublicPortMapping) == 0 { + return false + } + for _, p := range e.PublicPortMapping { + e.Port = fmt.Sprint(p.PublicPort) + break + } } - e.Port = e.FirstPort } if e.Scheme == "" { @@ -69,8 +76,19 @@ func (e *RawEntry) FillMissingFields() bool { e.Scheme = "http" } else if e.Port == "443" { e.Scheme = "https" - } else if isDocker && e.Port == "" { - return false + } else if isDocker { + if e.Port == "" { + return false + } + if p, ok := e.PublicPortMapping[e.Port]; ok { + if p.Type == "udp" { + e.Scheme = "udp" + } else { + e.Scheme = "http" + } + } else { + return false + } } else { e.Scheme = "http" } diff --git a/src/proxy/constants.go b/src/proxy/constants.go deleted file mode 100644 index 1278322..0000000 --- a/src/proxy/constants.go +++ /dev/null @@ -1,16 +0,0 @@ -package proxy - -const ( - StreamType_UDP string = "udp" - StreamType_TCP string = "tcp" - // StreamType_UDP_TCP Scheme = "udp-tcp" - // StreamType_TCP_UDP Scheme = "tcp-udp" - // StreamType_TLS Scheme = "tls" -) - -var ( - // TODO: support "tcp-udp", "udp-tcp", etc. - StreamSchemes = []string{StreamType_TCP, StreamType_UDP} - HTTPSchemes = []string{"http", "https"} - ValidSchemes = append(StreamSchemes, HTTPSchemes...) -) diff --git a/src/proxy/fields/port.go b/src/proxy/fields/port.go index 5708c25..7756492 100644 --- a/src/proxy/fields/port.go +++ b/src/proxy/fields/port.go @@ -28,6 +28,10 @@ func (p Port) inBound() bool { return p >= MinPort && p <= MaxPort } +func (p Port) String() string { + return strconv.Itoa(int(p)) +} + const ( MinPort = 0 MaxPort = 65535 diff --git a/src/proxy/fields/stream_port.go b/src/proxy/fields/stream_port.go index 9ab3941..0c8d674 100644 --- a/src/proxy/fields/stream_port.go +++ b/src/proxy/fields/stream_port.go @@ -37,7 +37,7 @@ func ValidateStreamPort(p string) (StreamPort, E.NestedError) { } else if err != nil { proxyPort, err = parseNameToPort(split[1]) if err != nil { - return ErrStreamPort, err + return ErrStreamPort, E.Invalid("stream port", p).With(proxyPort) } } diff --git a/src/proxy/fields/stream_port_test.go b/src/proxy/fields/stream_port_test.go index c16c9b9..39def28 100644 --- a/src/proxy/fields/stream_port_test.go +++ b/src/proxy/fields/stream_port_test.go @@ -4,7 +4,7 @@ import ( "testing" E "github.com/yusing/go-proxy/error" - U "github.com/yusing/go-proxy/utils/testing" + . "github.com/yusing/go-proxy/utils/testing" ) var validPorts = []string{ @@ -35,14 +35,14 @@ var outOfRangePorts = []string{ func TestStreamPort(t *testing.T) { for _, port := range validPorts { _, err := ValidateStreamPort(port) - U.ExpectNoError(t, err.Error()) + ExpectNoError(t, err.Error()) } for _, port := range invalidPorts { _, err := ValidateStreamPort(port) - U.ExpectError2(t, port, E.ErrInvalid, err.Error()) + ExpectError2(t, port, E.ErrInvalid, err.Error()) } for _, port := range outOfRangePorts { _, err := ValidateStreamPort(port) - U.ExpectError2(t, port, E.ErrOutOfRange, err.Error()) + ExpectError2(t, port, E.ErrOutOfRange, err.Error()) } } diff --git a/src/proxy/provider/docker_provider.go b/src/proxy/provider/docker_provider.go index 6ea412f..0f37066 100755 --- a/src/proxy/provider/docker_provider.go +++ b/src/proxy/provider/docker_provider.go @@ -137,31 +137,29 @@ func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.Ra errors.Add(p.applyLabel(container, entries, key, val)) } + // selecting correct host port + replacePrivPorts := func() { + if container.HostConfig.NetworkMode != "host" { + entries.RangeAll(func(_ string, entry *M.RawEntry) { + entryPortSplit := strings.Split(entry.Port, ":") + n := len(entryPortSplit) + // if the port matches the proxy port, replace it with the public port + if p, ok := container.PrivatePortMapping[entryPortSplit[n-1]]; ok { + entryPortSplit[n-1] = fmt.Sprint(p.PublicPort) + entry.Port = strings.Join(entryPortSplit, ":") + } + }) + } + } + replacePrivPorts() + // remove all entries that failed to fill in missing fields entries.RemoveAll(func(re *M.RawEntry) bool { return !re.FillMissingFields() }) - // selecting correct host port - if container.HostConfig.NetworkMode != "host" { - for _, a := range container.Aliases { - entry, ok := entries.Load(a) - if !ok { - continue - } - for _, p := range container.Ports { - containerPort := strconv.Itoa(int(p.PrivatePort)) - publicPort := strconv.Itoa(int(p.PublicPort)) - entryPortSplit := strings.Split(entry.Port, ":") - n := len(entryPortSplit) - if entryPortSplit[n-1] == containerPort { - entryPortSplit[n-1] = publicPort - entry.Port = strings.Join(entryPortSplit, ":") - break - } - } - } - } + // do it again since the port may got filled in + replacePrivPorts() return entries, errors.Build().Subject(container.ContainerName) } @@ -193,7 +191,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries, return ref } if index < 1 || index > len(container.Aliases) { - refErr.Add(E.Invalid("index", ref).Extraf("index out of range")) + refErr.Add(E.OutOfRange("index", ref)) return ref } return container.Aliases[index-1] diff --git a/src/proxy/provider/docker_provider_test.go b/src/proxy/provider/docker_provider_test.go index e468a54..06bdd30 100644 --- a/src/proxy/provider/docker_provider_test.go +++ b/src/proxy/provider/docker_provider_test.go @@ -8,15 +8,12 @@ import ( "github.com/yusing/go-proxy/common" D "github.com/yusing/go-proxy/docker" E "github.com/yusing/go-proxy/error" - F "github.com/yusing/go-proxy/utils/functional" + P "github.com/yusing/go-proxy/proxy" + T "github.com/yusing/go-proxy/proxy/fields" + . "github.com/yusing/go-proxy/utils/testing" ) -func get[KT comparable, VT any](m F.Map[KT, VT], key KT) VT { - v, _ := m.Load(key) - return v -} - var dummyNames = []string{"/a"} func TestApplyLabelFieldValidity(t *testing.T) { @@ -48,10 +45,10 @@ X_Custom_Header2: value3 "X-Custom-Header2", } var p DockerProvider - var c = D.FromDocker(&types.Container{ + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ - D.LableAliases: "a,b", + D.LabelAliases: "a,b", D.LabelIdleTimeout: common.IdleTimeoutDefault, D.LabelStopMethod: common.StopMethodDefault, D.LabelStopSignal: "SIGTERM", @@ -65,11 +62,16 @@ X_Custom_Header2: value3 "proxy.a.path_patterns": pathPatterns, "proxy.a.set_headers": setHeaders, "proxy.a.hide_headers": hideHeaders, - }}, "") - entries, err := p.entriesFromContainerLabels(c) + }, + Ports: []types.Port{ + {Type: "tcp", PrivatePort: 4567, PublicPort: 8888}, + }}, "")) ExpectNoError(t, err.Error()) - a := get(entries, "a") - b := get(entries, "b") + + a, ok := entries.Load("a") + ExpectTrue(t, ok) + b, ok := entries.Load("b") + ExpectTrue(t, ok) ExpectEqual(t, a.Scheme, "https") ExpectEqual(t, b.Scheme, "https") @@ -77,8 +79,8 @@ X_Custom_Header2: value3 ExpectEqual(t, a.Host, "app") ExpectEqual(t, b.Host, "app") - ExpectEqual(t, a.Port, "4567") - ExpectEqual(t, b.Port, "4567") + ExpectEqual(t, a.Port, "8888") + ExpectEqual(t, b.Port, "8888") ExpectTrue(t, a.NoTLSVerify) ExpectTrue(t, b.NoTLSVerify) @@ -110,38 +112,68 @@ X_Custom_Header2: value3 func TestApplyLabel(t *testing.T) { var p DockerProvider - var c = D.FromDocker(&types.Container{ + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ - D.LableAliases: "a,b,c", + D.LabelAliases: "a,b,c", "proxy.a.no_tls_verify": "true", "proxy.a.port": "3333", "proxy.b.port": "1234", "proxy.c.scheme": "https", - }}, "") - entries, err := p.entriesFromContainerLabels(c) + }, + Ports: []types.Port{ + {Type: "tcp", PrivatePort: 3333, PublicPort: 1111}, + {Type: "tcp", PrivatePort: 4444, PublicPort: 1234}, + }}, "", + )) + a, ok := entries.Load("a") + ExpectTrue(t, ok) + b, ok := entries.Load("b") + ExpectTrue(t, ok) + c, ok := entries.Load("c") + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) - ExpectEqual(t, get(entries, "a").NoTLSVerify, true) - ExpectEqual(t, get(entries, "b").Port, "1234") - ExpectEqual(t, get(entries, "c").Scheme, "https") + ExpectEqual(t, a.Scheme, "http") + ExpectEqual(t, a.Port, "1111") + ExpectEqual(t, a.NoTLSVerify, true) + ExpectEqual(t, b.Scheme, "http") + ExpectEqual(t, b.Port, "1234") + ExpectEqual(t, c.Scheme, "https") + ExpectEqual(t, c.Port, "1111") } func TestApplyLabelWithRef(t *testing.T) { var p DockerProvider - var c = D.FromDocker(&types.Container{ + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ - D.LableAliases: "a,b,c", + D.LabelAliases: "a,b,c", "proxy.$1.host": "localhost", + "proxy.*.port": "1111", "proxy.$1.port": "4444", - "proxy.$2.port": "1234", + "proxy.$2.port": "9999", "proxy.$3.scheme": "https", - }}, "") - entries, err := p.entriesFromContainerLabels(c) + }, + Ports: []types.Port{ + {Type: "tcp", PrivatePort: 3333, PublicPort: 9999}, + {Type: "tcp", PrivatePort: 4444, PublicPort: 5555}, + {Type: "tcp", PrivatePort: 1111, PublicPort: 2222}, + }}, "")) + a, ok := entries.Load("a") + ExpectTrue(t, ok) + b, ok := entries.Load("b") + ExpectTrue(t, ok) + c, ok := entries.Load("c") + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) - ExpectEqual(t, get(entries, "a").Host, "localhost") - ExpectEqual(t, get(entries, "b").Port, "1234") - ExpectEqual(t, get(entries, "c").Scheme, "https") + ExpectEqual(t, a.Scheme, "http") + ExpectEqual(t, a.Host, "localhost") + ExpectEqual(t, a.Port, "5555") + ExpectEqual(t, b.Port, "9999") + ExpectEqual(t, c.Scheme, "https") + ExpectEqual(t, c.Port, "2222") } func TestApplyLabelWithRefIndexError(t *testing.T) { @@ -149,21 +181,110 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { var c = D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ - D.LableAliases: "a,b", + D.LabelAliases: "a,b", "proxy.$1.host": "localhost", "proxy.$4.scheme": "https", }}, "") _, err := p.entriesFromContainerLabels(c) - ExpectError(t, E.ErrInvalid, err.Error()) + ExpectError(t, E.ErrOutOfRange, err.Error()) ExpectTrue(t, strings.Contains(err.String(), "index out of range")) - c = D.FromDocker(&types.Container{ + _, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ - D.LableAliases: "a,b", + D.LabelAliases: "a,b", "proxy.$0.host": "localhost", - }}, "") - _, err = p.entriesFromContainerLabels(c) - ExpectError(t, E.ErrInvalid, err.Error()) + }}, "")) + ExpectError(t, E.ErrOutOfRange, err.Error()) ExpectTrue(t, strings.Contains(err.String(), "index out of range")) } + +func TestStreamDefaultValues(t *testing.T) { + var p DockerProvider + var c = D.FromDocker(&types.Container{ + Names: dummyNames, + Labels: map[string]string{ + D.LabelAliases: "a", + "proxy.*.no_tls_verify": "true", + }, + Ports: []types.Port{ + {Type: "udp", PrivatePort: 1234, PublicPort: 5678}, + }}, "", + ) + entries, err := p.entriesFromContainerLabels(c) + ExpectNoError(t, err.Error()) + + raw, ok := entries.Load("a") + ExpectTrue(t, ok) + + entry, err := P.ValidateEntry(raw) + ExpectNoError(t, err.Error()) + + a := ExpectType[*P.StreamEntry](t, entry) + ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) + ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) + ExpectEqual(t, a.Port.ListeningPort, 0) + ExpectEqual(t, a.Port.ProxyPort, 5678) +} + +func TestExplicitExclude(t *testing.T) { + var p DockerProvider + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + Names: dummyNames, + Labels: map[string]string{ + D.LabelAliases: "a", + D.LabelExclude: "true", + "proxy.a.no_tls_verify": "true", + }}, "")) + ExpectNoError(t, err.Error()) + + _, ok := entries.Load("a") + ExpectFalse(t, ok) +} + +func TestImplicitExclude(t *testing.T) { + var p DockerProvider + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + Names: dummyNames, + Labels: map[string]string{ + D.LabelAliases: "a", + "proxy.a.no_tls_verify": "true", + }}, "")) + ExpectNoError(t, err.Error()) + + _, ok := entries.Load("a") + ExpectFalse(t, ok) +} + +func TestImplicitExcludeNoExposedPort(t *testing.T) { + var p DockerProvider + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + Image: "redis", + Names: []string{"redis"}, + Ports: []types.Port{ + {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed + }, + }, "")) + ExpectNoError(t, err.Error()) + + _, ok := entries.Load("redis") + ExpectFalse(t, ok) +} + +func TestExcludeNonExposedPort(t *testing.T) { + var p DockerProvider + entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + Image: "redis", + Names: []string{"redis"}, + Ports: []types.Port{ + {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed + }, + Labels: map[string]string{ + "proxy.redis.port": "6379:6379", // should be excluded even specified + }, + }, "")) + ExpectNoError(t, err.Error()) + + _, ok := entries.Load("redis") + ExpectFalse(t, ok) +} diff --git a/src/route/route.go b/src/route/route.go index 4ef4a13..0140289 100755 --- a/src/route/route.go +++ b/src/route/route.go @@ -17,15 +17,15 @@ type ( Type() RouteType URL() *url.URL } - Routes = F.Map[string, Route] - RouteType string + Routes = F.Map[string, Route] RouteImpl interface { Start() E.NestedError Stop() E.NestedError String() string } - route struct { + RouteType string + route struct { RouteImpl type_ RouteType entry *M.RawEntry @@ -58,7 +58,10 @@ func NewRoute(en *M.RawEntry) (Route, E.NestedError) { default: panic("bug: should not reach here") } - return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, err + if err != nil { + return nil, err + } + return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, nil } func (rt *route) Entry() *M.RawEntry { diff --git a/src/route/stream_route.go b/src/route/stream_route.go index f160108..8fe6668 100755 --- a/src/route/stream_route.go +++ b/src/route/stream_route.go @@ -14,7 +14,7 @@ import ( ) type StreamRoute struct { - P.StreamEntry + *P.StreamEntry StreamImpl `json:"-"` wg sync.WaitGroup @@ -31,6 +31,7 @@ type StreamImpl interface { Accept() (any, error) Handle(any) error CloseListeners() + String() string } func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { @@ -39,7 +40,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) } base := &StreamRoute{ - StreamEntry: *entry, + StreamEntry: entry, connCh: make(chan any, 100), } if entry.Scheme.ListeningScheme.IsTCP() { @@ -64,6 +65,7 @@ func (r *StreamRoute) Start() E.NestedError { if err := r.Setup(); err != nil { return E.FailWith("setup", err) } + r.l.Infof("listening on port %d", r.Port.ListeningPort) r.started.Store(true) r.wg.Add(2) go r.grAcceptConnections() diff --git a/src/route/udp_route.go b/src/route/udp_route.go index 3961fb0..41b11e2 100755 --- a/src/route/udp_route.go +++ b/src/route/udp_route.go @@ -52,10 +52,11 @@ func (route *UDPRoute) Setup() error { } //! this read the allocated listeningPort from orginal ':0' - route.Port.ListeningPort = T.Port(laddr.Port) + route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) route.listeningConn = source route.targetAddr = raddr + return nil } diff --git a/src/utils/testing/testing.go b/src/utils/testing/testing.go index 69b9334..23f8242 100644 --- a/src/utils/testing/testing.go +++ b/src/utils/testing/testing.go @@ -10,6 +10,7 @@ func ExpectNoError(t *testing.T, err error) { t.Helper() if err != nil && !reflect.ValueOf(err).IsNil() { t.Errorf("expected err=nil, got %s", err.Error()) + t.FailNow() } } @@ -17,6 +18,7 @@ func ExpectError(t *testing.T, expected error, err error) { t.Helper() if !errors.Is(err, expected) { t.Errorf("expected err %s, got %s", expected.Error(), err.Error()) + t.FailNow() } } @@ -24,6 +26,7 @@ func ExpectError2(t *testing.T, input any, expected error, err error) { t.Helper() if !errors.Is(err, expected) { t.Errorf("%v: expected err %s, got %s", input, expected.Error(), err.Error()) + t.FailNow() } } @@ -31,6 +34,7 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) { t.Helper() if got != want { t.Errorf("expected:\n%v, got\n%v", want, got) + t.FailNow() } } @@ -38,29 +42,34 @@ func ExpectDeepEqual[T any](t *testing.T, got T, want T) { t.Helper() if !reflect.DeepEqual(got, want) { t.Errorf("expected:\n%v, got\n%v", want, got) + t.FailNow() } } func ExpectTrue(t *testing.T, got bool) { t.Helper() if !got { - t.Errorf("expected true, got false") + t.Error("expected true") + t.FailNow() } } func ExpectFalse(t *testing.T, got bool) { t.Helper() if got { - t.Errorf("expected false, got true") + t.Error("expected false") + t.FailNow() } } -func ExpectType[T any](t *testing.T, got any) T { +func ExpectType[T any](t *testing.T, got any) (_ T) { t.Helper() tExpect := reflect.TypeFor[T]() _, ok := got.(T) if !ok { - t.Errorf("expected type %s, got %T", tExpect, got) + t.Fatalf("expected type %s, got %s", tExpect, reflect.TypeOf(got).Elem()) + t.FailNow() + return } return got.(T) }