mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
added support for a few middlewares, added match_domain
option, changed index reference prefix from $ to #, etc.
This commit is contained in:
parent
345a4417a6
commit
f474ae4f75
47 changed files with 1523 additions and 446 deletions
15
.github/workflows/docker-image.yml
vendored
15
.github/workflows/docker-image.yml
vendored
|
@ -130,3 +130,18 @@ jobs:
|
|||
run: |
|
||||
docker tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.meta.outputs.version }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||
docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||
scan:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- merge
|
||||
steps:
|
||||
- name: Scan Image with Trivy
|
||||
uses: aquasecurity/trivy-action@0.20.0
|
||||
with:
|
||||
image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||
format: "sarif"
|
||||
output: "trivy-results.sarif"
|
||||
- name: Upload Trivy SARIF Report
|
||||
uses: github/codeql-action/upload-sarif@v2
|
||||
with:
|
||||
sarif_file: "trivy-results.sarif"
|
||||
|
|
10
Dockerfile
10
Dockerfile
|
@ -9,15 +9,15 @@ COPY src/go.mod src/go.sum ./
|
|||
|
||||
# Utilize build cache
|
||||
RUN --mount=type=cache,target="/go/pkg/mod" \
|
||||
go mod download
|
||||
go mod graph | awk '{if ($1 !~ "@") print $2}' | xargs go get
|
||||
|
||||
# Now copy the remaining files
|
||||
COPY src/ ./
|
||||
ENV GOCACHE=/root/.cache/go-build
|
||||
|
||||
# Build the application with better caching
|
||||
RUN --mount=type=cache,target="/go/pkg/mod" \
|
||||
--mount=type=cache,target="/root/.cache/go-build" \
|
||||
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy ./
|
||||
--mount=type=bind,src=src,dst=/src \
|
||||
CGO_ENABLED=0 GOOS=linux go build -ldflags '-w -s' -pgo=auto -o /go-proxy .
|
||||
|
||||
# Stage 2: Final image
|
||||
FROM scratch
|
||||
|
@ -28,7 +28,7 @@ LABEL maintainer="yusing@6uo.me"
|
|||
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
|
||||
# copy binary
|
||||
COPY --from=builder /src/go-proxy /app/
|
||||
COPY --from=builder /go-proxy /app/
|
||||
|
||||
# copy schema directory
|
||||
COPY schema/ /app/schema/
|
||||
|
|
10
Makefile
10
Makefile
|
@ -1,4 +1,4 @@
|
|||
.PHONY: all build up quick-restart restart logs get udp-server
|
||||
.PHONY: all setup build test up restart logs get debug run archive repush rapid-crash debug-list-containers
|
||||
|
||||
all: debug
|
||||
|
||||
|
@ -9,7 +9,8 @@ setup:
|
|||
|
||||
build:
|
||||
mkdir -p bin
|
||||
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
|
||||
CGO_ENABLED=0 GOOS=linux \
|
||||
go build -ldflags '${BUILD_FLAG}' -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
|
||||
|
||||
test:
|
||||
go test ./src/...
|
||||
|
@ -29,6 +30,9 @@ get:
|
|||
debug:
|
||||
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||
|
||||
run:
|
||||
BUILD_FLAG="-s -w" make build && sudo bin/go-proxy
|
||||
|
||||
archive:
|
||||
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
|
||||
|
||||
|
@ -44,4 +48,4 @@ rapid-crash:
|
|||
sudo docker rm -f test_crash
|
||||
|
||||
debug-list-containers:
|
||||
bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq'
|
||||
bash -c 'echo -e "GET /containers/json HTTP/1.0\r\n" | sudo netcat -U /var/run/docker.sock | tail -n +9 | jq'
|
||||
|
|
|
@ -1,36 +1,64 @@
|
|||
# Autocert (choose one below and uncomment to enable)
|
||||
|
||||
#
|
||||
# 1. use existing cert
|
||||
#
|
||||
# autocert:
|
||||
# provider: local
|
||||
# cert_path: certs/cert.crt # optional, uncomment only if you need to change it
|
||||
# key_path: certs/priv.key # optional, uncomment only if you need to change it
|
||||
|
||||
#
|
||||
# cert_path: certs/cert.crt # optional, uncomment only if you need to change it
|
||||
# key_path: certs/priv.key # optional, uncomment only if you need to change it
|
||||
#
|
||||
# 2. cloudflare
|
||||
#
|
||||
# autocert:
|
||||
# provider: cloudflare
|
||||
# email: # ACME Email
|
||||
# domains: # a list of domains for cert registration
|
||||
# - x.y.z
|
||||
# email: abc@gmail.com # ACME Email
|
||||
# domains: # a list of domains for cert registration
|
||||
# - "*.y.z" # remember to use double quotes to surround wildcard domain
|
||||
# options:
|
||||
# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token
|
||||
|
||||
# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token
|
||||
#
|
||||
# 3. other providers, check docs/dns_providers.md for more
|
||||
|
||||
providers:
|
||||
# include files are standalone yaml files under `config/` directory
|
||||
#
|
||||
# include:
|
||||
# - providers.yml # config/providers.yml
|
||||
# # add some more below if you want
|
||||
# - file1.yml # config/file_1.yml
|
||||
# - file1.yml
|
||||
# - file2.yml
|
||||
|
||||
docker:
|
||||
# for value format, see https://docs.docker.com/reference/cli/dockerd/
|
||||
# $DOCKER_HOST implies unix:///var/run/docker.sock by default
|
||||
# $DOCKER_HOST implies environment variable `DOCKER_HOST` or unix:///var/run/docker.sock by default
|
||||
local: $DOCKER_HOST
|
||||
|
||||
# add more docker providers if needed
|
||||
# for value format, see https://docs.docker.com/reference/cli/dockerd/
|
||||
#
|
||||
# remote-1: tcp://10.0.2.1:2375
|
||||
# remote-2: ssh://root:1234@10.0.2.2
|
||||
# Fixed options (optional, non hot-reloadable)
|
||||
# if match_domains not defined
|
||||
# any host = alias+[any domain] will match
|
||||
# i.e. https://app1.y.z will match alias app1 for any domain y.z
|
||||
# but https://app1.node1.y.z will only match alias "app.node1"
|
||||
#
|
||||
# if match_domains defined
|
||||
# only host = alias+[one of match_domains] will match
|
||||
# i.e. match_domains = [node1.my.app, my.site]
|
||||
# https://app1.my.app, https://app1.my.net, etc. will not match even if app1 exists
|
||||
# only https://*.node1.my.app and https://*.my.site will match
|
||||
#
|
||||
#
|
||||
# match_domains:
|
||||
# - my.site
|
||||
# - node1.my.app
|
||||
|
||||
# Below are fixed options (non hot-reloadable)
|
||||
|
||||
# timeout for shutdown (in seconds)
|
||||
#
|
||||
# timeout_shutdown: 5
|
||||
# redirect_to_https: false # redirect http requests to https (if enabled)
|
||||
|
||||
# global setting redirect http requests to https (if https available, otherwise this will be ignored)
|
||||
# proxy.<alias>.middlewares.redirect_http will override this
|
||||
#
|
||||
# redirect_to_https: false
|
||||
|
|
|
@ -74,7 +74,7 @@
|
|||
| `proxy.stop_timeout` | time to wait for stop command | | `10s` | `number[unit]...` |
|
||||
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
|
||||
| `proxy.<alias>.<field>` | set field for specific alias | `proxy.gitlab-ssh.scheme` | N/A | N/A |
|
||||
| `proxy.$<index>.<field>` | set field for specific alias at index (starting from **1**) | `proxy.$3.port` | N/A | N/A |
|
||||
| `proxy.#<index>.<field>` | set field for specific alias at index (starting from **1**) | `proxy.#3.port` | N/A | N/A |
|
||||
| `proxy.*.<field>` | set field for all aliases | `proxy.*.set_headers` | N/A | N/A |
|
||||
|
||||
### Fields
|
||||
|
|
|
@ -37,7 +37,13 @@
|
|||
"title": "DNS Challenge Provider",
|
||||
"default": "local",
|
||||
"type": "string",
|
||||
"enum": ["local", "cloudflare", "clouddns", "duckdns", "ovh"]
|
||||
"enum": [
|
||||
"local",
|
||||
"cloudflare",
|
||||
"clouddns",
|
||||
"duckdns",
|
||||
"ovh"
|
||||
]
|
||||
},
|
||||
"options": {
|
||||
"title": "Provider specific options",
|
||||
|
@ -56,7 +62,12 @@
|
|||
}
|
||||
},
|
||||
"then": {
|
||||
"required": ["email", "domains", "provider", "options"]
|
||||
"required": [
|
||||
"email",
|
||||
"domains",
|
||||
"provider",
|
||||
"options"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -70,7 +81,9 @@
|
|||
"then": {
|
||||
"properties": {
|
||||
"options": {
|
||||
"required": ["auth_token"],
|
||||
"required": [
|
||||
"auth_token"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"auth_token": {
|
||||
|
@ -93,7 +106,11 @@
|
|||
"then": {
|
||||
"properties": {
|
||||
"options": {
|
||||
"required": ["client_id", "email", "password"],
|
||||
"required": [
|
||||
"client_id",
|
||||
"email",
|
||||
"password"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"client_id": {
|
||||
|
@ -124,7 +141,9 @@
|
|||
"then": {
|
||||
"properties": {
|
||||
"options": {
|
||||
"required": ["token"],
|
||||
"required": [
|
||||
"token"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"token": {
|
||||
|
@ -147,14 +166,21 @@
|
|||
"then": {
|
||||
"properties": {
|
||||
"options": {
|
||||
"required": ["application_secret", "consumer_key"],
|
||||
"required": [
|
||||
"application_secret",
|
||||
"consumer_key"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"oneOf": [
|
||||
{
|
||||
"required": ["application_key"]
|
||||
"required": [
|
||||
"application_key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"required": ["oauth2_config"]
|
||||
"required": [
|
||||
"oauth2_config"
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
|
@ -205,7 +231,10 @@
|
|||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["client_id", "client_secret"]
|
||||
"required": [
|
||||
"client_id",
|
||||
"client_secret"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -268,6 +297,14 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"match_domains": {
|
||||
"title": "Domains to match",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"minItems": 1
|
||||
},
|
||||
"timeout_shutdown": {
|
||||
"title": "Shutdown timeout (in seconds)",
|
||||
"type": "integer",
|
||||
|
@ -279,5 +316,7 @@
|
|||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["providers"]
|
||||
}
|
||||
"required": [
|
||||
"providers"
|
||||
]
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
@ -59,8 +60,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
|
|||
defer b.To(&res)
|
||||
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
b.Addf("provider is set to %q", ProviderLocal).WithSeverity(E.SeverityWarning)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.client == nil {
|
||||
|
@ -191,7 +191,19 @@ func (p *Provider) registerACME() E.NestedError {
|
|||
}
|
||||
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
|
||||
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
|
||||
//* This should have been done in setup
|
||||
//* but double check is always a good choice
|
||||
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
|
||||
return E.FailWith("create cert directory", err)
|
||||
}
|
||||
} else {
|
||||
return E.FailWith("stat cert directory", err)
|
||||
}
|
||||
}
|
||||
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
|
||||
if err != nil {
|
||||
return E.FailWith("write key file", err)
|
||||
}
|
||||
|
@ -227,6 +239,10 @@ func (p *Provider) certState() CertState {
|
|||
}
|
||||
|
||||
func (p *Provider) renewIfNeeded() E.NestedError {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch p.certState() {
|
||||
case CertStateExpired:
|
||||
logger.Info("certs expired, renewing")
|
||||
|
|
|
@ -14,7 +14,7 @@ func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
|
|||
}
|
||||
logger.Debug("obtaining cert due to error loading cert")
|
||||
if err = p.ObtainCert(); err != nil {
|
||||
return err.Warn()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
25
src/common/http.go
Normal file
25
src/common/http.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDialer = net.Dialer{
|
||||
Timeout: 60 * time.Second,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
DefaultTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: defaultDialer.DialContext,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
}
|
||||
DefaultTransportNoTLS = func() *http.Transport {
|
||||
var clone = DefaultTransport.Clone()
|
||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
return clone
|
||||
}()
|
||||
)
|
|
@ -31,25 +31,48 @@ type Config struct {
|
|||
reloadReq chan struct{}
|
||||
}
|
||||
|
||||
func Load() (*Config, E.NestedError) {
|
||||
cfg := &Config{
|
||||
var instance *Config
|
||||
|
||||
func GetConfig() *Config {
|
||||
return instance
|
||||
}
|
||||
|
||||
func Load() E.NestedError {
|
||||
if instance != nil {
|
||||
return nil
|
||||
}
|
||||
instance = &Config{
|
||||
value: M.DefaultConfig(),
|
||||
proxyProviders: F.NewMapOf[string, *PR.Provider](),
|
||||
l: logrus.WithField("module", "config"),
|
||||
watcher: W.NewFileWatcher(common.ConfigFileName),
|
||||
reloadReq: make(chan struct{}, 1),
|
||||
}
|
||||
return cfg, cfg.load()
|
||||
return instance.load()
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.NestedError {
|
||||
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
|
||||
}
|
||||
|
||||
func MatchDomains() []string {
|
||||
if instance == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return instance.value.MatchDomains
|
||||
}
|
||||
|
||||
func (cfg *Config) Value() M.Config {
|
||||
if cfg == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return *cfg.value
|
||||
}
|
||||
|
||||
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
|
||||
if instance == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return cfg.autocertProvider
|
||||
}
|
||||
|
||||
|
@ -61,13 +84,11 @@ func (cfg *Config) Dispose() {
|
|||
cfg.stopProviders()
|
||||
}
|
||||
|
||||
func (cfg *Config) Reload() E.NestedError {
|
||||
func (cfg *Config) Reload() (err E.NestedError) {
|
||||
cfg.stopProviders()
|
||||
if err := cfg.load(); err.HasError() {
|
||||
return err
|
||||
}
|
||||
err = cfg.load()
|
||||
cfg.StartProxyProviders()
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) StartProxyProviders() {
|
||||
|
@ -126,28 +147,28 @@ func (cfg *Config) load() (res E.NestedError) {
|
|||
data, err := E.Check(os.ReadFile(common.ConfigPath))
|
||||
if err.HasError() {
|
||||
b.Add(E.FailWith("read config", err))
|
||||
return
|
||||
logrus.Fatal(b.Build())
|
||||
}
|
||||
|
||||
if !common.NoSchemaValidation {
|
||||
if err = Validate(data); err.HasError() {
|
||||
b.Add(E.FailWith("schema validation", err))
|
||||
return
|
||||
logrus.Fatal(b.Build())
|
||||
}
|
||||
}
|
||||
|
||||
model := M.DefaultConfig()
|
||||
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
|
||||
b.Add(E.FailWith("parse config", err))
|
||||
return
|
||||
logrus.Fatal(b.Build())
|
||||
}
|
||||
|
||||
// errors are non fatal below
|
||||
b.WithSeverity(E.SeverityWarning)
|
||||
b.Add(cfg.initAutoCert(&model.AutoCert))
|
||||
b.Add(cfg.loadProviders(&model.Providers))
|
||||
|
||||
cfg.value = model
|
||||
R.SetFindMuxDomains(model.MatchDomains)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,17 +1,37 @@
|
|||
package docker
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
U "github.com/yusing/go-proxy/utils"
|
||||
F "github.com/yusing/go-proxy/utils/functional"
|
||||
)
|
||||
|
||||
type Label struct {
|
||||
Namespace string
|
||||
Target string
|
||||
Attribute string
|
||||
Value any
|
||||
/*
|
||||
Formats:
|
||||
- namespace.attribute
|
||||
- namespace.target.attribute
|
||||
- namespace.target.attribute.namespace2.attribute
|
||||
*/
|
||||
type (
|
||||
Label struct {
|
||||
Namespace string
|
||||
Target string
|
||||
Attribute string
|
||||
Value any
|
||||
}
|
||||
NestedLabelMap map[string]U.SerializedObject
|
||||
ValueParser func(string) (any, E.NestedError)
|
||||
ValueParserMap map[string]ValueParser
|
||||
)
|
||||
|
||||
func (l *Label) String() string {
|
||||
if l.Attribute == "" {
|
||||
return l.Namespace + "." + l.Target
|
||||
}
|
||||
return l.Namespace + "." + l.Target + "." + l.Attribute
|
||||
}
|
||||
|
||||
// Apply applies the value of a Label to the corresponding field in the given object.
|
||||
|
@ -23,12 +43,40 @@ type Label struct {
|
|||
// Returns:
|
||||
// - error: an error if the field does not exist.
|
||||
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
||||
return U.Deserialize(map[string]any{l.Attribute: l.Value}, obj)
|
||||
if obj == nil {
|
||||
return E.Invalid("nil object", l)
|
||||
}
|
||||
switch nestedLabel := l.Value.(type) {
|
||||
case *Label:
|
||||
var field reflect.Value
|
||||
objType := reflect.TypeFor[T]()
|
||||
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
|
||||
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
|
||||
field = reflect.ValueOf(obj).Elem().Field(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !field.IsValid() {
|
||||
return E.NotExist("field", l.Attribute)
|
||||
}
|
||||
dst, ok := field.Interface().(NestedLabelMap)
|
||||
if !ok {
|
||||
return E.Invalid("type", field.Type())
|
||||
}
|
||||
if dst == nil {
|
||||
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
|
||||
dst = field.Interface().(NestedLabelMap)
|
||||
}
|
||||
if dst[nestedLabel.Namespace] == nil {
|
||||
dst[nestedLabel.Namespace] = make(U.SerializedObject)
|
||||
}
|
||||
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
|
||||
return nil
|
||||
default:
|
||||
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
|
||||
}
|
||||
}
|
||||
|
||||
type ValueParser func(string) (any, E.NestedError)
|
||||
type ValueParserMap map[string]ValueParser
|
||||
|
||||
func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||
parts := strings.Split(label, ".")
|
||||
|
||||
|
@ -45,14 +93,22 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
|||
Value: value,
|
||||
}
|
||||
|
||||
if len(parts) == 3 {
|
||||
l.Attribute = parts[2]
|
||||
} else {
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
l.Attribute = l.Target
|
||||
case 3:
|
||||
l.Attribute = parts[2]
|
||||
default:
|
||||
l.Attribute = parts[2]
|
||||
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
|
||||
if err.HasError() {
|
||||
return nil, err
|
||||
}
|
||||
l.Value = nestedLabel
|
||||
}
|
||||
|
||||
// find if namespace has value parser
|
||||
pm, ok := labelValueParserMap[l.Namespace]
|
||||
pm, ok := valueParserMap.Load(l.Namespace)
|
||||
if !ok {
|
||||
return l, nil
|
||||
}
|
||||
|
@ -64,15 +120,28 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
|||
// try to parse value
|
||||
v, err := p(value)
|
||||
if err.HasError() {
|
||||
return nil, err
|
||||
return nil, err.Subject(label)
|
||||
}
|
||||
l.Value = v
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func RegisterNamespace(namespace string, pm ValueParserMap) {
|
||||
labelValueParserMap[namespace] = pm
|
||||
valueParserMap.Store(namespace, pm)
|
||||
}
|
||||
|
||||
func GetRegisteredNamespaces() map[string][]string {
|
||||
r := make(map[string][]string)
|
||||
|
||||
valueParserMap.RangeAll(func(ns string, vpm ValueParserMap) {
|
||||
r[ns] = make([]string, 0, len(vpm))
|
||||
for attr := range vpm {
|
||||
r[ns] = append(r[ns], attr)
|
||||
}
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// namespace:target.attribute -> func(string) (any, error)
|
||||
var labelValueParserMap = make(map[string]ValueParserMap)
|
||||
var valueParserMap = F.NewMapOf[string, ValueParserMap]()
|
||||
|
|
|
@ -7,7 +7,27 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func yamlListParser(value string) (any, E.NestedError) {
|
||||
const (
|
||||
NSProxy = "proxy"
|
||||
ProxyAttributePathPatterns = "path_patterns"
|
||||
ProxyAttributeNoTLSVerify = "no_tls_verify"
|
||||
ProxyAttributeMiddlewares = "middlewares"
|
||||
)
|
||||
|
||||
var _ = func() int {
|
||||
RegisterNamespace(NSProxy, ValueParserMap{
|
||||
ProxyAttributePathPatterns: YamlStringListParser,
|
||||
ProxyAttributeNoTLSVerify: BoolParser,
|
||||
})
|
||||
return 0
|
||||
}()
|
||||
|
||||
func YamlStringListParser(value string) (any, E.NestedError) {
|
||||
/*
|
||||
- foo
|
||||
- bar
|
||||
- baz
|
||||
*/
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return []string{}, nil
|
||||
|
@ -17,27 +37,36 @@ func yamlListParser(value string) (any, E.NestedError) {
|
|||
return data, err
|
||||
}
|
||||
|
||||
func yamlStringMappingParser(value string) (any, E.NestedError) {
|
||||
value = strings.TrimSpace(value)
|
||||
lines := strings.Split(value, "\n")
|
||||
h := make(map[string]string)
|
||||
for _, line := range lines {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, E.Invalid("set header statement", line)
|
||||
}
|
||||
key := strings.TrimSpace(parts[0])
|
||||
val := strings.TrimSpace(parts[1])
|
||||
if existing, ok := h[key]; ok {
|
||||
h[key] = existing + ", " + val
|
||||
} else {
|
||||
h[key] = val
|
||||
func YamlLikeMappingParser(allowDuplicate bool) func(string) (any, E.NestedError) {
|
||||
return func(value string) (any, E.NestedError) {
|
||||
/*
|
||||
foo: bar
|
||||
boo: baz
|
||||
*/
|
||||
value = strings.TrimSpace(value)
|
||||
lines := strings.Split(value, "\n")
|
||||
h := make(map[string]string)
|
||||
for _, line := range lines {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, E.Invalid("syntax", line)
|
||||
}
|
||||
key := strings.TrimSpace(parts[0])
|
||||
val := strings.TrimSpace(parts[1])
|
||||
if existing, ok := h[key]; ok {
|
||||
if !allowDuplicate {
|
||||
return nil, E.Duplicated("key", key)
|
||||
}
|
||||
h[key] = existing + ", " + val
|
||||
} else {
|
||||
h[key] = val
|
||||
}
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func boolParser(value string) (any, E.NestedError) {
|
||||
func BoolParser(value string) (any, E.NestedError) {
|
||||
switch strings.ToLower(value) {
|
||||
case "true", "yes", "1":
|
||||
return true, nil
|
||||
|
@ -47,15 +76,3 @@ func boolParser(value string) (any, E.NestedError) {
|
|||
return nil, E.Invalid("boolean value", value)
|
||||
}
|
||||
}
|
||||
|
||||
const NSProxy = "proxy"
|
||||
|
||||
var _ = func() int {
|
||||
RegisterNamespace(NSProxy, ValueParserMap{
|
||||
"path_patterns": yamlListParser,
|
||||
"set_headers": yamlStringMappingParser,
|
||||
"hide_headers": yamlListParser,
|
||||
"no_tls_verify": boolParser,
|
||||
})
|
||||
return 0
|
||||
}()
|
||||
|
|
|
@ -2,8 +2,6 @@ package docker
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
|
@ -14,21 +12,16 @@ func makeLabel(namespace string, alias string, field string) string {
|
|||
return fmt.Sprintf("%s.%s.%s", namespace, alias, field)
|
||||
}
|
||||
|
||||
func TestHomePageLabel(t *testing.T) {
|
||||
func TestParseLabel(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "ip"
|
||||
v := "bar"
|
||||
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
if pl.Target != alias {
|
||||
t.Errorf("Expected alias=%s, got %s", alias, pl.Target)
|
||||
}
|
||||
if pl.Attribute != field {
|
||||
t.Errorf("Expected field=%s, got %s", field, pl.Target)
|
||||
}
|
||||
if pl.Value != v {
|
||||
t.Errorf("Expected value=%q, got %s", v, pl.Value)
|
||||
}
|
||||
ExpectEqual(t, pl.Namespace, NSHomePage)
|
||||
ExpectEqual(t, pl.Target, alias)
|
||||
ExpectEqual(t, pl.Attribute, field)
|
||||
ExpectEqual(t, pl.Value.(string), v)
|
||||
}
|
||||
|
||||
func TestStringProxyLabel(t *testing.T) {
|
||||
|
@ -51,90 +44,63 @@ func TestBoolProxyLabelValid(t *testing.T) {
|
|||
}
|
||||
|
||||
for k, v := range tests {
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k)
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), k)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, pl.Value.(bool), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolProxyLabelInvalid(t *testing.T) {
|
||||
alias := "foo"
|
||||
field := "no_tls_verify"
|
||||
_, err := ParseLabel(makeLabel(NSProxy, alias, field), "invalid")
|
||||
_, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), "invalid")
|
||||
if !err.Is(E.ErrInvalid) {
|
||||
t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeaderProxyLabelValid(t *testing.T) {
|
||||
v := `
|
||||
X-Custom-Header1: foo, bar
|
||||
X-Custom-Header1: baz
|
||||
X-Custom-Header2: boo`
|
||||
v = strings.TrimPrefix(v, "\n")
|
||||
h := map[string]string{
|
||||
"X-Custom-Header1": "foo, bar, baz",
|
||||
"X-Custom-Header2": "boo",
|
||||
}
|
||||
// func TestSetHeaderProxyLabelValid(t *testing.T) {
|
||||
// v := `
|
||||
// X-Custom-Header1: foo, bar
|
||||
// X-Custom-Header1: baz
|
||||
// X-Custom-Header2: boo`
|
||||
// v = strings.TrimPrefix(v, "\n")
|
||||
// h := map[string]string{
|
||||
// "X-Custom-Header1": "foo, bar, baz",
|
||||
// "X-Custom-Header2": "boo",
|
||||
// }
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
hGot := ExpectType[map[string]string](t, pl.Value)
|
||||
if hGot != nil && !reflect.DeepEqual(h, hGot) {
|
||||
t.Errorf("Expected %v, got %v", h, hGot)
|
||||
}
|
||||
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
|
||||
// ExpectNoError(t, err.Error())
|
||||
// hGot := ExpectType[map[string]string](t, pl.Value)
|
||||
// ExpectFalse(t, hGot == nil)
|
||||
// ExpectDeepEqual(t, h, hGot)
|
||||
// }
|
||||
|
||||
}
|
||||
// func TestSetHeaderProxyLabelInvalid(t *testing.T) {
|
||||
// tests := []string{
|
||||
// "X-Custom-Header1 = bar",
|
||||
// "X-Custom-Header1",
|
||||
// "- X-Custom-Header1",
|
||||
// }
|
||||
|
||||
func TestSetHeaderProxyLabelInvalid(t *testing.T) {
|
||||
tests := []string{
|
||||
"X-Custom-Header1 = bar",
|
||||
"X-Custom-Header1",
|
||||
"- X-Custom-Header1",
|
||||
}
|
||||
|
||||
for _, v := range tests {
|
||||
_, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
|
||||
if !err.Is(E.ErrInvalid) {
|
||||
t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHideHeadersProxyLabel(t *testing.T) {
|
||||
v := `
|
||||
- X-Custom-Header1
|
||||
- X-Custom-Header2
|
||||
- X-Custom-Header3
|
||||
`
|
||||
v = strings.TrimPrefix(v, "\n")
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
sGot := ExpectType[[]string](t, pl.Value)
|
||||
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
|
||||
if sGot != nil {
|
||||
ExpectDeepEqual(t, sGot, sWant)
|
||||
}
|
||||
}
|
||||
|
||||
// func TestCommaSepProxyLabelSingle(t *testing.T) {
|
||||
// v := "a"
|
||||
// pl, err := ParseLabel("proxy.aliases", v)
|
||||
// ExpectNoError(t, err)
|
||||
// sGot := ExpectType[[]string](t, pl.Value)
|
||||
// sWant := []string{"a"}
|
||||
// if sGot != nil {
|
||||
// ExpectEqual(t, sGot, sWant)
|
||||
// for _, v := range tests {
|
||||
// _, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
|
||||
// if !err.Is(E.ErrInvalid) {
|
||||
// t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestCommaSepProxyLabelMulti(t *testing.T) {
|
||||
// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
|
||||
// pl, err := ParseLabel("proxy.aliases", v)
|
||||
// ExpectNoError(t, err)
|
||||
// func TestHideHeadersProxyLabel(t *testing.T) {
|
||||
// v := `
|
||||
// - X-Custom-Header1
|
||||
// - X-Custom-Header2
|
||||
// - X-Custom-Header3
|
||||
// `
|
||||
// v = strings.TrimPrefix(v, "\n")
|
||||
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeHideHeaders), v)
|
||||
// ExpectNoError(t, err.Error())
|
||||
// sGot := ExpectType[[]string](t, pl.Value)
|
||||
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
|
||||
// if sGot != nil {
|
||||
// ExpectEqual(t, sGot, sWant)
|
||||
// }
|
||||
// ExpectFalse(t, sGot == nil)
|
||||
// ExpectDeepEqual(t, sGot, sWant)
|
||||
// }
|
||||
|
|
85
src/docker/label_test.go
Normal file
85
src/docker/label_test.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package docker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
U "github.com/yusing/go-proxy/utils"
|
||||
. "github.com/yusing/go-proxy/utils/testing"
|
||||
)
|
||||
|
||||
func TestNestedLabel(t *testing.T) {
|
||||
mName := "middleware1"
|
||||
mAttr := "prop1"
|
||||
v := "value1"
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
sGot := ExpectType[*Label](t, pl.Value)
|
||||
ExpectFalse(t, sGot == nil)
|
||||
ExpectEqual(t, sGot.Namespace, mName)
|
||||
ExpectEqual(t, sGot.Attribute, mAttr)
|
||||
}
|
||||
|
||||
func TestApplyNestedLabel(t *testing.T) {
|
||||
entry := new(struct {
|
||||
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||
})
|
||||
mName := "middleware1"
|
||||
mAttr := "prop1"
|
||||
v := "value1"
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
middleware1, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
got := ExpectType[string](t, middleware1[mAttr])
|
||||
ExpectEqual(t, got, v)
|
||||
}
|
||||
|
||||
func TestApplyNestedLabelExisting(t *testing.T) {
|
||||
mName := "middleware1"
|
||||
mAttr := "prop1"
|
||||
v := "value1"
|
||||
|
||||
checkAttr := "prop2"
|
||||
checkV := "value2"
|
||||
entry := new(struct {
|
||||
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||
})
|
||||
entry.Middlewares = make(NestedLabelMap)
|
||||
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||
entry.Middlewares[mName][checkAttr] = checkV
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
middleware1, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
got := ExpectType[string](t, middleware1[mAttr])
|
||||
ExpectEqual(t, got, v)
|
||||
|
||||
// check if prop2 is affected
|
||||
ExpectFalse(t, middleware1[checkAttr] == nil)
|
||||
got = ExpectType[string](t, middleware1[checkAttr])
|
||||
ExpectEqual(t, got, checkV)
|
||||
}
|
||||
|
||||
func TestApplyNestedLabelNoAttr(t *testing.T) {
|
||||
mName := "middleware1"
|
||||
v := "value1"
|
||||
|
||||
entry := new(struct {
|
||||
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||
})
|
||||
entry.Middlewares = make(NestedLabelMap)
|
||||
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||
|
||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", ProxyAttributeMiddlewares, mName)), v)
|
||||
ExpectNoError(t, err.Error())
|
||||
err = ApplyLabel(entry, pl)
|
||||
ExpectNoError(t, err.Error())
|
||||
_, ok := entry.Middlewares[mName]
|
||||
ExpectTrue(t, ok)
|
||||
}
|
|
@ -10,9 +10,8 @@ type Builder struct {
|
|||
}
|
||||
|
||||
type builder struct {
|
||||
message string
|
||||
errors []NestedError
|
||||
severity Severity
|
||||
message string
|
||||
errors []NestedError
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -40,11 +39,6 @@ func (b Builder) Addf(format string, args ...any) Builder {
|
|||
return b.Add(errorf(format, args...))
|
||||
}
|
||||
|
||||
func (b Builder) WithSeverity(s Severity) Builder {
|
||||
b.severity = s
|
||||
return b
|
||||
}
|
||||
|
||||
// Build builds a NestedError based on the errors collected in the Builder.
|
||||
//
|
||||
// If there are no errors in the Builder, it returns a Nil() NestedError.
|
||||
|
@ -58,7 +52,7 @@ func (b Builder) Build() NestedError {
|
|||
} else if len(b.errors) == 1 {
|
||||
return b.errors[0]
|
||||
}
|
||||
return Join(b.message, b.errors...).Severity(b.severity)
|
||||
return Join(b.message, b.errors...)
|
||||
}
|
||||
|
||||
func (b Builder) To(ptr *NestedError) {
|
||||
|
|
|
@ -9,17 +9,10 @@ import (
|
|||
type (
|
||||
NestedError = *nestedError
|
||||
nestedError struct {
|
||||
subject string
|
||||
err error
|
||||
extras []nestedError
|
||||
severity Severity
|
||||
subject string
|
||||
err error
|
||||
extras []nestedError
|
||||
}
|
||||
Severity uint8
|
||||
)
|
||||
|
||||
const (
|
||||
SeverityWarning Severity = iota
|
||||
SeverityFatal
|
||||
)
|
||||
|
||||
func From(err error) NestedError {
|
||||
|
@ -164,22 +157,6 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
|
|||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) Severity(s Severity) NestedError {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
ne.severity = s
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) Warn() NestedError {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
ne.severity = SeverityWarning
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) NoError() bool {
|
||||
return ne == nil
|
||||
}
|
||||
|
@ -188,14 +165,6 @@ func (ne NestedError) HasError() bool {
|
|||
return ne != nil
|
||||
}
|
||||
|
||||
func (ne NestedError) IsFatal() bool {
|
||||
return ne != nil && ne.severity == SeverityFatal
|
||||
}
|
||||
|
||||
func (ne NestedError) IsWarning() bool {
|
||||
return ne != nil && ne.severity == SeverityWarning
|
||||
}
|
||||
|
||||
func errorf(format string, args ...any) NestedError {
|
||||
return From(fmt.Errorf(format, args...))
|
||||
}
|
||||
|
|
|
@ -31,11 +31,11 @@ func TestErrorNestedIs(t *testing.T) {
|
|||
|
||||
err = Failure("some reason")
|
||||
ExpectTrue(t, err.Is(ErrFailure))
|
||||
ExpectFalse(t, err.Is(ErrAlreadyExist))
|
||||
ExpectFalse(t, err.Is(ErrDuplicated))
|
||||
|
||||
err.With(AlreadyExist("something", ""))
|
||||
err.With(Duplicated("something", ""))
|
||||
ExpectTrue(t, err.Is(ErrFailure))
|
||||
ExpectTrue(t, err.Is(ErrAlreadyExist))
|
||||
ExpectTrue(t, err.Is(ErrDuplicated))
|
||||
ExpectFalse(t, err.Is(ErrInvalid))
|
||||
}
|
||||
|
||||
|
|
|
@ -5,14 +5,14 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrFailure = stderrors.New("failed")
|
||||
ErrInvalid = stderrors.New("invalid")
|
||||
ErrUnsupported = stderrors.New("unsupported")
|
||||
ErrUnexpected = stderrors.New("unexpected")
|
||||
ErrNotExists = stderrors.New("does not exist")
|
||||
ErrMissing = stderrors.New("missing")
|
||||
ErrAlreadyExist = stderrors.New("already exist")
|
||||
ErrOutOfRange = stderrors.New("out of range")
|
||||
ErrFailure = stderrors.New("failed")
|
||||
ErrInvalid = stderrors.New("invalid")
|
||||
ErrUnsupported = stderrors.New("unsupported")
|
||||
ErrUnexpected = stderrors.New("unexpected")
|
||||
ErrNotExists = stderrors.New("does not exist")
|
||||
ErrMissing = stderrors.New("missing")
|
||||
ErrDuplicated = stderrors.New("duplicated")
|
||||
ErrOutOfRange = stderrors.New("out of range")
|
||||
)
|
||||
|
||||
const fmtSubjectWhat = "%w %v: %q"
|
||||
|
@ -53,8 +53,8 @@ func Missing(subject any) NestedError {
|
|||
return errorf("%w %v", ErrMissing, subject)
|
||||
}
|
||||
|
||||
func AlreadyExist(subject, what any) NestedError {
|
||||
return errorf("%v %w: %v", subject, ErrAlreadyExist, what)
|
||||
func Duplicated(subject, what any) NestedError {
|
||||
return errorf("%w %v: %v", ErrDuplicated, subject, what)
|
||||
}
|
||||
|
||||
func OutOfRange(subject string, value any) NestedError {
|
||||
|
|
|
@ -17,7 +17,7 @@ require (
|
|||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.104.0 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.105.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
|
|
|
@ -4,8 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
|||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cloudflare/cloudflare-go v0.104.0 h1:R/lB0dZupaZbOgibAH/BRrkFbZ6Acn/WsKg2iX2xXuY=
|
||||
github.com/cloudflare/cloudflare-go v0.104.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
|
||||
github.com/cloudflare/cloudflare-go v0.105.0 h1:yu2IatITLZ4dw7/byzRrlE5DfUvtub0k9CHZ5zBlj90=
|
||||
github.com/cloudflare/cloudflare-go v0.105.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
|
96
src/http/modify_response_writer.go
Normal file
96
src/http/modify_response_writer.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
|
||||
// Copyright (c) 2020-2024 Traefik Labs
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ModifyResponseFunc func(*http.Response) error
|
||||
type ModifyResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
|
||||
headerSent bool
|
||||
code int
|
||||
|
||||
modifier ModifyResponseFunc
|
||||
modified bool
|
||||
modifierErr error
|
||||
}
|
||||
|
||||
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
|
||||
return &ModifyResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
modifier: f,
|
||||
code: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) WriteHeader(code int) {
|
||||
if w.headerSent {
|
||||
return
|
||||
}
|
||||
|
||||
if code >= http.StatusContinue && code < http.StatusOK {
|
||||
w.w.WriteHeader(code)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
w.headerSent = true
|
||||
w.code = code
|
||||
}()
|
||||
|
||||
if w.modifier == nil || w.modified {
|
||||
w.w.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
|
||||
resp := http.Response{
|
||||
Header: w.w.Header(),
|
||||
Request: w.r,
|
||||
}
|
||||
|
||||
if err := w.modifier(&resp); err != nil {
|
||||
w.modifierErr = err
|
||||
logger.Errorf("error modifying response: %s", err)
|
||||
w.w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.modified = true
|
||||
w.w.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
|
||||
w.WriteHeader(w.code)
|
||||
if w.modifierErr != nil {
|
||||
return 0, w.modifierErr
|
||||
}
|
||||
return w.w.Write(b)
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection.
|
||||
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := w.w.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (w *ModifyResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
|
@ -1,7 +1,13 @@
|
|||
package proxy
|
||||
// Copyright 2011 The Go Authors.
|
||||
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
|
||||
// https://cs.opensource.google/go/go/+/master:LICENSE
|
||||
|
||||
// A small mod on net/http/httputil/reverseproxy.go
|
||||
// that doubled the performance
|
||||
package http
|
||||
|
||||
// This is a small mod on net/http/httputil/reverseproxy.go
|
||||
// that boosts performance in some cases
|
||||
// and compatible to other modules of this project
|
||||
// Copyright (c) 2024 yusing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -52,6 +58,21 @@ type ProxyRequest struct {
|
|||
// r.SetXForwarded()
|
||||
// }
|
||||
func (r *ProxyRequest) SetXForwarded() {
|
||||
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
|
||||
if err == nil {
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
} else {
|
||||
r.Out.Header.Del("X-Forwarded-For")
|
||||
}
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
if r.In.TLS == nil {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", "http")
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", "https")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ProxyRequest) AddXForwarded() {
|
||||
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
|
||||
if err == nil {
|
||||
prior := r.Out.Header["X-Forwarded-For"]
|
||||
|
@ -104,28 +125,6 @@ type ReverseProxy struct {
|
|||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body.
|
||||
// If zero, no periodic flushing is done.
|
||||
// A negative value means to flush immediately
|
||||
// after each write to the client.
|
||||
// The FlushInterval is ignored when ReverseProxy
|
||||
// recognizes a response as a streaming response, or
|
||||
// if its ContentLength is -1; for such responses, writes
|
||||
// are flushed to the client immediately.
|
||||
// FlushInterval time.Duration
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging is done via the log package's standard logger.
|
||||
// ErrorLog *log.Logger
|
||||
|
||||
// BufferPool optionally specifies a buffer pool to
|
||||
// get byte slices for use by io.CopyBuffer when
|
||||
// copying HTTP response bodies.
|
||||
// BufferPool BufferPool
|
||||
|
||||
// ModifyResponse is an optional function that modifies the
|
||||
// Response from the backend. It is called if the backend
|
||||
// returns a response at all, with any HTTP status code.
|
||||
|
@ -208,36 +207,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
|||
// },
|
||||
// }
|
||||
//
|
||||
// TODO: headers in ModifyResponse
|
||||
func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy {
|
||||
// check on init rather than on request
|
||||
var setHeaders = func(r *http.Request) {}
|
||||
var hideHeaders = func(r *http.Request) {}
|
||||
if len(entry.SetHeaders) > 0 {
|
||||
setHeaders = func(r *http.Request) {
|
||||
h := entry.SetHeaders.Clone()
|
||||
for k, vv := range h {
|
||||
if k == "Host" {
|
||||
r.Host = vv[0]
|
||||
} else {
|
||||
r.Header[k] = vv
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(entry.HideHeaders) > 0 {
|
||||
hideHeaders = func(r *http.Request) {
|
||||
for _, k := range entry.HideHeaders {
|
||||
r.Header.Del(k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
rp := &ReverseProxy{
|
||||
Rewrite: func(pr *ProxyRequest) {
|
||||
rewriteRequestURL(pr.Out, target)
|
||||
// pr.SetXForwarded()
|
||||
setHeaders(pr.Out)
|
||||
hideHeaders(pr.Out)
|
||||
}, Transport: transport,
|
||||
}
|
||||
rp.ServeHTTP = rp.serveHTTP
|
||||
|
@ -256,6 +230,23 @@ func rewriteRequestURL(req *http.Request, target *url.URL) {
|
|||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
|
@ -331,12 +322,14 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := upgradeType(outreq.Header)
|
||||
reqUpType := UpgradeType(outreq.Header)
|
||||
if !IsPrint(reqUpType) {
|
||||
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
RemoveHopByHopHeaders(outreq.Header)
|
||||
|
||||
// Issue 21096: tell backend applications that care about trailer support
|
||||
// that we support trailers. (We do, but we don't go out of our way to
|
||||
// advertise that unless the incoming client request thought it was worth
|
||||
|
@ -458,16 +451,34 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func upgradeType(h http.Header) string {
|
||||
func UpgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
// RemoveHopByHopHeaders removes hop-by-hop headers.
|
||||
func RemoveHopByHopHeaders(h http.Header) {
|
||||
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||
for _, f := range h["Connection"] {
|
||||
for _, sf := range strings.Split(f, ",") {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||
// preserve it for backwards compatibility.
|
||||
for _, f := range hopHeaders {
|
||||
h.Del(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
reqUpType := UpgradeType(req.Header)
|
||||
resUpType := UpgradeType(res.Header)
|
||||
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
||||
}
|
27
src/main.go
27
src/main.go
|
@ -47,11 +47,10 @@ func main() {
|
|||
logrus.SetOutput(io.Discard)
|
||||
} else {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
DisableSorting: true,
|
||||
DisableLevelTruncation: true,
|
||||
FullTimestamp: true,
|
||||
ForceColors: true,
|
||||
TimestampFormat: "01-02 15:04:05",
|
||||
DisableSorting: true,
|
||||
FullTimestamp: true,
|
||||
ForceColors: true,
|
||||
TimestampFormat: "01-02 15:04:05",
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -76,10 +75,11 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
cfg, err := config.Load()
|
||||
if err.IsFatal() {
|
||||
log.Fatal(err)
|
||||
err := config.Load()
|
||||
if err != nil {
|
||||
logrus.Warn(err)
|
||||
}
|
||||
cfg := config.GetConfig()
|
||||
|
||||
switch args.Command {
|
||||
case common.CommandListConfigs:
|
||||
|
@ -96,6 +96,10 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
printJSON(docker.GetRegisteredNamespaces())
|
||||
}
|
||||
|
||||
cfg.StartProxyProviders()
|
||||
|
||||
if err.HasError() {
|
||||
|
@ -116,10 +120,7 @@ func main() {
|
|||
|
||||
if autocert != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
if err = autocert.Setup(ctx); err != nil && err.IsWarning() {
|
||||
cancel()
|
||||
l.Warn(err)
|
||||
} else if err.IsFatal() {
|
||||
if err = autocert.Setup(ctx); err != nil {
|
||||
l.Fatal(err)
|
||||
} else {
|
||||
onShutdown.Add(cancel)
|
||||
|
@ -192,7 +193,7 @@ func funcName(f func()) string {
|
|||
}
|
||||
|
||||
func printJSON(obj any) {
|
||||
j, err := E.Check(json.Marshal(obj))
|
||||
j, err := E.Check(json.MarshalIndent(obj, "", " "))
|
||||
if err.HasError() {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package model
|
|||
type Config struct {
|
||||
Providers ProxyProviders `yaml:",flow" json:"providers"`
|
||||
AutoCert AutoCertConfig `yaml:",flow" json:"autocert"`
|
||||
MatchDomains []string `yaml:"match_domains" json:"match_domains"`
|
||||
TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"`
|
||||
RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"`
|
||||
}
|
||||
|
@ -11,6 +12,6 @@ func DefaultConfig() *Config {
|
|||
return &Config{
|
||||
Providers: ProxyProviders{},
|
||||
TimeoutShutdown: 3,
|
||||
RedirectToHTTPS: true,
|
||||
RedirectToHTTPS: false,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,14 +14,13 @@ type (
|
|||
RawEntry struct {
|
||||
// raw entry object before validation
|
||||
// loaded from docker labels or yaml file
|
||||
Alias string `yaml:"-" json:"-"`
|
||||
Scheme string `yaml:"scheme" json:"scheme"`
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port string `yaml:"port" json:"port"`
|
||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
|
||||
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
|
||||
SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only
|
||||
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only
|
||||
Alias string `yaml:"-" json:"-"`
|
||||
Scheme string `yaml:"scheme" json:"scheme"`
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port string `yaml:"port" json:"port"`
|
||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
|
||||
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
|
||||
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares"`
|
||||
|
||||
/* Docker only */
|
||||
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
|
||||
|
@ -44,12 +43,16 @@ func (e *RawEntry) FillMissingFields() bool {
|
|||
if pp == "" {
|
||||
pp = strconv.Itoa(port)
|
||||
}
|
||||
e.Scheme = "tcp"
|
||||
if e.Scheme == "" {
|
||||
e.Scheme = "tcp"
|
||||
}
|
||||
} else if port, ok := ImageNamePortMap[e.ImageName]; ok {
|
||||
if pp == "" {
|
||||
pp = strconv.Itoa(port)
|
||||
}
|
||||
e.Scheme = "http"
|
||||
if e.Scheme == "" {
|
||||
e.Scheme = "http"
|
||||
}
|
||||
} else if pp == "" && e.Scheme == "https" {
|
||||
pp = "443"
|
||||
} else if pp == "" {
|
||||
|
|
|
@ -2,10 +2,10 @@ package proxy
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
M "github.com/yusing/go-proxy/models"
|
||||
T "github.com/yusing/go-proxy/proxy/fields"
|
||||
|
@ -18,8 +18,7 @@ type (
|
|||
URL *url.URL
|
||||
NoTLSVerify bool
|
||||
PathPatterns T.PathPatterns
|
||||
SetHeaders http.Header
|
||||
HideHeaders []string
|
||||
Middlewares D.NestedLabelMap
|
||||
|
||||
/* Docker only */
|
||||
IdleTimeout time.Duration
|
||||
|
@ -78,9 +77,6 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry
|
|||
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
|
||||
b.Add(err)
|
||||
|
||||
setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders)
|
||||
b.Add(err)
|
||||
|
||||
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
|
||||
b.Add(err)
|
||||
|
||||
|
@ -111,8 +107,7 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry
|
|||
URL: url,
|
||||
NoTLSVerify: m.NoTLSVerify,
|
||||
PathPatterns: pathPatterns,
|
||||
SetHeaders: setHeaders,
|
||||
HideHeaders: m.HideHeaders,
|
||||
Middlewares: m.Middlewares,
|
||||
IdleTimeout: idleTimeout,
|
||||
WakeTimeout: wakeTimeout,
|
||||
StopMethod: stopMethod,
|
||||
|
|
|
@ -4,7 +4,9 @@ import (
|
|||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
M "github.com/yusing/go-proxy/models"
|
||||
|
@ -17,6 +19,7 @@ type DockerProvider struct {
|
|||
}
|
||||
|
||||
var AliasRefRegex = regexp.MustCompile(`#\d+`)
|
||||
var AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
|
||||
hostname, err := D.ParseDockerHostname(dockerHost)
|
||||
|
@ -152,6 +155,20 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
|
|||
b := E.NewBuilder("errors in label %s", key)
|
||||
defer b.To(&res)
|
||||
|
||||
refErr := E.NewBuilder("errors parsing alias references")
|
||||
replaceIndexRef := func(ref string) string {
|
||||
index, err := strconv.Atoi(ref[1:])
|
||||
if err != nil {
|
||||
refErr.Add(E.Invalid("integer", ref))
|
||||
return ref
|
||||
}
|
||||
if index < 1 || index > len(container.Aliases) {
|
||||
refErr.Add(E.OutOfRange("index", ref))
|
||||
return ref
|
||||
}
|
||||
return container.Aliases[index-1]
|
||||
}
|
||||
|
||||
lbl, err := D.ParseLabel(key, val)
|
||||
if err.HasError() {
|
||||
b.Add(err.Subject(key))
|
||||
|
@ -163,22 +180,14 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
|
|||
// apply label for all aliases
|
||||
entries.RangeAll(func(a string, e *M.RawEntry) {
|
||||
if err = D.ApplyLabel(e, lbl); err.HasError() {
|
||||
b.Add(err.Subject(lbl.Target))
|
||||
b.Add(err.Subjectf("alias %s", lbl.Target))
|
||||
}
|
||||
})
|
||||
} else {
|
||||
refErr := E.NewBuilder("errors parsing alias references")
|
||||
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, func(ref string) string {
|
||||
index, err := strconv.Atoi(ref[1:])
|
||||
if err != nil {
|
||||
refErr.Add(E.Invalid("integer", ref))
|
||||
return ref
|
||||
}
|
||||
if index < 1 || index > len(container.Aliases) {
|
||||
refErr.Add(E.OutOfRange("index", ref))
|
||||
return ref
|
||||
}
|
||||
return container.Aliases[index-1]
|
||||
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef)
|
||||
lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string {
|
||||
logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#"))
|
||||
return replaceIndexRef(s)
|
||||
})
|
||||
if refErr.HasError() {
|
||||
b.Add(refErr.Build())
|
||||
|
@ -190,7 +199,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
|
|||
return
|
||||
}
|
||||
if err = D.ApplyLabel(config, lbl); err.HasError() {
|
||||
b.Add(err.Subject(lbl.Target))
|
||||
b.Add(err.Subjectf("alias %s", lbl.Target))
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
@ -132,7 +132,8 @@ func TestApplyLabel(t *testing.T) {
|
|||
ExpectEqual(t, b.Scheme, "http")
|
||||
ExpectEqual(t, b.Port, "1234")
|
||||
ExpectEqual(t, c.Scheme, "https")
|
||||
ExpectEqual(t, c.Port, "1111")
|
||||
// map does not necessary follow the order above
|
||||
ExpectEqualAny(t, c.Port, []string{"1111", "1234"})
|
||||
}
|
||||
|
||||
func TestApplyLabelWithRef(t *testing.T) {
|
||||
|
@ -142,9 +143,9 @@ func TestApplyLabelWithRef(t *testing.T) {
|
|||
Labels: map[string]string{
|
||||
D.LabelAliases: "a,b,c",
|
||||
"proxy.#1.host": "localhost",
|
||||
"proxy.*.port": "1111",
|
||||
"proxy.#1.port": "4444",
|
||||
"proxy.#2.port": "9999",
|
||||
"proxy.#3.port": "1111",
|
||||
"proxy.#3.scheme": "https",
|
||||
},
|
||||
Ports: []types.Port{
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/common"
|
||||
"github.com/yusing/go-proxy/docker/idlewatcher"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
. "github.com/yusing/go-proxy/http"
|
||||
P "github.com/yusing/go-proxy/proxy"
|
||||
PT "github.com/yusing/go-proxy/proxy/fields"
|
||||
"github.com/yusing/go-proxy/route/middleware"
|
||||
F "github.com/yusing/go-proxy/utils/functional"
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,7 @@ type (
|
|||
|
||||
entry *P.ReverseProxyEntry
|
||||
mux *http.ServeMux
|
||||
handler *P.ReverseProxy
|
||||
handler *ReverseProxy
|
||||
|
||||
regIdleWatcher func() E.NestedError
|
||||
unregIdleWatcher func()
|
||||
|
@ -36,18 +36,41 @@ type (
|
|||
SubdomainKey = PT.Alias
|
||||
)
|
||||
|
||||
var (
|
||||
findMuxFunc = findMuxAnyDomain
|
||||
|
||||
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
|
||||
httpRoutesMu sync.Mutex
|
||||
globalMux = http.NewServeMux() // TODO: support regex subdomain matching
|
||||
)
|
||||
|
||||
func SetFindMuxDomains(domains []string) {
|
||||
if len(domains) == 0 {
|
||||
findMuxFunc = findMuxAnyDomain
|
||||
} else {
|
||||
findMuxFunc = findMuxByDomain(domains)
|
||||
}
|
||||
}
|
||||
|
||||
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
||||
var trans *http.Transport
|
||||
var regIdleWatcher func() E.NestedError
|
||||
var unregIdleWatcher func()
|
||||
|
||||
if entry.NoTLSVerify {
|
||||
trans = transportNoTLS.Clone()
|
||||
trans = common.DefaultTransportNoTLS.Clone()
|
||||
} else {
|
||||
trans = transport.Clone()
|
||||
trans = common.DefaultTransport.Clone()
|
||||
}
|
||||
|
||||
rp := P.NewReverseProxy(entry.URL, trans, entry)
|
||||
rp := NewReverseProxy(entry.URL, trans)
|
||||
|
||||
if len(entry.Middlewares) > 0 {
|
||||
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if entry.UseIdleWatcher() {
|
||||
// allow time for response header up to `WakeTimeout`
|
||||
|
@ -74,7 +97,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
|||
|
||||
_, exists := httpRoutes.Load(entry.Alias)
|
||||
if exists {
|
||||
return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias)
|
||||
return nil, E.Duplicated("HTTPRoute alias", entry.Alias)
|
||||
}
|
||||
|
||||
r := &HTTPRoute{
|
||||
|
@ -94,11 +117,16 @@ func (r *HTTPRoute) String() string {
|
|||
}
|
||||
|
||||
func (r *HTTPRoute) Start() E.NestedError {
|
||||
if r.mux != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
httpRoutesMu.Lock()
|
||||
defer httpRoutesMu.Unlock()
|
||||
|
||||
if r.regIdleWatcher != nil {
|
||||
if err := r.regIdleWatcher(); err.HasError() {
|
||||
r.unregIdleWatcher = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -113,6 +141,10 @@ func (r *HTTPRoute) Start() E.NestedError {
|
|||
}
|
||||
|
||||
func (r *HTTPRoute) Stop() E.NestedError {
|
||||
if r.mux == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
httpRoutesMu.Lock()
|
||||
defer httpRoutesMu.Unlock()
|
||||
|
||||
|
@ -135,7 +167,7 @@ func (u *URL) MarshalText() (text []byte, err error) {
|
|||
}
|
||||
|
||||
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
mux, err := findMux(r.Host)
|
||||
mux, err := findMuxFunc(r.Host)
|
||||
if err != nil {
|
||||
err = E.Failure("request").
|
||||
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
|
||||
|
@ -147,7 +179,7 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
|||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func findMux(host string) (*http.ServeMux, E.NestedError) {
|
||||
func findMuxAnyDomain(host string) (*http.ServeMux, E.NestedError) {
|
||||
hostSplit := strings.Split(host, ".")
|
||||
n := len(hostSplit)
|
||||
if n <= 2 {
|
||||
|
@ -160,23 +192,21 @@ func findMux(host string) (*http.ServeMux, E.NestedError) {
|
|||
return nil, E.NotExist("route", sd)
|
||||
}
|
||||
|
||||
var (
|
||||
defaultDialer = net.Dialer{
|
||||
Timeout: 60 * time.Second,
|
||||
KeepAlive: 60 * time.Second,
|
||||
func findMuxByDomain(domains []string) func(host string) (*http.ServeMux, E.NestedError) {
|
||||
return func(host string) (*http.ServeMux, E.NestedError) {
|
||||
var subdomain string
|
||||
for _, domain := range domains {
|
||||
subdomain = strings.TrimSuffix(subdomain, domain)
|
||||
if subdomain != domain {
|
||||
break
|
||||
}
|
||||
}
|
||||
if subdomain == "" { // not matched
|
||||
return nil, E.Invalid("host", host)
|
||||
}
|
||||
if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok {
|
||||
return r.mux, nil
|
||||
}
|
||||
return nil, E.NotExist("route", subdomain)
|
||||
}
|
||||
transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: defaultDialer.DialContext,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
}
|
||||
transportNoTLS = func() *http.Transport {
|
||||
var clone = transport.Clone()
|
||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
return clone
|
||||
}()
|
||||
|
||||
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
|
||||
httpRoutesMu sync.Mutex
|
||||
globalMux = http.NewServeMux()
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
package middleware
|
||||
|
||||
var AddXForwarded = &Middleware{
|
||||
rewrite: func(r *ProxyRequest) {
|
||||
r.SetXForwarded()
|
||||
},
|
||||
}
|
249
src/route/middleware/forward_auth.go
Normal file
249
src/route/middleware/forward_auth.go
Normal file
|
@ -0,0 +1,249 @@
|
|||
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
|
||||
// Copyright (c) 2020-2024 Traefik Labs
|
||||
// Copyright (c) 2024 yusing
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/common"
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
gpHTTP "github.com/yusing/go-proxy/http"
|
||||
U "github.com/yusing/go-proxy/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
forwardAuth struct {
|
||||
*forwardAuthOpts
|
||||
m *Middleware
|
||||
client http.Client
|
||||
}
|
||||
forwardAuthOpts struct {
|
||||
Address string
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
AddAuthCookiesToResponse []string
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
xForwardedFor = "X-Forwarded-For"
|
||||
xForwardedMethod = "X-Forwarded-Method"
|
||||
xForwardedHost = "X-Forwarded-Host"
|
||||
xForwardedProto = "X-Forwarded-Proto"
|
||||
xForwardedURI = "X-Forwarded-Uri"
|
||||
xForwardedPort = "X-Forwarded-Port"
|
||||
)
|
||||
|
||||
var ForwardAuth = newForwardAuth()
|
||||
var faLogger = logrus.WithField("middleware", "ForwardAuth")
|
||||
|
||||
func newForwardAuth() (fa *forwardAuth) {
|
||||
fa = new(forwardAuth)
|
||||
fa.m = new(Middleware)
|
||||
fa.m.labelParserMap = D.ValueParserMap{
|
||||
"trust_forward_header": D.BoolParser,
|
||||
"auth_response_headers": D.YamlStringListParser,
|
||||
"add_auth_cookies_to_response": D.YamlStringListParser,
|
||||
}
|
||||
fa.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||
tr, ok := rp.Transport.(*http.Transport)
|
||||
if ok {
|
||||
tr = tr.Clone()
|
||||
} else {
|
||||
tr = common.DefaultTransport.Clone()
|
||||
}
|
||||
|
||||
faWithOpts := new(forwardAuth)
|
||||
faWithOpts.forwardAuthOpts = new(forwardAuthOpts)
|
||||
faWithOpts.client = http.Client{
|
||||
CheckRedirect: func(r *Request, via []*Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: tr,
|
||||
}
|
||||
faWithOpts.m = &Middleware{
|
||||
impl: faWithOpts,
|
||||
before: fa.forward,
|
||||
}
|
||||
|
||||
err := U.Deserialize(optsRaw, faWithOpts.forwardAuthOpts)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("set options", err)
|
||||
}
|
||||
_, err = E.Check(url.Parse(faWithOpts.Address))
|
||||
if err != nil {
|
||||
return nil, E.Invalid("address", faWithOpts.Address)
|
||||
}
|
||||
return faWithOpts.m, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
|
||||
removeHop(req.Header)
|
||||
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
req.Context(),
|
||||
http.MethodGet,
|
||||
fa.Address,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
copyHeader(faReq.Header, req.Header)
|
||||
removeHop(faReq.Header)
|
||||
|
||||
filterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
||||
fa.setAuthHeaders(req, faReq)
|
||||
|
||||
faResp, err := fa.client.Do(faReq)
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer faResp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
||||
copyHeader(w.Header(), faResp.Header)
|
||||
removeHop(w.Header())
|
||||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
w.Header().Set("Location", redirectURL.String())
|
||||
}
|
||||
|
||||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range fa.AuthResponseHeaders {
|
||||
key := http.CanonicalHeaderKey(key)
|
||||
req.Header.Del(key)
|
||||
if len(faResp.Header[key]) > 0 {
|
||||
req.Header[key] = append([]string(nil), faResp.Header[key]...)
|
||||
}
|
||||
}
|
||||
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
|
||||
authCookies := faResp.Cookies()
|
||||
|
||||
if len(authCookies) == 0 {
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error {
|
||||
fa.setAuthCookies(resp, authCookies)
|
||||
return nil
|
||||
}), req)
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
|
||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cookies := resp.Cookies()
|
||||
resp.Header.Del("Set-Cookie")
|
||||
|
||||
for _, cookie := range cookies {
|
||||
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
||||
// this cookie is not an auth cookie, so add it back
|
||||
resp.Header.Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
}
|
||||
|
||||
for _, cookie := range authCookies {
|
||||
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
||||
// this cookie is an auth cookie, so add to resp
|
||||
resp.Header.Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
if fa.TrustForwardHeader {
|
||||
if prior, ok := req.Header[xForwardedFor]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
}
|
||||
faReq.Header.Set(xForwardedFor, clientIP)
|
||||
}
|
||||
|
||||
xMethod := req.Header.Get(xForwardedMethod)
|
||||
switch {
|
||||
case xMethod != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(xForwardedMethod, xMethod)
|
||||
case req.Method != "":
|
||||
faReq.Header.Set(xForwardedMethod, req.Method)
|
||||
default:
|
||||
faReq.Header.Del(xForwardedMethod)
|
||||
}
|
||||
|
||||
xfp := req.Header.Get(xForwardedProto)
|
||||
switch {
|
||||
case xfp != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(xForwardedProto, xfp)
|
||||
case req.TLS != nil:
|
||||
faReq.Header.Set(xForwardedProto, "https")
|
||||
default:
|
||||
faReq.Header.Set(xForwardedProto, "http")
|
||||
}
|
||||
|
||||
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
|
||||
faReq.Header.Set(xForwardedPort, xfp)
|
||||
}
|
||||
|
||||
xfh := req.Header.Get(xForwardedHost)
|
||||
switch {
|
||||
case xfh != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(xForwardedHost, xfh)
|
||||
case req.Host != "":
|
||||
faReq.Header.Set(xForwardedHost, req.Host)
|
||||
default:
|
||||
faReq.Header.Del(xForwardedHost)
|
||||
}
|
||||
|
||||
xfURI := req.Header.Get(xForwardedURI)
|
||||
switch {
|
||||
case xfURI != "" && fa.TrustForwardHeader:
|
||||
faReq.Header.Set(xForwardedURI, xfURI)
|
||||
case req.URL.RequestURI() != "":
|
||||
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
|
||||
default:
|
||||
faReq.Header.Del(xForwardedURI)
|
||||
}
|
||||
}
|
44
src/route/middleware/headers.go
Normal file
44
src/route/middleware/headers.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
gpHTTP "github.com/yusing/go-proxy/http"
|
||||
)
|
||||
|
||||
func removeHop(h Header) {
|
||||
reqUpType := gpHTTP.UpgradeType(h)
|
||||
gpHTTP.RemoveHopByHopHeaders(h)
|
||||
|
||||
if reqUpType != "" {
|
||||
h.Set("Connection", "Upgrade")
|
||||
h.Set("Upgrade", reqUpType)
|
||||
} else {
|
||||
h.Del("Connection")
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeader(dst, src Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterHeaders(h Header, allowed []string) {
|
||||
if allowed == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range allowed {
|
||||
allowed[i] = http.CanonicalHeaderKey(allowed[i])
|
||||
}
|
||||
|
||||
for key := range h {
|
||||
if !slices.Contains(allowed, key) {
|
||||
h.Del(key)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,33 +3,42 @@ package middleware
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
P "github.com/yusing/go-proxy/proxy"
|
||||
gpHTTP "github.com/yusing/go-proxy/http"
|
||||
)
|
||||
|
||||
type (
|
||||
ReverseProxy = P.ReverseProxy
|
||||
ProxyRequest = P.ProxyRequest
|
||||
Error = E.NestedError
|
||||
|
||||
ReverseProxy = gpHTTP.ReverseProxy
|
||||
ProxyRequest = gpHTTP.ProxyRequest
|
||||
Request = http.Request
|
||||
Response = http.Response
|
||||
ResponseWriter = http.ResponseWriter
|
||||
Header = http.Header
|
||||
Cookie = http.Cookie
|
||||
|
||||
BeforeFunc func(w ResponseWriter, r *Request) (continue_ bool)
|
||||
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *ProxyRequest)
|
||||
ModifyResponseFunc func(res *Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError)
|
||||
|
||||
MiddlewareOptionsRaw map[string]string
|
||||
MiddlewareOptions map[string]interface{}
|
||||
OptionsRaw = map[string]any
|
||||
Options any
|
||||
|
||||
Middleware struct {
|
||||
name string
|
||||
|
||||
before BeforeFunc
|
||||
rewrite RewriteFunc
|
||||
modifyResponse ModifyResponseFunc
|
||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
|
||||
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
||||
|
||||
options MiddlewareOptions
|
||||
validateOptions func(opts MiddlewareOptionsRaw) (MiddlewareOptions, E.NestedError)
|
||||
transport http.RoundTripper
|
||||
|
||||
withOptions CloneWithOptFunc
|
||||
labelParserMap D.ValueParserMap
|
||||
impl any
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -41,41 +50,32 @@ func (m *Middleware) String() string {
|
|||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) WithOptions(optsRaw MiddlewareOptionsRaw) (*Middleware, E.NestedError) {
|
||||
if len(optsRaw) == 0 {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var opts MiddlewareOptions
|
||||
var err E.NestedError
|
||||
|
||||
if m.validateOptions != nil {
|
||||
if opts, err = m.validateOptions(optsRaw); err != nil {
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||
if len(optsRaw) != 0 && m.withOptions != nil {
|
||||
if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return mWithOpt, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Middleware{
|
||||
name: m.name,
|
||||
before: m.before,
|
||||
rewrite: m.rewrite,
|
||||
modifyResponse: m.modifyResponse,
|
||||
options: opts,
|
||||
}, nil
|
||||
// WithOptionsClone is called only once
|
||||
// set withOptions and labelParser will not be used after that
|
||||
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
|
||||
}
|
||||
|
||||
// TODO: check conflict
|
||||
func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptionsRaw) (out ReverseProxy, err E.NestedError) {
|
||||
out = rp
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
|
||||
befores := make([]BeforeFunc, 0, len(middlewares))
|
||||
rewrites := make([]RewriteFunc, 0, len(middlewares))
|
||||
modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares))
|
||||
|
||||
invalidM := E.NewBuilder("invalid middlewares")
|
||||
invalidOpts := E.NewBuilder("invalid options")
|
||||
defer invalidM.Add(invalidOpts.Build())
|
||||
defer invalidM.To(&err)
|
||||
defer func() {
|
||||
invalidM.Add(invalidOpts.Build())
|
||||
invalidM.To(&res)
|
||||
}()
|
||||
|
||||
for name, opts := range middlewares {
|
||||
m, ok := Get(name)
|
||||
|
@ -83,7 +83,8 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions
|
|||
invalidM.Addf("%s", name)
|
||||
continue
|
||||
}
|
||||
m, err = m.WithOptions(opts)
|
||||
|
||||
m, err := m.WithOptionsClone(opts, rp)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
|
@ -103,25 +104,37 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions
|
|||
return
|
||||
}
|
||||
|
||||
if len(befores) > 0 {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
for _, before := range befores {
|
||||
if !before(w, r) {
|
||||
return
|
||||
}
|
||||
origServeHTTP := rp.ServeHTTP
|
||||
for i, before := range befores {
|
||||
if i < len(befores)-1 {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
before(rp.ServeHTTP, w, r)
|
||||
}
|
||||
} else {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
before(origServeHTTP, w, r)
|
||||
}
|
||||
rp.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rewrites) > 0 {
|
||||
origRewrite := rp.Rewrite
|
||||
rp.Rewrite = func(req *ProxyRequest) {
|
||||
if origRewrite != nil {
|
||||
origRewrite(req)
|
||||
}
|
||||
for _, rewrite := range rewrites {
|
||||
rewrite(req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(modifyResponses) > 0 {
|
||||
origModifyResponse := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *Response) error {
|
||||
if origModifyResponse != nil {
|
||||
return origModifyResponse(res)
|
||||
}
|
||||
for _, modifyResponse := range modifyResponses {
|
||||
if err := modifyResponse(res); err != nil {
|
||||
return err
|
||||
|
|
|
@ -3,14 +3,11 @@ package middleware
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
)
|
||||
|
||||
var middlewares = map[string]*Middleware{
|
||||
"set_x_forwarded": SetXForwarded, // nginx
|
||||
"add_x_forwarded": AddXForwarded, // nginx
|
||||
"trust_forward_header": AddXForwarded, // traefik alias
|
||||
"redirect_http": RedirectHTTP,
|
||||
}
|
||||
var middlewares map[string]*Middleware
|
||||
|
||||
func Get(name string) (middleware *Middleware, ok bool) {
|
||||
middleware, ok = middlewares[name]
|
||||
|
@ -18,10 +15,23 @@ func Get(name string) (middleware *Middleware, ok bool) {
|
|||
}
|
||||
|
||||
// initialize middleware names
|
||||
var _ = func() (_ bool) {
|
||||
func init() {
|
||||
middlewares = map[string]*Middleware{
|
||||
"set_x_forwarded": SetXForwarded,
|
||||
"add_x_forwarded": AddXForwarded,
|
||||
"redirect_http": RedirectHTTP,
|
||||
"forward_auth": ForwardAuth.m,
|
||||
"modify_response": ModifyResponse.m,
|
||||
"modify_request": ModifyRequest.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range middlewares {
|
||||
names[m] = append(names[m], name)
|
||||
// register middleware name to docker label parsr
|
||||
// in order to parse middleware_name.option=value into correct type
|
||||
if m.labelParserMap != nil {
|
||||
D.RegisterNamespace(name, m.labelParserMap)
|
||||
}
|
||||
}
|
||||
for m, names := range names {
|
||||
if len(names) > 1 {
|
||||
|
@ -30,5 +40,4 @@ var _ = func() (_ bool) {
|
|||
m.name = names[0]
|
||||
}
|
||||
}
|
||||
return
|
||||
}()
|
||||
}
|
||||
|
|
58
src/route/middleware/modify_request.go
Normal file
58
src/route/middleware/modify_request.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
U "github.com/yusing/go-proxy/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyRequest struct {
|
||||
*modifyRequestOpts
|
||||
m *Middleware
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyRequestOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
}
|
||||
)
|
||||
|
||||
var ModifyRequest = newModifyRequest()
|
||||
|
||||
func newModifyRequest() (mr *modifyRequest) {
|
||||
mr = new(modifyRequest)
|
||||
mr.m = new(Middleware)
|
||||
mr.m.labelParserMap = D.ValueParserMap{
|
||||
"set_headers": D.YamlLikeMappingParser(true),
|
||||
"add_headers": D.YamlLikeMappingParser(true),
|
||||
"hide_headers": D.YamlStringListParser,
|
||||
}
|
||||
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||
mrWithOpts := new(modifyRequest)
|
||||
mrWithOpts.m = &Middleware{
|
||||
impl: mrWithOpts,
|
||||
rewrite: mrWithOpts.modifyRequest,
|
||||
}
|
||||
mrWithOpts.modifyRequestOpts = new(modifyRequestOpts)
|
||||
err := U.Deserialize(optsRaw, mrWithOpts.modifyRequestOpts)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("set options", err)
|
||||
}
|
||||
return mrWithOpts.m, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) modifyRequest(req *ProxyRequest) {
|
||||
for k, v := range mr.SetHeaders {
|
||||
req.Out.Header.Set(k, v)
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
req.Out.Header.Add(k, v)
|
||||
}
|
||||
for _, k := range mr.HideHeaders {
|
||||
req.Out.Header.Del(k)
|
||||
}
|
||||
}
|
34
src/route/middleware/modify_request_test.go
Normal file
34
src/route/middleware/modify_request_test.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetModifyRequest(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||
})
|
||||
}
|
61
src/route/middleware/modify_response.go
Normal file
61
src/route/middleware/modify_response.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
D "github.com/yusing/go-proxy/docker"
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
U "github.com/yusing/go-proxy/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyResponse struct {
|
||||
*modifyResponseOpts
|
||||
m *Middleware
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyResponseOpts struct {
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
}
|
||||
)
|
||||
|
||||
var ModifyResponse = newModifyResponse()
|
||||
|
||||
func newModifyResponse() (mr *modifyResponse) {
|
||||
mr = new(modifyResponse)
|
||||
mr.m = new(Middleware)
|
||||
mr.m.labelParserMap = D.ValueParserMap{
|
||||
"set_headers": D.YamlLikeMappingParser(true),
|
||||
"add_headers": D.YamlLikeMappingParser(true),
|
||||
"hide_headers": D.YamlStringListParser,
|
||||
}
|
||||
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||
mrWithOpts := new(modifyResponse)
|
||||
mrWithOpts.m = &Middleware{
|
||||
impl: mrWithOpts,
|
||||
modifyResponse: mrWithOpts.modifyResponse,
|
||||
}
|
||||
mrWithOpts.modifyResponseOpts = new(modifyResponseOpts)
|
||||
err := U.Deserialize(optsRaw, mrWithOpts.modifyResponseOpts)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("set options", err)
|
||||
}
|
||||
return mrWithOpts.m, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||
for k, v := range mr.SetHeaders {
|
||||
resp.Header.Set(k, v)
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
resp.Header.Add(k, v)
|
||||
}
|
||||
for _, k := range mr.HideHeaders {
|
||||
resp.Header.Del(k)
|
||||
}
|
||||
return nil
|
||||
}
|
35
src/route/middleware/modify_response_test.go
Normal file
35
src/route/middleware/modify_response_test.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetModifyResponse(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
|
||||
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
||||
})
|
||||
}
|
|
@ -7,14 +7,13 @@ import (
|
|||
)
|
||||
|
||||
var RedirectHTTP = &Middleware{
|
||||
before: func(w ResponseWriter, r *Request) (continue_ bool) {
|
||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
if r.TLS == nil {
|
||||
r.URL.Scheme = "https"
|
||||
r.URL.Host = r.URL.Hostname() + common.ProxyHTTPSPort
|
||||
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
continue_ = true
|
||||
return
|
||||
}
|
||||
return
|
||||
next.ServeHTTP(w, r)
|
||||
},
|
||||
}
|
||||
|
|
26
src/route/middleware/redirect_http_test.go
Normal file
26
src/route/middleware/redirect_http_test.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/common"
|
||||
. "github.com/yusing/go-proxy/utils/testing"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPs(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
scheme: "http",
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
|
||||
}
|
||||
|
||||
func TestNoRedirect(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
scheme: "https",
|
||||
})
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package middleware
|
||||
|
||||
var SetXForwarded = &Middleware{
|
||||
rewrite: func(r *ProxyRequest) {
|
||||
r.Out.Header.Del("X-Forwarded-For")
|
||||
r.Out.Header.Del("X-Forwarded-Host")
|
||||
r.Out.Header.Del("X-Forwarded-Proto")
|
||||
r.SetXForwarded()
|
||||
},
|
||||
}
|
17
src/route/middleware/test_data/sample_headers.json
Normal file
17
src/route/middleware/test_data/sample_headers.json
Normal file
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
|
||||
"Dnt": "1",
|
||||
"Host": "localhost",
|
||||
"Priority": "u=0, i",
|
||||
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
|
||||
"Sec-Ch-Ua-Mobile": "?0",
|
||||
"Sec-Ch-Ua-Platform": "\"Windows\"",
|
||||
"Sec-Fetch-Dest": "document",
|
||||
"Sec-Fetch-Mode": "navigate",
|
||||
"Sec-Fetch-Site": "none",
|
||||
"Sec-Fetch-User": "?1",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
|
||||
}
|
125
src/route/middleware/test_utils.go
Normal file
125
src/route/middleware/test_utils.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
||||
E "github.com/yusing/go-proxy/error"
|
||||
gpHTTP "github.com/yusing/go-proxy/http"
|
||||
)
|
||||
|
||||
//go:embed test_data/sample_headers.json
|
||||
var testHeadersRaw []byte
|
||||
var testHeaders http.Header
|
||||
|
||||
const testHost = "example.com"
|
||||
|
||||
func init() {
|
||||
tmp := map[string]string{}
|
||||
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
testHeaders = http.Header{}
|
||||
for k, v := range tmp {
|
||||
testHeaders.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
type requestHeaderRecorder struct {
|
||||
parent http.RoundTripper
|
||||
reqHeaders http.Header
|
||||
}
|
||||
|
||||
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.reqHeaders = req.Header
|
||||
if rt.parent != nil {
|
||||
return rt.parent.RoundTrip(req)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
RequestHeaders http.Header
|
||||
ResponseHeaders http.Header
|
||||
ResponseStatus int
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type testArgs struct {
|
||||
middlewareOpt OptionsRaw
|
||||
proxyURL string
|
||||
body []byte
|
||||
scheme string
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||
var body io.Reader
|
||||
var rt = new(requestHeaderRecorder)
|
||||
var proxyURL *url.URL
|
||||
var requestTarget string
|
||||
var err error
|
||||
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
|
||||
if args.body != nil {
|
||||
body = bytes.NewReader(args.body)
|
||||
}
|
||||
|
||||
if args.scheme == "" || args.scheme == "http" {
|
||||
requestTarget = "http://" + testHost
|
||||
} else if args.scheme == "https" {
|
||||
requestTarget = "https://" + testHost
|
||||
} else {
|
||||
panic("typo?")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if args.scheme == "https" && req.TLS == nil {
|
||||
panic("bug occurred")
|
||||
}
|
||||
|
||||
if args.proxyURL != "" {
|
||||
proxyURL, err = url.Parse(args.proxyURL)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
rt.parent = http.DefaultTransport
|
||||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
|
||||
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
|
||||
middleware.name: args.middlewareOpt,
|
||||
})
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
rp.ServeHTTP(w, req)
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
return &TestResult{
|
||||
RequestHeaders: rt.reqHeaders,
|
||||
ResponseHeaders: resp.Header,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
9
src/route/middleware/x_forwarded.go
Normal file
9
src/route/middleware/x_forwarded.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package middleware
|
||||
|
||||
var AddXForwarded = &Middleware{
|
||||
rewrite: (*ProxyRequest).AddXForwarded,
|
||||
}
|
||||
|
||||
var SetXForwarded = &Middleware{
|
||||
rewrite: (*ProxyRequest).SetXForwarded,
|
||||
}
|
|
@ -84,6 +84,11 @@ func NewServer(opt Options) (s *server) {
|
|||
}
|
||||
}
|
||||
|
||||
// Start will start the http and https servers.
|
||||
//
|
||||
// If both are not set, this does nothing.
|
||||
//
|
||||
// Start() is non-blocking
|
||||
func (s *server) Start() {
|
||||
if s.http == nil && s.https == nil {
|
||||
return
|
||||
|
|
|
@ -106,7 +106,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func Deserialize(src map[string]any, target any) E.NestedError {
|
||||
func Deserialize(src SerializedObject, target any) E.NestedError {
|
||||
// convert data fields to lower no-snake
|
||||
// convert target fields to lower
|
||||
// then check if the field of data is in the target
|
||||
|
@ -117,6 +117,10 @@ func Deserialize(src map[string]any, target any) E.NestedError {
|
|||
snakeCaseField := strings.ToLower(field.Name)
|
||||
mapping[snakeCaseField] = field.Name
|
||||
}
|
||||
tValue := reflect.ValueOf(target)
|
||||
if tValue.IsZero() {
|
||||
return E.Invalid("value", "nil")
|
||||
}
|
||||
for k, v := range src {
|
||||
kCleaned := toLowerNoSnake(k)
|
||||
if fieldName, ok := mapping[kCleaned]; ok {
|
||||
|
@ -150,13 +154,13 @@ func Deserialize(src map[string]any, target any) E.NestedError {
|
|||
}
|
||||
prop.Set(propNew)
|
||||
default:
|
||||
return E.Unsupported("field", k).Extraf("type=%s", propType)
|
||||
return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType)
|
||||
}
|
||||
} else {
|
||||
return E.Unsupported("field", k).Extraf("type=%s", propType)
|
||||
return E.Unsupported("field", k).Extraf("type %s is not settable", propType)
|
||||
}
|
||||
} else {
|
||||
return E.Failure("unknown field").With(k)
|
||||
return E.Unexpected("field", k)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,6 +38,17 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) {
|
|||
}
|
||||
}
|
||||
|
||||
func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) {
|
||||
t.Helper()
|
||||
for _, want := range wants {
|
||||
if got == want {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Errorf("expected any of:\n%v, got\n%v", wants, got)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
|
||||
t.Helper()
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
|
|
Loading…
Add table
Reference in a new issue