added support for a few middlewares, added match_domain option, changed index reference prefix from $ to #, etc.

This commit is contained in:
yusing 2024-09-27 09:57:57 +08:00
parent 345a4417a6
commit f474ae4f75
47 changed files with 1523 additions and 446 deletions

View file

@ -130,3 +130,18 @@ jobs:
run: |
docker tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.meta.outputs.version }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
scan:
runs-on: ubuntu-latest
needs:
- merge
steps:
- name: Scan Image with Trivy
uses: aquasecurity/trivy-action@0.20.0
with:
image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
format: "sarif"
output: "trivy-results.sarif"
- name: Upload Trivy SARIF Report
uses: github/codeql-action/upload-sarif@v2
with:
sarif_file: "trivy-results.sarif"

View file

@ -9,15 +9,15 @@ COPY src/go.mod src/go.sum ./
# Utilize build cache
RUN --mount=type=cache,target="/go/pkg/mod" \
go mod download
go mod graph | awk '{if ($1 !~ "@") print $2}' | xargs go get
# Now copy the remaining files
COPY src/ ./
ENV GOCACHE=/root/.cache/go-build
# Build the application with better caching
RUN --mount=type=cache,target="/go/pkg/mod" \
--mount=type=cache,target="/root/.cache/go-build" \
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy ./
--mount=type=bind,src=src,dst=/src \
CGO_ENABLED=0 GOOS=linux go build -ldflags '-w -s' -pgo=auto -o /go-proxy .
# Stage 2: Final image
FROM scratch
@ -28,7 +28,7 @@ LABEL maintainer="yusing@6uo.me"
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
# copy binary
COPY --from=builder /src/go-proxy /app/
COPY --from=builder /go-proxy /app/
# copy schema directory
COPY schema/ /app/schema/

View file

@ -1,4 +1,4 @@
.PHONY: all build up quick-restart restart logs get udp-server
.PHONY: all setup build test up restart logs get debug run archive repush rapid-crash debug-list-containers
all: debug
@ -9,7 +9,8 @@ setup:
build:
mkdir -p bin
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
CGO_ENABLED=0 GOOS=linux \
go build -ldflags '${BUILD_FLAG}' -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
test:
go test ./src/...
@ -29,6 +30,9 @@ get:
debug:
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
run:
BUILD_FLAG="-s -w" make build && sudo bin/go-proxy
archive:
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
@ -44,4 +48,4 @@ rapid-crash:
sudo docker rm -f test_crash
debug-list-containers:
bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq'
bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq'

View file

@ -1,36 +1,64 @@
# Autocert (choose one below and uncomment to enable)
#
# 1. use existing cert
#
# autocert:
# provider: local
# cert_path: certs/cert.crt # optional, uncomment only if you need to change it
# key_path: certs/priv.key # optional, uncomment only if you need to change it
#
# cert_path: certs/cert.crt # optional, uncomment only if you need to change it
# key_path: certs/priv.key # optional, uncomment only if you need to change it
#
# 2. cloudflare
#
# autocert:
# provider: cloudflare
# email: # ACME Email
# domains: # a list of domains for cert registration
# - x.y.z
# email: abc@gmail.com # ACME Email
# domains: # a list of domains for cert registration
# - "*.y.z" # remember to use double quotes to surround wildcard domain
# options:
# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token
# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token
#
# 3. other providers, check docs/dns_providers.md for more
providers:
# include files are standalone yaml files under `config/` directory
#
# include:
# - providers.yml # config/providers.yml
# # add some more below if you want
# - file1.yml # config/file_1.yml
# - file1.yml
# - file2.yml
docker:
# for value format, see https://docs.docker.com/reference/cli/dockerd/
# $DOCKER_HOST implies unix:///var/run/docker.sock by default
# $DOCKER_HOST implies environment variable `DOCKER_HOST` or unix:///var/run/docker.sock by default
local: $DOCKER_HOST
# add more docker providers if needed
# for value format, see https://docs.docker.com/reference/cli/dockerd/
#
# remote-1: tcp://10.0.2.1:2375
# remote-2: ssh://root:1234@10.0.2.2
# Fixed options (optional, non hot-reloadable)
# if match_domains not defined
# any host = alias+[any domain] will match
# i.e. https://app1.y.z will match alias app1 for any domain y.z
# but https://app1.node1.y.z will only match alias "app.node1"
#
# if match_domains defined
# only host = alias+[one of match_domains] will match
# i.e. match_domains = [node1.my.app, my.site]
# https://app1.my.app, https://app1.my.net, etc. will not match even if app1 exists
# only https://*.node1.my.app and https://*.my.site will match
#
#
# match_domains:
# - my.site
# - node1.my.app
# Below are fixed options (non hot-reloadable)
# timeout for shutdown (in seconds)
#
# timeout_shutdown: 5
# redirect_to_https: false # redirect http requests to https (if enabled)
# global setting redirect http requests to https (if https available, otherwise this will be ignored)
# proxy.<alias>.middlewares.redirect_http will override this
#
# redirect_to_https: false

View file

@ -74,7 +74,7 @@
| `proxy.stop_timeout` | time to wait for stop command | | `10s` | `number[unit]...` |
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
| `proxy.<alias>.<field>` | set field for specific alias | `proxy.gitlab-ssh.scheme` | N/A | N/A |
| `proxy.$<index>.<field>` | set field for specific alias at index (starting from **1**) | `proxy.$3.port` | N/A | N/A |
| `proxy.#<index>.<field>` | set field for specific alias at index (starting from **1**) | `proxy.#3.port` | N/A | N/A |
| `proxy.*.<field>` | set field for all aliases | `proxy.*.set_headers` | N/A | N/A |
### Fields

View file

@ -37,7 +37,13 @@
"title": "DNS Challenge Provider",
"default": "local",
"type": "string",
"enum": ["local", "cloudflare", "clouddns", "duckdns", "ovh"]
"enum": [
"local",
"cloudflare",
"clouddns",
"duckdns",
"ovh"
]
},
"options": {
"title": "Provider specific options",
@ -56,7 +62,12 @@
}
},
"then": {
"required": ["email", "domains", "provider", "options"]
"required": [
"email",
"domains",
"provider",
"options"
]
}
},
{
@ -70,7 +81,9 @@
"then": {
"properties": {
"options": {
"required": ["auth_token"],
"required": [
"auth_token"
],
"additionalProperties": false,
"properties": {
"auth_token": {
@ -93,7 +106,11 @@
"then": {
"properties": {
"options": {
"required": ["client_id", "email", "password"],
"required": [
"client_id",
"email",
"password"
],
"additionalProperties": false,
"properties": {
"client_id": {
@ -124,7 +141,9 @@
"then": {
"properties": {
"options": {
"required": ["token"],
"required": [
"token"
],
"additionalProperties": false,
"properties": {
"token": {
@ -147,14 +166,21 @@
"then": {
"properties": {
"options": {
"required": ["application_secret", "consumer_key"],
"required": [
"application_secret",
"consumer_key"
],
"additionalProperties": false,
"oneOf": [
{
"required": ["application_key"]
"required": [
"application_key"
]
},
{
"required": ["oauth2_config"]
"required": [
"oauth2_config"
]
}
],
"properties": {
@ -205,7 +231,10 @@
"type": "string"
}
},
"required": ["client_id", "client_secret"]
"required": [
"client_id",
"client_secret"
]
}
}
}
@ -268,6 +297,14 @@
}
}
},
"match_domains": {
"title": "Domains to match",
"type": "array",
"items": {
"type": "string"
},
"minItems": 1
},
"timeout_shutdown": {
"title": "Shutdown timeout (in seconds)",
"type": "integer",
@ -279,5 +316,7 @@
}
},
"additionalProperties": false,
"required": ["providers"]
}
"required": [
"providers"
]
}

View file

@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"os"
"path"
"reflect"
"sort"
"time"
@ -59,8 +60,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
defer b.To(&res)
if p.cfg.Provider == ProviderLocal {
b.Addf("provider is set to %q", ProviderLocal).WithSeverity(E.SeverityWarning)
return
return nil
}
if p.client == nil {
@ -191,7 +191,19 @@ func (p *Provider) registerACME() E.NestedError {
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
//* This should have been done in setup
//* but double check is always a good choice
_, err := os.Stat(path.Dir(p.cfg.CertPath))
if err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
return E.FailWith("create cert directory", err)
}
} else {
return E.FailWith("stat cert directory", err)
}
}
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.FailWith("write key file", err)
}
@ -227,6 +239,10 @@ func (p *Provider) certState() CertState {
}
func (p *Provider) renewIfNeeded() E.NestedError {
if p.cfg.Provider == ProviderLocal {
return nil
}
switch p.certState() {
case CertStateExpired:
logger.Info("certs expired, renewing")

View file

@ -14,7 +14,7 @@ func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
}
logger.Debug("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err.Warn()
return err
}
}

25
src/common/http.go Normal file
View file

@ -0,0 +1,25 @@
package common
import (
"crypto/tls"
"net"
"net/http"
"time"
)
var (
defaultDialer = net.Dialer{
Timeout: 60 * time.Second,
KeepAlive: 60 * time.Second,
}
DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
MaxIdleConnsPerHost: 1000,
}
DefaultTransportNoTLS = func() *http.Transport {
var clone = DefaultTransport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
)

View file

@ -31,25 +31,48 @@ type Config struct {
reloadReq chan struct{}
}
func Load() (*Config, E.NestedError) {
cfg := &Config{
var instance *Config
func GetConfig() *Config {
return instance
}
func Load() E.NestedError {
if instance != nil {
return nil
}
instance = &Config{
value: M.DefaultConfig(),
proxyProviders: F.NewMapOf[string, *PR.Provider](),
l: logrus.WithField("module", "config"),
watcher: W.NewFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
return cfg, cfg.load()
return instance.load()
}
func Validate(data []byte) E.NestedError {
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
}
func MatchDomains() []string {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return instance.value.MatchDomains
}
func (cfg *Config) Value() M.Config {
if cfg == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return *cfg.value
}
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return cfg.autocertProvider
}
@ -61,13 +84,11 @@ func (cfg *Config) Dispose() {
cfg.stopProviders()
}
func (cfg *Config) Reload() E.NestedError {
func (cfg *Config) Reload() (err E.NestedError) {
cfg.stopProviders()
if err := cfg.load(); err.HasError() {
return err
}
err = cfg.load()
cfg.StartProxyProviders()
return nil
return
}
func (cfg *Config) StartProxyProviders() {
@ -126,28 +147,28 @@ func (cfg *Config) load() (res E.NestedError) {
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err.HasError() {
b.Add(E.FailWith("read config", err))
return
logrus.Fatal(b.Build())
}
if !common.NoSchemaValidation {
if err = Validate(data); err.HasError() {
b.Add(E.FailWith("schema validation", err))
return
logrus.Fatal(b.Build())
}
}
model := M.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
b.Add(E.FailWith("parse config", err))
return
logrus.Fatal(b.Build())
}
// errors are non fatal below
b.WithSeverity(E.SeverityWarning)
b.Add(cfg.initAutoCert(&model.AutoCert))
b.Add(cfg.loadProviders(&model.Providers))
cfg.value = model
R.SetFindMuxDomains(model.MatchDomains)
return
}

View file

@ -1,17 +1,37 @@
package docker
import (
"reflect"
"strings"
E "github.com/yusing/go-proxy/error"
U "github.com/yusing/go-proxy/utils"
F "github.com/yusing/go-proxy/utils/functional"
)
type Label struct {
Namespace string
Target string
Attribute string
Value any
/*
Formats:
- namespace.attribute
- namespace.target.attribute
- namespace.target.attribute.namespace2.attribute
*/
type (
Label struct {
Namespace string
Target string
Attribute string
Value any
}
NestedLabelMap map[string]U.SerializedObject
ValueParser func(string) (any, E.NestedError)
ValueParserMap map[string]ValueParser
)
func (l *Label) String() string {
if l.Attribute == "" {
return l.Namespace + "." + l.Target
}
return l.Namespace + "." + l.Target + "." + l.Attribute
}
// Apply applies the value of a Label to the corresponding field in the given object.
@ -23,12 +43,40 @@ type Label struct {
// Returns:
// - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
return U.Deserialize(map[string]any{l.Attribute: l.Value}, obj)
if obj == nil {
return E.Invalid("nil object", l)
}
switch nestedLabel := l.Value.(type) {
case *Label:
var field reflect.Value
objType := reflect.TypeFor[T]()
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
field = reflect.ValueOf(obj).Elem().Field(i)
break
}
}
if !field.IsValid() {
return E.NotExist("field", l.Attribute)
}
dst, ok := field.Interface().(NestedLabelMap)
if !ok {
return E.Invalid("type", field.Type())
}
if dst == nil {
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
dst = field.Interface().(NestedLabelMap)
}
if dst[nestedLabel.Namespace] == nil {
dst[nestedLabel.Namespace] = make(U.SerializedObject)
}
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
return nil
default:
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
}
}
type ValueParser func(string) (any, E.NestedError)
type ValueParserMap map[string]ValueParser
func ParseLabel(label string, value string) (*Label, E.NestedError) {
parts := strings.Split(label, ".")
@ -45,14 +93,22 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
Value: value,
}
if len(parts) == 3 {
l.Attribute = parts[2]
} else {
switch len(parts) {
case 2:
l.Attribute = l.Target
case 3:
l.Attribute = parts[2]
default:
l.Attribute = parts[2]
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
if err.HasError() {
return nil, err
}
l.Value = nestedLabel
}
// find if namespace has value parser
pm, ok := labelValueParserMap[l.Namespace]
pm, ok := valueParserMap.Load(l.Namespace)
if !ok {
return l, nil
}
@ -64,15 +120,28 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
// try to parse value
v, err := p(value)
if err.HasError() {
return nil, err
return nil, err.Subject(label)
}
l.Value = v
return l, nil
}
func RegisterNamespace(namespace string, pm ValueParserMap) {
labelValueParserMap[namespace] = pm
valueParserMap.Store(namespace, pm)
}
func GetRegisteredNamespaces() map[string][]string {
r := make(map[string][]string)
valueParserMap.RangeAll(func(ns string, vpm ValueParserMap) {
r[ns] = make([]string, 0, len(vpm))
for attr := range vpm {
r[ns] = append(r[ns], attr)
}
})
return r
}
// namespace:target.attribute -> func(string) (any, error)
var labelValueParserMap = make(map[string]ValueParserMap)
var valueParserMap = F.NewMapOf[string, ValueParserMap]()

View file

@ -7,7 +7,27 @@ import (
"gopkg.in/yaml.v3"
)
func yamlListParser(value string) (any, E.NestedError) {
const (
NSProxy = "proxy"
ProxyAttributePathPatterns = "path_patterns"
ProxyAttributeNoTLSVerify = "no_tls_verify"
ProxyAttributeMiddlewares = "middlewares"
)
var _ = func() int {
RegisterNamespace(NSProxy, ValueParserMap{
ProxyAttributePathPatterns: YamlStringListParser,
ProxyAttributeNoTLSVerify: BoolParser,
})
return 0
}()
func YamlStringListParser(value string) (any, E.NestedError) {
/*
- foo
- bar
- baz
*/
value = strings.TrimSpace(value)
if value == "" {
return []string{}, nil
@ -17,27 +37,36 @@ func yamlListParser(value string) (any, E.NestedError) {
return data, err
}
func yamlStringMappingParser(value string) (any, E.NestedError) {
value = strings.TrimSpace(value)
lines := strings.Split(value, "\n")
h := make(map[string]string)
for _, line := range lines {
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
return nil, E.Invalid("set header statement", line)
}
key := strings.TrimSpace(parts[0])
val := strings.TrimSpace(parts[1])
if existing, ok := h[key]; ok {
h[key] = existing + ", " + val
} else {
h[key] = val
func YamlLikeMappingParser(allowDuplicate bool) func(string) (any, E.NestedError) {
return func(value string) (any, E.NestedError) {
/*
foo: bar
boo: baz
*/
value = strings.TrimSpace(value)
lines := strings.Split(value, "\n")
h := make(map[string]string)
for _, line := range lines {
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
return nil, E.Invalid("syntax", line)
}
key := strings.TrimSpace(parts[0])
val := strings.TrimSpace(parts[1])
if existing, ok := h[key]; ok {
if !allowDuplicate {
return nil, E.Duplicated("key", key)
}
h[key] = existing + ", " + val
} else {
h[key] = val
}
}
return h, nil
}
return h, nil
}
func boolParser(value string) (any, E.NestedError) {
func BoolParser(value string) (any, E.NestedError) {
switch strings.ToLower(value) {
case "true", "yes", "1":
return true, nil
@ -47,15 +76,3 @@ func boolParser(value string) (any, E.NestedError) {
return nil, E.Invalid("boolean value", value)
}
}
const NSProxy = "proxy"
var _ = func() int {
RegisterNamespace(NSProxy, ValueParserMap{
"path_patterns": yamlListParser,
"set_headers": yamlStringMappingParser,
"hide_headers": yamlListParser,
"no_tls_verify": boolParser,
})
return 0
}()

View file

@ -2,8 +2,6 @@ package docker
import (
"fmt"
"reflect"
"strings"
"testing"
E "github.com/yusing/go-proxy/error"
@ -14,21 +12,16 @@ func makeLabel(namespace string, alias string, field string) string {
return fmt.Sprintf("%s.%s.%s", namespace, alias, field)
}
func TestHomePageLabel(t *testing.T) {
func TestParseLabel(t *testing.T) {
alias := "foo"
field := "ip"
v := "bar"
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
ExpectNoError(t, err.Error())
if pl.Target != alias {
t.Errorf("Expected alias=%s, got %s", alias, pl.Target)
}
if pl.Attribute != field {
t.Errorf("Expected field=%s, got %s", field, pl.Target)
}
if pl.Value != v {
t.Errorf("Expected value=%q, got %s", v, pl.Value)
}
ExpectEqual(t, pl.Namespace, NSHomePage)
ExpectEqual(t, pl.Target, alias)
ExpectEqual(t, pl.Attribute, field)
ExpectEqual(t, pl.Value.(string), v)
}
func TestStringProxyLabel(t *testing.T) {
@ -51,90 +44,63 @@ func TestBoolProxyLabelValid(t *testing.T) {
}
for k, v := range tests {
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k)
pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), k)
ExpectNoError(t, err.Error())
ExpectEqual(t, pl.Value.(bool), v)
}
}
func TestBoolProxyLabelInvalid(t *testing.T) {
alias := "foo"
field := "no_tls_verify"
_, err := ParseLabel(makeLabel(NSProxy, alias, field), "invalid")
_, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), "invalid")
if !err.Is(E.ErrInvalid) {
t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error())
}
}
func TestSetHeaderProxyLabelValid(t *testing.T) {
v := `
X-Custom-Header1: foo, bar
X-Custom-Header1: baz
X-Custom-Header2: boo`
v = strings.TrimPrefix(v, "\n")
h := map[string]string{
"X-Custom-Header1": "foo, bar, baz",
"X-Custom-Header2": "boo",
}
// func TestSetHeaderProxyLabelValid(t *testing.T) {
// v := `
// X-Custom-Header1: foo, bar
// X-Custom-Header1: baz
// X-Custom-Header2: boo`
// v = strings.TrimPrefix(v, "\n")
// h := map[string]string{
// "X-Custom-Header1": "foo, bar, baz",
// "X-Custom-Header2": "boo",
// }
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
ExpectNoError(t, err.Error())
hGot := ExpectType[map[string]string](t, pl.Value)
if hGot != nil && !reflect.DeepEqual(h, hGot) {
t.Errorf("Expected %v, got %v", h, hGot)
}
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
// ExpectNoError(t, err.Error())
// hGot := ExpectType[map[string]string](t, pl.Value)
// ExpectFalse(t, hGot == nil)
// ExpectDeepEqual(t, h, hGot)
// }
}
// func TestSetHeaderProxyLabelInvalid(t *testing.T) {
// tests := []string{
// "X-Custom-Header1 = bar",
// "X-Custom-Header1",
// "- X-Custom-Header1",
// }
func TestSetHeaderProxyLabelInvalid(t *testing.T) {
tests := []string{
"X-Custom-Header1 = bar",
"X-Custom-Header1",
"- X-Custom-Header1",
}
for _, v := range tests {
_, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
if !err.Is(E.ErrInvalid) {
t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
}
}
}
func TestHideHeadersProxyLabel(t *testing.T) {
v := `
- X-Custom-Header1
- X-Custom-Header2
- X-Custom-Header3
`
v = strings.TrimPrefix(v, "\n")
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v)
ExpectNoError(t, err.Error())
sGot := ExpectType[[]string](t, pl.Value)
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
if sGot != nil {
ExpectDeepEqual(t, sGot, sWant)
}
}
// func TestCommaSepProxyLabelSingle(t *testing.T) {
// v := "a"
// pl, err := ParseLabel("proxy.aliases", v)
// ExpectNoError(t, err)
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"a"}
// if sGot != nil {
// ExpectEqual(t, sGot, sWant)
// for _, v := range tests {
// _, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
// if !err.Is(E.ErrInvalid) {
// t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
// }
// }
// }
// func TestCommaSepProxyLabelMulti(t *testing.T) {
// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
// pl, err := ParseLabel("proxy.aliases", v)
// ExpectNoError(t, err)
// func TestHideHeadersProxyLabel(t *testing.T) {
// v := `
// - X-Custom-Header1
// - X-Custom-Header2
// - X-Custom-Header3
// `
// v = strings.TrimPrefix(v, "\n")
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeHideHeaders), v)
// ExpectNoError(t, err.Error())
// sGot := ExpectType[[]string](t, pl.Value)
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
// if sGot != nil {
// ExpectEqual(t, sGot, sWant)
// }
// ExpectFalse(t, sGot == nil)
// ExpectDeepEqual(t, sGot, sWant)
// }

85
src/docker/label_test.go Normal file
View file

@ -0,0 +1,85 @@
package docker
import (
"fmt"
"testing"
U "github.com/yusing/go-proxy/utils"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestNestedLabel(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
ExpectNoError(t, err.Error())
sGot := ExpectType[*Label](t, pl.Value)
ExpectFalse(t, sGot == nil)
ExpectEqual(t, sGot.Namespace, mName)
ExpectEqual(t, sGot.Attribute, mAttr)
}
func TestApplyNestedLabel(t *testing.T) {
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
mName := "middleware1"
mAttr := "prop1"
v := "value1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
ExpectEqual(t, got, v)
}
func TestApplyNestedLabelExisting(t *testing.T) {
mName := "middleware1"
mAttr := "prop1"
v := "value1"
checkAttr := "prop2"
checkV := "value2"
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject)
entry.Middlewares[mName][checkAttr] = checkV
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr])
ExpectEqual(t, got, v)
// check if prop2 is affected
ExpectFalse(t, middleware1[checkAttr] == nil)
got = ExpectType[string](t, middleware1[checkAttr])
ExpectEqual(t, got, checkV)
}
func TestApplyNestedLabelNoAttr(t *testing.T) {
mName := "middleware1"
v := "value1"
entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"`
})
entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject)
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", ProxyAttributeMiddlewares, mName)), v)
ExpectNoError(t, err.Error())
err = ApplyLabel(entry, pl)
ExpectNoError(t, err.Error())
_, ok := entry.Middlewares[mName]
ExpectTrue(t, ok)
}

View file

@ -10,9 +10,8 @@ type Builder struct {
}
type builder struct {
message string
errors []NestedError
severity Severity
message string
errors []NestedError
sync.Mutex
}
@ -40,11 +39,6 @@ func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
}
func (b Builder) WithSeverity(s Severity) Builder {
b.severity = s
return b
}
// Build builds a NestedError based on the errors collected in the Builder.
//
// If there are no errors in the Builder, it returns a Nil() NestedError.
@ -58,7 +52,7 @@ func (b Builder) Build() NestedError {
} else if len(b.errors) == 1 {
return b.errors[0]
}
return Join(b.message, b.errors...).Severity(b.severity)
return Join(b.message, b.errors...)
}
func (b Builder) To(ptr *NestedError) {

View file

@ -9,17 +9,10 @@ import (
type (
NestedError = *nestedError
nestedError struct {
subject string
err error
extras []nestedError
severity Severity
subject string
err error
extras []nestedError
}
Severity uint8
)
const (
SeverityWarning Severity = iota
SeverityFatal
)
func From(err error) NestedError {
@ -164,22 +157,6 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
return ne
}
func (ne NestedError) Severity(s Severity) NestedError {
if ne == nil {
return ne
}
ne.severity = s
return ne
}
func (ne NestedError) Warn() NestedError {
if ne == nil {
return ne
}
ne.severity = SeverityWarning
return ne
}
func (ne NestedError) NoError() bool {
return ne == nil
}
@ -188,14 +165,6 @@ func (ne NestedError) HasError() bool {
return ne != nil
}
func (ne NestedError) IsFatal() bool {
return ne != nil && ne.severity == SeverityFatal
}
func (ne NestedError) IsWarning() bool {
return ne != nil && ne.severity == SeverityWarning
}
func errorf(format string, args ...any) NestedError {
return From(fmt.Errorf(format, args...))
}

View file

@ -31,11 +31,11 @@ func TestErrorNestedIs(t *testing.T) {
err = Failure("some reason")
ExpectTrue(t, err.Is(ErrFailure))
ExpectFalse(t, err.Is(ErrAlreadyExist))
ExpectFalse(t, err.Is(ErrDuplicated))
err.With(AlreadyExist("something", ""))
err.With(Duplicated("something", ""))
ExpectTrue(t, err.Is(ErrFailure))
ExpectTrue(t, err.Is(ErrAlreadyExist))
ExpectTrue(t, err.Is(ErrDuplicated))
ExpectFalse(t, err.Is(ErrInvalid))
}

View file

@ -5,14 +5,14 @@ import (
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrAlreadyExist = stderrors.New("already exist")
ErrOutOfRange = stderrors.New("out of range")
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
)
const fmtSubjectWhat = "%w %v: %q"
@ -53,8 +53,8 @@ func Missing(subject any) NestedError {
return errorf("%w %v", ErrMissing, subject)
}
func AlreadyExist(subject, what any) NestedError {
return errorf("%v %w: %v", subject, ErrAlreadyExist, what)
func Duplicated(subject, what any) NestedError {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
}
func OutOfRange(subject string, value any) NestedError {

View file

@ -17,7 +17,7 @@ require (
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cloudflare/cloudflare-go v0.104.0 // indirect
github.com/cloudflare/cloudflare-go v0.105.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect

View file

@ -4,8 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cloudflare/cloudflare-go v0.104.0 h1:R/lB0dZupaZbOgibAH/BRrkFbZ6Acn/WsKg2iX2xXuY=
github.com/cloudflare/cloudflare-go v0.104.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
github.com/cloudflare/cloudflare-go v0.105.0 h1:yu2IatITLZ4dw7/byzRrlE5DfUvtub0k9CHZ5zBlj90=
github.com/cloudflare/cloudflare-go v0.105.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View file

@ -0,0 +1,96 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
// Copyright (c) 2020-2024 Traefik Labs
package http
import (
"bufio"
"fmt"
"net"
"net/http"
)
type ModifyResponseFunc func(*http.Response) error
type ModifyResponseWriter struct {
w http.ResponseWriter
r *http.Request
headerSent bool
code int
modifier ModifyResponseFunc
modified bool
modifierErr error
}
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
return &ModifyResponseWriter{
w: w,
r: r,
modifier: f,
code: http.StatusOK,
}
}
func (w *ModifyResponseWriter) WriteHeader(code int) {
if w.headerSent {
return
}
if code >= http.StatusContinue && code < http.StatusOK {
w.w.WriteHeader(code)
}
defer func() {
w.headerSent = true
w.code = code
}()
if w.modifier == nil || w.modified {
w.w.WriteHeader(code)
return
}
resp := http.Response{
Header: w.w.Header(),
Request: w.r,
}
if err := w.modifier(&resp); err != nil {
w.modifierErr = err
logger.Errorf("error modifying response: %s", err)
w.w.WriteHeader(http.StatusInternalServerError)
return
}
w.modified = true
w.w.WriteHeader(code)
}
func (w *ModifyResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
w.WriteHeader(w.code)
if w.modifierErr != nil {
return 0, w.modifierErr
}
return w.w.Write(b)
}
// Hijack hijacks the connection.
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.w.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
}
// Flush sends any buffered data to the client.
func (w *ModifyResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}

View file

@ -1,7 +1,13 @@
package proxy
// Copyright 2011 The Go Authors.
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
// https://cs.opensource.google/go/go/+/master:LICENSE
// A small mod on net/http/httputil/reverseproxy.go
// that doubled the performance
package http
// This is a small mod on net/http/httputil/reverseproxy.go
// that boosts performance in some cases
// and compatible to other modules of this project
// Copyright (c) 2024 yusing
import (
"context"
@ -52,6 +58,21 @@ type ProxyRequest struct {
// r.SetXForwarded()
// }
func (r *ProxyRequest) SetXForwarded() {
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
if err == nil {
r.Out.Header.Set("X-Forwarded-For", clientIP)
} else {
r.Out.Header.Del("X-Forwarded-For")
}
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
if r.In.TLS == nil {
r.Out.Header.Set("X-Forwarded-Proto", "http")
} else {
r.Out.Header.Set("X-Forwarded-Proto", "https")
}
}
func (r *ProxyRequest) AddXForwarded() {
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
if err == nil {
prior := r.Out.Header["X-Forwarded-For"]
@ -104,28 +125,6 @@ type ReverseProxy struct {
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body.
// If zero, no periodic flushing is done.
// A negative value means to flush immediately
// after each write to the client.
// The FlushInterval is ignored when ReverseProxy
// recognizes a response as a streaming response, or
// if its ContentLength is -1; for such responses, writes
// are flushed to the client immediately.
// FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging is done via the log package's standard logger.
// ErrorLog *log.Logger
// BufferPool optionally specifies a buffer pool to
// get byte slices for use by io.CopyBuffer when
// copying HTTP response bodies.
// BufferPool BufferPool
// ModifyResponse is an optional function that modifies the
// Response from the backend. It is called if the backend
// returns a response at all, with any HTTP status code.
@ -208,36 +207,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// },
// }
//
// TODO: headers in ModifyResponse
func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy {
// check on init rather than on request
var setHeaders = func(r *http.Request) {}
var hideHeaders = func(r *http.Request) {}
if len(entry.SetHeaders) > 0 {
setHeaders = func(r *http.Request) {
h := entry.SetHeaders.Clone()
for k, vv := range h {
if k == "Host" {
r.Host = vv[0]
} else {
r.Header[k] = vv
}
}
}
}
if len(entry.HideHeaders) > 0 {
hideHeaders = func(r *http.Request) {
for _, k := range entry.HideHeaders {
r.Header.Del(k)
}
}
}
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
rp := &ReverseProxy{
Rewrite: func(pr *ProxyRequest) {
rewriteRequestURL(pr.Out, target)
// pr.SetXForwarded()
setHeaders(pr.Out)
hideHeaders(pr.Out)
}, Transport: transport,
}
rp.ServeHTTP = rp.serveHTTP
@ -256,6 +230,23 @@ func rewriteRequestURL(req *http.Request, target *url.URL) {
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
@ -331,12 +322,14 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Close = false
reqUpType := upgradeType(outreq.Header)
reqUpType := UpgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
return
}
RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
@ -458,16 +451,34 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
func upgradeType(h http.Header) string {
func UpgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
// RemoveHopByHopHeaders removes hop-by-hop headers.
func RemoveHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
reqUpType := UpgradeType(req.Header)
resUpType := UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
}

View file

@ -47,11 +47,10 @@ func main() {
logrus.SetOutput(io.Discard)
} else {
logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true,
DisableLevelTruncation: true,
FullTimestamp: true,
ForceColors: true,
TimestampFormat: "01-02 15:04:05",
DisableSorting: true,
FullTimestamp: true,
ForceColors: true,
TimestampFormat: "01-02 15:04:05",
})
}
@ -76,10 +75,11 @@ func main() {
return
}
cfg, err := config.Load()
if err.IsFatal() {
log.Fatal(err)
err := config.Load()
if err != nil {
logrus.Warn(err)
}
cfg := config.GetConfig()
switch args.Command {
case common.CommandListConfigs:
@ -96,6 +96,10 @@ func main() {
return
}
if common.IsDebug {
printJSON(docker.GetRegisteredNamespaces())
}
cfg.StartProxyProviders()
if err.HasError() {
@ -116,10 +120,7 @@ func main() {
if autocert != nil {
ctx, cancel := context.WithCancel(context.Background())
if err = autocert.Setup(ctx); err != nil && err.IsWarning() {
cancel()
l.Warn(err)
} else if err.IsFatal() {
if err = autocert.Setup(ctx); err != nil {
l.Fatal(err)
} else {
onShutdown.Add(cancel)
@ -192,7 +193,7 @@ func funcName(f func()) string {
}
func printJSON(obj any) {
j, err := E.Check(json.Marshal(obj))
j, err := E.Check(json.MarshalIndent(obj, "", " "))
if err.HasError() {
logrus.Fatal(err)
}

View file

@ -3,6 +3,7 @@ package model
type Config struct {
Providers ProxyProviders `yaml:",flow" json:"providers"`
AutoCert AutoCertConfig `yaml:",flow" json:"autocert"`
MatchDomains []string `yaml:"match_domains" json:"match_domains"`
TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"`
RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"`
}
@ -11,6 +12,6 @@ func DefaultConfig() *Config {
return &Config{
Providers: ProxyProviders{},
TimeoutShutdown: 3,
RedirectToHTTPS: true,
RedirectToHTTPS: false,
}
}

View file

@ -14,14 +14,13 @@ type (
RawEntry struct {
// raw entry object before validation
// loaded from docker labels or yaml file
Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"`
Port string `yaml:"port" json:"port"`
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only
Alias string `yaml:"-" json:"-"`
Scheme string `yaml:"scheme" json:"scheme"`
Host string `yaml:"host" json:"host"`
Port string `yaml:"port" json:"port"`
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares"`
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
@ -44,12 +43,16 @@ func (e *RawEntry) FillMissingFields() bool {
if pp == "" {
pp = strconv.Itoa(port)
}
e.Scheme = "tcp"
if e.Scheme == "" {
e.Scheme = "tcp"
}
} else if port, ok := ImageNamePortMap[e.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
e.Scheme = "http"
if e.Scheme == "" {
e.Scheme = "http"
}
} else if pp == "" && e.Scheme == "https" {
pp = "443"
} else if pp == "" {

View file

@ -2,10 +2,10 @@ package proxy
import (
"fmt"
"net/http"
"net/url"
"time"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
T "github.com/yusing/go-proxy/proxy/fields"
@ -18,8 +18,7 @@ type (
URL *url.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
SetHeaders http.Header
HideHeaders []string
Middlewares D.NestedLabelMap
/* Docker only */
IdleTimeout time.Duration
@ -78,9 +77,6 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
@ -111,8 +107,7 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry
URL: url,
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
SetHeaders: setHeaders,
HideHeaders: m.HideHeaders,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,

View file

@ -4,7 +4,9 @@ import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
@ -17,6 +19,7 @@ type DockerProvider struct {
}
var AliasRefRegex = regexp.MustCompile(`#\d+`)
var AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
hostname, err := D.ParseDockerHostname(dockerHost)
@ -152,6 +155,20 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
refErr := E.NewBuilder("errors parsing alias references")
replaceIndexRef := func(ref string) string {
index, err := strconv.Atoi(ref[1:])
if err != nil {
refErr.Add(E.Invalid("integer", ref))
return ref
}
if index < 1 || index > len(container.Aliases) {
refErr.Add(E.OutOfRange("index", ref))
return ref
}
return container.Aliases[index-1]
}
lbl, err := D.ParseLabel(key, val)
if err.HasError() {
b.Add(err.Subject(key))
@ -163,22 +180,14 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
// apply label for all aliases
entries.RangeAll(func(a string, e *M.RawEntry) {
if err = D.ApplyLabel(e, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
b.Add(err.Subjectf("alias %s", lbl.Target))
}
})
} else {
refErr := E.NewBuilder("errors parsing alias references")
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, func(ref string) string {
index, err := strconv.Atoi(ref[1:])
if err != nil {
refErr.Add(E.Invalid("integer", ref))
return ref
}
if index < 1 || index > len(container.Aliases) {
refErr.Add(E.OutOfRange("index", ref))
return ref
}
return container.Aliases[index-1]
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef)
lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string {
logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#"))
return replaceIndexRef(s)
})
if refErr.HasError() {
b.Add(refErr.Build())
@ -190,7 +199,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
return
}
if err = D.ApplyLabel(config, lbl); err.HasError() {
b.Add(err.Subject(lbl.Target))
b.Add(err.Subjectf("alias %s", lbl.Target))
}
}
return

View file

@ -132,7 +132,8 @@ func TestApplyLabel(t *testing.T) {
ExpectEqual(t, b.Scheme, "http")
ExpectEqual(t, b.Port, "1234")
ExpectEqual(t, c.Scheme, "https")
ExpectEqual(t, c.Port, "1111")
// map does not necessary follow the order above
ExpectEqualAny(t, c.Port, []string{"1111", "1234"})
}
func TestApplyLabelWithRef(t *testing.T) {
@ -142,9 +143,9 @@ func TestApplyLabelWithRef(t *testing.T) {
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.#1.host": "localhost",
"proxy.*.port": "1111",
"proxy.#1.port": "4444",
"proxy.#2.port": "9999",
"proxy.#3.port": "1111",
"proxy.#3.scheme": "https",
},
Ports: []types.Port{

View file

@ -1,20 +1,20 @@
package route
import (
"crypto/tls"
"net"
"sync"
"time"
"net/http"
"net/url"
"strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/common"
"github.com/yusing/go-proxy/docker/idlewatcher"
E "github.com/yusing/go-proxy/error"
. "github.com/yusing/go-proxy/http"
P "github.com/yusing/go-proxy/proxy"
PT "github.com/yusing/go-proxy/proxy/fields"
"github.com/yusing/go-proxy/route/middleware"
F "github.com/yusing/go-proxy/utils/functional"
)
@ -26,7 +26,7 @@ type (
entry *P.ReverseProxyEntry
mux *http.ServeMux
handler *P.ReverseProxy
handler *ReverseProxy
regIdleWatcher func() E.NestedError
unregIdleWatcher func()
@ -36,18 +36,41 @@ type (
SubdomainKey = PT.Alias
)
var (
findMuxFunc = findMuxAnyDomain
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
globalMux = http.NewServeMux() // TODO: support regex subdomain matching
)
func SetFindMuxDomains(domains []string) {
if len(domains) == 0 {
findMuxFunc = findMuxAnyDomain
} else {
findMuxFunc = findMuxByDomain(domains)
}
}
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
var trans *http.Transport
var regIdleWatcher func() E.NestedError
var unregIdleWatcher func()
if entry.NoTLSVerify {
trans = transportNoTLS.Clone()
trans = common.DefaultTransportNoTLS.Clone()
} else {
trans = transport.Clone()
trans = common.DefaultTransport.Clone()
}
rp := P.NewReverseProxy(entry.URL, trans, entry)
rp := NewReverseProxy(entry.URL, trans)
if len(entry.Middlewares) > 0 {
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
if err != nil {
return nil, err
}
}
if entry.UseIdleWatcher() {
// allow time for response header up to `WakeTimeout`
@ -74,7 +97,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
_, exists := httpRoutes.Load(entry.Alias)
if exists {
return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias)
return nil, E.Duplicated("HTTPRoute alias", entry.Alias)
}
r := &HTTPRoute{
@ -94,11 +117,16 @@ func (r *HTTPRoute) String() string {
}
func (r *HTTPRoute) Start() E.NestedError {
if r.mux != nil {
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.regIdleWatcher != nil {
if err := r.regIdleWatcher(); err.HasError() {
r.unregIdleWatcher = nil
return err
}
}
@ -113,6 +141,10 @@ func (r *HTTPRoute) Start() E.NestedError {
}
func (r *HTTPRoute) Stop() E.NestedError {
if r.mux == nil {
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
@ -135,7 +167,7 @@ func (u *URL) MarshalText() (text []byte, err error) {
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMux(r.Host)
mux, err := findMuxFunc(r.Host)
if err != nil {
err = E.Failure("request").
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
@ -147,7 +179,7 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux.ServeHTTP(w, r)
}
func findMux(host string) (*http.ServeMux, E.NestedError) {
func findMuxAnyDomain(host string) (*http.ServeMux, E.NestedError) {
hostSplit := strings.Split(host, ".")
n := len(hostSplit)
if n <= 2 {
@ -160,23 +192,21 @@ func findMux(host string) (*http.ServeMux, E.NestedError) {
return nil, E.NotExist("route", sd)
}
var (
defaultDialer = net.Dialer{
Timeout: 60 * time.Second,
KeepAlive: 60 * time.Second,
func findMuxByDomain(domains []string) func(host string) (*http.ServeMux, E.NestedError) {
return func(host string) (*http.ServeMux, E.NestedError) {
var subdomain string
for _, domain := range domains {
subdomain = strings.TrimSuffix(subdomain, domain)
if subdomain != domain {
break
}
}
if subdomain == "" { // not matched
return nil, E.Invalid("host", host)
}
if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok {
return r.mux, nil
}
return nil, E.NotExist("route", subdomain)
}
transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: defaultDialer.DialContext,
MaxIdleConnsPerHost: 1000,
}
transportNoTLS = func() *http.Transport {
var clone = transport.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return clone
}()
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
httpRoutesMu sync.Mutex
globalMux = http.NewServeMux()
)
}

View file

@ -1,7 +0,0 @@
package middleware
var AddXForwarded = &Middleware{
rewrite: func(r *ProxyRequest) {
r.SetXForwarded()
},
}

View file

@ -0,0 +1,249 @@
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
// Copyright (c) 2020-2024 Traefik Labs
// Copyright (c) 2024 yusing
package middleware
import (
"io"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/common"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
gpHTTP "github.com/yusing/go-proxy/http"
U "github.com/yusing/go-proxy/utils"
)
type (
forwardAuth struct {
*forwardAuthOpts
m *Middleware
client http.Client
}
forwardAuthOpts struct {
Address string
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
const (
xForwardedFor = "X-Forwarded-For"
xForwardedMethod = "X-Forwarded-Method"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
xForwardedURI = "X-Forwarded-Uri"
xForwardedPort = "X-Forwarded-Port"
)
var ForwardAuth = newForwardAuth()
var faLogger = logrus.WithField("middleware", "ForwardAuth")
func newForwardAuth() (fa *forwardAuth) {
fa = new(forwardAuth)
fa.m = new(Middleware)
fa.m.labelParserMap = D.ValueParserMap{
"trust_forward_header": D.BoolParser,
"auth_response_headers": D.YamlStringListParser,
"add_auth_cookies_to_response": D.YamlStringListParser,
}
fa.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
tr, ok := rp.Transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = common.DefaultTransport.Clone()
}
faWithOpts := new(forwardAuth)
faWithOpts.forwardAuthOpts = new(forwardAuthOpts)
faWithOpts.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
faWithOpts.m = &Middleware{
impl: faWithOpts,
before: fa.forward,
}
err := U.Deserialize(optsRaw, faWithOpts.forwardAuthOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
_, err = E.Check(url.Parse(faWithOpts.Address))
if err != nil {
return nil, E.Invalid("address", faWithOpts.Address)
}
return faWithOpts.m, nil
}
return
}
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
removeHop(req.Header)
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
fa.Address,
nil,
)
if err != nil {
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
copyHeader(faReq.Header, req.Header)
removeHop(faReq.Header)
filterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
faResp, err := fa.client.Do(faReq)
if err != nil {
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()
body, err := io.ReadAll(faResp.Body)
if err != nil {
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
copyHeader(w.Header(), faResp.Header)
removeHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
}
return
}
for _, key := range fa.AuthResponseHeaders {
key := http.CanonicalHeaderKey(key)
req.Header.Del(key)
if len(faResp.Header[key]) > 0 {
req.Header[key] = append([]string(nil), faResp.Header[key]...)
}
}
req.RequestURI = req.URL.RequestURI()
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
}
next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
}
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
cookies := resp.Cookies()
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is not an auth cookie, so add it back
resp.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
// this cookie is an auth cookie, so add to resp
resp.Header.Add("Set-Cookie", cookie.String())
}
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(xForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method)
default:
faReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(xForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https")
default:
faReq.Header.Set(xForwardedProto, "http")
}
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp)
}
xfh := req.Header.Get(xForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host)
default:
faReq.Header.Del(xForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(xForwardedURI)
}
}

View file

@ -0,0 +1,44 @@
package middleware
import (
"net/http"
"slices"
gpHTTP "github.com/yusing/go-proxy/http"
)
func removeHop(h Header) {
reqUpType := gpHTTP.UpgradeType(h)
gpHTTP.RemoveHopByHopHeaders(h)
if reqUpType != "" {
h.Set("Connection", "Upgrade")
h.Set("Upgrade", reqUpType)
} else {
h.Del("Connection")
}
}
func copyHeader(dst, src Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func filterHeaders(h Header, allowed []string) {
if allowed == nil {
return
}
for i := range allowed {
allowed[i] = http.CanonicalHeaderKey(allowed[i])
}
for key := range h {
if !slices.Contains(allowed, key) {
h.Del(key)
}
}
}

View file

@ -3,33 +3,42 @@ package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
P "github.com/yusing/go-proxy/proxy"
gpHTTP "github.com/yusing/go-proxy/http"
)
type (
ReverseProxy = P.ReverseProxy
ProxyRequest = P.ProxyRequest
Error = E.NestedError
ReverseProxy = gpHTTP.ReverseProxy
ProxyRequest = gpHTTP.ProxyRequest
Request = http.Request
Response = http.Response
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(w ResponseWriter, r *Request) (continue_ bool)
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
RewriteFunc func(req *ProxyRequest)
ModifyResponseFunc func(res *Response) error
CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError)
MiddlewareOptionsRaw map[string]string
MiddlewareOptions map[string]interface{}
OptionsRaw = map[string]any
Options any
Middleware struct {
name string
before BeforeFunc
rewrite RewriteFunc
modifyResponse ModifyResponseFunc
before BeforeFunc // runs before ReverseProxy.ServeHTTP
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
options MiddlewareOptions
validateOptions func(opts MiddlewareOptionsRaw) (MiddlewareOptions, E.NestedError)
transport http.RoundTripper
withOptions CloneWithOptFunc
labelParserMap D.ValueParserMap
impl any
}
)
@ -41,41 +50,32 @@ func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptions(optsRaw MiddlewareOptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) == 0 {
return m, nil
}
var opts MiddlewareOptions
var err E.NestedError
if m.validateOptions != nil {
if opts, err = m.validateOptions(optsRaw); err != nil {
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil {
return nil, err
} else {
return mWithOpt, nil
}
}
return &Middleware{
name: m.name,
before: m.before,
rewrite: m.rewrite,
modifyResponse: m.modifyResponse,
options: opts,
}, nil
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
}
// TODO: check conflict
func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptionsRaw) (out ReverseProxy, err E.NestedError) {
out = rp
// TODO: check conflict or duplicates
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
defer invalidM.Add(invalidOpts.Build())
defer invalidM.To(&err)
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
for name, opts := range middlewares {
m, ok := Get(name)
@ -83,7 +83,8 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions
invalidM.Addf("%s", name)
continue
}
m, err = m.WithOptions(opts)
m, err := m.WithOptionsClone(opts, rp)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@ -103,25 +104,37 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions
return
}
if len(befores) > 0 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
for _, before := range befores {
if !before(w, r) {
return
}
origServeHTTP := rp.ServeHTTP
for i, before := range befores {
if i < len(befores)-1 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(rp.ServeHTTP, w, r)
}
} else {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(origServeHTTP, w, r)
}
rp.ServeHTTP(w, r)
}
}
if len(rewrites) > 0 {
origRewrite := rp.Rewrite
rp.Rewrite = func(req *ProxyRequest) {
if origRewrite != nil {
origRewrite(req)
}
for _, rewrite := range rewrites {
rewrite(req)
}
}
}
if len(modifyResponses) > 0 {
origModifyResponse := rp.ModifyResponse
rp.ModifyResponse = func(res *Response) error {
if origModifyResponse != nil {
return origModifyResponse(res)
}
for _, modifyResponse := range modifyResponses {
if err := modifyResponse(res); err != nil {
return err

View file

@ -3,14 +3,11 @@ package middleware
import (
"fmt"
"strings"
D "github.com/yusing/go-proxy/docker"
)
var middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded, // nginx
"add_x_forwarded": AddXForwarded, // nginx
"trust_forward_header": AddXForwarded, // traefik alias
"redirect_http": RedirectHTTP,
}
var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
@ -18,10 +15,23 @@ func Get(name string) (middleware *Middleware, ok bool) {
}
// initialize middleware names
var _ = func() (_ bool) {
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"add_x_forwarded": AddXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
D.RegisterNamespace(name, m.labelParserMap)
}
}
for m, names := range names {
if len(names) > 1 {
@ -30,5 +40,4 @@ var _ = func() (_ bool) {
m.name = names[0]
}
}
return
}()
}

View file

@ -0,0 +1,58 @@
package middleware
import (
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
U "github.com/yusing/go-proxy/utils"
)
type (
modifyRequest struct {
*modifyRequestOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyRequest = newModifyRequest()
func newModifyRequest() (mr *modifyRequest) {
mr = new(modifyRequest)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
mrWithOpts := new(modifyRequest)
mrWithOpts.m = &Middleware{
impl: mrWithOpts,
rewrite: mrWithOpts.modifyRequest,
}
mrWithOpts.modifyRequestOpts = new(modifyRequestOpts)
err := U.Deserialize(optsRaw, mrWithOpts.modifyRequestOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mrWithOpts.m, nil
}
return
}
func (mr *modifyRequest) modifyRequest(req *ProxyRequest) {
for k, v := range mr.SetHeaders {
req.Out.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
req.Out.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
req.Out.Header.Del(k)
}
}

View file

@ -0,0 +1,34 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestSetModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
})
}

View file

@ -0,0 +1,61 @@
package middleware
import (
"net/http"
D "github.com/yusing/go-proxy/docker"
E "github.com/yusing/go-proxy/error"
U "github.com/yusing/go-proxy/utils"
)
type (
modifyResponse struct {
*modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct {
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)
var ModifyResponse = newModifyResponse()
func newModifyResponse() (mr *modifyResponse) {
mr = new(modifyResponse)
mr.m = new(Middleware)
mr.m.labelParserMap = D.ValueParserMap{
"set_headers": D.YamlLikeMappingParser(true),
"add_headers": D.YamlLikeMappingParser(true),
"hide_headers": D.YamlStringListParser,
}
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
mrWithOpts := new(modifyResponse)
mrWithOpts.m = &Middleware{
impl: mrWithOpts,
modifyResponse: mrWithOpts.modifyResponse,
}
mrWithOpts.modifyResponseOpts = new(modifyResponseOpts)
err := U.Deserialize(optsRaw, mrWithOpts.modifyResponseOpts)
if err != nil {
return nil, E.FailWith("set options", err)
}
return mrWithOpts.m, nil
}
return
}
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
for k, v := range mr.SetHeaders {
resp.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
resp.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
resp.Header.Del(k)
}
return nil
}

View file

@ -0,0 +1,35 @@
package middleware
import (
"slices"
"testing"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestSetModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
})
}

View file

@ -7,14 +7,13 @@ import (
)
var RedirectHTTP = &Middleware{
before: func(w ResponseWriter, r *Request) (continue_ bool) {
before: func(next http.Handler, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + common.ProxyHTTPSPort
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
} else {
continue_ = true
return
}
return
next.ServeHTTP(w, r)
},
}

View file

@ -0,0 +1,26 @@
package middleware
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/common"
. "github.com/yusing/go-proxy/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https",
})
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}

View file

@ -1,10 +0,0 @@
package middleware
var SetXForwarded = &Middleware{
rewrite: func(r *ProxyRequest) {
r.Out.Header.Del("X-Forwarded-For")
r.Out.Header.Del("X-Forwarded-Host")
r.Out.Header.Del("X-Forwarded-Proto")
r.SetXForwarded()
},
}

View file

@ -0,0 +1,17 @@
{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"Accept-Encoding": "gzip, deflate, br, zstd",
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
"Dnt": "1",
"Host": "localhost",
"Priority": "u=0, i",
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
"Sec-Ch-Ua-Mobile": "?0",
"Sec-Ch-Ua-Platform": "\"Windows\"",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

View file

@ -0,0 +1,125 @@
package middleware
import (
"bytes"
_ "embed"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
E "github.com/yusing/go-proxy/error"
gpHTTP "github.com/yusing/go-proxy/http"
)
//go:embed test_data/sample_headers.json
var testHeadersRaw []byte
var testHeaders http.Header
const testHost = "example.com"
func init() {
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
panic(err)
}
testHeaders = http.Header{}
for k, v := range tmp {
testHeaders.Set(k, v)
}
}
type requestHeaderRecorder struct {
parent http.RoundTripper
reqHeaders http.Header
}
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.reqHeaders = req.Header
if rt.parent != nil {
return rt.parent.RoundTrip(req)
}
return &http.Response{
StatusCode: http.StatusOK,
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
}, nil
}
type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
Data []byte
}
type testArgs struct {
middlewareOpt OptionsRaw
proxyURL string
body []byte
scheme string
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rt = new(requestHeaderRecorder)
var proxyURL *url.URL
var requestTarget string
var err error
if args == nil {
args = new(testArgs)
}
if args.body != nil {
body = bytes.NewReader(args.body)
}
if args.scheme == "" || args.scheme == "http" {
requestTarget = "http://" + testHost
} else if args.scheme == "https" {
requestTarget = "https://" + testHost
} else {
panic("typo?")
}
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
w := httptest.NewRecorder()
if args.scheme == "https" && req.TLS == nil {
panic("bug occurred")
}
if args.proxyURL != "" {
proxyURL, err = url.Parse(args.proxyURL)
if err != nil {
return nil, E.From(err)
}
rt.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
middleware.name: args.middlewareOpt,
})
if setOptErr != nil {
return nil, setOptErr
}
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rt.reqHeaders,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
Data: data,
}, nil
}

View file

@ -0,0 +1,9 @@
package middleware
var AddXForwarded = &Middleware{
rewrite: (*ProxyRequest).AddXForwarded,
}
var SetXForwarded = &Middleware{
rewrite: (*ProxyRequest).SetXForwarded,
}

View file

@ -84,6 +84,11 @@ func NewServer(opt Options) (s *server) {
}
}
// Start will start the http and https servers.
//
// If both are not set, this does nothing.
//
// Start() is non-blocking
func (s *server) Start() {
if s.http == nil && s.https == nil {
return

View file

@ -106,7 +106,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
return result, nil
}
func Deserialize(src map[string]any, target any) E.NestedError {
func Deserialize(src SerializedObject, target any) E.NestedError {
// convert data fields to lower no-snake
// convert target fields to lower
// then check if the field of data is in the target
@ -117,6 +117,10 @@ func Deserialize(src map[string]any, target any) E.NestedError {
snakeCaseField := strings.ToLower(field.Name)
mapping[snakeCaseField] = field.Name
}
tValue := reflect.ValueOf(target)
if tValue.IsZero() {
return E.Invalid("value", "nil")
}
for k, v := range src {
kCleaned := toLowerNoSnake(k)
if fieldName, ok := mapping[kCleaned]; ok {
@ -150,13 +154,13 @@ func Deserialize(src map[string]any, target any) E.NestedError {
}
prop.Set(propNew)
default:
return E.Unsupported("field", k).Extraf("type=%s", propType)
return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType)
}
} else {
return E.Unsupported("field", k).Extraf("type=%s", propType)
return E.Unsupported("field", k).Extraf("type %s is not settable", propType)
}
} else {
return E.Failure("unknown field").With(k)
return E.Unexpected("field", k)
}
}

View file

@ -38,6 +38,17 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) {
}
}
func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) {
t.Helper()
for _, want := range wants {
if got == want {
return
}
}
t.Errorf("expected any of:\n%v, got\n%v", wants, got)
t.FailNow()
}
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
t.Helper()
if !reflect.DeepEqual(got, want) {