smarter scheme and port detection

This commit is contained in:
yusing 2024-09-23 20:16:38 +08:00
parent 8e2cc56afb
commit 71e8e4a462
14 changed files with 294 additions and 132 deletions

View file

@ -13,9 +13,10 @@ type ProxyProperties struct {
DockerHost string `yaml:"-" json:"docker_host"`
ContainerName string `yaml:"-" json:"container_name"`
ImageName string `yaml:"-" json:"image_name"`
PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port
PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port
Aliases []string `yaml:"-" json:"aliases"`
IsExcluded bool `yaml:"-" json:"is_excluded"`
FirstPort string `yaml:"-" json:"first_port"`
IdleTimeout string `yaml:"-" json:"idle_timeout"`
WakeTimeout string `yaml:"-" json:"wake_timeout"`
StopMethod string `yaml:"-" json:"stop_method"`
@ -29,15 +30,18 @@ type Container struct {
*ProxyProperties
}
type PortMapping = map[string]types.Port
func FromDocker(c *types.Container, dockerHost string) (res Container) {
res.Container = c
res.ProxyProperties = &ProxyProperties{
DockerHost: dockerHost,
ContainerName: res.getName(),
ImageName: res.getImageName(),
PublicPortMapping: res.getPublicPortMapping(),
PrivatePortMapping: res.getPrivatePortMapping(),
Aliases: res.getAliases(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LableExclude)),
FirstPort: res.firstPortOrEmpty(),
IsExcluded: U.ParseBool(res.getDeleteLabel(LabelExclude)),
IdleTimeout: res.getDeleteLabel(LabelIdleTimeout),
WakeTimeout: res.getDeleteLabel(LabelWakeTimeout),
StopMethod: res.getDeleteLabel(LabelStopMethod),
@ -81,7 +85,7 @@ func (c Container) getDeleteLabel(label string) string {
}
func (c Container) getAliases() []string {
if l := c.getDeleteLabel(LableAliases); l != "" {
if l := c.getDeleteLabel(LabelAliases); l != "" {
return U.CommaSeperatedList(l)
} else {
return []string{c.getName()}
@ -98,14 +102,24 @@ func (c Container) getImageName() string {
return slashSep[len(slashSep)-1]
}
func (c Container) firstPortOrEmpty() string {
if len(c.Ports) == 0 {
return ""
func (c Container) getPublicPortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
if v.PublicPort == 0 {
continue
}
for _, p := range c.Ports {
if p.PublicPort != 0 {
return fmt.Sprint(p.PublicPort)
res[fmt.Sprint(v.PublicPort)] = v
}
return res
}
return ""
func (c Container) getPrivatePortMapping() PortMapping {
res := make(PortMapping)
for _, v := range c.Ports {
if v.PublicPort == 0 {
continue
}
res[fmt.Sprint(v.PrivatePort)] = v
}
return res
}

View file

@ -3,8 +3,8 @@ package docker
const (
WildcardAlias = "*"
LableAliases = NSProxy + ".aliases"
LableExclude = NSProxy + ".exclude"
LabelAliases = NSProxy + ".aliases"
LabelExclude = NSProxy + ".exclude"
LabelIdleTimeout = NSProxy + ".idle_timeout"
LabelWakeTimeout = NSProxy + ".wake_timeout"
LabelStopMethod = NSProxy + ".stop_method"

View file

@ -118,9 +118,9 @@ func (ne NestedError) With(s any) NestedError {
case string:
msg = ss
case fmt.Stringer:
msg = ss.String()
return ne.append(ss.String())
default:
msg = fmt.Sprint(s)
return ne.append(fmt.Sprint(s))
}
return ne.withError(From(errors.New(msg)))
}
@ -201,6 +201,14 @@ func (ne NestedError) withError(err NestedError) NestedError {
return ne
}
func (ne NestedError) append(msg string) NestedError {
if ne == nil {
return nil
}
ne.err = fmt.Errorf("%w %s", ne.err, msg)
return ne
}
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
for i := 0; i < level; i++ {
sb.WriteString(" ")

View file

@ -1,6 +1,7 @@
package model
import (
"fmt"
"strconv"
"strings"
@ -53,11 +54,17 @@ func (e *RawEntry) FillMissingFields() bool {
}
}
if e.Port == "" {
if e.FirstPort == "" {
if e.PublicPortMapping != nil {
if _, ok := e.PublicPortMapping[e.Port]; !ok { // port is not exposed, but specified
// try to fallback to first public port
if len(e.PublicPortMapping) == 0 {
return false
}
e.Port = e.FirstPort
for _, p := range e.PublicPortMapping {
e.Port = fmt.Sprint(p.PublicPort)
break
}
}
}
if e.Scheme == "" {
@ -69,8 +76,19 @@ func (e *RawEntry) FillMissingFields() bool {
e.Scheme = "http"
} else if e.Port == "443" {
e.Scheme = "https"
} else if isDocker && e.Port == "" {
} else if isDocker {
if e.Port == "" {
return false
}
if p, ok := e.PublicPortMapping[e.Port]; ok {
if p.Type == "udp" {
e.Scheme = "udp"
} else {
e.Scheme = "http"
}
} else {
return false
}
} else {
e.Scheme = "http"
}

View file

@ -1,16 +0,0 @@
package proxy
const (
StreamType_UDP string = "udp"
StreamType_TCP string = "tcp"
// StreamType_UDP_TCP Scheme = "udp-tcp"
// StreamType_TCP_UDP Scheme = "tcp-udp"
// StreamType_TLS Scheme = "tls"
)
var (
// TODO: support "tcp-udp", "udp-tcp", etc.
StreamSchemes = []string{StreamType_TCP, StreamType_UDP}
HTTPSchemes = []string{"http", "https"}
ValidSchemes = append(StreamSchemes, HTTPSchemes...)
)

View file

@ -28,6 +28,10 @@ func (p Port) inBound() bool {
return p >= MinPort && p <= MaxPort
}
func (p Port) String() string {
return strconv.Itoa(int(p))
}
const (
MinPort = 0
MaxPort = 65535

View file

@ -37,7 +37,7 @@ func ValidateStreamPort(p string) (StreamPort, E.NestedError) {
} else if err != nil {
proxyPort, err = parseNameToPort(split[1])
if err != nil {
return ErrStreamPort, err
return ErrStreamPort, E.Invalid("stream port", p).With(proxyPort)
}
}

View file

@ -4,7 +4,7 @@ import (
"testing"
E "github.com/yusing/go-proxy/error"
U "github.com/yusing/go-proxy/utils/testing"
. "github.com/yusing/go-proxy/utils/testing"
)
var validPorts = []string{
@ -35,14 +35,14 @@ var outOfRangePorts = []string{
func TestStreamPort(t *testing.T) {
for _, port := range validPorts {
_, err := ValidateStreamPort(port)
U.ExpectNoError(t, err.Error())
ExpectNoError(t, err.Error())
}
for _, port := range invalidPorts {
_, err := ValidateStreamPort(port)
U.ExpectError2(t, port, E.ErrInvalid, err.Error())
ExpectError2(t, port, E.ErrInvalid, err.Error())
}
for _, port := range outOfRangePorts {
_, err := ValidateStreamPort(port)
U.ExpectError2(t, port, E.ErrOutOfRange, err.Error())
ExpectError2(t, port, E.ErrOutOfRange, err.Error())
}
}

View file

@ -137,31 +137,29 @@ func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.Ra
errors.Add(p.applyLabel(container, entries, key, val))
}
// selecting correct host port
replacePrivPorts := func() {
if container.HostConfig.NetworkMode != "host" {
entries.RangeAll(func(_ string, entry *M.RawEntry) {
entryPortSplit := strings.Split(entry.Port, ":")
n := len(entryPortSplit)
// if the port matches the proxy port, replace it with the public port
if p, ok := container.PrivatePortMapping[entryPortSplit[n-1]]; ok {
entryPortSplit[n-1] = fmt.Sprint(p.PublicPort)
entry.Port = strings.Join(entryPortSplit, ":")
}
})
}
}
replacePrivPorts()
// remove all entries that failed to fill in missing fields
entries.RemoveAll(func(re *M.RawEntry) bool {
return !re.FillMissingFields()
})
// selecting correct host port
if container.HostConfig.NetworkMode != "host" {
for _, a := range container.Aliases {
entry, ok := entries.Load(a)
if !ok {
continue
}
for _, p := range container.Ports {
containerPort := strconv.Itoa(int(p.PrivatePort))
publicPort := strconv.Itoa(int(p.PublicPort))
entryPortSplit := strings.Split(entry.Port, ":")
n := len(entryPortSplit)
if entryPortSplit[n-1] == containerPort {
entryPortSplit[n-1] = publicPort
entry.Port = strings.Join(entryPortSplit, ":")
break
}
}
}
}
// do it again since the port may got filled in
replacePrivPorts()
return entries, errors.Build().Subject(container.ContainerName)
}
@ -193,7 +191,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
return ref
}
if index < 1 || index > len(container.Aliases) {
refErr.Add(E.Invalid("index", ref).Extraf("index out of range"))
refErr.Add(E.OutOfRange("index", ref))
return ref
}
return container.Aliases[index-1]

View file

@ -8,15 +8,12 @@ import (
"github.com/yusing/go-proxy/common"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
F "github.com/yusing/go-proxy/utils/functional"
P "github.com/yusing/go-proxy/proxy"
T "github.com/yusing/go-proxy/proxy/fields"
. "github.com/yusing/go-proxy/utils/testing"
)
func get[KT comparable, VT any](m F.Map[KT, VT], key KT) VT {
v, _ := m.Load(key)
return v
}
var dummyNames = []string{"/a"}
func TestApplyLabelFieldValidity(t *testing.T) {
@ -48,10 +45,10 @@ X_Custom_Header2: value3
"X-Custom-Header2",
}
var p DockerProvider
var c = D.FromDocker(&types.Container{
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b",
D.LabelAliases: "a,b",
D.LabelIdleTimeout: common.IdleTimeoutDefault,
D.LabelStopMethod: common.StopMethodDefault,
D.LabelStopSignal: "SIGTERM",
@ -65,11 +62,16 @@ X_Custom_Header2: value3
"proxy.a.path_patterns": pathPatterns,
"proxy.a.set_headers": setHeaders,
"proxy.a.hide_headers": hideHeaders,
}}, "")
entries, err := p.entriesFromContainerLabels(c)
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 4567, PublicPort: 8888},
}}, ""))
ExpectNoError(t, err.Error())
a := get(entries, "a")
b := get(entries, "b")
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
ExpectEqual(t, a.Scheme, "https")
ExpectEqual(t, b.Scheme, "https")
@ -77,8 +79,8 @@ X_Custom_Header2: value3
ExpectEqual(t, a.Host, "app")
ExpectEqual(t, b.Host, "app")
ExpectEqual(t, a.Port, "4567")
ExpectEqual(t, b.Port, "4567")
ExpectEqual(t, a.Port, "8888")
ExpectEqual(t, b.Port, "8888")
ExpectTrue(t, a.NoTLSVerify)
ExpectTrue(t, b.NoTLSVerify)
@ -110,38 +112,68 @@ X_Custom_Header2: value3
func TestApplyLabel(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b,c",
D.LabelAliases: "a,b,c",
"proxy.a.no_tls_verify": "true",
"proxy.a.port": "3333",
"proxy.b.port": "1234",
"proxy.c.scheme": "https",
}}, "")
entries, err := p.entriesFromContainerLabels(c)
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 3333, PublicPort: 1111},
{Type: "tcp", PrivatePort: 4444, PublicPort: 1234},
}}, "",
))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
c, ok := entries.Load("c")
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, get(entries, "a").NoTLSVerify, true)
ExpectEqual(t, get(entries, "b").Port, "1234")
ExpectEqual(t, get(entries, "c").Scheme, "https")
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Port, "1111")
ExpectEqual(t, a.NoTLSVerify, true)
ExpectEqual(t, b.Scheme, "http")
ExpectEqual(t, b.Port, "1234")
ExpectEqual(t, c.Scheme, "https")
ExpectEqual(t, c.Port, "1111")
}
func TestApplyLabelWithRef(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b,c",
D.LabelAliases: "a,b,c",
"proxy.$1.host": "localhost",
"proxy.*.port": "1111",
"proxy.$1.port": "4444",
"proxy.$2.port": "1234",
"proxy.$2.port": "9999",
"proxy.$3.scheme": "https",
}}, "")
entries, err := p.entriesFromContainerLabels(c)
},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 3333, PublicPort: 9999},
{Type: "tcp", PrivatePort: 4444, PublicPort: 5555},
{Type: "tcp", PrivatePort: 1111, PublicPort: 2222},
}}, ""))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
ExpectTrue(t, ok)
c, ok := entries.Load("c")
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, get(entries, "a").Host, "localhost")
ExpectEqual(t, get(entries, "b").Port, "1234")
ExpectEqual(t, get(entries, "c").Scheme, "https")
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Host, "localhost")
ExpectEqual(t, a.Port, "5555")
ExpectEqual(t, b.Port, "9999")
ExpectEqual(t, c.Scheme, "https")
ExpectEqual(t, c.Port, "2222")
}
func TestApplyLabelWithRefIndexError(t *testing.T) {
@ -149,21 +181,110 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b",
D.LabelAliases: "a,b",
"proxy.$1.host": "localhost",
"proxy.$4.scheme": "https",
}}, "")
_, err := p.entriesFromContainerLabels(c)
ExpectError(t, E.ErrInvalid, err.Error())
ExpectError(t, E.ErrOutOfRange, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
c = D.FromDocker(&types.Container{
_, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LableAliases: "a,b",
D.LabelAliases: "a,b",
"proxy.$0.host": "localhost",
}}, "")
_, err = p.entriesFromContainerLabels(c)
ExpectError(t, E.ErrInvalid, err.Error())
}}, ""))
ExpectError(t, E.ErrOutOfRange, err.Error())
ExpectTrue(t, strings.Contains(err.String(), "index out of range"))
}
func TestStreamDefaultValues(t *testing.T) {
var p DockerProvider
var c = D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
"proxy.*.no_tls_verify": "true",
},
Ports: []types.Port{
{Type: "udp", PrivatePort: 1234, PublicPort: 5678},
}}, "",
)
entries, err := p.entriesFromContainerLabels(c)
ExpectNoError(t, err.Error())
raw, ok := entries.Load("a")
ExpectTrue(t, ok)
entry, err := P.ValidateEntry(raw)
ExpectNoError(t, err.Error())
a := ExpectType[*P.StreamEntry](t, entry)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Port.ListeningPort, 0)
ExpectEqual(t, a.Port.ProxyPort, 5678)
}
func TestExplicitExclude(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
D.LabelExclude: "true",
"proxy.a.no_tls_verify": "true",
}}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("a")
ExpectFalse(t, ok)
}
func TestImplicitExclude(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a",
"proxy.a.no_tls_verify": "true",
}}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("a")
ExpectFalse(t, ok)
}
func TestImplicitExcludeNoExposedPort(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectFalse(t, ok)
}
func TestExcludeNonExposedPort(t *testing.T) {
var p DockerProvider
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Image: "redis",
Names: []string{"redis"},
Ports: []types.Port{
{Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed
},
Labels: map[string]string{
"proxy.redis.port": "6379:6379", // should be excluded even specified
},
}, ""))
ExpectNoError(t, err.Error())
_, ok := entries.Load("redis")
ExpectFalse(t, ok)
}

View file

@ -18,13 +18,13 @@ type (
URL() *url.URL
}
Routes = F.Map[string, Route]
RouteType string
RouteImpl interface {
Start() E.NestedError
Stop() E.NestedError
String() string
}
RouteType string
route struct {
RouteImpl
type_ RouteType
@ -58,7 +58,10 @@ func NewRoute(en *M.RawEntry) (Route, E.NestedError) {
default:
panic("bug: should not reach here")
}
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, err
if err != nil {
return nil, err
}
return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, nil
}
func (rt *route) Entry() *M.RawEntry {

View file

@ -14,7 +14,7 @@ import (
)
type StreamRoute struct {
P.StreamEntry
*P.StreamEntry
StreamImpl `json:"-"`
wg sync.WaitGroup
@ -31,6 +31,7 @@ type StreamImpl interface {
Accept() (any, error)
Handle(any) error
CloseListeners()
String() string
}
func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
@ -39,7 +40,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
base := &StreamRoute{
StreamEntry: *entry,
StreamEntry: entry,
connCh: make(chan any, 100),
}
if entry.Scheme.ListeningScheme.IsTCP() {
@ -64,6 +65,7 @@ func (r *StreamRoute) Start() E.NestedError {
if err := r.Setup(); err != nil {
return E.FailWith("setup", err)
}
r.l.Infof("listening on port %d", r.Port.ListeningPort)
r.started.Store(true)
r.wg.Add(2)
go r.grAcceptConnections()

View file

@ -52,10 +52,11 @@ func (route *UDPRoute) Setup() error {
}
//! this read the allocated listeningPort from orginal ':0'
route.Port.ListeningPort = T.Port(laddr.Port)
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source
route.targetAddr = raddr
return nil
}

View file

@ -10,6 +10,7 @@ func ExpectNoError(t *testing.T, err error) {
t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() {
t.Errorf("expected err=nil, got %s", err.Error())
t.FailNow()
}
}
@ -17,6 +18,7 @@ func ExpectError(t *testing.T, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("expected err %s, got %s", expected.Error(), err.Error())
t.FailNow()
}
}
@ -24,6 +26,7 @@ func ExpectError2(t *testing.T, input any, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("%v: expected err %s, got %s", input, expected.Error(), err.Error())
t.FailNow()
}
}
@ -31,6 +34,7 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) {
t.Helper()
if got != want {
t.Errorf("expected:\n%v, got\n%v", want, got)
t.FailNow()
}
}
@ -38,29 +42,34 @@ func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Errorf("expected:\n%v, got\n%v", want, got)
t.FailNow()
}
}
func ExpectTrue(t *testing.T, got bool) {
t.Helper()
if !got {
t.Errorf("expected true, got false")
t.Error("expected true")
t.FailNow()
}
}
func ExpectFalse(t *testing.T, got bool) {
t.Helper()
if got {
t.Errorf("expected false, got true")
t.Error("expected false")
t.FailNow()
}
}
func ExpectType[T any](t *testing.T, got any) T {
func ExpectType[T any](t *testing.T, got any) (_ T) {
t.Helper()
tExpect := reflect.TypeFor[T]()
_, ok := got.(T)
if !ok {
t.Errorf("expected type %s, got %T", tExpect, got)
t.Fatalf("expected type %s, got %s", tExpect, reflect.TypeOf(got).Elem())
t.FailNow()
return
}
return got.(T)
}