mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-21 20:04:03 +02:00
added support for a few middlewares, added match_domain
option, changed index reference prefix from $ to #, etc.
This commit is contained in:
parent
345a4417a6
commit
f474ae4f75
47 changed files with 1523 additions and 446 deletions
15
.github/workflows/docker-image.yml
vendored
15
.github/workflows/docker-image.yml
vendored
|
@ -130,3 +130,18 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
docker tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.meta.outputs.version }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
docker tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.meta.outputs.version }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||||
docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||||
|
scan:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs:
|
||||||
|
- merge
|
||||||
|
steps:
|
||||||
|
- name: Scan Image with Trivy
|
||||||
|
uses: aquasecurity/trivy-action@0.20.0
|
||||||
|
with:
|
||||||
|
image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||||
|
format: "sarif"
|
||||||
|
output: "trivy-results.sarif"
|
||||||
|
- name: Upload Trivy SARIF Report
|
||||||
|
uses: github/codeql-action/upload-sarif@v2
|
||||||
|
with:
|
||||||
|
sarif_file: "trivy-results.sarif"
|
||||||
|
|
10
Dockerfile
10
Dockerfile
|
@ -9,15 +9,15 @@ COPY src/go.mod src/go.sum ./
|
||||||
|
|
||||||
# Utilize build cache
|
# Utilize build cache
|
||||||
RUN --mount=type=cache,target="/go/pkg/mod" \
|
RUN --mount=type=cache,target="/go/pkg/mod" \
|
||||||
go mod download
|
go mod graph | awk '{if ($1 !~ "@") print $2}' | xargs go get
|
||||||
|
|
||||||
# Now copy the remaining files
|
ENV GOCACHE=/root/.cache/go-build
|
||||||
COPY src/ ./
|
|
||||||
|
|
||||||
# Build the application with better caching
|
# Build the application with better caching
|
||||||
RUN --mount=type=cache,target="/go/pkg/mod" \
|
RUN --mount=type=cache,target="/go/pkg/mod" \
|
||||||
--mount=type=cache,target="/root/.cache/go-build" \
|
--mount=type=cache,target="/root/.cache/go-build" \
|
||||||
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy ./
|
--mount=type=bind,src=src,dst=/src \
|
||||||
|
CGO_ENABLED=0 GOOS=linux go build -ldflags '-w -s' -pgo=auto -o /go-proxy .
|
||||||
|
|
||||||
# Stage 2: Final image
|
# Stage 2: Final image
|
||||||
FROM scratch
|
FROM scratch
|
||||||
|
@ -28,7 +28,7 @@ LABEL maintainer="yusing@6uo.me"
|
||||||
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
|
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
|
||||||
|
|
||||||
# copy binary
|
# copy binary
|
||||||
COPY --from=builder /src/go-proxy /app/
|
COPY --from=builder /go-proxy /app/
|
||||||
|
|
||||||
# copy schema directory
|
# copy schema directory
|
||||||
COPY schema/ /app/schema/
|
COPY schema/ /app/schema/
|
||||||
|
|
8
Makefile
8
Makefile
|
@ -1,4 +1,4 @@
|
||||||
.PHONY: all build up quick-restart restart logs get udp-server
|
.PHONY: all setup build test up restart logs get debug run archive repush rapid-crash debug-list-containers
|
||||||
|
|
||||||
all: debug
|
all: debug
|
||||||
|
|
||||||
|
@ -9,7 +9,8 @@ setup:
|
||||||
|
|
||||||
build:
|
build:
|
||||||
mkdir -p bin
|
mkdir -p bin
|
||||||
CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
|
CGO_ENABLED=0 GOOS=linux \
|
||||||
|
go build -ldflags '${BUILD_FLAG}' -pgo=auto -o bin/go-proxy github.com/yusing/go-proxy
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test ./src/...
|
go test ./src/...
|
||||||
|
@ -29,6 +30,9 @@ get:
|
||||||
debug:
|
debug:
|
||||||
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
|
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||||
|
|
||||||
|
run:
|
||||||
|
BUILD_FLAG="-s -w" make build && sudo bin/go-proxy
|
||||||
|
|
||||||
archive:
|
archive:
|
||||||
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
|
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
|
||||||
|
|
||||||
|
|
|
@ -1,36 +1,64 @@
|
||||||
# Autocert (choose one below and uncomment to enable)
|
# Autocert (choose one below and uncomment to enable)
|
||||||
|
#
|
||||||
# 1. use existing cert
|
# 1. use existing cert
|
||||||
|
#
|
||||||
# autocert:
|
# autocert:
|
||||||
# provider: local
|
# provider: local
|
||||||
# cert_path: certs/cert.crt # optional, uncomment only if you need to change it
|
#
|
||||||
# key_path: certs/priv.key # optional, uncomment only if you need to change it
|
# cert_path: certs/cert.crt # optional, uncomment only if you need to change it
|
||||||
|
# key_path: certs/priv.key # optional, uncomment only if you need to change it
|
||||||
|
#
|
||||||
# 2. cloudflare
|
# 2. cloudflare
|
||||||
|
#
|
||||||
# autocert:
|
# autocert:
|
||||||
# provider: cloudflare
|
# provider: cloudflare
|
||||||
# email: # ACME Email
|
# email: abc@gmail.com # ACME Email
|
||||||
# domains: # a list of domains for cert registration
|
# domains: # a list of domains for cert registration
|
||||||
# - x.y.z
|
# - "*.y.z" # remember to use double quotes to surround wildcard domain
|
||||||
# options:
|
# options:
|
||||||
# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token
|
# auth_token: c1234565789-abcdefghijklmnopqrst # your zone API token
|
||||||
|
#
|
||||||
# 3. other providers, check docs/dns_providers.md for more
|
# 3. other providers, check docs/dns_providers.md for more
|
||||||
|
|
||||||
providers:
|
providers:
|
||||||
|
# include files are standalone yaml files under `config/` directory
|
||||||
|
#
|
||||||
# include:
|
# include:
|
||||||
# - providers.yml # config/providers.yml
|
# - file1.yml
|
||||||
# # add some more below if you want
|
|
||||||
# - file1.yml # config/file_1.yml
|
|
||||||
# - file2.yml
|
# - file2.yml
|
||||||
|
|
||||||
docker:
|
docker:
|
||||||
# for value format, see https://docs.docker.com/reference/cli/dockerd/
|
# $DOCKER_HOST implies environment variable `DOCKER_HOST` or unix:///var/run/docker.sock by default
|
||||||
# $DOCKER_HOST implies unix:///var/run/docker.sock by default
|
|
||||||
local: $DOCKER_HOST
|
local: $DOCKER_HOST
|
||||||
|
|
||||||
# add more docker providers if needed
|
# add more docker providers if needed
|
||||||
|
# for value format, see https://docs.docker.com/reference/cli/dockerd/
|
||||||
|
#
|
||||||
# remote-1: tcp://10.0.2.1:2375
|
# remote-1: tcp://10.0.2.1:2375
|
||||||
# remote-2: ssh://root:1234@10.0.2.2
|
# remote-2: ssh://root:1234@10.0.2.2
|
||||||
# Fixed options (optional, non hot-reloadable)
|
# if match_domains not defined
|
||||||
|
# any host = alias+[any domain] will match
|
||||||
|
# i.e. https://app1.y.z will match alias app1 for any domain y.z
|
||||||
|
# but https://app1.node1.y.z will only match alias "app.node1"
|
||||||
|
#
|
||||||
|
# if match_domains defined
|
||||||
|
# only host = alias+[one of match_domains] will match
|
||||||
|
# i.e. match_domains = [node1.my.app, my.site]
|
||||||
|
# https://app1.my.app, https://app1.my.net, etc. will not match even if app1 exists
|
||||||
|
# only https://*.node1.my.app and https://*.my.site will match
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# match_domains:
|
||||||
|
# - my.site
|
||||||
|
# - node1.my.app
|
||||||
|
|
||||||
|
# Below are fixed options (non hot-reloadable)
|
||||||
|
|
||||||
|
# timeout for shutdown (in seconds)
|
||||||
|
#
|
||||||
# timeout_shutdown: 5
|
# timeout_shutdown: 5
|
||||||
# redirect_to_https: false # redirect http requests to https (if enabled)
|
|
||||||
|
# global setting redirect http requests to https (if https available, otherwise this will be ignored)
|
||||||
|
# proxy.<alias>.middlewares.redirect_http will override this
|
||||||
|
#
|
||||||
|
# redirect_to_https: false
|
||||||
|
|
|
@ -74,7 +74,7 @@
|
||||||
| `proxy.stop_timeout` | time to wait for stop command | | `10s` | `number[unit]...` |
|
| `proxy.stop_timeout` | time to wait for stop command | | `10s` | `number[unit]...` |
|
||||||
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
|
| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix |
|
||||||
| `proxy.<alias>.<field>` | set field for specific alias | `proxy.gitlab-ssh.scheme` | N/A | N/A |
|
| `proxy.<alias>.<field>` | set field for specific alias | `proxy.gitlab-ssh.scheme` | N/A | N/A |
|
||||||
| `proxy.$<index>.<field>` | set field for specific alias at index (starting from **1**) | `proxy.$3.port` | N/A | N/A |
|
| `proxy.#<index>.<field>` | set field for specific alias at index (starting from **1**) | `proxy.#3.port` | N/A | N/A |
|
||||||
| `proxy.*.<field>` | set field for all aliases | `proxy.*.set_headers` | N/A | N/A |
|
| `proxy.*.<field>` | set field for all aliases | `proxy.*.set_headers` | N/A | N/A |
|
||||||
|
|
||||||
### Fields
|
### Fields
|
||||||
|
|
|
@ -37,7 +37,13 @@
|
||||||
"title": "DNS Challenge Provider",
|
"title": "DNS Challenge Provider",
|
||||||
"default": "local",
|
"default": "local",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["local", "cloudflare", "clouddns", "duckdns", "ovh"]
|
"enum": [
|
||||||
|
"local",
|
||||||
|
"cloudflare",
|
||||||
|
"clouddns",
|
||||||
|
"duckdns",
|
||||||
|
"ovh"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"options": {
|
"options": {
|
||||||
"title": "Provider specific options",
|
"title": "Provider specific options",
|
||||||
|
@ -56,7 +62,12 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"then": {
|
"then": {
|
||||||
"required": ["email", "domains", "provider", "options"]
|
"required": [
|
||||||
|
"email",
|
||||||
|
"domains",
|
||||||
|
"provider",
|
||||||
|
"options"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -70,7 +81,9 @@
|
||||||
"then": {
|
"then": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"options": {
|
"options": {
|
||||||
"required": ["auth_token"],
|
"required": [
|
||||||
|
"auth_token"
|
||||||
|
],
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"properties": {
|
"properties": {
|
||||||
"auth_token": {
|
"auth_token": {
|
||||||
|
@ -93,7 +106,11 @@
|
||||||
"then": {
|
"then": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"options": {
|
"options": {
|
||||||
"required": ["client_id", "email", "password"],
|
"required": [
|
||||||
|
"client_id",
|
||||||
|
"email",
|
||||||
|
"password"
|
||||||
|
],
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"properties": {
|
"properties": {
|
||||||
"client_id": {
|
"client_id": {
|
||||||
|
@ -124,7 +141,9 @@
|
||||||
"then": {
|
"then": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"options": {
|
"options": {
|
||||||
"required": ["token"],
|
"required": [
|
||||||
|
"token"
|
||||||
|
],
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"properties": {
|
"properties": {
|
||||||
"token": {
|
"token": {
|
||||||
|
@ -147,14 +166,21 @@
|
||||||
"then": {
|
"then": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"options": {
|
"options": {
|
||||||
"required": ["application_secret", "consumer_key"],
|
"required": [
|
||||||
|
"application_secret",
|
||||||
|
"consumer_key"
|
||||||
|
],
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"oneOf": [
|
"oneOf": [
|
||||||
{
|
{
|
||||||
"required": ["application_key"]
|
"required": [
|
||||||
|
"application_key"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"required": ["oauth2_config"]
|
"required": [
|
||||||
|
"oauth2_config"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -205,7 +231,10 @@
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["client_id", "client_secret"]
|
"required": [
|
||||||
|
"client_id",
|
||||||
|
"client_secret"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -268,6 +297,14 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"match_domains": {
|
||||||
|
"title": "Domains to match",
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"minItems": 1
|
||||||
|
},
|
||||||
"timeout_shutdown": {
|
"timeout_shutdown": {
|
||||||
"title": "Shutdown timeout (in seconds)",
|
"title": "Shutdown timeout (in seconds)",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
|
@ -279,5 +316,7 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": ["providers"]
|
"required": [
|
||||||
|
"providers"
|
||||||
|
]
|
||||||
}
|
}
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
@ -59,8 +60,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
|
||||||
defer b.To(&res)
|
defer b.To(&res)
|
||||||
|
|
||||||
if p.cfg.Provider == ProviderLocal {
|
if p.cfg.Provider == ProviderLocal {
|
||||||
b.Addf("provider is set to %q", ProviderLocal).WithSeverity(E.SeverityWarning)
|
return nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.client == nil {
|
if p.client == nil {
|
||||||
|
@ -191,7 +191,19 @@ func (p *Provider) registerACME() E.NestedError {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
|
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
|
||||||
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
|
//* This should have been done in setup
|
||||||
|
//* but double check is always a good choice
|
||||||
|
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
|
||||||
|
return E.FailWith("create cert directory", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return E.FailWith("stat cert directory", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.FailWith("write key file", err)
|
return E.FailWith("write key file", err)
|
||||||
}
|
}
|
||||||
|
@ -227,6 +239,10 @@ func (p *Provider) certState() CertState {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) renewIfNeeded() E.NestedError {
|
func (p *Provider) renewIfNeeded() E.NestedError {
|
||||||
|
if p.cfg.Provider == ProviderLocal {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
switch p.certState() {
|
switch p.certState() {
|
||||||
case CertStateExpired:
|
case CertStateExpired:
|
||||||
logger.Info("certs expired, renewing")
|
logger.Info("certs expired, renewing")
|
||||||
|
|
|
@ -14,7 +14,7 @@ func (p *Provider) Setup(ctx context.Context) (err E.NestedError) {
|
||||||
}
|
}
|
||||||
logger.Debug("obtaining cert due to error loading cert")
|
logger.Debug("obtaining cert due to error loading cert")
|
||||||
if err = p.ObtainCert(); err != nil {
|
if err = p.ObtainCert(); err != nil {
|
||||||
return err.Warn()
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
25
src/common/http.go
Normal file
25
src/common/http.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultDialer = net.Dialer{
|
||||||
|
Timeout: 60 * time.Second,
|
||||||
|
KeepAlive: 60 * time.Second,
|
||||||
|
}
|
||||||
|
DefaultTransport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: defaultDialer.DialContext,
|
||||||
|
MaxIdleConnsPerHost: 1000,
|
||||||
|
}
|
||||||
|
DefaultTransportNoTLS = func() *http.Transport {
|
||||||
|
var clone = DefaultTransport.Clone()
|
||||||
|
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
return clone
|
||||||
|
}()
|
||||||
|
)
|
|
@ -31,25 +31,48 @@ type Config struct {
|
||||||
reloadReq chan struct{}
|
reloadReq chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Load() (*Config, E.NestedError) {
|
var instance *Config
|
||||||
cfg := &Config{
|
|
||||||
|
func GetConfig() *Config {
|
||||||
|
return instance
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load() E.NestedError {
|
||||||
|
if instance != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
instance = &Config{
|
||||||
|
value: M.DefaultConfig(),
|
||||||
proxyProviders: F.NewMapOf[string, *PR.Provider](),
|
proxyProviders: F.NewMapOf[string, *PR.Provider](),
|
||||||
l: logrus.WithField("module", "config"),
|
l: logrus.WithField("module", "config"),
|
||||||
watcher: W.NewFileWatcher(common.ConfigFileName),
|
watcher: W.NewFileWatcher(common.ConfigFileName),
|
||||||
reloadReq: make(chan struct{}, 1),
|
reloadReq: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
return cfg, cfg.load()
|
return instance.load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func Validate(data []byte) E.NestedError {
|
func Validate(data []byte) E.NestedError {
|
||||||
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
|
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MatchDomains() []string {
|
||||||
|
if instance == nil {
|
||||||
|
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||||
|
}
|
||||||
|
return instance.value.MatchDomains
|
||||||
|
}
|
||||||
|
|
||||||
func (cfg *Config) Value() M.Config {
|
func (cfg *Config) Value() M.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||||
|
}
|
||||||
return *cfg.value
|
return *cfg.value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
|
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
|
||||||
|
if instance == nil {
|
||||||
|
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||||
|
}
|
||||||
return cfg.autocertProvider
|
return cfg.autocertProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,13 +84,11 @@ func (cfg *Config) Dispose() {
|
||||||
cfg.stopProviders()
|
cfg.stopProviders()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) Reload() E.NestedError {
|
func (cfg *Config) Reload() (err E.NestedError) {
|
||||||
cfg.stopProviders()
|
cfg.stopProviders()
|
||||||
if err := cfg.load(); err.HasError() {
|
err = cfg.load()
|
||||||
return err
|
|
||||||
}
|
|
||||||
cfg.StartProxyProviders()
|
cfg.StartProxyProviders()
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) StartProxyProviders() {
|
func (cfg *Config) StartProxyProviders() {
|
||||||
|
@ -126,28 +147,28 @@ func (cfg *Config) load() (res E.NestedError) {
|
||||||
data, err := E.Check(os.ReadFile(common.ConfigPath))
|
data, err := E.Check(os.ReadFile(common.ConfigPath))
|
||||||
if err.HasError() {
|
if err.HasError() {
|
||||||
b.Add(E.FailWith("read config", err))
|
b.Add(E.FailWith("read config", err))
|
||||||
return
|
logrus.Fatal(b.Build())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !common.NoSchemaValidation {
|
if !common.NoSchemaValidation {
|
||||||
if err = Validate(data); err.HasError() {
|
if err = Validate(data); err.HasError() {
|
||||||
b.Add(E.FailWith("schema validation", err))
|
b.Add(E.FailWith("schema validation", err))
|
||||||
return
|
logrus.Fatal(b.Build())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
model := M.DefaultConfig()
|
model := M.DefaultConfig()
|
||||||
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
|
if err := E.From(yaml.Unmarshal(data, model)); err.HasError() {
|
||||||
b.Add(E.FailWith("parse config", err))
|
b.Add(E.FailWith("parse config", err))
|
||||||
return
|
logrus.Fatal(b.Build())
|
||||||
}
|
}
|
||||||
|
|
||||||
// errors are non fatal below
|
// errors are non fatal below
|
||||||
b.WithSeverity(E.SeverityWarning)
|
|
||||||
b.Add(cfg.initAutoCert(&model.AutoCert))
|
b.Add(cfg.initAutoCert(&model.AutoCert))
|
||||||
b.Add(cfg.loadProviders(&model.Providers))
|
b.Add(cfg.loadProviders(&model.Providers))
|
||||||
|
|
||||||
cfg.value = model
|
cfg.value = model
|
||||||
|
R.SetFindMuxDomains(model.MatchDomains)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,37 @@
|
||||||
package docker
|
package docker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/error"
|
E "github.com/yusing/go-proxy/error"
|
||||||
U "github.com/yusing/go-proxy/utils"
|
U "github.com/yusing/go-proxy/utils"
|
||||||
|
F "github.com/yusing/go-proxy/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Label struct {
|
/*
|
||||||
Namespace string
|
Formats:
|
||||||
Target string
|
- namespace.attribute
|
||||||
Attribute string
|
- namespace.target.attribute
|
||||||
Value any
|
- namespace.target.attribute.namespace2.attribute
|
||||||
|
*/
|
||||||
|
type (
|
||||||
|
Label struct {
|
||||||
|
Namespace string
|
||||||
|
Target string
|
||||||
|
Attribute string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
NestedLabelMap map[string]U.SerializedObject
|
||||||
|
ValueParser func(string) (any, E.NestedError)
|
||||||
|
ValueParserMap map[string]ValueParser
|
||||||
|
)
|
||||||
|
|
||||||
|
func (l *Label) String() string {
|
||||||
|
if l.Attribute == "" {
|
||||||
|
return l.Namespace + "." + l.Target
|
||||||
|
}
|
||||||
|
return l.Namespace + "." + l.Target + "." + l.Attribute
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply applies the value of a Label to the corresponding field in the given object.
|
// Apply applies the value of a Label to the corresponding field in the given object.
|
||||||
|
@ -23,12 +43,40 @@ type Label struct {
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: an error if the field does not exist.
|
// - error: an error if the field does not exist.
|
||||||
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
||||||
return U.Deserialize(map[string]any{l.Attribute: l.Value}, obj)
|
if obj == nil {
|
||||||
|
return E.Invalid("nil object", l)
|
||||||
|
}
|
||||||
|
switch nestedLabel := l.Value.(type) {
|
||||||
|
case *Label:
|
||||||
|
var field reflect.Value
|
||||||
|
objType := reflect.TypeFor[T]()
|
||||||
|
for i := 0; i < reflect.TypeFor[T]().NumField(); i++ {
|
||||||
|
if objType.Field(i).Tag.Get("yaml") == l.Attribute {
|
||||||
|
field = reflect.ValueOf(obj).Elem().Field(i)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !field.IsValid() {
|
||||||
|
return E.NotExist("field", l.Attribute)
|
||||||
|
}
|
||||||
|
dst, ok := field.Interface().(NestedLabelMap)
|
||||||
|
if !ok {
|
||||||
|
return E.Invalid("type", field.Type())
|
||||||
|
}
|
||||||
|
if dst == nil {
|
||||||
|
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
|
||||||
|
dst = field.Interface().(NestedLabelMap)
|
||||||
|
}
|
||||||
|
if dst[nestedLabel.Namespace] == nil {
|
||||||
|
dst[nestedLabel.Namespace] = make(U.SerializedObject)
|
||||||
|
}
|
||||||
|
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ValueParser func(string) (any, E.NestedError)
|
|
||||||
type ValueParserMap map[string]ValueParser
|
|
||||||
|
|
||||||
func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||||
parts := strings.Split(label, ".")
|
parts := strings.Split(label, ".")
|
||||||
|
|
||||||
|
@ -45,14 +93,22 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||||
Value: value,
|
Value: value,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) == 3 {
|
switch len(parts) {
|
||||||
l.Attribute = parts[2]
|
case 2:
|
||||||
} else {
|
|
||||||
l.Attribute = l.Target
|
l.Attribute = l.Target
|
||||||
|
case 3:
|
||||||
|
l.Attribute = parts[2]
|
||||||
|
default:
|
||||||
|
l.Attribute = parts[2]
|
||||||
|
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value)
|
||||||
|
if err.HasError() {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l.Value = nestedLabel
|
||||||
}
|
}
|
||||||
|
|
||||||
// find if namespace has value parser
|
// find if namespace has value parser
|
||||||
pm, ok := labelValueParserMap[l.Namespace]
|
pm, ok := valueParserMap.Load(l.Namespace)
|
||||||
if !ok {
|
if !ok {
|
||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
@ -64,15 +120,28 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||||
// try to parse value
|
// try to parse value
|
||||||
v, err := p(value)
|
v, err := p(value)
|
||||||
if err.HasError() {
|
if err.HasError() {
|
||||||
return nil, err
|
return nil, err.Subject(label)
|
||||||
}
|
}
|
||||||
l.Value = v
|
l.Value = v
|
||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterNamespace(namespace string, pm ValueParserMap) {
|
func RegisterNamespace(namespace string, pm ValueParserMap) {
|
||||||
labelValueParserMap[namespace] = pm
|
valueParserMap.Store(namespace, pm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRegisteredNamespaces() map[string][]string {
|
||||||
|
r := make(map[string][]string)
|
||||||
|
|
||||||
|
valueParserMap.RangeAll(func(ns string, vpm ValueParserMap) {
|
||||||
|
r[ns] = make([]string, 0, len(vpm))
|
||||||
|
for attr := range vpm {
|
||||||
|
r[ns] = append(r[ns], attr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// namespace:target.attribute -> func(string) (any, error)
|
// namespace:target.attribute -> func(string) (any, error)
|
||||||
var labelValueParserMap = make(map[string]ValueParserMap)
|
var valueParserMap = F.NewMapOf[string, ValueParserMap]()
|
||||||
|
|
|
@ -7,7 +7,27 @@ import (
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func yamlListParser(value string) (any, E.NestedError) {
|
const (
|
||||||
|
NSProxy = "proxy"
|
||||||
|
ProxyAttributePathPatterns = "path_patterns"
|
||||||
|
ProxyAttributeNoTLSVerify = "no_tls_verify"
|
||||||
|
ProxyAttributeMiddlewares = "middlewares"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = func() int {
|
||||||
|
RegisterNamespace(NSProxy, ValueParserMap{
|
||||||
|
ProxyAttributePathPatterns: YamlStringListParser,
|
||||||
|
ProxyAttributeNoTLSVerify: BoolParser,
|
||||||
|
})
|
||||||
|
return 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
func YamlStringListParser(value string) (any, E.NestedError) {
|
||||||
|
/*
|
||||||
|
- foo
|
||||||
|
- bar
|
||||||
|
- baz
|
||||||
|
*/
|
||||||
value = strings.TrimSpace(value)
|
value = strings.TrimSpace(value)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
|
@ -17,27 +37,36 @@ func yamlListParser(value string) (any, E.NestedError) {
|
||||||
return data, err
|
return data, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func yamlStringMappingParser(value string) (any, E.NestedError) {
|
func YamlLikeMappingParser(allowDuplicate bool) func(string) (any, E.NestedError) {
|
||||||
value = strings.TrimSpace(value)
|
return func(value string) (any, E.NestedError) {
|
||||||
lines := strings.Split(value, "\n")
|
/*
|
||||||
h := make(map[string]string)
|
foo: bar
|
||||||
for _, line := range lines {
|
boo: baz
|
||||||
parts := strings.SplitN(line, ":", 2)
|
*/
|
||||||
if len(parts) != 2 {
|
value = strings.TrimSpace(value)
|
||||||
return nil, E.Invalid("set header statement", line)
|
lines := strings.Split(value, "\n")
|
||||||
}
|
h := make(map[string]string)
|
||||||
key := strings.TrimSpace(parts[0])
|
for _, line := range lines {
|
||||||
val := strings.TrimSpace(parts[1])
|
parts := strings.SplitN(line, ":", 2)
|
||||||
if existing, ok := h[key]; ok {
|
if len(parts) != 2 {
|
||||||
h[key] = existing + ", " + val
|
return nil, E.Invalid("syntax", line)
|
||||||
} else {
|
}
|
||||||
h[key] = val
|
key := strings.TrimSpace(parts[0])
|
||||||
|
val := strings.TrimSpace(parts[1])
|
||||||
|
if existing, ok := h[key]; ok {
|
||||||
|
if !allowDuplicate {
|
||||||
|
return nil, E.Duplicated("key", key)
|
||||||
|
}
|
||||||
|
h[key] = existing + ", " + val
|
||||||
|
} else {
|
||||||
|
h[key] = val
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return h, nil
|
||||||
}
|
}
|
||||||
return h, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func boolParser(value string) (any, E.NestedError) {
|
func BoolParser(value string) (any, E.NestedError) {
|
||||||
switch strings.ToLower(value) {
|
switch strings.ToLower(value) {
|
||||||
case "true", "yes", "1":
|
case "true", "yes", "1":
|
||||||
return true, nil
|
return true, nil
|
||||||
|
@ -47,15 +76,3 @@ func boolParser(value string) (any, E.NestedError) {
|
||||||
return nil, E.Invalid("boolean value", value)
|
return nil, E.Invalid("boolean value", value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const NSProxy = "proxy"
|
|
||||||
|
|
||||||
var _ = func() int {
|
|
||||||
RegisterNamespace(NSProxy, ValueParserMap{
|
|
||||||
"path_patterns": yamlListParser,
|
|
||||||
"set_headers": yamlStringMappingParser,
|
|
||||||
"hide_headers": yamlListParser,
|
|
||||||
"no_tls_verify": boolParser,
|
|
||||||
})
|
|
||||||
return 0
|
|
||||||
}()
|
|
||||||
|
|
|
@ -2,8 +2,6 @@ package docker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/error"
|
E "github.com/yusing/go-proxy/error"
|
||||||
|
@ -14,21 +12,16 @@ func makeLabel(namespace string, alias string, field string) string {
|
||||||
return fmt.Sprintf("%s.%s.%s", namespace, alias, field)
|
return fmt.Sprintf("%s.%s.%s", namespace, alias, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHomePageLabel(t *testing.T) {
|
func TestParseLabel(t *testing.T) {
|
||||||
alias := "foo"
|
alias := "foo"
|
||||||
field := "ip"
|
field := "ip"
|
||||||
v := "bar"
|
v := "bar"
|
||||||
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
|
pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v)
|
||||||
ExpectNoError(t, err.Error())
|
ExpectNoError(t, err.Error())
|
||||||
if pl.Target != alias {
|
ExpectEqual(t, pl.Namespace, NSHomePage)
|
||||||
t.Errorf("Expected alias=%s, got %s", alias, pl.Target)
|
ExpectEqual(t, pl.Target, alias)
|
||||||
}
|
ExpectEqual(t, pl.Attribute, field)
|
||||||
if pl.Attribute != field {
|
ExpectEqual(t, pl.Value.(string), v)
|
||||||
t.Errorf("Expected field=%s, got %s", field, pl.Target)
|
|
||||||
}
|
|
||||||
if pl.Value != v {
|
|
||||||
t.Errorf("Expected value=%q, got %s", v, pl.Value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStringProxyLabel(t *testing.T) {
|
func TestStringProxyLabel(t *testing.T) {
|
||||||
|
@ -51,90 +44,63 @@ func TestBoolProxyLabelValid(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range tests {
|
for k, v := range tests {
|
||||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k)
|
pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), k)
|
||||||
ExpectNoError(t, err.Error())
|
ExpectNoError(t, err.Error())
|
||||||
ExpectEqual(t, pl.Value.(bool), v)
|
ExpectEqual(t, pl.Value.(bool), v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBoolProxyLabelInvalid(t *testing.T) {
|
func TestBoolProxyLabelInvalid(t *testing.T) {
|
||||||
alias := "foo"
|
_, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeNoTLSVerify), "invalid")
|
||||||
field := "no_tls_verify"
|
|
||||||
_, err := ParseLabel(makeLabel(NSProxy, alias, field), "invalid")
|
|
||||||
if !err.Is(E.ErrInvalid) {
|
if !err.Is(E.ErrInvalid) {
|
||||||
t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error())
|
t.Errorf("Expected err InvalidProxyLabel, got %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetHeaderProxyLabelValid(t *testing.T) {
|
// func TestSetHeaderProxyLabelValid(t *testing.T) {
|
||||||
v := `
|
// v := `
|
||||||
X-Custom-Header1: foo, bar
|
// X-Custom-Header1: foo, bar
|
||||||
X-Custom-Header1: baz
|
// X-Custom-Header1: baz
|
||||||
X-Custom-Header2: boo`
|
// X-Custom-Header2: boo`
|
||||||
v = strings.TrimPrefix(v, "\n")
|
// v = strings.TrimPrefix(v, "\n")
|
||||||
h := map[string]string{
|
// h := map[string]string{
|
||||||
"X-Custom-Header1": "foo, bar, baz",
|
// "X-Custom-Header1": "foo, bar, baz",
|
||||||
"X-Custom-Header2": "boo",
|
// "X-Custom-Header2": "boo",
|
||||||
}
|
// }
|
||||||
|
|
||||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
|
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
|
||||||
ExpectNoError(t, err.Error())
|
// ExpectNoError(t, err.Error())
|
||||||
hGot := ExpectType[map[string]string](t, pl.Value)
|
// hGot := ExpectType[map[string]string](t, pl.Value)
|
||||||
if hGot != nil && !reflect.DeepEqual(h, hGot) {
|
// ExpectFalse(t, hGot == nil)
|
||||||
t.Errorf("Expected %v, got %v", h, hGot)
|
// ExpectDeepEqual(t, h, hGot)
|
||||||
}
|
// }
|
||||||
|
|
||||||
}
|
// func TestSetHeaderProxyLabelInvalid(t *testing.T) {
|
||||||
|
// tests := []string{
|
||||||
|
// "X-Custom-Header1 = bar",
|
||||||
|
// "X-Custom-Header1",
|
||||||
|
// "- X-Custom-Header1",
|
||||||
|
// }
|
||||||
|
|
||||||
func TestSetHeaderProxyLabelInvalid(t *testing.T) {
|
// for _, v := range tests {
|
||||||
tests := []string{
|
// _, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeSetHeaders), v)
|
||||||
"X-Custom-Header1 = bar",
|
// if !err.Is(E.ErrInvalid) {
|
||||||
"X-Custom-Header1",
|
// t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
|
||||||
"- X-Custom-Header1",
|
// }
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range tests {
|
|
||||||
_, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v)
|
|
||||||
if !err.Is(E.ErrInvalid) {
|
|
||||||
t.Errorf("Expected invalid err for %q, got %s", v, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHideHeadersProxyLabel(t *testing.T) {
|
|
||||||
v := `
|
|
||||||
- X-Custom-Header1
|
|
||||||
- X-Custom-Header2
|
|
||||||
- X-Custom-Header3
|
|
||||||
`
|
|
||||||
v = strings.TrimPrefix(v, "\n")
|
|
||||||
pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v)
|
|
||||||
ExpectNoError(t, err.Error())
|
|
||||||
sGot := ExpectType[[]string](t, pl.Value)
|
|
||||||
sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
|
|
||||||
if sGot != nil {
|
|
||||||
ExpectDeepEqual(t, sGot, sWant)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// func TestCommaSepProxyLabelSingle(t *testing.T) {
|
|
||||||
// v := "a"
|
|
||||||
// pl, err := ParseLabel("proxy.aliases", v)
|
|
||||||
// ExpectNoError(t, err)
|
|
||||||
// sGot := ExpectType[[]string](t, pl.Value)
|
|
||||||
// sWant := []string{"a"}
|
|
||||||
// if sGot != nil {
|
|
||||||
// ExpectEqual(t, sGot, sWant)
|
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// func TestCommaSepProxyLabelMulti(t *testing.T) {
|
// func TestHideHeadersProxyLabel(t *testing.T) {
|
||||||
// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3"
|
// v := `
|
||||||
// pl, err := ParseLabel("proxy.aliases", v)
|
// - X-Custom-Header1
|
||||||
// ExpectNoError(t, err)
|
// - X-Custom-Header2
|
||||||
|
// - X-Custom-Header3
|
||||||
|
// `
|
||||||
|
// v = strings.TrimPrefix(v, "\n")
|
||||||
|
// pl, err := ParseLabel(makeLabel(NSProxy, "foo", ProxyAttributeHideHeaders), v)
|
||||||
|
// ExpectNoError(t, err.Error())
|
||||||
// sGot := ExpectType[[]string](t, pl.Value)
|
// sGot := ExpectType[[]string](t, pl.Value)
|
||||||
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
|
// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"}
|
||||||
// if sGot != nil {
|
// ExpectFalse(t, sGot == nil)
|
||||||
// ExpectEqual(t, sGot, sWant)
|
// ExpectDeepEqual(t, sGot, sWant)
|
||||||
// }
|
|
||||||
// }
|
// }
|
||||||
|
|
85
src/docker/label_test.go
Normal file
85
src/docker/label_test.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package docker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
U "github.com/yusing/go-proxy/utils"
|
||||||
|
. "github.com/yusing/go-proxy/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNestedLabel(t *testing.T) {
|
||||||
|
mName := "middleware1"
|
||||||
|
mAttr := "prop1"
|
||||||
|
v := "value1"
|
||||||
|
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
sGot := ExpectType[*Label](t, pl.Value)
|
||||||
|
ExpectFalse(t, sGot == nil)
|
||||||
|
ExpectEqual(t, sGot.Namespace, mName)
|
||||||
|
ExpectEqual(t, sGot.Attribute, mAttr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyNestedLabel(t *testing.T) {
|
||||||
|
entry := new(struct {
|
||||||
|
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||||
|
})
|
||||||
|
mName := "middleware1"
|
||||||
|
mAttr := "prop1"
|
||||||
|
v := "value1"
|
||||||
|
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
err = ApplyLabel(entry, pl)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
middleware1, ok := entry.Middlewares[mName]
|
||||||
|
ExpectTrue(t, ok)
|
||||||
|
got := ExpectType[string](t, middleware1[mAttr])
|
||||||
|
ExpectEqual(t, got, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyNestedLabelExisting(t *testing.T) {
|
||||||
|
mName := "middleware1"
|
||||||
|
mAttr := "prop1"
|
||||||
|
v := "value1"
|
||||||
|
|
||||||
|
checkAttr := "prop2"
|
||||||
|
checkV := "value2"
|
||||||
|
entry := new(struct {
|
||||||
|
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||||
|
})
|
||||||
|
entry.Middlewares = make(NestedLabelMap)
|
||||||
|
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||||
|
entry.Middlewares[mName][checkAttr] = checkV
|
||||||
|
|
||||||
|
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s.%s", ProxyAttributeMiddlewares, mName, mAttr)), v)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
err = ApplyLabel(entry, pl)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
middleware1, ok := entry.Middlewares[mName]
|
||||||
|
ExpectTrue(t, ok)
|
||||||
|
got := ExpectType[string](t, middleware1[mAttr])
|
||||||
|
ExpectEqual(t, got, v)
|
||||||
|
|
||||||
|
// check if prop2 is affected
|
||||||
|
ExpectFalse(t, middleware1[checkAttr] == nil)
|
||||||
|
got = ExpectType[string](t, middleware1[checkAttr])
|
||||||
|
ExpectEqual(t, got, checkV)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyNestedLabelNoAttr(t *testing.T) {
|
||||||
|
mName := "middleware1"
|
||||||
|
v := "value1"
|
||||||
|
|
||||||
|
entry := new(struct {
|
||||||
|
Middlewares NestedLabelMap `yaml:"middlewares"`
|
||||||
|
})
|
||||||
|
entry.Middlewares = make(NestedLabelMap)
|
||||||
|
entry.Middlewares[mName] = make(U.SerializedObject)
|
||||||
|
|
||||||
|
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", ProxyAttributeMiddlewares, mName)), v)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
err = ApplyLabel(entry, pl)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
_, ok := entry.Middlewares[mName]
|
||||||
|
ExpectTrue(t, ok)
|
||||||
|
}
|
|
@ -10,9 +10,8 @@ type Builder struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type builder struct {
|
type builder struct {
|
||||||
message string
|
message string
|
||||||
errors []NestedError
|
errors []NestedError
|
||||||
severity Severity
|
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,11 +39,6 @@ func (b Builder) Addf(format string, args ...any) Builder {
|
||||||
return b.Add(errorf(format, args...))
|
return b.Add(errorf(format, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b Builder) WithSeverity(s Severity) Builder {
|
|
||||||
b.severity = s
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build builds a NestedError based on the errors collected in the Builder.
|
// Build builds a NestedError based on the errors collected in the Builder.
|
||||||
//
|
//
|
||||||
// If there are no errors in the Builder, it returns a Nil() NestedError.
|
// If there are no errors in the Builder, it returns a Nil() NestedError.
|
||||||
|
@ -58,7 +52,7 @@ func (b Builder) Build() NestedError {
|
||||||
} else if len(b.errors) == 1 {
|
} else if len(b.errors) == 1 {
|
||||||
return b.errors[0]
|
return b.errors[0]
|
||||||
}
|
}
|
||||||
return Join(b.message, b.errors...).Severity(b.severity)
|
return Join(b.message, b.errors...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b Builder) To(ptr *NestedError) {
|
func (b Builder) To(ptr *NestedError) {
|
||||||
|
|
|
@ -9,17 +9,10 @@ import (
|
||||||
type (
|
type (
|
||||||
NestedError = *nestedError
|
NestedError = *nestedError
|
||||||
nestedError struct {
|
nestedError struct {
|
||||||
subject string
|
subject string
|
||||||
err error
|
err error
|
||||||
extras []nestedError
|
extras []nestedError
|
||||||
severity Severity
|
|
||||||
}
|
}
|
||||||
Severity uint8
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
SeverityWarning Severity = iota
|
|
||||||
SeverityFatal
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func From(err error) NestedError {
|
func From(err error) NestedError {
|
||||||
|
@ -164,22 +157,6 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError {
|
||||||
return ne
|
return ne
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ne NestedError) Severity(s Severity) NestedError {
|
|
||||||
if ne == nil {
|
|
||||||
return ne
|
|
||||||
}
|
|
||||||
ne.severity = s
|
|
||||||
return ne
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ne NestedError) Warn() NestedError {
|
|
||||||
if ne == nil {
|
|
||||||
return ne
|
|
||||||
}
|
|
||||||
ne.severity = SeverityWarning
|
|
||||||
return ne
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ne NestedError) NoError() bool {
|
func (ne NestedError) NoError() bool {
|
||||||
return ne == nil
|
return ne == nil
|
||||||
}
|
}
|
||||||
|
@ -188,14 +165,6 @@ func (ne NestedError) HasError() bool {
|
||||||
return ne != nil
|
return ne != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ne NestedError) IsFatal() bool {
|
|
||||||
return ne != nil && ne.severity == SeverityFatal
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ne NestedError) IsWarning() bool {
|
|
||||||
return ne != nil && ne.severity == SeverityWarning
|
|
||||||
}
|
|
||||||
|
|
||||||
func errorf(format string, args ...any) NestedError {
|
func errorf(format string, args ...any) NestedError {
|
||||||
return From(fmt.Errorf(format, args...))
|
return From(fmt.Errorf(format, args...))
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,11 +31,11 @@ func TestErrorNestedIs(t *testing.T) {
|
||||||
|
|
||||||
err = Failure("some reason")
|
err = Failure("some reason")
|
||||||
ExpectTrue(t, err.Is(ErrFailure))
|
ExpectTrue(t, err.Is(ErrFailure))
|
||||||
ExpectFalse(t, err.Is(ErrAlreadyExist))
|
ExpectFalse(t, err.Is(ErrDuplicated))
|
||||||
|
|
||||||
err.With(AlreadyExist("something", ""))
|
err.With(Duplicated("something", ""))
|
||||||
ExpectTrue(t, err.Is(ErrFailure))
|
ExpectTrue(t, err.Is(ErrFailure))
|
||||||
ExpectTrue(t, err.Is(ErrAlreadyExist))
|
ExpectTrue(t, err.Is(ErrDuplicated))
|
||||||
ExpectFalse(t, err.Is(ErrInvalid))
|
ExpectFalse(t, err.Is(ErrInvalid))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,14 +5,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrFailure = stderrors.New("failed")
|
ErrFailure = stderrors.New("failed")
|
||||||
ErrInvalid = stderrors.New("invalid")
|
ErrInvalid = stderrors.New("invalid")
|
||||||
ErrUnsupported = stderrors.New("unsupported")
|
ErrUnsupported = stderrors.New("unsupported")
|
||||||
ErrUnexpected = stderrors.New("unexpected")
|
ErrUnexpected = stderrors.New("unexpected")
|
||||||
ErrNotExists = stderrors.New("does not exist")
|
ErrNotExists = stderrors.New("does not exist")
|
||||||
ErrMissing = stderrors.New("missing")
|
ErrMissing = stderrors.New("missing")
|
||||||
ErrAlreadyExist = stderrors.New("already exist")
|
ErrDuplicated = stderrors.New("duplicated")
|
||||||
ErrOutOfRange = stderrors.New("out of range")
|
ErrOutOfRange = stderrors.New("out of range")
|
||||||
)
|
)
|
||||||
|
|
||||||
const fmtSubjectWhat = "%w %v: %q"
|
const fmtSubjectWhat = "%w %v: %q"
|
||||||
|
@ -53,8 +53,8 @@ func Missing(subject any) NestedError {
|
||||||
return errorf("%w %v", ErrMissing, subject)
|
return errorf("%w %v", ErrMissing, subject)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AlreadyExist(subject, what any) NestedError {
|
func Duplicated(subject, what any) NestedError {
|
||||||
return errorf("%v %w: %v", subject, ErrAlreadyExist, what)
|
return errorf("%w %v: %v", ErrDuplicated, subject, what)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OutOfRange(subject string, value any) NestedError {
|
func OutOfRange(subject string, value any) NestedError {
|
||||||
|
|
|
@ -17,7 +17,7 @@ require (
|
||||||
require (
|
require (
|
||||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cloudflare/cloudflare-go v0.104.0 // indirect
|
github.com/cloudflare/cloudflare-go v0.105.0 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/go-connections v0.5.0 // indirect
|
github.com/docker/go-connections v0.5.0 // indirect
|
||||||
|
|
|
@ -4,8 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
||||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
github.com/cloudflare/cloudflare-go v0.104.0 h1:R/lB0dZupaZbOgibAH/BRrkFbZ6Acn/WsKg2iX2xXuY=
|
github.com/cloudflare/cloudflare-go v0.105.0 h1:yu2IatITLZ4dw7/byzRrlE5DfUvtub0k9CHZ5zBlj90=
|
||||||
github.com/cloudflare/cloudflare-go v0.104.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
|
github.com/cloudflare/cloudflare-go v0.105.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM=
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
|
96
src/http/modify_response_writer.go
Normal file
96
src/http/modify_response_writer.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/response_modifier.go)
|
||||||
|
// Copyright (c) 2020-2024 Traefik Labs
|
||||||
|
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModifyResponseFunc func(*http.Response) error
|
||||||
|
type ModifyResponseWriter struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
r *http.Request
|
||||||
|
|
||||||
|
headerSent bool
|
||||||
|
code int
|
||||||
|
|
||||||
|
modifier ModifyResponseFunc
|
||||||
|
modified bool
|
||||||
|
modifierErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyResponseFunc) *ModifyResponseWriter {
|
||||||
|
return &ModifyResponseWriter{
|
||||||
|
w: w,
|
||||||
|
r: r,
|
||||||
|
modifier: f,
|
||||||
|
code: http.StatusOK,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ModifyResponseWriter) WriteHeader(code int) {
|
||||||
|
if w.headerSent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if code >= http.StatusContinue && code < http.StatusOK {
|
||||||
|
w.w.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
w.headerSent = true
|
||||||
|
w.code = code
|
||||||
|
}()
|
||||||
|
|
||||||
|
if w.modifier == nil || w.modified {
|
||||||
|
w.w.WriteHeader(code)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := http.Response{
|
||||||
|
Header: w.w.Header(),
|
||||||
|
Request: w.r,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.modifier(&resp); err != nil {
|
||||||
|
w.modifierErr = err
|
||||||
|
logger.Errorf("error modifying response: %s", err)
|
||||||
|
w.w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.modified = true
|
||||||
|
w.w.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ModifyResponseWriter) Header() http.Header {
|
||||||
|
return w.w.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ModifyResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
w.WriteHeader(w.code)
|
||||||
|
if w.modifierErr != nil {
|
||||||
|
return 0, w.modifierErr
|
||||||
|
}
|
||||||
|
return w.w.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack hijacks the connection.
|
||||||
|
func (w *ModifyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if h, ok := w.w.(http.Hijacker); ok {
|
||||||
|
return h.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush sends any buffered data to the client.
|
||||||
|
func (w *ModifyResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.w.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,13 @@
|
||||||
package proxy
|
// Copyright 2011 The Go Authors.
|
||||||
|
// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go)
|
||||||
|
// https://cs.opensource.google/go/go/+/master:LICENSE
|
||||||
|
|
||||||
// A small mod on net/http/httputil/reverseproxy.go
|
package http
|
||||||
// that doubled the performance
|
|
||||||
|
// This is a small mod on net/http/httputil/reverseproxy.go
|
||||||
|
// that boosts performance in some cases
|
||||||
|
// and compatible to other modules of this project
|
||||||
|
// Copyright (c) 2024 yusing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -52,6 +58,21 @@ type ProxyRequest struct {
|
||||||
// r.SetXForwarded()
|
// r.SetXForwarded()
|
||||||
// }
|
// }
|
||||||
func (r *ProxyRequest) SetXForwarded() {
|
func (r *ProxyRequest) SetXForwarded() {
|
||||||
|
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||||
|
} else {
|
||||||
|
r.Out.Header.Del("X-Forwarded-For")
|
||||||
|
}
|
||||||
|
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||||
|
if r.In.TLS == nil {
|
||||||
|
r.Out.Header.Set("X-Forwarded-Proto", "http")
|
||||||
|
} else {
|
||||||
|
r.Out.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ProxyRequest) AddXForwarded() {
|
||||||
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
|
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
prior := r.Out.Header["X-Forwarded-For"]
|
prior := r.Out.Header["X-Forwarded-For"]
|
||||||
|
@ -104,28 +125,6 @@ type ReverseProxy struct {
|
||||||
// If nil, http.DefaultTransport is used.
|
// If nil, http.DefaultTransport is used.
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
|
|
||||||
// FlushInterval specifies the flush interval
|
|
||||||
// to flush to the client while copying the
|
|
||||||
// response body.
|
|
||||||
// If zero, no periodic flushing is done.
|
|
||||||
// A negative value means to flush immediately
|
|
||||||
// after each write to the client.
|
|
||||||
// The FlushInterval is ignored when ReverseProxy
|
|
||||||
// recognizes a response as a streaming response, or
|
|
||||||
// if its ContentLength is -1; for such responses, writes
|
|
||||||
// are flushed to the client immediately.
|
|
||||||
// FlushInterval time.Duration
|
|
||||||
|
|
||||||
// ErrorLog specifies an optional logger for errors
|
|
||||||
// that occur when attempting to proxy the request.
|
|
||||||
// If nil, logging is done via the log package's standard logger.
|
|
||||||
// ErrorLog *log.Logger
|
|
||||||
|
|
||||||
// BufferPool optionally specifies a buffer pool to
|
|
||||||
// get byte slices for use by io.CopyBuffer when
|
|
||||||
// copying HTTP response bodies.
|
|
||||||
// BufferPool BufferPool
|
|
||||||
|
|
||||||
// ModifyResponse is an optional function that modifies the
|
// ModifyResponse is an optional function that modifies the
|
||||||
// Response from the backend. It is called if the backend
|
// Response from the backend. It is called if the backend
|
||||||
// returns a response at all, with any HTTP status code.
|
// returns a response at all, with any HTTP status code.
|
||||||
|
@ -208,36 +207,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||||
// },
|
// },
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// TODO: headers in ModifyResponse
|
|
||||||
func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy {
|
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
|
||||||
// check on init rather than on request
|
|
||||||
var setHeaders = func(r *http.Request) {}
|
|
||||||
var hideHeaders = func(r *http.Request) {}
|
|
||||||
if len(entry.SetHeaders) > 0 {
|
|
||||||
setHeaders = func(r *http.Request) {
|
|
||||||
h := entry.SetHeaders.Clone()
|
|
||||||
for k, vv := range h {
|
|
||||||
if k == "Host" {
|
|
||||||
r.Host = vv[0]
|
|
||||||
} else {
|
|
||||||
r.Header[k] = vv
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(entry.HideHeaders) > 0 {
|
|
||||||
hideHeaders = func(r *http.Request) {
|
|
||||||
for _, k := range entry.HideHeaders {
|
|
||||||
r.Header.Del(k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rp := &ReverseProxy{
|
rp := &ReverseProxy{
|
||||||
Rewrite: func(pr *ProxyRequest) {
|
Rewrite: func(pr *ProxyRequest) {
|
||||||
rewriteRequestURL(pr.Out, target)
|
rewriteRequestURL(pr.Out, target)
|
||||||
// pr.SetXForwarded()
|
|
||||||
setHeaders(pr.Out)
|
|
||||||
hideHeaders(pr.Out)
|
|
||||||
}, Transport: transport,
|
}, Transport: transport,
|
||||||
}
|
}
|
||||||
rp.ServeHTTP = rp.serveHTTP
|
rp.ServeHTTP = rp.serveHTTP
|
||||||
|
@ -256,6 +230,23 @@ func rewriteRequestURL(req *http.Request, target *url.URL) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
|
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||||
|
// Connection header field. These are the headers defined by the
|
||||||
|
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||||
|
// compatibility.
|
||||||
|
var hopHeaders = []string{
|
||||||
|
"Connection",
|
||||||
|
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||||
|
"Keep-Alive",
|
||||||
|
"Proxy-Authenticate",
|
||||||
|
"Proxy-Authorization",
|
||||||
|
"Te", // canonicalized version of "TE"
|
||||||
|
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||||
|
"Transfer-Encoding",
|
||||||
|
"Upgrade",
|
||||||
|
}
|
||||||
|
|
||||||
func copyHeader(dst, src http.Header) {
|
func copyHeader(dst, src http.Header) {
|
||||||
for k, vv := range src {
|
for k, vv := range src {
|
||||||
for _, v := range vv {
|
for _, v := range vv {
|
||||||
|
@ -331,12 +322,14 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
outreq.Close = false
|
outreq.Close = false
|
||||||
|
|
||||||
reqUpType := upgradeType(outreq.Header)
|
reqUpType := UpgradeType(outreq.Header)
|
||||||
if !IsPrint(reqUpType) {
|
if !IsPrint(reqUpType) {
|
||||||
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RemoveHopByHopHeaders(outreq.Header)
|
||||||
|
|
||||||
// Issue 21096: tell backend applications that care about trailer support
|
// Issue 21096: tell backend applications that care about trailer support
|
||||||
// that we support trailers. (We do, but we don't go out of our way to
|
// that we support trailers. (We do, but we don't go out of our way to
|
||||||
// advertise that unless the incoming client request thought it was worth
|
// advertise that unless the incoming client request thought it was worth
|
||||||
|
@ -458,16 +451,34 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func upgradeType(h http.Header) string {
|
func UpgradeType(h http.Header) string {
|
||||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return h.Get("Upgrade")
|
return h.Get("Upgrade")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveHopByHopHeaders removes hop-by-hop headers.
|
||||||
|
func RemoveHopByHopHeaders(h http.Header) {
|
||||||
|
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||||
|
for _, f := range h["Connection"] {
|
||||||
|
for _, sf := range strings.Split(f, ",") {
|
||||||
|
if sf = textproto.TrimString(sf); sf != "" {
|
||||||
|
h.Del(sf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||||
|
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||||
|
// preserve it for backwards compatibility.
|
||||||
|
for _, f := range hopHeaders {
|
||||||
|
h.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||||
reqUpType := upgradeType(req.Header)
|
reqUpType := UpgradeType(req.Header)
|
||||||
resUpType := upgradeType(res.Header)
|
resUpType := UpgradeType(res.Header)
|
||||||
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||||
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
||||||
}
|
}
|
27
src/main.go
27
src/main.go
|
@ -47,11 +47,10 @@ func main() {
|
||||||
logrus.SetOutput(io.Discard)
|
logrus.SetOutput(io.Discard)
|
||||||
} else {
|
} else {
|
||||||
logrus.SetFormatter(&logrus.TextFormatter{
|
logrus.SetFormatter(&logrus.TextFormatter{
|
||||||
DisableSorting: true,
|
DisableSorting: true,
|
||||||
DisableLevelTruncation: true,
|
FullTimestamp: true,
|
||||||
FullTimestamp: true,
|
ForceColors: true,
|
||||||
ForceColors: true,
|
TimestampFormat: "01-02 15:04:05",
|
||||||
TimestampFormat: "01-02 15:04:05",
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,10 +75,11 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := config.Load()
|
err := config.Load()
|
||||||
if err.IsFatal() {
|
if err != nil {
|
||||||
log.Fatal(err)
|
logrus.Warn(err)
|
||||||
}
|
}
|
||||||
|
cfg := config.GetConfig()
|
||||||
|
|
||||||
switch args.Command {
|
switch args.Command {
|
||||||
case common.CommandListConfigs:
|
case common.CommandListConfigs:
|
||||||
|
@ -96,6 +96,10 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if common.IsDebug {
|
||||||
|
printJSON(docker.GetRegisteredNamespaces())
|
||||||
|
}
|
||||||
|
|
||||||
cfg.StartProxyProviders()
|
cfg.StartProxyProviders()
|
||||||
|
|
||||||
if err.HasError() {
|
if err.HasError() {
|
||||||
|
@ -116,10 +120,7 @@ func main() {
|
||||||
|
|
||||||
if autocert != nil {
|
if autocert != nil {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
if err = autocert.Setup(ctx); err != nil && err.IsWarning() {
|
if err = autocert.Setup(ctx); err != nil {
|
||||||
cancel()
|
|
||||||
l.Warn(err)
|
|
||||||
} else if err.IsFatal() {
|
|
||||||
l.Fatal(err)
|
l.Fatal(err)
|
||||||
} else {
|
} else {
|
||||||
onShutdown.Add(cancel)
|
onShutdown.Add(cancel)
|
||||||
|
@ -192,7 +193,7 @@ func funcName(f func()) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func printJSON(obj any) {
|
func printJSON(obj any) {
|
||||||
j, err := E.Check(json.Marshal(obj))
|
j, err := E.Check(json.MarshalIndent(obj, "", " "))
|
||||||
if err.HasError() {
|
if err.HasError() {
|
||||||
logrus.Fatal(err)
|
logrus.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package model
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Providers ProxyProviders `yaml:",flow" json:"providers"`
|
Providers ProxyProviders `yaml:",flow" json:"providers"`
|
||||||
AutoCert AutoCertConfig `yaml:",flow" json:"autocert"`
|
AutoCert AutoCertConfig `yaml:",flow" json:"autocert"`
|
||||||
|
MatchDomains []string `yaml:"match_domains" json:"match_domains"`
|
||||||
TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"`
|
TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"`
|
||||||
RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"`
|
RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"`
|
||||||
}
|
}
|
||||||
|
@ -11,6 +12,6 @@ func DefaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Providers: ProxyProviders{},
|
Providers: ProxyProviders{},
|
||||||
TimeoutShutdown: 3,
|
TimeoutShutdown: 3,
|
||||||
RedirectToHTTPS: true,
|
RedirectToHTTPS: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,14 +14,13 @@ type (
|
||||||
RawEntry struct {
|
RawEntry struct {
|
||||||
// raw entry object before validation
|
// raw entry object before validation
|
||||||
// loaded from docker labels or yaml file
|
// loaded from docker labels or yaml file
|
||||||
Alias string `yaml:"-" json:"-"`
|
Alias string `yaml:"-" json:"-"`
|
||||||
Scheme string `yaml:"scheme" json:"scheme"`
|
Scheme string `yaml:"scheme" json:"scheme"`
|
||||||
Host string `yaml:"host" json:"host"`
|
Host string `yaml:"host" json:"host"`
|
||||||
Port string `yaml:"port" json:"port"`
|
Port string `yaml:"port" json:"port"`
|
||||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
|
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify"` // https proxy only
|
||||||
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
|
PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only
|
||||||
SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only
|
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares"`
|
||||||
HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only
|
|
||||||
|
|
||||||
/* Docker only */
|
/* Docker only */
|
||||||
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
|
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
|
||||||
|
@ -44,12 +43,16 @@ func (e *RawEntry) FillMissingFields() bool {
|
||||||
if pp == "" {
|
if pp == "" {
|
||||||
pp = strconv.Itoa(port)
|
pp = strconv.Itoa(port)
|
||||||
}
|
}
|
||||||
e.Scheme = "tcp"
|
if e.Scheme == "" {
|
||||||
|
e.Scheme = "tcp"
|
||||||
|
}
|
||||||
} else if port, ok := ImageNamePortMap[e.ImageName]; ok {
|
} else if port, ok := ImageNamePortMap[e.ImageName]; ok {
|
||||||
if pp == "" {
|
if pp == "" {
|
||||||
pp = strconv.Itoa(port)
|
pp = strconv.Itoa(port)
|
||||||
}
|
}
|
||||||
e.Scheme = "http"
|
if e.Scheme == "" {
|
||||||
|
e.Scheme = "http"
|
||||||
|
}
|
||||||
} else if pp == "" && e.Scheme == "https" {
|
} else if pp == "" && e.Scheme == "https" {
|
||||||
pp = "443"
|
pp = "443"
|
||||||
} else if pp == "" {
|
} else if pp == "" {
|
||||||
|
|
|
@ -2,10 +2,10 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
D "github.com/yusing/go-proxy/docker"
|
||||||
E "github.com/yusing/go-proxy/error"
|
E "github.com/yusing/go-proxy/error"
|
||||||
M "github.com/yusing/go-proxy/models"
|
M "github.com/yusing/go-proxy/models"
|
||||||
T "github.com/yusing/go-proxy/proxy/fields"
|
T "github.com/yusing/go-proxy/proxy/fields"
|
||||||
|
@ -18,8 +18,7 @@ type (
|
||||||
URL *url.URL
|
URL *url.URL
|
||||||
NoTLSVerify bool
|
NoTLSVerify bool
|
||||||
PathPatterns T.PathPatterns
|
PathPatterns T.PathPatterns
|
||||||
SetHeaders http.Header
|
Middlewares D.NestedLabelMap
|
||||||
HideHeaders []string
|
|
||||||
|
|
||||||
/* Docker only */
|
/* Docker only */
|
||||||
IdleTimeout time.Duration
|
IdleTimeout time.Duration
|
||||||
|
@ -78,9 +77,6 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry
|
||||||
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
|
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
|
||||||
b.Add(err)
|
b.Add(err)
|
||||||
|
|
||||||
setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders)
|
|
||||||
b.Add(err)
|
|
||||||
|
|
||||||
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
|
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
|
||||||
b.Add(err)
|
b.Add(err)
|
||||||
|
|
||||||
|
@ -111,8 +107,7 @@ func validateRPEntry(m *M.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry
|
||||||
URL: url,
|
URL: url,
|
||||||
NoTLSVerify: m.NoTLSVerify,
|
NoTLSVerify: m.NoTLSVerify,
|
||||||
PathPatterns: pathPatterns,
|
PathPatterns: pathPatterns,
|
||||||
SetHeaders: setHeaders,
|
Middlewares: m.Middlewares,
|
||||||
HideHeaders: m.HideHeaders,
|
|
||||||
IdleTimeout: idleTimeout,
|
IdleTimeout: idleTimeout,
|
||||||
WakeTimeout: wakeTimeout,
|
WakeTimeout: wakeTimeout,
|
||||||
StopMethod: stopMethod,
|
StopMethod: stopMethod,
|
||||||
|
|
|
@ -4,7 +4,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
D "github.com/yusing/go-proxy/docker"
|
D "github.com/yusing/go-proxy/docker"
|
||||||
E "github.com/yusing/go-proxy/error"
|
E "github.com/yusing/go-proxy/error"
|
||||||
M "github.com/yusing/go-proxy/models"
|
M "github.com/yusing/go-proxy/models"
|
||||||
|
@ -17,6 +19,7 @@ type DockerProvider struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var AliasRefRegex = regexp.MustCompile(`#\d+`)
|
var AliasRefRegex = regexp.MustCompile(`#\d+`)
|
||||||
|
var AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
|
||||||
|
|
||||||
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
|
func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) {
|
||||||
hostname, err := D.ParseDockerHostname(dockerHost)
|
hostname, err := D.ParseDockerHostname(dockerHost)
|
||||||
|
@ -152,6 +155,20 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
|
||||||
b := E.NewBuilder("errors in label %s", key)
|
b := E.NewBuilder("errors in label %s", key)
|
||||||
defer b.To(&res)
|
defer b.To(&res)
|
||||||
|
|
||||||
|
refErr := E.NewBuilder("errors parsing alias references")
|
||||||
|
replaceIndexRef := func(ref string) string {
|
||||||
|
index, err := strconv.Atoi(ref[1:])
|
||||||
|
if err != nil {
|
||||||
|
refErr.Add(E.Invalid("integer", ref))
|
||||||
|
return ref
|
||||||
|
}
|
||||||
|
if index < 1 || index > len(container.Aliases) {
|
||||||
|
refErr.Add(E.OutOfRange("index", ref))
|
||||||
|
return ref
|
||||||
|
}
|
||||||
|
return container.Aliases[index-1]
|
||||||
|
}
|
||||||
|
|
||||||
lbl, err := D.ParseLabel(key, val)
|
lbl, err := D.ParseLabel(key, val)
|
||||||
if err.HasError() {
|
if err.HasError() {
|
||||||
b.Add(err.Subject(key))
|
b.Add(err.Subject(key))
|
||||||
|
@ -163,22 +180,14 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
|
||||||
// apply label for all aliases
|
// apply label for all aliases
|
||||||
entries.RangeAll(func(a string, e *M.RawEntry) {
|
entries.RangeAll(func(a string, e *M.RawEntry) {
|
||||||
if err = D.ApplyLabel(e, lbl); err.HasError() {
|
if err = D.ApplyLabel(e, lbl); err.HasError() {
|
||||||
b.Add(err.Subject(lbl.Target))
|
b.Add(err.Subjectf("alias %s", lbl.Target))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
refErr := E.NewBuilder("errors parsing alias references")
|
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef)
|
||||||
lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, func(ref string) string {
|
lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string {
|
||||||
index, err := strconv.Atoi(ref[1:])
|
logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#"))
|
||||||
if err != nil {
|
return replaceIndexRef(s)
|
||||||
refErr.Add(E.Invalid("integer", ref))
|
|
||||||
return ref
|
|
||||||
}
|
|
||||||
if index < 1 || index > len(container.Aliases) {
|
|
||||||
refErr.Add(E.OutOfRange("index", ref))
|
|
||||||
return ref
|
|
||||||
}
|
|
||||||
return container.Aliases[index-1]
|
|
||||||
})
|
})
|
||||||
if refErr.HasError() {
|
if refErr.HasError() {
|
||||||
b.Add(refErr.Build())
|
b.Add(refErr.Build())
|
||||||
|
@ -190,7 +199,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries M.RawEntries,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = D.ApplyLabel(config, lbl); err.HasError() {
|
if err = D.ApplyLabel(config, lbl); err.HasError() {
|
||||||
b.Add(err.Subject(lbl.Target))
|
b.Add(err.Subjectf("alias %s", lbl.Target))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -132,7 +132,8 @@ func TestApplyLabel(t *testing.T) {
|
||||||
ExpectEqual(t, b.Scheme, "http")
|
ExpectEqual(t, b.Scheme, "http")
|
||||||
ExpectEqual(t, b.Port, "1234")
|
ExpectEqual(t, b.Port, "1234")
|
||||||
ExpectEqual(t, c.Scheme, "https")
|
ExpectEqual(t, c.Scheme, "https")
|
||||||
ExpectEqual(t, c.Port, "1111")
|
// map does not necessary follow the order above
|
||||||
|
ExpectEqualAny(t, c.Port, []string{"1111", "1234"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyLabelWithRef(t *testing.T) {
|
func TestApplyLabelWithRef(t *testing.T) {
|
||||||
|
@ -142,9 +143,9 @@ func TestApplyLabelWithRef(t *testing.T) {
|
||||||
Labels: map[string]string{
|
Labels: map[string]string{
|
||||||
D.LabelAliases: "a,b,c",
|
D.LabelAliases: "a,b,c",
|
||||||
"proxy.#1.host": "localhost",
|
"proxy.#1.host": "localhost",
|
||||||
"proxy.*.port": "1111",
|
|
||||||
"proxy.#1.port": "4444",
|
"proxy.#1.port": "4444",
|
||||||
"proxy.#2.port": "9999",
|
"proxy.#2.port": "9999",
|
||||||
|
"proxy.#3.port": "1111",
|
||||||
"proxy.#3.scheme": "https",
|
"proxy.#3.scheme": "https",
|
||||||
},
|
},
|
||||||
Ports: []types.Port{
|
Ports: []types.Port{
|
||||||
|
|
|
@ -1,20 +1,20 @@
|
||||||
package route
|
package route
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/yusing/go-proxy/common"
|
||||||
"github.com/yusing/go-proxy/docker/idlewatcher"
|
"github.com/yusing/go-proxy/docker/idlewatcher"
|
||||||
E "github.com/yusing/go-proxy/error"
|
E "github.com/yusing/go-proxy/error"
|
||||||
|
. "github.com/yusing/go-proxy/http"
|
||||||
P "github.com/yusing/go-proxy/proxy"
|
P "github.com/yusing/go-proxy/proxy"
|
||||||
PT "github.com/yusing/go-proxy/proxy/fields"
|
PT "github.com/yusing/go-proxy/proxy/fields"
|
||||||
|
"github.com/yusing/go-proxy/route/middleware"
|
||||||
F "github.com/yusing/go-proxy/utils/functional"
|
F "github.com/yusing/go-proxy/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ type (
|
||||||
|
|
||||||
entry *P.ReverseProxyEntry
|
entry *P.ReverseProxyEntry
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
handler *P.ReverseProxy
|
handler *ReverseProxy
|
||||||
|
|
||||||
regIdleWatcher func() E.NestedError
|
regIdleWatcher func() E.NestedError
|
||||||
unregIdleWatcher func()
|
unregIdleWatcher func()
|
||||||
|
@ -36,18 +36,41 @@ type (
|
||||||
SubdomainKey = PT.Alias
|
SubdomainKey = PT.Alias
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
findMuxFunc = findMuxAnyDomain
|
||||||
|
|
||||||
|
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
|
||||||
|
httpRoutesMu sync.Mutex
|
||||||
|
globalMux = http.NewServeMux() // TODO: support regex subdomain matching
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetFindMuxDomains(domains []string) {
|
||||||
|
if len(domains) == 0 {
|
||||||
|
findMuxFunc = findMuxAnyDomain
|
||||||
|
} else {
|
||||||
|
findMuxFunc = findMuxByDomain(domains)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
||||||
var trans *http.Transport
|
var trans *http.Transport
|
||||||
var regIdleWatcher func() E.NestedError
|
var regIdleWatcher func() E.NestedError
|
||||||
var unregIdleWatcher func()
|
var unregIdleWatcher func()
|
||||||
|
|
||||||
if entry.NoTLSVerify {
|
if entry.NoTLSVerify {
|
||||||
trans = transportNoTLS.Clone()
|
trans = common.DefaultTransportNoTLS.Clone()
|
||||||
} else {
|
} else {
|
||||||
trans = transport.Clone()
|
trans = common.DefaultTransport.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := P.NewReverseProxy(entry.URL, trans, entry)
|
rp := NewReverseProxy(entry.URL, trans)
|
||||||
|
|
||||||
|
if len(entry.Middlewares) > 0 {
|
||||||
|
err := middleware.PatchReverseProxy(rp, entry.Middlewares)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if entry.UseIdleWatcher() {
|
if entry.UseIdleWatcher() {
|
||||||
// allow time for response header up to `WakeTimeout`
|
// allow time for response header up to `WakeTimeout`
|
||||||
|
@ -74,7 +97,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
||||||
|
|
||||||
_, exists := httpRoutes.Load(entry.Alias)
|
_, exists := httpRoutes.Load(entry.Alias)
|
||||||
if exists {
|
if exists {
|
||||||
return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias)
|
return nil, E.Duplicated("HTTPRoute alias", entry.Alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
r := &HTTPRoute{
|
r := &HTTPRoute{
|
||||||
|
@ -94,11 +117,16 @@ func (r *HTTPRoute) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HTTPRoute) Start() E.NestedError {
|
func (r *HTTPRoute) Start() E.NestedError {
|
||||||
|
if r.mux != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
httpRoutesMu.Lock()
|
httpRoutesMu.Lock()
|
||||||
defer httpRoutesMu.Unlock()
|
defer httpRoutesMu.Unlock()
|
||||||
|
|
||||||
if r.regIdleWatcher != nil {
|
if r.regIdleWatcher != nil {
|
||||||
if err := r.regIdleWatcher(); err.HasError() {
|
if err := r.regIdleWatcher(); err.HasError() {
|
||||||
|
r.unregIdleWatcher = nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,6 +141,10 @@ func (r *HTTPRoute) Start() E.NestedError {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HTTPRoute) Stop() E.NestedError {
|
func (r *HTTPRoute) Stop() E.NestedError {
|
||||||
|
if r.mux == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
httpRoutesMu.Lock()
|
httpRoutesMu.Lock()
|
||||||
defer httpRoutesMu.Unlock()
|
defer httpRoutesMu.Unlock()
|
||||||
|
|
||||||
|
@ -135,7 +167,7 @@ func (u *URL) MarshalText() (text []byte, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
mux, err := findMux(r.Host)
|
mux, err := findMuxFunc(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = E.Failure("request").
|
err = E.Failure("request").
|
||||||
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
|
Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path).
|
||||||
|
@ -147,7 +179,7 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
mux.ServeHTTP(w, r)
|
mux.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func findMux(host string) (*http.ServeMux, E.NestedError) {
|
func findMuxAnyDomain(host string) (*http.ServeMux, E.NestedError) {
|
||||||
hostSplit := strings.Split(host, ".")
|
hostSplit := strings.Split(host, ".")
|
||||||
n := len(hostSplit)
|
n := len(hostSplit)
|
||||||
if n <= 2 {
|
if n <= 2 {
|
||||||
|
@ -160,23 +192,21 @@ func findMux(host string) (*http.ServeMux, E.NestedError) {
|
||||||
return nil, E.NotExist("route", sd)
|
return nil, E.NotExist("route", sd)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func findMuxByDomain(domains []string) func(host string) (*http.ServeMux, E.NestedError) {
|
||||||
defaultDialer = net.Dialer{
|
return func(host string) (*http.ServeMux, E.NestedError) {
|
||||||
Timeout: 60 * time.Second,
|
var subdomain string
|
||||||
KeepAlive: 60 * time.Second,
|
for _, domain := range domains {
|
||||||
|
subdomain = strings.TrimSuffix(subdomain, domain)
|
||||||
|
if subdomain != domain {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if subdomain == "" { // not matched
|
||||||
|
return nil, E.Invalid("host", host)
|
||||||
|
}
|
||||||
|
if r, ok := httpRoutes.Load(PT.Alias(subdomain)); ok {
|
||||||
|
return r.mux, nil
|
||||||
|
}
|
||||||
|
return nil, E.NotExist("route", subdomain)
|
||||||
}
|
}
|
||||||
transport = &http.Transport{
|
}
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: defaultDialer.DialContext,
|
|
||||||
MaxIdleConnsPerHost: 1000,
|
|
||||||
}
|
|
||||||
transportNoTLS = func() *http.Transport {
|
|
||||||
var clone = transport.Clone()
|
|
||||||
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
return clone
|
|
||||||
}()
|
|
||||||
|
|
||||||
httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]()
|
|
||||||
httpRoutesMu sync.Mutex
|
|
||||||
globalMux = http.NewServeMux()
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
var AddXForwarded = &Middleware{
|
|
||||||
rewrite: func(r *ProxyRequest) {
|
|
||||||
r.SetXForwarded()
|
|
||||||
},
|
|
||||||
}
|
|
249
src/route/middleware/forward_auth.go
Normal file
249
src/route/middleware/forward_auth.go
Normal file
|
@ -0,0 +1,249 @@
|
||||||
|
// Modified from Traefik Labs's MIT-licensed code (https://github.com/traefik/traefik/blob/master/pkg/middlewares/auth/forward.go)
|
||||||
|
// Copyright (c) 2020-2024 Traefik Labs
|
||||||
|
// Copyright (c) 2024 yusing
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/yusing/go-proxy/common"
|
||||||
|
D "github.com/yusing/go-proxy/docker"
|
||||||
|
E "github.com/yusing/go-proxy/error"
|
||||||
|
gpHTTP "github.com/yusing/go-proxy/http"
|
||||||
|
U "github.com/yusing/go-proxy/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
forwardAuth struct {
|
||||||
|
*forwardAuthOpts
|
||||||
|
m *Middleware
|
||||||
|
client http.Client
|
||||||
|
}
|
||||||
|
forwardAuthOpts struct {
|
||||||
|
Address string
|
||||||
|
TrustForwardHeader bool
|
||||||
|
AuthResponseHeaders []string
|
||||||
|
AddAuthCookiesToResponse []string
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
xForwardedFor = "X-Forwarded-For"
|
||||||
|
xForwardedMethod = "X-Forwarded-Method"
|
||||||
|
xForwardedHost = "X-Forwarded-Host"
|
||||||
|
xForwardedProto = "X-Forwarded-Proto"
|
||||||
|
xForwardedURI = "X-Forwarded-Uri"
|
||||||
|
xForwardedPort = "X-Forwarded-Port"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ForwardAuth = newForwardAuth()
|
||||||
|
var faLogger = logrus.WithField("middleware", "ForwardAuth")
|
||||||
|
|
||||||
|
func newForwardAuth() (fa *forwardAuth) {
|
||||||
|
fa = new(forwardAuth)
|
||||||
|
fa.m = new(Middleware)
|
||||||
|
fa.m.labelParserMap = D.ValueParserMap{
|
||||||
|
"trust_forward_header": D.BoolParser,
|
||||||
|
"auth_response_headers": D.YamlStringListParser,
|
||||||
|
"add_auth_cookies_to_response": D.YamlStringListParser,
|
||||||
|
}
|
||||||
|
fa.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||||
|
tr, ok := rp.Transport.(*http.Transport)
|
||||||
|
if ok {
|
||||||
|
tr = tr.Clone()
|
||||||
|
} else {
|
||||||
|
tr = common.DefaultTransport.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
faWithOpts := new(forwardAuth)
|
||||||
|
faWithOpts.forwardAuthOpts = new(forwardAuthOpts)
|
||||||
|
faWithOpts.client = http.Client{
|
||||||
|
CheckRedirect: func(r *Request, via []*Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
faWithOpts.m = &Middleware{
|
||||||
|
impl: faWithOpts,
|
||||||
|
before: fa.forward,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := U.Deserialize(optsRaw, faWithOpts.forwardAuthOpts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.FailWith("set options", err)
|
||||||
|
}
|
||||||
|
_, err = E.Check(url.Parse(faWithOpts.Address))
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Invalid("address", faWithOpts.Address)
|
||||||
|
}
|
||||||
|
return faWithOpts.m, nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
|
||||||
|
removeHop(req.Header)
|
||||||
|
|
||||||
|
faReq, err := http.NewRequestWithContext(
|
||||||
|
req.Context(),
|
||||||
|
http.MethodGet,
|
||||||
|
fa.Address,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copyHeader(faReq.Header, req.Header)
|
||||||
|
removeHop(faReq.Header)
|
||||||
|
|
||||||
|
filterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
||||||
|
fa.setAuthHeaders(req, faReq)
|
||||||
|
|
||||||
|
faResp, err := fa.client.Do(faReq)
|
||||||
|
if err != nil {
|
||||||
|
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer faResp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(faResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
copyHeader(w.Header(), faResp.Header)
|
||||||
|
removeHop(w.Header())
|
||||||
|
|
||||||
|
redirectURL, err := faResp.Location()
|
||||||
|
if err != nil {
|
||||||
|
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
} else if redirectURL.String() != "" {
|
||||||
|
w.Header().Set("Location", redirectURL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(faResp.StatusCode)
|
||||||
|
|
||||||
|
if _, err = w.Write(body); err != nil {
|
||||||
|
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range fa.AuthResponseHeaders {
|
||||||
|
key := http.CanonicalHeaderKey(key)
|
||||||
|
req.Header.Del(key)
|
||||||
|
if len(faResp.Header[key]) > 0 {
|
||||||
|
req.Header[key] = append([]string(nil), faResp.Header[key]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req.RequestURI = req.URL.RequestURI()
|
||||||
|
|
||||||
|
authCookies := faResp.Cookies()
|
||||||
|
|
||||||
|
if len(authCookies) == 0 {
|
||||||
|
next.ServeHTTP(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(gpHTTP.NewModifyResponseWriter(w, req, func(resp *Response) error {
|
||||||
|
fa.setAuthCookies(resp, authCookies)
|
||||||
|
return nil
|
||||||
|
}), req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
|
||||||
|
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := resp.Cookies()
|
||||||
|
resp.Header.Del("Set-Cookie")
|
||||||
|
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if !slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
||||||
|
// this cookie is not an auth cookie, so add it back
|
||||||
|
resp.Header.Add("Set-Cookie", cookie.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cookie := range authCookies {
|
||||||
|
if slices.Contains(fa.AddAuthCookiesToResponse, cookie.Name) {
|
||||||
|
// this cookie is an auth cookie, so add to resp
|
||||||
|
resp.Header.Add("Set-Cookie", cookie.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
|
||||||
|
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||||
|
if fa.TrustForwardHeader {
|
||||||
|
if prior, ok := req.Header[xForwardedFor]; ok {
|
||||||
|
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
faReq.Header.Set(xForwardedFor, clientIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
xMethod := req.Header.Get(xForwardedMethod)
|
||||||
|
switch {
|
||||||
|
case xMethod != "" && fa.TrustForwardHeader:
|
||||||
|
faReq.Header.Set(xForwardedMethod, xMethod)
|
||||||
|
case req.Method != "":
|
||||||
|
faReq.Header.Set(xForwardedMethod, req.Method)
|
||||||
|
default:
|
||||||
|
faReq.Header.Del(xForwardedMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
xfp := req.Header.Get(xForwardedProto)
|
||||||
|
switch {
|
||||||
|
case xfp != "" && fa.TrustForwardHeader:
|
||||||
|
faReq.Header.Set(xForwardedProto, xfp)
|
||||||
|
case req.TLS != nil:
|
||||||
|
faReq.Header.Set(xForwardedProto, "https")
|
||||||
|
default:
|
||||||
|
faReq.Header.Set(xForwardedProto, "http")
|
||||||
|
}
|
||||||
|
|
||||||
|
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
|
||||||
|
faReq.Header.Set(xForwardedPort, xfp)
|
||||||
|
}
|
||||||
|
|
||||||
|
xfh := req.Header.Get(xForwardedHost)
|
||||||
|
switch {
|
||||||
|
case xfh != "" && fa.TrustForwardHeader:
|
||||||
|
faReq.Header.Set(xForwardedHost, xfh)
|
||||||
|
case req.Host != "":
|
||||||
|
faReq.Header.Set(xForwardedHost, req.Host)
|
||||||
|
default:
|
||||||
|
faReq.Header.Del(xForwardedHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
xfURI := req.Header.Get(xForwardedURI)
|
||||||
|
switch {
|
||||||
|
case xfURI != "" && fa.TrustForwardHeader:
|
||||||
|
faReq.Header.Set(xForwardedURI, xfURI)
|
||||||
|
case req.URL.RequestURI() != "":
|
||||||
|
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
|
||||||
|
default:
|
||||||
|
faReq.Header.Del(xForwardedURI)
|
||||||
|
}
|
||||||
|
}
|
44
src/route/middleware/headers.go
Normal file
44
src/route/middleware/headers.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
gpHTTP "github.com/yusing/go-proxy/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func removeHop(h Header) {
|
||||||
|
reqUpType := gpHTTP.UpgradeType(h)
|
||||||
|
gpHTTP.RemoveHopByHopHeaders(h)
|
||||||
|
|
||||||
|
if reqUpType != "" {
|
||||||
|
h.Set("Connection", "Upgrade")
|
||||||
|
h.Set("Upgrade", reqUpType)
|
||||||
|
} else {
|
||||||
|
h.Del("Connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHeader(dst, src Header) {
|
||||||
|
for k, vv := range src {
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterHeaders(h Header, allowed []string) {
|
||||||
|
if allowed == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range allowed {
|
||||||
|
allowed[i] = http.CanonicalHeaderKey(allowed[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range h {
|
||||||
|
if !slices.Contains(allowed, key) {
|
||||||
|
h.Del(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,33 +3,42 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
D "github.com/yusing/go-proxy/docker"
|
||||||
E "github.com/yusing/go-proxy/error"
|
E "github.com/yusing/go-proxy/error"
|
||||||
P "github.com/yusing/go-proxy/proxy"
|
gpHTTP "github.com/yusing/go-proxy/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
ReverseProxy = P.ReverseProxy
|
Error = E.NestedError
|
||||||
ProxyRequest = P.ProxyRequest
|
|
||||||
|
ReverseProxy = gpHTTP.ReverseProxy
|
||||||
|
ProxyRequest = gpHTTP.ProxyRequest
|
||||||
Request = http.Request
|
Request = http.Request
|
||||||
Response = http.Response
|
Response = http.Response
|
||||||
ResponseWriter = http.ResponseWriter
|
ResponseWriter = http.ResponseWriter
|
||||||
|
Header = http.Header
|
||||||
|
Cookie = http.Cookie
|
||||||
|
|
||||||
BeforeFunc func(w ResponseWriter, r *Request) (continue_ bool)
|
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
|
||||||
RewriteFunc func(req *ProxyRequest)
|
RewriteFunc func(req *ProxyRequest)
|
||||||
ModifyResponseFunc func(res *Response) error
|
ModifyResponseFunc func(res *Response) error
|
||||||
|
CloneWithOptFunc func(opts OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError)
|
||||||
|
|
||||||
MiddlewareOptionsRaw map[string]string
|
OptionsRaw = map[string]any
|
||||||
MiddlewareOptions map[string]interface{}
|
Options any
|
||||||
|
|
||||||
Middleware struct {
|
Middleware struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
before BeforeFunc
|
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||||
rewrite RewriteFunc
|
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
|
||||||
modifyResponse ModifyResponseFunc
|
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
||||||
|
|
||||||
options MiddlewareOptions
|
transport http.RoundTripper
|
||||||
validateOptions func(opts MiddlewareOptionsRaw) (MiddlewareOptions, E.NestedError)
|
|
||||||
|
withOptions CloneWithOptFunc
|
||||||
|
labelParserMap D.ValueParserMap
|
||||||
|
impl any
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,41 +50,32 @@ func (m *Middleware) String() string {
|
||||||
return m.name
|
return m.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) WithOptions(optsRaw MiddlewareOptionsRaw) (*Middleware, E.NestedError) {
|
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||||
if len(optsRaw) == 0 {
|
if len(optsRaw) != 0 && m.withOptions != nil {
|
||||||
return m, nil
|
if mWithOpt, err := m.withOptions(optsRaw, rp); err != nil {
|
||||||
}
|
|
||||||
|
|
||||||
var opts MiddlewareOptions
|
|
||||||
var err E.NestedError
|
|
||||||
|
|
||||||
if m.validateOptions != nil {
|
|
||||||
if opts, err = m.validateOptions(optsRaw); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return mWithOpt, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Middleware{
|
// WithOptionsClone is called only once
|
||||||
name: m.name,
|
// set withOptions and labelParser will not be used after that
|
||||||
before: m.before,
|
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
|
||||||
rewrite: m.rewrite,
|
|
||||||
modifyResponse: m.modifyResponse,
|
|
||||||
options: opts,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check conflict
|
// TODO: check conflict or duplicates
|
||||||
func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptionsRaw) (out ReverseProxy, err E.NestedError) {
|
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
|
||||||
out = rp
|
|
||||||
|
|
||||||
befores := make([]BeforeFunc, 0, len(middlewares))
|
befores := make([]BeforeFunc, 0, len(middlewares))
|
||||||
rewrites := make([]RewriteFunc, 0, len(middlewares))
|
rewrites := make([]RewriteFunc, 0, len(middlewares))
|
||||||
modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares))
|
modifyResponses := make([]ModifyResponseFunc, 0, len(middlewares))
|
||||||
|
|
||||||
invalidM := E.NewBuilder("invalid middlewares")
|
invalidM := E.NewBuilder("invalid middlewares")
|
||||||
invalidOpts := E.NewBuilder("invalid options")
|
invalidOpts := E.NewBuilder("invalid options")
|
||||||
defer invalidM.Add(invalidOpts.Build())
|
defer func() {
|
||||||
defer invalidM.To(&err)
|
invalidM.Add(invalidOpts.Build())
|
||||||
|
invalidM.To(&res)
|
||||||
|
}()
|
||||||
|
|
||||||
for name, opts := range middlewares {
|
for name, opts := range middlewares {
|
||||||
m, ok := Get(name)
|
m, ok := Get(name)
|
||||||
|
@ -83,7 +83,8 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions
|
||||||
invalidM.Addf("%s", name)
|
invalidM.Addf("%s", name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m, err = m.WithOptions(opts)
|
|
||||||
|
m, err := m.WithOptionsClone(opts, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
invalidOpts.Add(err.Subject(name))
|
invalidOpts.Add(err.Subject(name))
|
||||||
continue
|
continue
|
||||||
|
@ -103,25 +104,37 @@ func PatchReverseProxy(rp ReverseProxy, middlewares map[string]MiddlewareOptions
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(befores) > 0 {
|
origServeHTTP := rp.ServeHTTP
|
||||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
for i, before := range befores {
|
||||||
for _, before := range befores {
|
if i < len(befores)-1 {
|
||||||
if !before(w, r) {
|
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||||
return
|
before(rp.ServeHTTP, w, r)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||||
|
before(origServeHTTP, w, r)
|
||||||
}
|
}
|
||||||
rp.ServeHTTP(w, r)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(rewrites) > 0 {
|
if len(rewrites) > 0 {
|
||||||
|
origRewrite := rp.Rewrite
|
||||||
rp.Rewrite = func(req *ProxyRequest) {
|
rp.Rewrite = func(req *ProxyRequest) {
|
||||||
|
if origRewrite != nil {
|
||||||
|
origRewrite(req)
|
||||||
|
}
|
||||||
for _, rewrite := range rewrites {
|
for _, rewrite := range rewrites {
|
||||||
rewrite(req)
|
rewrite(req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(modifyResponses) > 0 {
|
if len(modifyResponses) > 0 {
|
||||||
|
origModifyResponse := rp.ModifyResponse
|
||||||
rp.ModifyResponse = func(res *Response) error {
|
rp.ModifyResponse = func(res *Response) error {
|
||||||
|
if origModifyResponse != nil {
|
||||||
|
return origModifyResponse(res)
|
||||||
|
}
|
||||||
for _, modifyResponse := range modifyResponses {
|
for _, modifyResponse := range modifyResponses {
|
||||||
if err := modifyResponse(res); err != nil {
|
if err := modifyResponse(res); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -3,14 +3,11 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
D "github.com/yusing/go-proxy/docker"
|
||||||
)
|
)
|
||||||
|
|
||||||
var middlewares = map[string]*Middleware{
|
var middlewares map[string]*Middleware
|
||||||
"set_x_forwarded": SetXForwarded, // nginx
|
|
||||||
"add_x_forwarded": AddXForwarded, // nginx
|
|
||||||
"trust_forward_header": AddXForwarded, // traefik alias
|
|
||||||
"redirect_http": RedirectHTTP,
|
|
||||||
}
|
|
||||||
|
|
||||||
func Get(name string) (middleware *Middleware, ok bool) {
|
func Get(name string) (middleware *Middleware, ok bool) {
|
||||||
middleware, ok = middlewares[name]
|
middleware, ok = middlewares[name]
|
||||||
|
@ -18,10 +15,23 @@ func Get(name string) (middleware *Middleware, ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize middleware names
|
// initialize middleware names
|
||||||
var _ = func() (_ bool) {
|
func init() {
|
||||||
|
middlewares = map[string]*Middleware{
|
||||||
|
"set_x_forwarded": SetXForwarded,
|
||||||
|
"add_x_forwarded": AddXForwarded,
|
||||||
|
"redirect_http": RedirectHTTP,
|
||||||
|
"forward_auth": ForwardAuth.m,
|
||||||
|
"modify_response": ModifyResponse.m,
|
||||||
|
"modify_request": ModifyRequest.m,
|
||||||
|
}
|
||||||
names := make(map[*Middleware][]string)
|
names := make(map[*Middleware][]string)
|
||||||
for name, m := range middlewares {
|
for name, m := range middlewares {
|
||||||
names[m] = append(names[m], name)
|
names[m] = append(names[m], name)
|
||||||
|
// register middleware name to docker label parsr
|
||||||
|
// in order to parse middleware_name.option=value into correct type
|
||||||
|
if m.labelParserMap != nil {
|
||||||
|
D.RegisterNamespace(name, m.labelParserMap)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for m, names := range names {
|
for m, names := range names {
|
||||||
if len(names) > 1 {
|
if len(names) > 1 {
|
||||||
|
@ -30,5 +40,4 @@ var _ = func() (_ bool) {
|
||||||
m.name = names[0]
|
m.name = names[0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
}
|
||||||
}()
|
|
||||||
|
|
58
src/route/middleware/modify_request.go
Normal file
58
src/route/middleware/modify_request.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
D "github.com/yusing/go-proxy/docker"
|
||||||
|
E "github.com/yusing/go-proxy/error"
|
||||||
|
U "github.com/yusing/go-proxy/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
modifyRequest struct {
|
||||||
|
*modifyRequestOpts
|
||||||
|
m *Middleware
|
||||||
|
}
|
||||||
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
|
modifyRequestOpts struct {
|
||||||
|
SetHeaders map[string]string
|
||||||
|
AddHeaders map[string]string
|
||||||
|
HideHeaders []string
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var ModifyRequest = newModifyRequest()
|
||||||
|
|
||||||
|
func newModifyRequest() (mr *modifyRequest) {
|
||||||
|
mr = new(modifyRequest)
|
||||||
|
mr.m = new(Middleware)
|
||||||
|
mr.m.labelParserMap = D.ValueParserMap{
|
||||||
|
"set_headers": D.YamlLikeMappingParser(true),
|
||||||
|
"add_headers": D.YamlLikeMappingParser(true),
|
||||||
|
"hide_headers": D.YamlStringListParser,
|
||||||
|
}
|
||||||
|
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||||
|
mrWithOpts := new(modifyRequest)
|
||||||
|
mrWithOpts.m = &Middleware{
|
||||||
|
impl: mrWithOpts,
|
||||||
|
rewrite: mrWithOpts.modifyRequest,
|
||||||
|
}
|
||||||
|
mrWithOpts.modifyRequestOpts = new(modifyRequestOpts)
|
||||||
|
err := U.Deserialize(optsRaw, mrWithOpts.modifyRequestOpts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.FailWith("set options", err)
|
||||||
|
}
|
||||||
|
return mrWithOpts.m, nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *modifyRequest) modifyRequest(req *ProxyRequest) {
|
||||||
|
for k, v := range mr.SetHeaders {
|
||||||
|
req.Out.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
for k, v := range mr.AddHeaders {
|
||||||
|
req.Out.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
for _, k := range mr.HideHeaders {
|
||||||
|
req.Out.Header.Del(k)
|
||||||
|
}
|
||||||
|
}
|
34
src/route/middleware/modify_request_test.go
Normal file
34
src/route/middleware/modify_request_test.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetModifyRequest(t *testing.T) {
|
||||||
|
opts := OptionsRaw{
|
||||||
|
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
|
||||||
|
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||||
|
"hide_headers": []string{"Accept"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("set_options", func(t *testing.T) {
|
||||||
|
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||||
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||||
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("request_headers", func(t *testing.T) {
|
||||||
|
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
|
||||||
|
middlewareOpt: opts,
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||||
|
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||||
|
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||||
|
})
|
||||||
|
}
|
61
src/route/middleware/modify_response.go
Normal file
61
src/route/middleware/modify_response.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
D "github.com/yusing/go-proxy/docker"
|
||||||
|
E "github.com/yusing/go-proxy/error"
|
||||||
|
U "github.com/yusing/go-proxy/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
modifyResponse struct {
|
||||||
|
*modifyResponseOpts
|
||||||
|
m *Middleware
|
||||||
|
}
|
||||||
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
|
modifyResponseOpts struct {
|
||||||
|
SetHeaders map[string]string
|
||||||
|
AddHeaders map[string]string
|
||||||
|
HideHeaders []string
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var ModifyResponse = newModifyResponse()
|
||||||
|
|
||||||
|
func newModifyResponse() (mr *modifyResponse) {
|
||||||
|
mr = new(modifyResponse)
|
||||||
|
mr.m = new(Middleware)
|
||||||
|
mr.m.labelParserMap = D.ValueParserMap{
|
||||||
|
"set_headers": D.YamlLikeMappingParser(true),
|
||||||
|
"add_headers": D.YamlLikeMappingParser(true),
|
||||||
|
"hide_headers": D.YamlStringListParser,
|
||||||
|
}
|
||||||
|
mr.m.withOptions = func(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||||
|
mrWithOpts := new(modifyResponse)
|
||||||
|
mrWithOpts.m = &Middleware{
|
||||||
|
impl: mrWithOpts,
|
||||||
|
modifyResponse: mrWithOpts.modifyResponse,
|
||||||
|
}
|
||||||
|
mrWithOpts.modifyResponseOpts = new(modifyResponseOpts)
|
||||||
|
err := U.Deserialize(optsRaw, mrWithOpts.modifyResponseOpts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.FailWith("set options", err)
|
||||||
|
}
|
||||||
|
return mrWithOpts.m, nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||||
|
for k, v := range mr.SetHeaders {
|
||||||
|
resp.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
for k, v := range mr.AddHeaders {
|
||||||
|
resp.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
for _, k := range mr.HideHeaders {
|
||||||
|
resp.Header.Del(k)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
35
src/route/middleware/modify_response_test.go
Normal file
35
src/route/middleware/modify_response_test.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetModifyResponse(t *testing.T) {
|
||||||
|
opts := OptionsRaw{
|
||||||
|
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
|
||||||
|
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||||
|
"hide_headers": []string{"Accept"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("set_options", func(t *testing.T) {
|
||||||
|
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||||
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||||
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("request_headers", func(t *testing.T) {
|
||||||
|
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
|
||||||
|
middlewareOpt: opts,
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||||
|
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
|
||||||
|
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
||||||
|
})
|
||||||
|
}
|
|
@ -7,14 +7,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var RedirectHTTP = &Middleware{
|
var RedirectHTTP = &Middleware{
|
||||||
before: func(w ResponseWriter, r *Request) (continue_ bool) {
|
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
||||||
if r.TLS == nil {
|
if r.TLS == nil {
|
||||||
r.URL.Scheme = "https"
|
r.URL.Scheme = "https"
|
||||||
r.URL.Host = r.URL.Hostname() + common.ProxyHTTPSPort
|
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
|
||||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||||
} else {
|
return
|
||||||
continue_ = true
|
|
||||||
}
|
}
|
||||||
return
|
next.ServeHTTP(w, r)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
26
src/route/middleware/redirect_http_test.go
Normal file
26
src/route/middleware/redirect_http_test.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/common"
|
||||||
|
. "github.com/yusing/go-proxy/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRedirectToHTTPs(t *testing.T) {
|
||||||
|
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||||
|
scheme: "http",
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoRedirect(t *testing.T) {
|
||||||
|
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||||
|
scheme: "https",
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err.Error())
|
||||||
|
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||||
|
}
|
|
@ -1,10 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
var SetXForwarded = &Middleware{
|
|
||||||
rewrite: func(r *ProxyRequest) {
|
|
||||||
r.Out.Header.Del("X-Forwarded-For")
|
|
||||||
r.Out.Header.Del("X-Forwarded-Host")
|
|
||||||
r.Out.Header.Del("X-Forwarded-Proto")
|
|
||||||
r.SetXForwarded()
|
|
||||||
},
|
|
||||||
}
|
|
17
src/route/middleware/test_data/sample_headers.json
Normal file
17
src/route/middleware/test_data/sample_headers.json
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
{
|
||||||
|
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||||
|
"Accept-Encoding": "gzip, deflate, br, zstd",
|
||||||
|
"Accept-Language": "en,zh-HK;q=0.9,zh-TW;q=0.8,zh-CN;q=0.7,zh;q=0.6",
|
||||||
|
"Dnt": "1",
|
||||||
|
"Host": "localhost",
|
||||||
|
"Priority": "u=0, i",
|
||||||
|
"Sec-Ch-Ua": "\"Chromium\";v=\"129\", \"Not=A?Brand\";v=\"8\"",
|
||||||
|
"Sec-Ch-Ua-Mobile": "?0",
|
||||||
|
"Sec-Ch-Ua-Platform": "\"Windows\"",
|
||||||
|
"Sec-Fetch-Dest": "document",
|
||||||
|
"Sec-Fetch-Mode": "navigate",
|
||||||
|
"Sec-Fetch-Site": "none",
|
||||||
|
"Sec-Fetch-User": "?1",
|
||||||
|
"Upgrade-Insecure-Requests": "1",
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
|
||||||
|
}
|
125
src/route/middleware/test_utils.go
Normal file
125
src/route/middleware/test_utils.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/error"
|
||||||
|
gpHTTP "github.com/yusing/go-proxy/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed test_data/sample_headers.json
|
||||||
|
var testHeadersRaw []byte
|
||||||
|
var testHeaders http.Header
|
||||||
|
|
||||||
|
const testHost = "example.com"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tmp := map[string]string{}
|
||||||
|
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
testHeaders = http.Header{}
|
||||||
|
for k, v := range tmp {
|
||||||
|
testHeaders.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestHeaderRecorder struct {
|
||||||
|
parent http.RoundTripper
|
||||||
|
reqHeaders http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
rt.reqHeaders = req.Header
|
||||||
|
if rt.parent != nil {
|
||||||
|
return rt.parent.RoundTrip(req)
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: testHeaders,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestResult struct {
|
||||||
|
RequestHeaders http.Header
|
||||||
|
ResponseHeaders http.Header
|
||||||
|
ResponseStatus int
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type testArgs struct {
|
||||||
|
middlewareOpt OptionsRaw
|
||||||
|
proxyURL string
|
||||||
|
body []byte
|
||||||
|
scheme string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||||
|
var body io.Reader
|
||||||
|
var rt = new(requestHeaderRecorder)
|
||||||
|
var proxyURL *url.URL
|
||||||
|
var requestTarget string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if args == nil {
|
||||||
|
args = new(testArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.body != nil {
|
||||||
|
body = bytes.NewReader(args.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.scheme == "" || args.scheme == "http" {
|
||||||
|
requestTarget = "http://" + testHost
|
||||||
|
} else if args.scheme == "https" {
|
||||||
|
requestTarget = "https://" + testHost
|
||||||
|
} else {
|
||||||
|
panic("typo?")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if args.scheme == "https" && req.TLS == nil {
|
||||||
|
panic("bug occurred")
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.proxyURL != "" {
|
||||||
|
proxyURL, err = url.Parse(args.proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.From(err)
|
||||||
|
}
|
||||||
|
rt.parent = http.DefaultTransport
|
||||||
|
} else {
|
||||||
|
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||||
|
}
|
||||||
|
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
|
||||||
|
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
|
||||||
|
middleware.name: args.middlewareOpt,
|
||||||
|
})
|
||||||
|
if setOptErr != nil {
|
||||||
|
return nil, setOptErr
|
||||||
|
}
|
||||||
|
rp.ServeHTTP(w, req)
|
||||||
|
resp := w.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.From(err)
|
||||||
|
}
|
||||||
|
return &TestResult{
|
||||||
|
RequestHeaders: rt.reqHeaders,
|
||||||
|
ResponseHeaders: resp.Header,
|
||||||
|
ResponseStatus: resp.StatusCode,
|
||||||
|
Data: data,
|
||||||
|
}, nil
|
||||||
|
}
|
9
src/route/middleware/x_forwarded.go
Normal file
9
src/route/middleware/x_forwarded.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
var AddXForwarded = &Middleware{
|
||||||
|
rewrite: (*ProxyRequest).AddXForwarded,
|
||||||
|
}
|
||||||
|
|
||||||
|
var SetXForwarded = &Middleware{
|
||||||
|
rewrite: (*ProxyRequest).SetXForwarded,
|
||||||
|
}
|
|
@ -84,6 +84,11 @@ func NewServer(opt Options) (s *server) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start will start the http and https servers.
|
||||||
|
//
|
||||||
|
// If both are not set, this does nothing.
|
||||||
|
//
|
||||||
|
// Start() is non-blocking
|
||||||
func (s *server) Start() {
|
func (s *server) Start() {
|
||||||
if s.http == nil && s.https == nil {
|
if s.http == nil && s.https == nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -106,7 +106,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Deserialize(src map[string]any, target any) E.NestedError {
|
func Deserialize(src SerializedObject, target any) E.NestedError {
|
||||||
// convert data fields to lower no-snake
|
// convert data fields to lower no-snake
|
||||||
// convert target fields to lower
|
// convert target fields to lower
|
||||||
// then check if the field of data is in the target
|
// then check if the field of data is in the target
|
||||||
|
@ -117,6 +117,10 @@ func Deserialize(src map[string]any, target any) E.NestedError {
|
||||||
snakeCaseField := strings.ToLower(field.Name)
|
snakeCaseField := strings.ToLower(field.Name)
|
||||||
mapping[snakeCaseField] = field.Name
|
mapping[snakeCaseField] = field.Name
|
||||||
}
|
}
|
||||||
|
tValue := reflect.ValueOf(target)
|
||||||
|
if tValue.IsZero() {
|
||||||
|
return E.Invalid("value", "nil")
|
||||||
|
}
|
||||||
for k, v := range src {
|
for k, v := range src {
|
||||||
kCleaned := toLowerNoSnake(k)
|
kCleaned := toLowerNoSnake(k)
|
||||||
if fieldName, ok := mapping[kCleaned]; ok {
|
if fieldName, ok := mapping[kCleaned]; ok {
|
||||||
|
@ -150,13 +154,13 @@ func Deserialize(src map[string]any, target any) E.NestedError {
|
||||||
}
|
}
|
||||||
prop.Set(propNew)
|
prop.Set(propNew)
|
||||||
default:
|
default:
|
||||||
return E.Unsupported("field", k).Extraf("type=%s", propType)
|
return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return E.Unsupported("field", k).Extraf("type=%s", propType)
|
return E.Unsupported("field", k).Extraf("type %s is not settable", propType)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return E.Failure("unknown field").With(k)
|
return E.Unexpected("field", k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,17 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) {
|
||||||
|
t.Helper()
|
||||||
|
for _, want := range wants {
|
||||||
|
if got == want {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf("expected any of:\n%v, got\n%v", wants, got)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
|
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if !reflect.DeepEqual(got, want) {
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
|
Loading…
Add table
Reference in a new issue