feat: proxmox idlewatcher (#88)

* feat: idle sleep for proxmox LXCs

* refactor: replace deprecated docker api types

* chore(api): remove debug task list endpoint

* refactor: move servemux to gphttp/servemux; favicon.go to v1/favicon

* refactor: introduce Pool interface, move agent_pool to agent module

* refactor: simplify api code

* feat: introduce debug api

* refactor: remove net.URL and net.CIDR types, improved unmarshal handling

* chore: update Makefile for debug build tag, update README

* chore: add gperr.Unwrap method

* feat: relative time and duration formatting

* chore: add ROOT_DIR environment variable, refactor

* migration: move homepage override and icon cache to $BASE_DIR/data, add migration code

* fix: nil dereference on marshalling service health

* fix: wait for route deletion

* chore: enhance tasks debuggability

* feat: stdout access logger and MultiWriter

* fix(agent): remove agent properly on verify error

* fix(metrics): disk exclusion logic and added corresponding tests

* chore: update schema and prettify, fix package.json and Makefile

* fix: I/O buffer not being shrunk before putting back to pool

* feat: enhanced error handling module

* chore: deps upgrade

* feat: better value formatting and handling

---------

Co-authored-by: yusing <yusing@6uo.me>
This commit is contained in:
Yuzerion 2025-04-16 14:52:33 +08:00 committed by GitHub
parent 88f3a95b61
commit 57292f0fe8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
173 changed files with 4131 additions and 2096 deletions

1
.gitignore vendored
View file

@ -11,6 +11,7 @@ error_pages/
!examples/error_pages/ !examples/error_pages/
profiles/ profiles/
data/ data/
debug/
logs/ logs/
log/ log/

View file

@ -27,16 +27,16 @@ endif
ifeq ($(debug), 1) ifeq ($(debug), 1)
CGO_ENABLED = 0 CGO_ENABLED = 0
GODOXY_DEBUG = 1 GODOXY_DEBUG = 1
BUILD_FLAGS += -gcflags=all='-N -l' BUILD_FLAGS += -gcflags=all='-N -l' -tags debug
else ifeq ($(pprof), 1) else ifeq ($(pprof), 1)
CGO_ENABLED = 1 CGO_ENABLED = 1
GORACE = log_path=logs/pprof strip_path_prefix=$(shell pwd)/ halt_on_error=1 GORACE = log_path=logs/pprof strip_path_prefix=$(shell pwd)/ halt_on_error=1
BUILD_FLAGS = -tags pprof BUILD_FLAGS += -tags pprof
VERSION := ${VERSION}-pprof VERSION := ${VERSION}-pprof
else else
CGO_ENABLED = 0 CGO_ENABLED = 0
LDFLAGS += -s -w LDFLAGS += -s -w
BUILD_FLAGS = -pgo=auto -tags production BUILD_FLAGS += -pgo=auto -tags production
endif endif
BUILD_FLAGS += -ldflags='$(LDFLAGS)' BUILD_FLAGS += -ldflags='$(LDFLAGS)'
@ -50,6 +50,8 @@ export GODEBUG
export GORACE export GORACE
export BUILD_FLAGS export BUILD_FLAGS
.PHONY: debug
test: test:
GODOXY_TEST=1 go test ./internal/... GODOXY_TEST=1 go test ./internal/...
@ -67,6 +69,10 @@ build:
run: run:
[ -f .env ] && godotenv -f .env go run ${BUILD_FLAGS} ${CMD_PATH} [ -f .env ] && godotenv -f .env go run ${BUILD_FLAGS} ${CMD_PATH}
debug:
make NAME="godoxy-test" debug=1 build
sh -c 'HTTP_ADDR=:81 HTTPS_ADDR=:8443 API_ADDR=:8899 DEBUG=1 bin/godoxy-test'
mtrace: mtrace:
bin/godoxy debug-ls-mtrace > mtrace.json bin/godoxy debug-ls-mtrace > mtrace.json
@ -88,43 +94,5 @@ cloc:
link-binary: link-binary:
ln -s /app/${NAME} bin/run ln -s /app/${NAME} bin/run
# To generate schema
# comment out this part from typescript-json-schema.js#L884
#
# if (indexType.flags !== ts.TypeFlags.Number && !isIndexedObject) {
# throw new Error("Not supported: IndexSignatureDeclaration with index symbol other than a number or a string");
# }
gen-schema-single:
bun --bun run typescript-json-schema --noExtraProps --required --skipLibCheck --tsNodeRegister=true -o schemas/${OUT} schemas/${IN} ${CLASS}
# minify
python3 -c "import json; f=open('schemas/${OUT}', 'r'); j=json.load(f); f.close(); f=open('schemas/${OUT}', 'w'); json.dump(j, f, separators=(',', ':'));"
gen-schema:
cd schemas && bun --bun tsc
make IN=config/config.ts \
CLASS=Config \
OUT=config.schema.json \
gen-schema-single
make IN=providers/routes.ts \
CLASS=Routes \
OUT=routes.schema.json \
gen-schema-single
make IN=middlewares/middleware_compose.ts \
CLASS=MiddlewareCompose \
OUT=middleware_compose.schema.json \
gen-schema-single
make IN=docker.ts \
CLASS=DockerRoutes \
OUT=docker_routes.schema.json \
gen-schema-single
cd ..
publish-schema:
cd schemas && bun publish && cd ..
update-schema-generator:
pnpm up -g typescript-json-schema
push-github: push-github:
git push origin $(shell git rev-parse --abbrev-ref HEAD) git push origin $(shell git rev-parse --abbrev-ref HEAD)

View file

@ -5,6 +5,7 @@
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=yusing_godoxy) [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=yusing_godoxy)
![GitHub last commit](https://img.shields.io/github/last-commit/yusing/godoxy) ![GitHub last commit](https://img.shields.io/github/last-commit/yusing/godoxy)
[![Lines of Code](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=ncloc)](https://sonarcloud.io/summary/new_code?id=yusing_godoxy) [![Lines of Code](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=ncloc)](https://sonarcloud.io/summary/new_code?id=yusing_godoxy)
![Demo](https://img.shields.io/website?url=https%3A%2F%2Fgodoxy.demo.6uo.me&label=Demo&link=https%3A%2F%2Fgodoxy.demo.6uo.me)
[![Discord](https://dcbadge.limes.pink/api/server/umReR62nRd?style=flat)](https://discord.gg/umReR62nRd) [![Discord](https://dcbadge.limes.pink/api/server/umReR62nRd?style=flat)](https://discord.gg/umReR62nRd)
A lightweight, simple, and [performant](https://github.com/yusing/godoxy/wiki/Benchmarks) reverse proxy with WebUI. A lightweight, simple, and [performant](https://github.com/yusing/godoxy/wiki/Benchmarks) reverse proxy with WebUI.
@ -47,19 +48,17 @@ For full documentation, check out **[Wiki](https://github.com/yusing/godoxy/wiki
- Effortless configuration - Effortless configuration
- Simple multi-node setup with GoDoxy agents or Docker Socket Proxies - Simple multi-node setup with GoDoxy agents or Docker Socket Proxies
- Error messages is clear and detailed, easy troubleshooting - Error messages is clear and detailed, easy troubleshooting
- Auto SSL with Let's Encrypt (See [Supported DNS-01 Challenge Providers](https://github.com/yusing/go-proxy/wiki/Supported-DNS%E2%80%9001-Providers)) - **Auto SSL** with Let's Encrypt (See [Supported DNS-01 Challenge Providers](https://github.com/yusing/go-proxy/wiki/Supported-DNS%E2%80%9001-Providers))
- Auto hot-reload on container state / config file changes - **Auto hot-reload** on container state / config file changes
- Container aware: create routes dynamically from running docker containers - **Container aware**: create routes dynamically from running docker containers
- **idlesleeper**: stop and wake containers based on traffic _(optional, see [screenshots](#idlesleeper))_ - **idlesleeper**: stop and wake containers based on traffic _(optional, see [screenshots](#idlesleeper))_
- HTTP reserve proxy and TCP/UDP port forwarding - HTTP reserve proxy and TCP/UDP port forwarding
- OpenID Connect integration: SSO and secure your apps easily - **OpenID Connect integration**: SSO and secure your apps easily
- [HTTP middleware](https://github.com/yusing/go-proxy/wiki/Middlewares) and [Custom error pages support](https://github.com/yusing/go-proxy/wiki/Middlewares#custom-error-pages) - [HTTP middleware](https://github.com/yusing/go-proxy/wiki/Middlewares) and [Custom error pages support](https://github.com/yusing/go-proxy/wiki/Middlewares#custom-error-pages)
- **Web UI with App dashboard, config editor, _uptime and system metrics_, _docker logs viewer_** - **Web UI with App dashboard, config editor, _uptime and system metrics_, _docker logs viewer_**
- Supports linux/amd64 and linux/arm64 - Supports **linux/amd64** and **linux/arm64**
- Written in **[Go](https://go.dev)** - Written in **[Go](https://go.dev)**
[🔼Back to top](#table-of-content)
## Prerequisites ## Prerequisites
Setup Wildcard DNS Record(s) for machine running `GoDoxy`, e.g. Setup Wildcard DNS Record(s) for machine running `GoDoxy`, e.g.
@ -74,13 +73,17 @@ Setup Wildcard DNS Record(s) for machine running `GoDoxy`, e.g.
3. Create a route if applicable (a route is like a "Virtual Host" in NPM) 3. Create a route if applicable (a route is like a "Virtual Host" in NPM)
4. Watch for container / config changes and update automatically 4. Watch for container / config changes and update automatically
GoDoxy uses the label `proxy.aliases` as the subdomain(s), if unset it defaults to the `container_name` field in docker compose. > [!NOTE]
> GoDoxy uses the label `proxy.aliases` as the subdomain(s), if unset it defaults to the `container_name` field in docker compose.
For example, with the label `proxy.aliases: qbt` you can access your app via `qbt.domain.com`. >
> For example, with the label `proxy.aliases: qbt` you can access your app via `qbt.domain.com`.
## Setup ## Setup
**NOTE:** GoDoxy is designed to be (and only works when) running in `host` network mode, do not change it. To change listening ports, modify `.env`. > [!NOTE]
> GoDoxy is designed to be running in `host` network mode, do not change it.
>
> To change listening ports, modify `.env`.
1. Prepare a new directory for docker compose and config files. 1. Prepare a new directory for docker compose and config files.
@ -90,11 +93,7 @@ For example, with the label `proxy.aliases: qbt` you can access your app via `qb
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/yusing/godoxy/main/scripts/setup.sh)" /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/yusing/godoxy/main/scripts/setup.sh)"
``` ```
3. Start the container `docker compose up -d` and wait for it to be ready 3. You may now do some extra configuration on WebUI `https://godoxy.yourdomain.com`
4. You may now do some extra configuration on WebUI `https://godoxy.yourdomain.com`
[🔼Back to top](#table-of-content)
## Screenshots ## Screenshots
@ -127,8 +126,6 @@ For example, with the label `proxy.aliases: qbt` you can access your app via `qb
</table> </table>
</div> </div>
[🔼Back to top](#table-of-content)
## Manual Setup ## Manual Setup
1. Make `config` directory then grab `config.example.yml` into `config/config.yml` 1. Make `config` directory then grab `config.example.yml` into `config/config.yml`

View file

@ -5,7 +5,8 @@
[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy) [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
![GitHub last commit](https://img.shields.io/github/last-commit/yusing/godoxy) ![GitHub last commit](https://img.shields.io/github/last-commit/yusing/godoxy)
[![Lines of Code](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=ncloc)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy) [![Lines of Code](https://sonarcloud.io/api/project_badges/measure?project=yusing_go-proxy&metric=ncloc)](https://sonarcloud.io/summary/new_code?id=yusing_go-proxy)
[![](https://dcbadge.limes.pink/api/server/umReR62nRd?style=flat)](https://discord.gg/umReR62nRd) ![Demo](https://img.shields.io/website?url=https%3A%2F%2Fgodoxy.demo.6uo.me&label=Demo&link=https%3A%2F%2Fgodoxy.demo.6uo.me)
[![Discord](https://dcbadge.limes.pink/api/server/umReR62nRd?style=flat)](https://discord.gg/umReR62nRd)
輕量、易用、 [高效能](https://github.com/yusing/godoxy/wiki/Benchmarks),且帶有主頁和配置面板的反向代理 輕量、易用、 [高效能](https://github.com/yusing/godoxy/wiki/Benchmarks),且帶有主頁和配置面板的反向代理
@ -68,7 +69,10 @@
## 安裝 ## 安裝
**注意:** GoDoxy 設計為(且僅在)`host` 網路模式下運作,請勿更改。如需更改監聽埠,請修改 `.env` > [!NOTE]
> GoDoxy 僅在 `host` 網路模式下運作,請勿更改。
>
> 如需更改監聽埠,請修改 `.env`
1. 準備一個新目錄用於 docker compose 和配置文件。 1. 準備一個新目錄用於 docker compose 和配置文件。
@ -78,9 +82,7 @@
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/yusing/godoxy/main/scripts/setup.sh)" /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/yusing/godoxy/main/scripts/setup.sh)"
``` ```
3. 啟動容器 `docker compose up -d` 並等待就緒 3. 現在可以在 WebUI `https://godoxy.yourdomain.com` 進行額外配置
4. 現在可以在 WebUI `https://godoxy.yourdomain.com` 進行額外配置
[🔼回到頂部](#目錄) [🔼回到頂部](#目錄)

16
agent/pkg/agent/agents.go Normal file
View file

@ -0,0 +1,16 @@
package agent
import (
"github.com/yusing/go-proxy/internal/utils/pool"
)
type agents struct{ pool.Pool[*AgentConfig] }
var Agents = agents{pool.New[*AgentConfig]("agents")}
func (agents agents) Get(agentAddrOrDockerHost string) (*AgentConfig, bool) {
if !IsDockerHostAgent(agentAddrOrDockerHost) {
return agents.Base().Load(agentAddrOrDockerHost)
}
return agents.Base().Load(GetAgentAddrFromDockerHost(agentAddrOrDockerHost))
}

View file

@ -4,19 +4,16 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/agent/pkg/certs" "github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
) )
@ -26,7 +23,6 @@ type AgentConfig struct {
httpClient *http.Client httpClient *http.Client
tlsConfig *tls.Config tlsConfig *tls.Config
name string name string
l zerolog.Logger
} }
const ( const (
@ -49,8 +45,8 @@ const (
) )
var ( var (
AgentURL = types.MustParseURL(APIBaseURL) AgentURL, _ = url.Parse(APIBaseURL)
HTTPProxyURL = types.MustParseURL(APIBaseURL + EndpointProxyHTTP) HTTPProxyURL, _ = url.Parse(APIBaseURL + EndpointProxyHTTP)
HTTPProxyURLPrefixLen = len(APIEndpointBase + EndpointProxyHTTP) HTTPProxyURLPrefixLen = len(APIEndpointBase + EndpointProxyHTTP)
) )
@ -71,6 +67,11 @@ func GetAgentAddrFromDockerHost(dockerHost string) string {
return dockerHost[FakeDockerHostPrefixLen:] return dockerHost[FakeDockerHostPrefixLen:]
} }
// Key implements pool.Object
func (cfg *AgentConfig) Key() string {
return cfg.Addr
}
func (cfg *AgentConfig) FakeDockerHost() string { func (cfg *AgentConfig) FakeDockerHost() string {
return FakeDockerHostPrefix + cfg.Addr return FakeDockerHostPrefix + cfg.Addr
} }
@ -121,7 +122,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte)
versionStr := string(version) versionStr := string(version)
// skip version check for dev versions // skip version check for dev versions
if strings.HasPrefix(versionStr, "v") && !checkVersion(versionStr, pkg.GetVersion()) { if strings.HasPrefix(versionStr, "v") && !checkVersion(versionStr, pkg.GetVersion().String()) {
return gperr.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), versionStr) return gperr.Errorf("agent version mismatch: server: %s, agent: %s", pkg.GetVersion(), versionStr)
} }
@ -132,8 +133,6 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte)
} }
cfg.name = string(name) cfg.name = string(name)
cfg.l = logging.With().Str("agent", cfg.name).Logger()
cfg.l.Info().Msg("agent initialized")
return nil return nil
} }
@ -193,9 +192,10 @@ func (cfg *AgentConfig) String() string {
return cfg.name + "@" + cfg.Addr return cfg.name + "@" + cfg.Addr
} }
func (cfg *AgentConfig) MarshalJSON() ([]byte, error) { // MarshalMap implements pool.Object
return json.Marshal(map[string]string{ func (cfg *AgentConfig) MarshalMap() map[string]any {
return map[string]any{
"name": cfg.Name(), "name": cfg.Name(),
"addr": cfg.Addr, "addr": cfg.Addr,
}) }
} }

View file

@ -59,7 +59,7 @@ func AgentCertsFilepath(host string) (filepathOut string, ok bool) {
if !isValidAgentHost(host) { if !isValidAgentHost(host) {
return "", false return "", false
} }
return filepath.Join(common.AgentCertsBasePath, host+".zip"), true return filepath.Join(common.CertsDir, host+".zip"), true
} }
func ExtractCert(data []byte) (ca, crt, key []byte, err error) { func ExtractCert(data []byte) (ca, crt, key []byte, err error) {

View file

@ -8,7 +8,6 @@ import (
"strings" "strings"
"github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor" "github.com/yusing/go-proxy/internal/watcher/health/monitor"
) )
@ -44,11 +43,11 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return return
} }
result, err = monitor.NewHTTPHealthMonitor(types.NewURL(&url.URL{ result, err = monitor.NewHTTPHealthMonitor(&url.URL{
Scheme: scheme, Scheme: scheme,
Host: host, Host: host,
Path: path, Path: path,
}), defaultHealthConfig).CheckHealth() }, defaultHealthConfig).CheckHealth()
case "tcp", "udp": case "tcp", "udp":
host := query.Get("host") host := query.Get("host")
if host == "" { if host == "" {
@ -63,10 +62,10 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return return
} }
result, err = monitor.NewRawHealthMonitor(types.NewURL(&url.URL{ result, err = monitor.NewRawHealthMonitor(&url.URL{
Scheme: scheme, Scheme: scheme,
Host: host, Host: host,
}), defaultHealthConfig).CheckHealth() }, defaultHealthConfig).CheckHealth()
} }
if err != nil { if err != nil {

View file

@ -9,7 +9,6 @@ import (
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
) )
func serviceUnavailable(w http.ResponseWriter, r *http.Request) { func serviceUnavailable(w http.ResponseWriter, r *http.Request) {
@ -22,10 +21,10 @@ func DockerSocketHandler() http.HandlerFunc {
logging.Warn().Err(err).Msg("failed to connect to docker client") logging.Warn().Err(err).Msg("failed to connect to docker client")
return serviceUnavailable return serviceUnavailable
} }
rp := reverseproxy.NewReverseProxy("docker", types.NewURL(&url.URL{ rp := reverseproxy.NewReverseProxy("docker", &url.URL{
Scheme: "http", Scheme: "http",
Host: client.DummyHost, Host: client.DummyHost,
}), dockerClient.HTTPClient().Transport) }, dockerClient.HTTPClient().Transport)
return rp.ServeHTTP return rp.ServeHTTP
} }

View file

@ -12,7 +12,6 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -55,9 +54,9 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
logging.Debug().Msgf("proxy http request: %s %s", r.Method, r.URL.String()) logging.Debug().Msgf("proxy http request: %s %s", r.Method, r.URL.String())
rp := reverseproxy.NewReverseProxy("agent", types.NewURL(&url.URL{ rp := reverseproxy.NewReverseProxy("agent", &url.URL{
Scheme: scheme, Scheme: scheme,
Host: host, Host: host,
}), transport) }, transport)
rp.ServeHTTP(w, r) rp.ServeHTTP(w, r)
} }

View file

@ -7,6 +7,7 @@ import (
"sync" "sync"
"github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/auth"
debugapi "github.com/yusing/go-proxy/internal/api/v1/debug"
"github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
@ -19,6 +20,7 @@ import (
"github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes/routequery"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/migrations"
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
) )
@ -38,6 +40,9 @@ func parallel(fns ...func()) {
func main() { func main() {
initProfiling() initProfiling()
if err := migrations.RunMigrations(); err != nil {
gperr.LogFatal("migration error", err)
}
args := pkg.GetArgs(common.MainServerCommandValidator{}) args := pkg.GetArgs(common.MainServerCommandValidator{})
switch args.Command { switch args.Command {
@ -146,6 +151,8 @@ func main() {
uptime.Poller.Start() uptime.Poller.Start()
config.WatchChanges() config.WatchChanges()
debugapi.StartServer(cfg)
task.WaitExit(cfg.Value().TimeoutShutdown) task.WaitExit(cfg.Value().TimeoutShutdown)
} }

View file

@ -1,4 +1,4 @@
//go:build production //go:build !pprof
package main package main

8
go.mod
View file

@ -36,7 +36,7 @@ require (
// favicon extraction // favicon extraction
require ( require (
github.com/PuerkitoBio/goquery v1.10.2 // parsing HTML for extract fav icon github.com/PuerkitoBio/goquery v1.10.3 // parsing HTML for extract fav icon
github.com/vincent-petithory/dataurl v1.0.0 // data url for fav icon github.com/vincent-petithory/dataurl v1.0.0 // data url for fav icon
) )
@ -63,6 +63,8 @@ require (
github.com/stretchr/testify v1.10.0 // testing utilities github.com/stretchr/testify v1.10.0 // testing utilities
) )
require github.com/luthermonson/go-proxmox v0.2.2
require ( require (
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/cascadia v1.3.3 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect
@ -73,7 +75,7 @@ require (
github.com/cloudflare/cloudflare-go v0.115.0 // indirect github.com/cloudflare/cloudflare-go v0.115.0 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/diskfs/go-diskfs v1.5.2 // indirect github.com/diskfs/go-diskfs v1.6.0 // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
github.com/djherbis/times v1.6.0 // indirect github.com/djherbis/times v1.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
@ -111,7 +113,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.63.0 // indirect github.com/prometheus/common v0.63.0 // indirect
github.com/prometheus/procfs v0.16.0 // indirect github.com/prometheus/procfs v0.16.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect

12
go.sum
View file

@ -2,8 +2,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PuerkitoBio/goquery v1.10.2 h1:7fh2BdHcG6VFZsK7toXBT/Bh1z5Wmy8Q9MV9HqT2AM8= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
github.com/PuerkitoBio/goquery v1.10.2/go.mod h1:0guWGjcLu9AYC7C1GHnpysHy056u9aEkUHwhdnePMCU= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@ -27,8 +27,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/diskfs/go-diskfs v1.5.2 h1:Aj+f4sYlu3seXJe5KwyOWlol0eRBG9EKGYYYm37DO9s= github.com/diskfs/go-diskfs v1.6.0 h1:YmK5+vLSfkwC6kKKRTRPGaDGNF+Xh8FXeiNHwryDfu4=
github.com/diskfs/go-diskfs v1.5.2/go.mod h1:bRFumZeGFCO8C2KNswrQeuj2m1WCVr4Ms5IjWMczMDk= github.com/diskfs/go-diskfs v1.6.0/go.mod h1:bRFumZeGFCO8C2KNswrQeuj2m1WCVr4Ms5IjWMczMDk=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c=
@ -174,8 +174,8 @@ github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA98k= github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA98k=
github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18= github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18=
github.com/prometheus/procfs v0.16.0 h1:xh6oHhKwnOJKMYiYBDWmkHqQPyiY40sny36Cmx2bbsM= github.com/prometheus/procfs v0.16.0 h1:xh6oHhKwnOJKMYiYBDWmkHqQPyiY40sny36Cmx2bbsM=

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"fmt"
"net/http" "net/http"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
@ -9,63 +8,17 @@ import (
"github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/certapi" "github.com/yusing/go-proxy/internal/api/v1/certapi"
"github.com/yusing/go-proxy/internal/api/v1/dockerapi" "github.com/yusing/go-proxy/internal/api/v1/dockerapi"
"github.com/yusing/go-proxy/internal/api/v1/favicon"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/uptime" "github.com/yusing/go-proxy/internal/metrics/uptime"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/servemux"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type (
ServeMux struct {
*http.ServeMux
cfg config.ConfigInstance
}
WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request)
)
func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) {
var handler http.HandlerFunc
switch h := h.(type) {
case func(http.ResponseWriter, *http.Request):
handler = h
case http.Handler:
handler = h.ServeHTTP
case WithCfgHandler:
handler = func(w http.ResponseWriter, r *http.Request) {
h(mux.cfg, w, r)
}
default:
panic(fmt.Errorf("unsupported handler type: %T", h))
}
matchDomains := mux.cfg.Value().MatchDomains
if len(matchDomains) > 0 {
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}
}
if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler)
}
if methods == "" {
mux.ServeMux.HandleFunc(endpoint, handler)
} else {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
}
}
func NewHandler(cfg config.ConfigInstance) http.Handler { func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := ServeMux{http.NewServeMux(), cfg} mux := servemux.NewServeMux(cfg)
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
@ -79,7 +32,7 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux.HandleFunc("POST", "/v1/file/validate/{type}", v1.ValidateFile, true) mux.HandleFunc("POST", "/v1/file/validate/{type}", v1.ValidateFile, true)
mux.HandleFunc("GET", "/v1/health", v1.Health, true) mux.HandleFunc("GET", "/v1/health", v1.Health, true)
mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true) mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true)
mux.HandleFunc("GET", "/v1/favicon", v1.GetFavIcon, true) mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true)
mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true) mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true)
mux.HandleFunc("GET", "/v1/agents", v1.ListAgents, true) mux.HandleFunc("GET", "/v1/agents", v1.ListAgents, true)
mux.HandleFunc("GET", "/v1/agents/new", v1.NewAgent, true) mux.HandleFunc("GET", "/v1/agents/new", v1.NewAgent, true)

View file

@ -4,21 +4,11 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/coder/websocket" "github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/coder/websocket/wsjson"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
) )
func ListAgents(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func ListAgents(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) { gpwebsocket.DynamicJSONHandler(w, r, agent.Agents.Slice, 10*time.Second)
gpwebsocket.Periodic(w, r, 10*time.Second, func(conn *websocket.Conn) error {
wsjson.Write(r.Context(), conn, cfg.ListAgents())
return nil
})
} else {
gphttp.RespondJSON(w, r, cfg.ListAgents())
}
} }

View file

@ -27,7 +27,7 @@ func fileType(file string) FileType {
switch { switch {
case strings.HasPrefix(path.Base(file), "config."): case strings.HasPrefix(path.Base(file), "config."):
return FileTypeConfig return FileTypeConfig
case strings.HasPrefix(file, common.MiddlewareComposeBasePath): case strings.HasPrefix(file, common.MiddlewareComposeDir):
return FileTypeMiddleware return FileTypeMiddleware
} }
return FileTypeProvider return FileTypeProvider
@ -43,9 +43,9 @@ func (t FileType) IsValid() bool {
func (t FileType) GetPath(filename string) string { func (t FileType) GetPath(filename string) string {
if t == FileTypeMiddleware { if t == FileTypeMiddleware {
return path.Join(common.MiddlewareComposeBasePath, filename) return path.Join(common.MiddlewareComposeDir, filename)
} }
return path.Join(common.ConfigBasePath, filename) return path.Join(common.ConfigDir, filename)
} }
func getArgs(r *http.Request) (fileType FileType, filename string, err error) { func getArgs(r *http.Request) (fileType FileType, filename string, err error) {

View file

@ -0,0 +1,75 @@
//go:build debug
package debugapi
import (
"iter"
"net/http"
"sort"
"time"
"github.com/yusing/go-proxy/agent/pkg/agent"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/idlewatcher"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/servemux"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/task"
)
func StartServer(cfg config.ConfigInstance) {
srv := server.NewServer(server.Options{
Name: "debug",
HTTPAddr: "127.0.0.1:7777",
Handler: newHandler(cfg),
})
srv.Start(task.RootTask("debug_server", false))
}
type debuggable interface {
MarshalMap() map[string]any
Key() string
}
func toSortedSlice[T debuggable](data iter.Seq2[string, T]) []map[string]any {
s := make([]map[string]any, 0)
for _, v := range data {
m := v.MarshalMap()
m["key"] = v.Key()
s = append(s, m)
}
sort.Slice(s, func(i, j int) bool {
return s[i]["key"].(string) < s[j]["key"].(string)
})
return s
}
func jsonHandler[T debuggable](getData iter.Seq2[string, T]) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
gpwebsocket.DynamicJSONHandler(w, r, func() []map[string]any {
return toSortedSlice(getData)
}, 1*time.Second)
}
}
func iterMap[K comparable, V debuggable](m func() map[K]V) iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
for k, v := range m() {
if !yield(k, v) {
break
}
}
}
}
func newHandler(cfg config.ConfigInstance) http.Handler {
mux := servemux.NewServeMux(cfg)
mux.HandleFunc("GET", "/tasks", jsonHandler(task.AllTasks()))
mux.HandleFunc("GET", "/idlewatcher", jsonHandler(idlewatcher.Watchers()))
mux.HandleFunc("GET", "/agents", jsonHandler(agent.Agents.Iter))
mux.HandleFunc("GET", "/proxmox", jsonHandler(proxmox.Clients.Iter))
mux.HandleFunc("GET", "/docker", jsonHandler(iterMap(docker.Clients)))
return mux
}

View file

@ -0,0 +1,11 @@
//go:build !debug
package debugapi
import (
config "github.com/yusing/go-proxy/internal/config/types"
)
func StartServer(cfg config.ConfigInstance) {
// do nothing
}

View file

@ -18,7 +18,7 @@ type Container struct {
} }
func Containers(w http.ResponseWriter, r *http.Request) { func Containers(w http.ResponseWriter, r *http.Request) {
serveHTTP[Container, []Container](w, r, GetContainers) serveHTTP[Container](w, r, GetContainers)
} }
func GetContainers(ctx context.Context, dockerClients DockerClients) ([]Container, gperr.Error) { func GetContainers(ctx context.Context, dockerClients DockerClients) ([]Container, gperr.Error) {

View file

@ -25,7 +25,7 @@ func (d *dockerInfo) MarshalJSON() ([]byte, error) {
}, },
"images": d.Images, "images": d.Images,
"n_cpu": d.NCPU, "n_cpu": d.NCPU,
"memory": strutils.FormatByteSizeWithUnit(d.MemTotal), "memory": strutils.FormatByteSize(d.MemTotal),
}) })
} }

View file

@ -22,7 +22,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
until := query.Get("to") until := query.Get("to")
levels := query.Get("levels") // TODO: implement levels levels := query.Get("levels") // TODO: implement levels
dockerClient, found, err := getDockerClient(w, server) dockerClient, found, err := getDockerClient(server)
if err != nil { if err != nil {
gphttp.BadRequest(w, err.Error()) gphttp.BadRequest(w, err.Error())
return return

View file

@ -8,6 +8,7 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/agent/pkg/agent"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
@ -44,7 +45,7 @@ func getDockerClients() (DockerClients, gperr.Error) {
dockerClients[name] = dockerClient dockerClients[name] = dockerClient
} }
for _, agent := range cfg.ListAgents() { for _, agent := range agent.Agents.Iter {
dockerClient, err := docker.NewClient(agent.FakeDockerHost()) dockerClient, err := docker.NewClient(agent.FakeDockerHost())
if err != nil { if err != nil {
connErrs.Add(err) connErrs.Add(err)
@ -56,7 +57,7 @@ func getDockerClients() (DockerClients, gperr.Error) {
return dockerClients, connErrs.Error() return dockerClients, connErrs.Error()
} }
func getDockerClient(w http.ResponseWriter, server string) (*docker.SharedClient, bool, error) { func getDockerClient(server string) (*docker.SharedClient, bool, error) {
cfg := config.GetInstance() cfg := config.GetInstance()
var host string var host string
for name, h := range cfg.Value().Providers.Docker { for name, h := range cfg.Value().Providers.Docker {
@ -65,7 +66,7 @@ func getDockerClient(w http.ResponseWriter, server string) (*docker.SharedClient
break break
} }
} }
for _, agent := range cfg.ListAgents() { for _, agent := range agent.Agents.Iter {
if agent.Name() == server { if agent.Name() == server {
host = agent.FakeDockerHost() host = agent.FakeDockerHost()
break break
@ -119,6 +120,6 @@ func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, g
}) })
} else { } else {
result, err := getResult(r.Context(), dockerClients) result, err := getResult(r.Context(), dockerClients)
handleResult[V, T](w, err, result) handleResult[V](w, err, result)
} }
} }

View file

@ -1,4 +1,4 @@
package v1 package favicon
import ( import (
"errors" "errors"

View file

@ -4,20 +4,10 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes/routequery"
) )
func Health(w http.ResponseWriter, r *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) { gpwebsocket.DynamicJSONHandler(w, r, routequery.HealthMap, 1*time.Second)
gpwebsocket.Periodic(w, r, 1*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, routequery.HealthMap())
})
} else {
gphttp.RespondJSON(w, r, routequery.HealthMap())
}
} }

View file

@ -13,7 +13,6 @@ import (
"github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes/routequery"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -28,7 +27,6 @@ const (
ListRouteProviders = "route_providers" ListRouteProviders = "route_providers"
ListHomepageCategories = "homepage_categories" ListHomepageCategories = "homepage_categories"
ListIcons = "icons" ListIcons = "icons"
ListTasks = "tasks"
) )
func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
@ -76,8 +74,6 @@ func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
icons = []string{} icons = []string{}
} }
gphttp.RespondJSON(w, r, icons) gphttp.RespondJSON(w, r, icons)
case ListTasks:
gphttp.RespondJSON(w, r, task.DebugTaskList())
default: default:
gphttp.BadRequest(w, fmt.Sprintf("invalid what: %s", what)) gphttp.BadRequest(w, fmt.Sprintf("invalid what: %s", what))
} }
@ -98,7 +94,7 @@ func listRoute(which string) any {
} }
func listFiles(w http.ResponseWriter, r *http.Request) { func listFiles(w http.ResponseWriter, r *http.Request) {
files, err := utils.ListFiles(common.ConfigBasePath, 0, true) files, err := utils.ListFiles(common.ConfigDir, 0, true)
if err != nil { if err != nil {
gphttp.ServerError(w, r, err) gphttp.ServerError(w, r, err)
return return
@ -111,17 +107,17 @@ func listFiles(w http.ResponseWriter, r *http.Request) {
for _, file := range files { for _, file := range files {
t := fileType(file) t := fileType(file)
file = strings.TrimPrefix(file, common.ConfigBasePath+"/") file = strings.TrimPrefix(file, common.ConfigDir+"/")
resp[t] = append(resp[t], file) resp[t] = append(resp[t], file)
} }
mids, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0, true) mids, err := utils.ListFiles(common.MiddlewareComposeDir, 0, true)
if err != nil { if err != nil {
gphttp.ServerError(w, r, err) gphttp.ServerError(w, r, err)
return return
} }
for _, mid := range mids { for _, mid := range mids {
mid = strings.TrimPrefix(mid, common.MiddlewareComposeBasePath+"/") mid = strings.TrimPrefix(mid, common.MiddlewareComposeDir+"/")
resp[FileTypeMiddleware] = append(resp[FileTypeMiddleware], mid) resp[FileTypeMiddleware] = append(resp[FileTypeMiddleware], mid)
} }
gphttp.RespondJSON(w, r, resp) gphttp.RespondJSON(w, r, resp)

View file

@ -40,7 +40,7 @@ func NewAgent(w http.ResponseWriter, r *http.Request) {
return return
} }
hostport := fmt.Sprintf("%s:%d", host, port) hostport := fmt.Sprintf("%s:%d", host, port)
if _, ok := config.GetInstance().GetAgent(hostport); ok { if _, ok := agent.Agents.Get(hostport); ok {
gphttp.ClientError(w, gphttp.ErrAlreadyExists("agent", hostport), http.StatusConflict) gphttp.ClientError(w, gphttp.ErrAlreadyExists("agent", hostport), http.StatusConflict)
return return
} }

View file

@ -58,7 +58,3 @@ func ListRoutes() (map[string]map[string]any, gperr.Error) {
func ListMiddlewareTraces() (middleware.Traces, gperr.Error) { func ListMiddlewareTraces() (middleware.Traces, gperr.Error) {
return List[middleware.Traces](v1.ListMiddlewareTraces) return List[middleware.Traces](v1.ListMiddlewareTraces)
} }
func DebugListTasks() (map[string]any, gperr.Error) {
return List[map[string]any](v1.ListTasks)
}

View file

@ -4,30 +4,18 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) { gpwebsocket.DynamicJSONHandler(w, r, func() map[string]any {
gpwebsocket.Periodic(w, r, 1*time.Second, func(conn *websocket.Conn) error { return map[string]any{
return wsjson.Write(r.Context(), conn, getStats(cfg)) "proxies": cfg.Statistics(),
}) "uptime": strutils.FormatDuration(time.Since(startTime)),
} else { }
gphttp.RespondJSON(w, r, getStats(cfg)) }, 1*time.Second)
}
} }
var startTime = time.Now() var startTime = time.Now()
func getStats(cfg config.ConfigInstance) map[string]any {
return map[string]any{
"proxies": cfg.Statistics(),
"uptime": strutils.FormatDuration(time.Since(startTime)),
}
}

View file

@ -3,8 +3,8 @@ package v1
import ( import (
"net/http" "net/http"
"github.com/yusing/go-proxy/agent/pkg/agent"
agentPkg "github.com/yusing/go-proxy/agent/pkg/agent" agentPkg "github.com/yusing/go-proxy/agent/pkg/agent"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/metrics/systeminfo" "github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp"
@ -12,7 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
) )
func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func SystemInfo(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query() query := r.URL.Query()
agentAddr := query.Get("agent_addr") agentAddr := query.Get("agent_addr")
query.Del("agent_addr") query.Del("agent_addr")
@ -21,7 +21,7 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
return return
} }
agent, ok := cfg.GetAgent(agentAddr) agent, ok := agent.Agents.Get(agentAddr)
if !ok { if !ok {
gphttp.NotFound(w, "agent_addr") gphttp.NotFound(w, "agent_addr")
return return

View file

@ -8,5 +8,5 @@ import (
) )
func GetVersion(w http.ResponseWriter, r *http.Request) { func GetVersion(w http.ResponseWriter, r *http.Request) {
gphttp.WriteBody(w, []byte(pkg.GetVersion())) gphttp.WriteBody(w, []byte(pkg.GetVersion().String()))
} }

View file

@ -1,18 +1,20 @@
package autocert package autocert
import ( import (
"path/filepath"
"github.com/go-acme/lego/v4/providers/dns/clouddns" "github.com/go-acme/lego/v4/providers/dns/clouddns"
"github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns" "github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/go-acme/lego/v4/providers/dns/porkbun" "github.com/go-acme/lego/v4/providers/dns/porkbun"
"github.com/yusing/go-proxy/internal/common"
) )
const ( var (
certBasePath = "certs/" CertFileDefault = filepath.Join(common.CertsDir, "cert.crt")
CertFileDefault = certBasePath + "cert.crt" KeyFileDefault = filepath.Join(common.CertsDir, "priv.key")
KeyFileDefault = certBasePath + "priv.key" ACMEKeyFileDefault = filepath.Join(common.CertsDir, "acme.key")
ACMEKeyFileDefault = certBasePath + "acme.key"
) )
const ( const (

View file

@ -1,46 +1,10 @@
package common package common
import ( import "time"
"time"
)
// file, folder structure
const (
DotEnvPath = ".env"
DotEnvExamplePath = ".env.example"
ConfigBasePath = "config"
ConfigFileName = "config.yml"
ConfigExampleFileName = "config.example.yml"
ConfigPath = ConfigBasePath + "/" + ConfigFileName
HomepageJSONConfigPath = ConfigBasePath + "/.homepage.json"
IconListCachePath = ConfigBasePath + "/.icon_list_cache.json"
IconCachePath = ConfigBasePath + "/.icon_cache.json"
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
ComposeFileName = "compose.yml"
ComposeExampleFileName = "compose.example.yml"
ErrorPagesBasePath = "error_pages"
AgentCertsBasePath = "certs"
)
var RequiredDirectories = []string{
ConfigBasePath,
ErrorPagesBasePath,
MiddlewareComposeBasePath,
}
const DockerHostFromEnv = "$DOCKER_HOST" const DockerHostFromEnv = "$DOCKER_HOST"
const ( const (
HealthCheckIntervalDefault = 5 * time.Second HealthCheckIntervalDefault = 5 * time.Second
HealthCheckTimeoutDefault = 5 * time.Second HealthCheckTimeoutDefault = 5 * time.Second
WakeTimeoutDefault = "30s"
StopTimeoutDefault = "30s"
StopMethodDefault = "stop"
) )

View file

@ -19,6 +19,8 @@ var (
IsDebug = GetEnvBool("DEBUG", IsTest) IsDebug = GetEnvBool("DEBUG", IsTest)
IsTrace = GetEnvBool("TRACE", false) && IsDebug IsTrace = GetEnvBool("TRACE", false) && IsDebug
RootDir = GetEnvString("ROOT_DIR", "./")
HTTP3Enabled = GetEnvBool("HTTP3_ENABLED", true) HTTP3Enabled = GetEnvBool("HTTP3_ENABLED", true)
ProxyHTTPAddr, ProxyHTTPAddr,

33
internal/common/paths.go Normal file
View file

@ -0,0 +1,33 @@
package common
import (
"path/filepath"
)
// file, folder structure
var (
ConfigDir = filepath.Join(RootDir, "config")
ConfigFileName = "config.yml"
ConfigExampleFileName = "config.example.yml"
ConfigPath = filepath.Join(ConfigDir, ConfigFileName)
MiddlewareComposeDir = filepath.Join(ConfigDir, "middlewares")
ErrorPagesDir = filepath.Join(RootDir, "error_pages")
CertsDir = filepath.Join(RootDir, "certs")
DataDir = filepath.Join(RootDir, "data")
MetricsDataDir = filepath.Join(DataDir, "metrics")
HomepageJSONConfigPath = filepath.Join(DataDir, "homepage.json")
IconListCachePath = filepath.Join(DataDir, "icon_list_cache.json")
IconCachePath = filepath.Join(DataDir, "icon_cache.json")
)
var RequiredDirectories = []string{
ConfigDir,
ErrorPagesDir,
MiddlewareComposeDir,
DataDir,
MetricsDataDir,
}

View file

@ -1,66 +0,0 @@
package config
import (
"slices"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/utils/functional"
)
var agentPool = functional.NewMapOf[string, *agent.AgentConfig]()
func addAgent(agent *agent.AgentConfig) {
agentPool.Store(agent.Addr, agent)
}
func removeAllAgents() {
agentPool.Clear()
}
func GetAgent(addr string) (agent *agent.AgentConfig, ok bool) {
agent, ok = agentPool.Load(addr)
return
}
func (cfg *Config) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) {
if !agent.IsDockerHostAgent(agentAddrOrDockerHost) {
return GetAgent(agentAddrOrDockerHost)
}
return GetAgent(agent.GetAgentAddrFromDockerHost(agentAddrOrDockerHost))
}
func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) {
if slices.ContainsFunc(cfg.value.Providers.Agents, func(a *agent.AgentConfig) bool {
return a.Addr == host
}) {
return 0, gperr.New("agent already exists")
}
var agentCfg agent.AgentConfig
agentCfg.Addr = host
err := agentCfg.InitWithCerts(cfg.task.Context(), ca.Cert, client.Cert, client.Key)
if err != nil {
return 0, gperr.Wrap(err, "failed to start agent")
}
addAgent(&agentCfg)
provider := provider.NewAgentProvider(&agentCfg)
if err := cfg.errIfExists(provider); err != nil {
return 0, err
}
err = provider.LoadRoutes()
if err != nil {
return 0, gperr.Wrap(err, "failed to load routes")
}
return provider.NumRoutes(), nil
}
func (cfg *Config) ListAgents() []*agent.AgentConfig {
agents := make([]*agent.AgentConfig, 0, agentPool.Size())
agentPool.RangeAll(func(key string, value *agent.AgentConfig) {
agents = append(agents, value)
})
return agents
}

View file

@ -19,6 +19,7 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
proxy "github.com/yusing/go-proxy/internal/route/provider" proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
@ -215,23 +216,22 @@ func (cfg *Config) StartServers(opts ...*StartServersOptions) {
} }
func (cfg *Config) load() gperr.Error { func (cfg *Config) load() gperr.Error {
const errMsg = "config load error"
data, err := os.ReadFile(common.ConfigPath) data, err := os.ReadFile(common.ConfigPath)
if err != nil { if err != nil {
gperr.LogFatal(errMsg, err) gperr.LogFatal("error reading config", err)
} }
model := config.DefaultConfig() model := config.DefaultConfig()
if err := utils.UnmarshalValidateYAML(data, model); err != nil { if err := utils.UnmarshalValidateYAML(data, model); err != nil {
gperr.LogFatal(errMsg, err) gperr.LogFatal("error unmarshalling config", err)
} }
// errors are non fatal below // errors are non fatal below
errs := gperr.NewBuilder(errMsg) errs := gperr.NewBuilder()
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
cfg.initNotification(model.Providers.Notification) cfg.initNotification(model.Providers.Notification)
errs.Add(cfg.initProxmox(model.Providers.Proxmox))
errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.initAutoCert(model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers)) errs.Add(cfg.loadRouteProviders(&model.Providers))
@ -256,6 +256,18 @@ func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
} }
} }
func (cfg *Config) initProxmox(proxmoxCfgs []proxmox.Config) (err gperr.Error) {
errs := gperr.NewBuilder("proxmox config errors")
for _, proxmoxCfg := range proxmoxCfgs {
if err := proxmoxCfg.Init(); err != nil {
errs.Add(err.Subject(proxmoxCfg.URL))
} else {
proxmox.Clients.Add(proxmoxCfg.Client())
}
}
return errs.Error()
}
func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err gperr.Error) { func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err gperr.Error) {
if cfg.autocertProvider != nil { if cfg.autocertProvider != nil {
return return
@ -277,8 +289,8 @@ func (cfg *Config) errIfExists(p *proxy.Provider) gperr.Error {
func (cfg *Config) initAgents(agentCfgs []*agent.AgentConfig) gperr.Error { func (cfg *Config) initAgents(agentCfgs []*agent.AgentConfig) gperr.Error {
var wg sync.WaitGroup var wg sync.WaitGroup
var errs gperr.Builder
errs := gperr.NewBuilderWithConcurrency()
wg.Add(len(agentCfgs)) wg.Add(len(agentCfgs))
for _, agentCfg := range agentCfgs { for _, agentCfg := range agentCfgs {
go func(agentCfg *agent.AgentConfig) { go func(agentCfg *agent.AgentConfig) {
@ -286,7 +298,7 @@ func (cfg *Config) initAgents(agentCfgs []*agent.AgentConfig) gperr.Error {
if err := agentCfg.Init(cfg.task.Context()); err != nil { if err := agentCfg.Init(cfg.task.Context()); err != nil {
errs.Add(err.Subject(agentCfg.String())) errs.Add(err.Subject(agentCfg.String()))
} else { } else {
addAgent(agentCfg) agent.Agents.Add(agentCfg)
} }
}(agentCfg) }(agentCfg)
} }
@ -298,7 +310,7 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
errs := gperr.NewBuilder("route provider errors") errs := gperr.NewBuilder("route provider errors")
results := gperr.NewBuilder("loaded route providers") results := gperr.NewBuilder("loaded route providers")
removeAllAgents() agent.Agents.Clear()
n := len(providers.Agents) + len(providers.Docker) + len(providers.Files) n := len(providers.Agents) + len(providers.Docker) + len(providers.Files)
if n == 0 { if n == 0 {
@ -309,12 +321,12 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
errs.Add(cfg.initAgents(providers.Agents)) errs.Add(cfg.initAgents(providers.Agents))
for _, agent := range providers.Agents { for _, a := range providers.Agents {
if !agent.IsInitialized() { // failed to initialize if !a.IsInitialized() { // failed to initialize
continue continue
} }
addAgent(agent) agent.Agents.Add(a)
routeProviders = append(routeProviders, proxy.NewAgentProvider(agent)) routeProviders = append(routeProviders, proxy.NewAgentProvider(a))
} }
for _, filename := range providers.Files { for _, filename := range providers.Files {
routeProviders = append(routeProviders, proxy.NewFileProvider(filename)) routeProviders = append(routeProviders, proxy.NewFileProvider(filename))
@ -338,6 +350,8 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
lenLongestName = len(k) lenLongestName = len(k)
} }
}) })
errs.EnableConcurrency()
results.EnableConcurrency()
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
if err := p.LoadRoutes(); err != nil { if err := p.LoadRoutes(); err != nil {
errs.Add(err.Subject(p.String())) errs.Add(err.Subject(p.String()))

View file

@ -56,7 +56,7 @@ func TestFileProviderValidate(t *testing.T) {
cfg := config.DefaultConfig() cfg := config.DefaultConfig()
if tt.init != nil { if tt.init != nil {
for _, filename := range tt.filenames { for _, filename := range tt.filenames {
filepath := path.Join(common.ConfigBasePath, filename) filepath := path.Join(common.ConfigDir, filename)
assert.NoError(t, tt.init(filepath)) assert.NoError(t, tt.init(filepath))
} }
} }
@ -67,7 +67,7 @@ func TestFileProviderValidate(t *testing.T) {
})), cfg) })), cfg)
if tt.cleanup != nil { if tt.cleanup != nil {
for _, filename := range tt.filenames { for _, filename := range tt.filenames {
filepath := path.Join(common.ConfigBasePath, filename) filepath := path.Join(common.ConfigDir, filename)
assert.NoError(t, tt.cleanup(filepath)) assert.NoError(t, tt.cleanup(filepath))
} }
} }

View file

@ -1,6 +1,10 @@
package config package config
import ( import (
"slices"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/route/provider"
) )
@ -51,3 +55,32 @@ func (cfg *Config) Statistics() map[string]any {
"providers": providerStats, "providers": providerStats,
} }
} }
func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) {
if slices.ContainsFunc(cfg.value.Providers.Agents, func(a *agent.AgentConfig) bool {
return a.Addr == host
}) {
return 0, gperr.New("agent already exists")
}
agentCfg := new(agent.AgentConfig)
agentCfg.Addr = host
err := agentCfg.InitWithCerts(cfg.task.Context(), ca.Cert, client.Cert, client.Key)
if err != nil {
return 0, gperr.Wrap(err, "failed to start agent")
}
// must add it first to let LoadRoutes() reference from it
agent.Agents.Add(agentCfg)
provider := provider.NewAgentProvider(agentCfg)
if err := cfg.errIfExists(provider); err != nil {
agent.Agents.Del(agentCfg)
return 0, err
}
err = provider.LoadRoutes()
if err != nil {
agent.Agents.Del(agentCfg)
return 0, gperr.Wrap(err, "failed to load routes")
}
return provider.NumRoutes(), nil
}

View file

@ -14,7 +14,7 @@ import (
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
proxmox "github.com/yusing/go-proxy/internal/proxmox/types" "github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -28,11 +28,11 @@ type (
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"` TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
} }
Providers struct { Providers struct {
Files []string `json:"include" yaml:"include,omitempty" validate:"unique,dive,config_file_exists"` Files []string `json:"include" validate:"unique,dive,config_file_exists"`
Docker map[string]string `json:"docker" yaml:"docker,omitempty" validate:"unique,dive,unix_addr|url"` Docker map[string]string `json:"docker" validate:"unique,dive,unix_addr|url"`
Proxmox map[string]proxmox.Config `json:"proxmox" yaml:"proxmox,omitempty"` Proxmox []proxmox.Config `json:"proxmox"`
Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty" validate:"unique=Addr"` Agents []*agent.AgentConfig `json:"agents" validate:"unique=Addr"`
Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty" validate:"unique=ProviderName"` Notification []notif.NotificationConfig `json:"notification" validate:"unique=ProviderName"`
} }
Entrypoint struct { Entrypoint struct {
Middlewares []map[string]any `json:"middlewares"` Middlewares []map[string]any `json:"middlewares"`
@ -45,9 +45,7 @@ type (
Statistics() map[string]any Statistics() map[string]any
RouteProviderList() []string RouteProviderList() []string
Context() context.Context Context() context.Context
GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool)
VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error)
ListAgents() []*agent.AgentConfig
AutoCertProvider() *autocert.Provider AutoCertProvider() *autocert.Provider
} }
) )
@ -104,7 +102,7 @@ func init() {
}) })
utils.MustRegisterValidation("config_file_exists", func(fl validator.FieldLevel) bool { utils.MustRegisterValidation("config_file_exists", func(fl validator.FieldLevel) bool {
filename := fl.Field().Interface().(string) filename := fl.Field().Interface().(string)
info, err := os.Stat(path.Join(common.ConfigBasePath, filename)) info, err := os.Stat(path.Join(common.ConfigDir, filename))
return err == nil && !info.IsDir() return err == nil && !info.IsDir()
}) })
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"maps"
"net" "net"
"net/http" "net/http"
"sync" "sync"
@ -14,16 +15,15 @@ import (
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
SharedClient struct { SharedClient struct {
*client.Client *client.Client
key string
refCount uint32 refCount uint32
closedOn int64 closedOn int64
@ -66,7 +66,7 @@ func initClientCleaner() {
defer clientMapMu.Unlock() defer clientMapMu.Unlock()
for _, c := range clientMap { for _, c := range clientMap {
delete(clientMap, c.key) delete(clientMap, c.Key())
c.Client.Close() c.Client.Close()
} }
}) })
@ -80,30 +80,20 @@ func closeTimedOutClients() {
for _, c := range clientMap { for _, c := range clientMap {
if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs { if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs {
delete(clientMap, c.key) delete(clientMap, c.Key())
c.Client.Close() c.Client.Close()
logging.Debug().Str("host", c.key).Msg("docker client closed") logging.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
} }
} }
} }
func (c *SharedClient) Address() string { func Clients() map[string]*SharedClient {
return c.addr clientMapMu.RLock()
} defer clientMapMu.RUnlock()
func (c *SharedClient) CheckConnection(ctx context.Context) error { clients := make(map[string]*SharedClient, len(clientMap))
conn, err := c.dial(ctx) maps.Copy(clients, clientMap)
if err != nil { return clients
return err
}
conn.Close()
return nil
}
// if the client is still referenced, this is no-op.
func (c *SharedClient) Close() {
atomic.StoreInt64(&c.closedOn, time.Now().Unix())
atomic.AddUint32(&c.refCount, ^uint32(0))
} }
// NewClient creates a new Docker client connection to the specified host. // NewClient creates a new Docker client connection to the specified host.
@ -134,7 +124,7 @@ func NewClient(host string) (*SharedClient, error) {
var dial func(ctx context.Context) (net.Conn, error) var dial func(ctx context.Context) (net.Conn, error)
if agent.IsDockerHostAgent(host) { if agent.IsDockerHostAgent(host) {
cfg, ok := config.GetInstance().GetAgent(host) cfg, ok := agent.Agents.Get(host)
if !ok { if !ok {
panic(fmt.Errorf("agent %q not found", host)) panic(fmt.Errorf("agent %q not found", host))
} }
@ -187,7 +177,6 @@ func NewClient(host string) (*SharedClient, error) {
c := &SharedClient{ c := &SharedClient{
Client: client, Client: client,
key: host,
refCount: 1, refCount: 1,
addr: addr, addr: addr,
dial: dial, dial: dial,
@ -197,9 +186,44 @@ func NewClient(host string) (*SharedClient, error) {
if c.dial == nil { if c.dial == nil {
c.dial = client.Dialer() c.dial = client.Dialer()
} }
if c.addr == "" {
c.addr = c.Client.DaemonHost()
}
defer logging.Debug().Str("host", host).Msg("docker client initialized") defer logging.Debug().Str("host", host).Msg("docker client initialized")
clientMap[c.key] = c clientMap[c.Key()] = c
return c, nil return c, nil
} }
func (c *SharedClient) Key() string {
return c.DaemonHost()
}
func (c *SharedClient) Address() string {
return c.addr
}
func (c *SharedClient) CheckConnection(ctx context.Context) error {
conn, err := c.dial(ctx)
if err != nil {
return err
}
conn.Close()
return nil
}
// if the client is still referenced, this is no-op.
func (c *SharedClient) Close() {
atomic.StoreInt64(&c.closedOn, time.Now().Unix())
atomic.AddUint32(&c.refCount, ^uint32(0))
}
func (c *SharedClient) MarshalMap() map[string]any {
return map[string]any{
"host": c.DaemonHost(),
"addr": c.addr,
"ref_count": c.refCount,
"closed_on": strutils.FormatUnixTime(c.closedOn),
}
}

View file

@ -8,16 +8,17 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/go-connections/nat" "github.com/docker/go-connections/nat"
"github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/agent"
config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
PortMapping = map[int]*container.Port PortMapping = map[int]*container.Port
Container struct { Container struct {
_ U.NoCopy _ utils.NoCopy
DockerHost string `json:"docker_host"` DockerHost string `json:"docker_host"`
Image *ContainerImage `json:"image"` Image *ContainerImage `json:"image"`
@ -26,7 +27,8 @@ type (
Agent *agent.AgentConfig `json:"agent"` Agent *agent.AgentConfig `json:"agent"`
Labels map[string]string `json:"-"` RouteConfig map[string]string `json:"route_config"`
IdlewatcherConfig *idlewatcher.Config `json:"idlewatcher_config"`
Mounts []string `json:"mounts"` Mounts []string `json:"mounts"`
@ -35,16 +37,10 @@ type (
PublicHostname string `json:"public_hostname"` PublicHostname string `json:"public_hostname"`
PrivateHostname string `json:"private_hostname"` PrivateHostname string `json:"private_hostname"`
Aliases []string `json:"aliases"` Aliases []string `json:"aliases"`
IsExcluded bool `json:"is_excluded"` IsExcluded bool `json:"is_excluded"`
IsExplicit bool `json:"is_explicit"` IsExplicit bool `json:"is_explicit"`
IdleTimeout string `json:"idle_timeout,omitempty"` Running bool `json:"running"`
WakeTimeout string `json:"wake_timeout,omitempty"`
StopMethod string `json:"stop_method,omitempty"`
StopTimeout string `json:"stop_timeout,omitempty"` // stop_method = "stop" only
StopSignal string `json:"stop_signal,omitempty"` // stop_method = "stop" | "kill" only
StartEndpoint string `json:"start_endpoint,omitempty"`
Running bool `json:"running"`
} }
ContainerImage struct { ContainerImage struct {
Author string `json:"author,omitempty"` Author string `json:"author,omitempty"`
@ -69,21 +65,15 @@ func FromDocker(c *container.Summary, dockerHost string) (res *Container) {
PublicPortMapping: helper.getPublicPortMapping(), PublicPortMapping: helper.getPublicPortMapping(),
PrivatePortMapping: helper.getPrivatePortMapping(), PrivatePortMapping: helper.getPrivatePortMapping(),
Aliases: helper.getAliases(), Aliases: helper.getAliases(),
IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
IsExplicit: isExplicit, IsExplicit: isExplicit,
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), Running: c.Status == "running" || c.State == "running",
WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout),
StopMethod: helper.getDeleteLabel(LabelStopMethod),
StopTimeout: helper.getDeleteLabel(LabelStopTimeout),
StopSignal: helper.getDeleteLabel(LabelStopSignal),
StartEndpoint: helper.getDeleteLabel(LabelStartEndpoint),
Running: c.Status == "running" || c.State == "running",
} }
if agent.IsDockerHostAgent(dockerHost) { if agent.IsDockerHostAgent(dockerHost) {
var ok bool var ok bool
res.Agent, ok = config.GetInstance().GetAgent(dockerHost) res.Agent, ok = agent.Agents.Get(dockerHost)
if !ok { if !ok {
logging.Error().Msgf("agent %q not found", dockerHost) logging.Error().Msgf("agent %q not found", dockerHost)
} }
@ -91,6 +81,7 @@ func FromDocker(c *container.Summary, dockerHost string) (res *Container) {
res.setPrivateHostname(helper) res.setPrivateHostname(helper)
res.setPublicHostname() res.setPublicHostname()
res.loadDeleteIdlewatcherLabels(helper)
for lbl := range c.Labels { for lbl := range c.Labels {
if strings.HasPrefix(lbl, NSProxy+".") { if strings.HasPrefix(lbl, NSProxy+".") {
@ -200,3 +191,31 @@ func (c *Container) setPrivateHostname(helper containerHelper) {
return return
} }
} }
func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
cfg := map[string]any{
"idle_timeout": helper.getDeleteLabel(LabelIdleTimeout),
"wake_timeout": helper.getDeleteLabel(LabelWakeTimeout),
"stop_method": helper.getDeleteLabel(LabelStopMethod),
"stop_timeout": helper.getDeleteLabel(LabelStopTimeout),
"stop_signal": helper.getDeleteLabel(LabelStopSignal),
"start_endpoint": helper.getDeleteLabel(LabelStartEndpoint),
}
// set only if idlewatcher is enabled
idleTimeout := cfg["idle_timeout"]
if idleTimeout != "" {
idwCfg := &idlewatcher.Config{
Docker: &idlewatcher.DockerConfig{
DockerHost: c.DockerHost,
ContainerID: c.ContainerID,
ContainerName: c.ContainerName,
},
}
err := utils.MapUnmarshalValidate(cfg, idwCfg)
if err != nil {
gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err))
} else {
c.IdlewatcherConfig = idwCfg
}
}
}

View file

@ -24,5 +24,5 @@ func (c *SharedClient) Inspect(containerID string) (*Container, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return FromInspectResponse(json, c.key), nil return FromInspectResponse(json, c.DaemonHost()), nil
} }

View file

@ -61,7 +61,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config)
return return
} }
ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) ep.accessLogger, err = accesslog.NewAccessLogger(parent, cfg)
if err != nil { if err != nil {
return return
} }

View file

@ -1,9 +1,10 @@
package gperr package gperr
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"encoding/json"
) )
// baseError is an immutable wrapper around an error. // baseError is an immutable wrapper around an error.
@ -48,17 +49,6 @@ func (err *baseError) Error() string {
return err.Err.Error() return err.Err.Error()
} }
// MarshalJSON implements the json.Marshaler interface.
func (err *baseError) MarshalJSON() ([]byte, error) { func (err *baseError) MarshalJSON() ([]byte, error) {
//nolint:errorlint return json.Marshal(err.Err)
switch err := err.Err.(type) {
case Error, *withSubject:
return json.Marshal(err)
case json.Marshaler:
return err.MarshalJSON()
case interface{ MarshalText() ([]byte, error) }:
return err.MarshalText()
default:
return json.Marshal(err.Error())
}
} }

View file

@ -24,6 +24,10 @@ type Builder struct {
rwLock rwLock
} }
type multiline struct {
*Builder
}
// NewBuilder creates a new Builder. // NewBuilder creates a new Builder.
// //
// If about is not provided, the Builder will not have a subject // If about is not provided, the Builder will not have a subject
@ -78,12 +82,15 @@ func (b *Builder) Add(err error) *Builder {
return b return b
} }
wrapped := wrap(err)
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
switch err := wrapped.(type) { b.add(err)
return b
}
func (b *Builder) add(err error) {
switch err := err.(type) {
case *baseError: case *baseError:
b.errs = append(b.errs, err.Err) b.errs = append(b.errs, err.Err)
case *nestedError: case *nestedError:
@ -92,11 +99,11 @@ func (b *Builder) Add(err error) *Builder {
} else { } else {
b.errs = append(b.errs, err) b.errs = append(b.errs, err)
} }
case *MultilineError:
b.add(&err.nestedError)
default: default:
panic("bug: should not reach here") b.errs = append(b.errs, err)
} }
return b
} }
func (b *Builder) Adds(err string) *Builder { func (b *Builder) Adds(err string) *Builder {
@ -144,8 +151,9 @@ func (b *Builder) AddRange(errs ...error) *Builder {
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
b.errs = append(b.errs, nonNilErrs...) for _, err := range nonNilErrs {
b.add(err)
}
return b return b
} }

View file

@ -6,14 +6,14 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
) )
func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) { func log(_ string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
var l *zerolog.Logger var l *zerolog.Logger
if len(logger) > 0 { if len(logger) > 0 {
l = logger[0] l = logger[0]
} else { } else {
l = logging.GetLogger() l = logging.GetLogger()
} }
l.WithLevel(level).Msg(msg + ": " + err.Error()) l.WithLevel(level).Msg(err.Error())
} }
func LogFatal(msg string, err error, logger ...*zerolog.Logger) { func LogFatal(msg string, err error, logger ...*zerolog.Logger) {

View file

@ -0,0 +1,45 @@
package gperr
import (
"fmt"
"reflect"
)
type MultilineError struct {
nestedError
}
func Multiline() *MultilineError {
return &MultilineError{}
}
func (m *MultilineError) add(err error) {
m.Extras = append(m.Extras, err)
}
func (m *MultilineError) Addf(format string, args ...any) *MultilineError {
m.add(fmt.Errorf(format, args...))
return m
}
func (m *MultilineError) Adds(s string) *MultilineError {
m.add(newError(s))
return m
}
func (m *MultilineError) AddLines(lines any) *MultilineError {
v := reflect.ValueOf(lines)
if v.Kind() == reflect.Slice {
for i := range v.Len() {
switch v := v.Index(i).Interface().(type) {
case string:
m.add(newError(v))
case error:
m.add(v)
default:
m.add(fmt.Errorf("%v", v))
}
}
}
return m
}

View file

@ -0,0 +1,38 @@
package gperr
import (
"net"
"testing"
"github.com/stretchr/testify/require"
)
func TestMultiline(t *testing.T) {
multiline := Multiline()
multiline.Addf("line 1 %s", "test")
multiline.Adds("line 2")
multiline.AddLines([]any{1, "2", 3.0, net.IPv4(127, 0, 0, 1)})
t.Error(New("result").With(multiline))
t.Error(multiline.Subject("subject").Withf("inner"))
}
func TestWrapMultiline(t *testing.T) {
multiline := Multiline()
var wrapper error = wrap(multiline)
_, ok := wrapper.(*MultilineError)
if !ok {
t.Errorf("wrapper is not a MultilineError")
}
}
func TestPrependSubjectMultiline(t *testing.T) {
multiline := Multiline()
multiline.Addf("line 1 %s", "test")
multiline.Adds("line 2")
multiline.AddLines([]any{1, "2", 3.0, net.IPv4(127, 0, 0, 1)})
multiline.Subject("subject")
builder := NewBuilder()
builder.Add(multiline)
require.Equal(t, len(builder.errs), len(multiline.Extras), builder.errs)
}

View file

@ -15,7 +15,7 @@ type nestedError struct {
func (err nestedError) Subject(subject string) Error { func (err nestedError) Subject(subject string) Error {
if err.Err == nil { if err.Err == nil {
err.Err = newError(subject) err.Err = PrependSubject(subject, errStr(""))
} else { } else {
err.Err = PrependSubject(subject, err.Err) err.Err = PrependSubject(subject, err.Err)
} }

View file

@ -1,10 +1,12 @@
package gperr package gperr
import ( import (
"encoding/json" "errors"
"slices" "slices"
"strings" "strings"
"encoding/json"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi" "github.com/yusing/go-proxy/internal/utils/strutils/ansi"
) )
@ -64,7 +66,7 @@ func (err *withSubject) Prepend(subject string) *withSubject {
} }
func (err *withSubject) Is(other error) bool { func (err *withSubject) Is(other error) bool {
return err.Err == other return errors.Is(other, err.Err)
} }
func (err *withSubject) Unwrap() error { func (err *withSubject) Unwrap() error {
@ -92,7 +94,6 @@ func (err *withSubject) Error() string {
return sb.String() return sb.String()
} }
// MarshalJSON implements the json.Marshaler interface.
func (err *withSubject) MarshalJSON() ([]byte, error) { func (err *withSubject) MarshalJSON() ([]byte, error) {
subjects := slices.Clone(err.Subjects) subjects := slices.Clone(err.Subjects)
slices.Reverse(subjects) slices.Reverse(subjects)

View file

@ -1,9 +1,10 @@
package gperr package gperr
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"encoding/json"
) )
func newError(message string) error { func newError(message string) error {
@ -41,6 +42,18 @@ func Wrap(err error, message ...string) Error {
return &baseError{fmt.Errorf("%s: %w", message[0], err)} return &baseError{fmt.Errorf("%s: %w", message[0], err)}
} }
func Unwrap(err error) Error {
//nolint:errorlint
switch err := err.(type) {
case interface{ Unwrap() []error }:
return &nestedError{Extras: err.Unwrap()}
case interface{ Unwrap() error }:
return &baseError{err.Unwrap()}
default:
return &baseError{err}
}
}
func wrap(err error) Error { func wrap(err error) Error {
if err == nil { if err == nil {
return nil return nil

View file

@ -2,15 +2,14 @@ package homepage
import ( import (
"net/http" "net/http"
"net/url"
net "github.com/yusing/go-proxy/internal/net/types"
) )
type route interface { type route interface {
TargetName() string TargetName() string
ProviderName() string ProviderName() string
Reference() string Reference() string
TargetURL() *net.URL TargetURL() *url.URL
} }
type httpRoute interface { type httpRoute interface {

View file

@ -0,0 +1,13 @@
package idlewatcher
import "context"
func (w *Watcher) cancelled(reqCtx context.Context) bool {
select {
case <-reqCtx.Done():
w.l.Debug().AnErr("cause", context.Cause(reqCtx)).Msg("wake canceled")
return true
default:
return false
}
}

View file

@ -1,60 +0,0 @@
package idlewatcher
import (
"context"
"errors"
"github.com/docker/docker/api/types/container"
)
type (
containerMeta struct {
ContainerID, ContainerName string
}
containerState struct {
running bool
ready bool
err error
}
)
func (w *Watcher) ContainerID() string {
return w.route.ContainerInfo().ContainerID
}
func (w *Watcher) ContainerName() string {
return w.route.ContainerInfo().ContainerName
}
func (w *Watcher) containerStop(ctx context.Context) error {
return w.client.ContainerStop(ctx, w.ContainerID(), container.StopOptions{
Signal: string(w.Config().StopSignal),
Timeout: &w.Config().StopTimeout,
})
}
func (w *Watcher) containerPause(ctx context.Context) error {
return w.client.ContainerPause(ctx, w.ContainerID())
}
func (w *Watcher) containerKill(ctx context.Context) error {
return w.client.ContainerKill(ctx, w.ContainerID(), string(w.Config().StopSignal))
}
func (w *Watcher) containerUnpause(ctx context.Context) error {
return w.client.ContainerUnpause(ctx, w.ContainerID())
}
func (w *Watcher) containerStart(ctx context.Context) error {
return w.client.ContainerStart(ctx, w.ContainerID(), container.StartOptions{})
}
func (w *Watcher) containerStatus() (string, error) {
ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout"))
defer cancel()
json, err := w.client.ContainerInspect(ctx, w.ContainerID())
if err != nil {
return "", err
}
return json.State.Status, nil
}

View file

@ -0,0 +1,40 @@
package idlewatcher
import (
"iter"
"strconv"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type watcherDebug struct {
*Watcher
}
func (w watcherDebug) MarshalMap() map[string]any {
state := w.state.Load()
return map[string]any{
"name": w.Name(),
"state": map[string]string{
"status": string(state.status),
"ready": strconv.FormatBool(state.ready),
"err": fmtErr(state.err),
},
"expires": strutils.FormatTime(w.expires()),
"last_reset": strutils.FormatTime(w.lastReset.Load()),
"config": w.cfg,
}
}
func Watchers() iter.Seq2[string, watcherDebug] {
return func(yield func(string, watcherDebug) bool) {
watcherMapMu.RLock()
defer watcherMapMu.RUnlock()
for k, w := range watcherMap {
if !yield(k, watcherDebug{w}) {
return
}
}
}
}

View file

@ -42,20 +42,6 @@ func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
} }
} }
func (w *Watcher) cancelled(reqCtx context.Context, rw http.ResponseWriter) bool {
select {
case <-reqCtx.Done():
w.WakeDebug().Str("cause", context.Cause(reqCtx).Error()).Msg("canceled")
return true
case <-w.task.Context().Done():
w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled")
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
default:
return false
}
}
func isFaviconPath(path string) bool { func isFaviconPath(path string) bool {
return path == "/favicon.ico" return path == "/favicon.ico"
} }
@ -70,13 +56,13 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
// handle favicon request // handle favicon request
if isFaviconPath(r.URL.Path) { if isFaviconPath(r.URL.Path) {
r.URL.RawQuery = "alias=" + w.route.TargetName() r.URL.RawQuery = "alias=" + w.rp.TargetName
favicon.GetFavIcon(rw, r) favicon.GetFavIcon(rw, r)
return false return false
} }
// Check if start endpoint is configured and request path matches // Check if start endpoint is configured and request path matches
if w.Config().StartEndpoint != "" && r.URL.Path != w.Config().StartEndpoint { if w.cfg.StartEndpoint != "" && r.URL.Path != w.cfg.StartEndpoint {
http.Error(rw, "Forbidden: Container can only be started via configured start endpoint", http.StatusForbidden) http.Error(rw, "Forbidden: Container can only be started via configured start endpoint", http.StatusForbidden)
return false return false
} }
@ -95,44 +81,48 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
rw.Header().Add("Cache-Control", "must-revalidate") rw.Header().Add("Cache-Control", "must-revalidate")
rw.Header().Add("Connection", "close") rw.Header().Add("Connection", "close")
if _, err := rw.Write(body); err != nil { if _, err := rw.Write(body); err != nil {
w.Err(err).Msg("error writing http response") return false
} }
return false return false
} }
ctx, cancel := context.WithTimeoutCause(r.Context(), w.Config().WakeTimeout, errors.New("wake timeout")) ctx, cancel := context.WithTimeoutCause(r.Context(), w.cfg.WakeTimeout, errors.New("wake timeout"))
defer cancel() defer cancel()
if w.cancelled(ctx, rw) { if w.cancelled(ctx) {
gphttp.ServerError(rw, r, context.Cause(ctx), http.StatusServiceUnavailable)
return false return false
} }
w.WakeTrace().Msg("signal received") w.l.Trace().Msg("signal received")
err := w.wakeIfStopped() err := w.wakeIfStopped()
if err != nil { if err != nil {
w.WakeError(err) gphttp.ServerError(rw, r, err)
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return false return false
} }
var ready bool
for { for {
if w.cancelled(ctx, rw) { w.resetIdleTimer()
if w.cancelled(ctx) {
gphttp.ServerError(rw, r, context.Cause(ctx), http.StatusServiceUnavailable)
return false return false
} }
ready, err := w.checkUpdateState() w, ready, err = checkUpdateState(w.Key())
if err != nil { if err != nil {
http.Error(rw, "Error waking container", http.StatusInternalServerError) gphttp.ServerError(rw, r, err)
return false return false
} }
if ready { if ready {
w.resetIdleTimer()
if isCheckRedirect { if isCheckRedirect {
w.Debug().Msgf("redirecting to %s ...", w.hc.URL()) w.l.Debug().Stringer("url", w.hc.URL()).Msg("container is ready, redirecting")
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
return false return false
} }
w.Debug().Msgf("passing through to %s ...", w.hc.URL()) w.l.Debug().Stringer("url", w.hc.URL()).Msg("container is ready, passing through")
return true return true
} }

View file

@ -3,11 +3,10 @@ package idlewatcher
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"time" "time"
"github.com/yusing/go-proxy/internal/net/types" gpnet "github.com/yusing/go-proxy/internal/net/types"
) )
// Setup implements types.Stream. // Setup implements types.Stream.
@ -21,19 +20,19 @@ func (w *Watcher) Setup() error {
} }
// Accept implements types.Stream. // Accept implements types.Stream.
func (w *Watcher) Accept() (conn types.StreamConn, err error) { func (w *Watcher) Accept() (conn gpnet.StreamConn, err error) {
conn, err = w.stream.Accept() conn, err = w.stream.Accept()
if err != nil { if err != nil {
return return
} }
if wakeErr := w.wakeFromStream(); wakeErr != nil { if wakeErr := w.wakeFromStream(); wakeErr != nil {
w.WakeError(wakeErr) w.l.Err(wakeErr).Msg("error waking container")
} }
return return
} }
// Handle implements types.Stream. // Handle implements types.Stream.
func (w *Watcher) Handle(conn types.StreamConn) error { func (w *Watcher) Handle(conn gpnet.StreamConn) error {
if err := w.wakeFromStream(); err != nil { if err := w.wakeFromStream(); err != nil {
return err return err
} }
@ -53,35 +52,29 @@ func (w *Watcher) wakeFromStream() error {
return nil return nil
} }
w.WakeDebug().Msg("wake signal received") w.l.Debug().Msg("wake signal received")
wakeErr := w.wakeIfStopped() err := w.wakeIfStopped()
if wakeErr != nil { if err != nil {
wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr) return err
w.WakeError(wakeErr)
return wakeErr
} }
ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.Config().WakeTimeout, errors.New("wake timeout")) ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.cfg.WakeTimeout, errors.New("wake timeout"))
defer cancel() defer cancel()
var ready bool
for { for {
select { if w.cancelled(ctx) {
case <-w.task.Context().Done(): return context.Cause(ctx)
cause := w.task.FinishCause()
w.WakeDebug().Str("cause", cause.Error()).Msg("canceled")
return cause
case <-ctx.Done():
cause := context.Cause(ctx)
w.WakeDebug().Str("cause", cause.Error()).Msg("timeout")
return cause
default:
} }
if ready, err := w.checkUpdateState(); err != nil { w, ready, err = checkUpdateState(w.Key())
if err != nil {
return err return err
} else if ready { }
if ready {
w.resetIdleTimer() w.resetIdleTimer()
w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String()) w.l.Debug().Stringer("url", w.hc.URL()).Msg("container is ready, passing through")
return nil return nil
} }

View file

@ -0,0 +1,122 @@
package idlewatcher
import (
"errors"
"time"
"github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// Start implements health.HealthMonitor.
func (w *Watcher) Start(parent task.Parent) gperr.Error {
w.task.OnCancel("route_cleanup", func() {
parent.Finish(w.task.FinishCause())
})
return nil
}
// Task implements health.HealthMonitor.
func (w *Watcher) Task() *task.Task {
return w.task
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) {
if w.stream != nil {
w.stream.Close()
}
}
// Name implements health.HealthMonitor.
func (w *Watcher) Name() string {
return w.cfg.ContainerName()
}
// String implements health.HealthMonitor.
func (w *Watcher) String() string {
return w.Name()
}
// Uptime implements health.HealthMonitor.
func (w *Watcher) Uptime() time.Duration {
return 0
}
// Latency implements health.HealthMonitor.
func (w *Watcher) Latency() time.Duration {
return 0
}
// Status implements health.HealthMonitor.
func (w *Watcher) Status() health.Status {
state := w.state.Load()
if state.err != nil {
return health.StatusError
}
if state.ready {
return health.StatusHealthy
}
if state.status == idlewatcher.ContainerStatusRunning {
return health.StatusStarting
}
return health.StatusNapping
}
func checkUpdateState(key string) (w *Watcher, ready bool, err error) {
watcherMapMu.RLock()
w, ok := watcherMap[key]
if !ok {
watcherMapMu.RUnlock()
return nil, false, errors.New("watcher not found")
}
watcherMapMu.RUnlock()
// already ready
if w.ready() {
return w, true, nil
}
if !w.running() {
return w, false, nil
}
// the new container info not yet updated
if w.hc.URL().Host == "" {
return w, false, nil
}
res, err := w.hc.CheckHealth()
if err != nil {
w.setError(err)
return w, false, err
}
if res.Healthy {
w.setReady()
return w, true, nil
}
w.setStarting()
return w, false, nil
}
// MarshalMap implements health.HealthMonitor.
func (w *Watcher) MarshalMap() map[string]any {
url := w.hc.URL()
if url.Port() == "0" {
url = nil
}
var detail string
if err := w.error(); err != nil {
detail = err.Error()
}
return (&health.JSONRepresentation{
Name: w.Name(),
Status: w.Status(),
Config: dummyHealthCheckConfig,
URL: url,
Detail: detail,
}).MarshalMap()
}

View file

@ -19,11 +19,11 @@ var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage))) var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
func (w *Watcher) makeLoadingPageBody() []byte { func (w *Watcher) makeLoadingPageBody() []byte {
msg := w.ContainerName() + " is starting..." msg := w.cfg.ContainerName() + " is starting..."
data := new(templateData) data := new(templateData)
data.CheckRedirectHeader = httpheaders.HeaderGoDoxyCheckRedirect data.CheckRedirectHeader = httpheaders.HeaderGoDoxyCheckRedirect
data.Title = w.route.HomepageItem().Name data.Title = w.cfg.ContainerName()
data.Message = msg data.Message = msg
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(httpheaders.HeaderGoDoxyCheckRedirect))) buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(httpheaders.HeaderGoDoxyCheckRedirect)))

View file

@ -0,0 +1,90 @@
package provider
import (
"context"
"github.com/docker/docker/api/types/container"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/watcher"
)
type DockerProvider struct {
client *docker.SharedClient
watcher *watcher.DockerWatcher
containerID string
}
var startOptions = container.StartOptions{}
func NewDockerProvider(dockerHost, containerID string) (idlewatcher.Provider, error) {
client, err := docker.NewClient(dockerHost)
if err != nil {
return nil, err
}
return &DockerProvider{
client: client,
watcher: watcher.NewDockerWatcher(dockerHost),
containerID: containerID,
}, nil
}
func (p *DockerProvider) ContainerPause(ctx context.Context) error {
return p.client.ContainerPause(ctx, p.containerID)
}
func (p *DockerProvider) ContainerUnpause(ctx context.Context) error {
return p.client.ContainerUnpause(ctx, p.containerID)
}
func (p *DockerProvider) ContainerStart(ctx context.Context) error {
return p.client.ContainerStart(ctx, p.containerID, startOptions)
}
func (p *DockerProvider) ContainerStop(ctx context.Context, signal idlewatcher.Signal, timeout int) error {
return p.client.ContainerStop(ctx, p.containerID, container.StopOptions{
Signal: string(signal),
Timeout: &timeout,
})
}
func (p *DockerProvider) ContainerKill(ctx context.Context, signal idlewatcher.Signal) error {
return p.client.ContainerKill(ctx, p.containerID, string(signal))
}
func (p *DockerProvider) ContainerStatus(ctx context.Context) (idlewatcher.ContainerStatus, error) {
status, err := p.client.ContainerInspect(ctx, p.containerID)
if err != nil {
return idlewatcher.ContainerStatusError, err
}
switch status.State.Status {
case "running":
return idlewatcher.ContainerStatusRunning, nil
case "exited", "dead", "restarting":
return idlewatcher.ContainerStatusStopped, nil
case "paused":
return idlewatcher.ContainerStatusPaused, nil
}
return idlewatcher.ContainerStatusError, idlewatcher.ErrUnexpectedContainerStatus.Subject(status.State.Status)
}
func (p *DockerProvider) Watch(ctx context.Context) (eventCh <-chan watcher.Event, errCh <-chan gperr.Error) {
return p.watcher.EventsWithOptions(ctx, watcher.DockerListOptions{
Filters: watcher.NewDockerFilter(
watcher.DockerFilterContainer,
watcher.DockerFilterContainerNameID(p.containerID),
watcher.DockerFilterStart,
watcher.DockerFilterStop,
watcher.DockerFilterDie,
watcher.DockerFilterKill,
watcher.DockerFilterDestroy,
watcher.DockerFilterPause,
watcher.DockerFilterUnpause,
),
})
}
func (p *DockerProvider) Close() {
p.client.Close()
}

View file

@ -0,0 +1,129 @@
package provider
import (
"context"
"strconv"
"time"
"github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
type ProxmoxProvider struct {
*proxmox.Node
vmid int
lxcName string
running bool
}
const proxmoxStateCheckInterval = 1 * time.Second
var ErrNodeNotFound = gperr.New("node not found in pool")
func NewProxmoxProvider(nodeName string, vmid int) (idlewatcher.Provider, error) {
node, ok := proxmox.Nodes.Get(nodeName)
if !ok {
return nil, ErrNodeNotFound.Subject(nodeName).
Withf("available nodes: %s", proxmox.AvailableNodeNames())
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
lxcName, err := node.LXCName(ctx, vmid)
if err != nil {
return nil, err
}
return &ProxmoxProvider{Node: node, vmid: vmid, lxcName: lxcName}, nil
}
func (p *ProxmoxProvider) ContainerPause(ctx context.Context) error {
return p.LXCAction(ctx, p.vmid, proxmox.LXCSuspend)
}
func (p *ProxmoxProvider) ContainerUnpause(ctx context.Context) error {
return p.LXCAction(ctx, p.vmid, proxmox.LXCResume)
}
func (p *ProxmoxProvider) ContainerStart(ctx context.Context) error {
return p.LXCAction(ctx, p.vmid, proxmox.LXCStart)
}
func (p *ProxmoxProvider) ContainerStop(ctx context.Context, _ idlewatcher.Signal, _ int) error {
return p.LXCAction(ctx, p.vmid, proxmox.LXCShutdown)
}
func (p *ProxmoxProvider) ContainerKill(ctx context.Context, _ idlewatcher.Signal) error {
return p.LXCAction(ctx, p.vmid, proxmox.LXCShutdown)
}
func (p *ProxmoxProvider) ContainerStatus(ctx context.Context) (idlewatcher.ContainerStatus, error) {
status, err := p.LXCStatus(ctx, p.vmid)
if err != nil {
return idlewatcher.ContainerStatusError, err
}
switch status {
case proxmox.LXCStatusRunning:
return idlewatcher.ContainerStatusRunning, nil
case proxmox.LXCStatusStopped:
return idlewatcher.ContainerStatusStopped, nil
}
return idlewatcher.ContainerStatusError, idlewatcher.ErrUnexpectedContainerStatus.Subject(string(status))
}
func (p *ProxmoxProvider) Watch(ctx context.Context) (<-chan watcher.Event, <-chan gperr.Error) {
eventCh := make(chan watcher.Event)
errCh := make(chan gperr.Error)
go func() {
defer close(eventCh)
defer close(errCh)
var err error
p.running, err = p.LXCIsRunning(ctx, p.vmid)
if err != nil {
errCh <- gperr.Wrap(err)
return
}
ticker := time.NewTicker(proxmoxStateCheckInterval)
defer ticker.Stop()
event := watcher.Event{
Type: events.EventTypeDocker,
ActorID: strconv.Itoa(p.vmid),
ActorName: p.lxcName,
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
status, err := p.ContainerStatus(ctx)
if err != nil {
errCh <- gperr.Wrap(err)
return
}
running := status == idlewatcher.ContainerStatusRunning
if p.running != running {
p.running = running
if running {
event.Action = events.ActionContainerStart
} else {
event.Action = events.ActionContainerStop
}
eventCh <- event
}
}
}
}()
return eventCh, errCh
}
func (p *ProxmoxProvider) Close() {
// noop
}

View file

@ -1,7 +1,9 @@
package idlewatcher package idlewatcher
import idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
func (w *Watcher) running() bool { func (w *Watcher) running() bool {
return w.state.Load().running return w.state.Load().status == idlewatcher.ContainerStatusRunning
} }
func (w *Watcher) ready() bool { func (w *Watcher) ready() bool {
@ -14,26 +16,29 @@ func (w *Watcher) error() error {
func (w *Watcher) setReady() { func (w *Watcher) setReady() {
w.state.Store(&containerState{ w.state.Store(&containerState{
running: true, status: idlewatcher.ContainerStatusRunning,
ready: true, ready: true,
}) })
} }
func (w *Watcher) setStarting() { func (w *Watcher) setStarting() {
w.state.Store(&containerState{ w.state.Store(&containerState{
running: true, status: idlewatcher.ContainerStatusRunning,
ready: false, ready: false,
}) })
} }
func (w *Watcher) setNapping() { func (w *Watcher) setNapping(status idlewatcher.ContainerStatus) {
w.setError(nil) w.state.Store(&containerState{
status: status,
ready: false,
})
} }
func (w *Watcher) setError(err error) { func (w *Watcher) setError(err error) {
w.state.Store(&containerState{ w.state.Store(&containerState{
running: false, status: idlewatcher.ContainerStatusError,
ready: false, ready: false,
err: err, err: err,
}) })
} }

View file

@ -1,110 +1,128 @@
package idlewatcher package idlewatcher
import ( import (
"errors"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
) )
type ( type (
Config struct { Config struct {
IdleTimeout time.Duration `json:"idle_timeout,omitempty"` Proxmox *ProxmoxConfig `json:"proxmox,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"` Docker *DockerConfig `json:"docker,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
StopMethod StopMethod `json:"stop_method,omitempty"` IdleTimeout time.Duration `json:"idle_timeout"`
WakeTimeout time.Duration `json:"wake_timeout"`
StopTimeout time.Duration `json:"stop_timeout"`
StopMethod StopMethod `json:"stop_method"`
StopSignal Signal `json:"stop_signal,omitempty"` StopSignal Signal `json:"stop_signal,omitempty"`
StartEndpoint string `json:"start_endpoint,omitempty"` // Optional path that must be hit to start container StartEndpoint string `json:"start_endpoint,omitempty"` // Optional path that must be hit to start container
} }
StopMethod string StopMethod string
Signal string Signal string
DockerConfig struct {
DockerHost string `json:"docker_host" validate:"required"`
ContainerID string `json:"container_id" validate:"required"`
ContainerName string `json:"container_name" validate:"required"`
}
ProxmoxConfig struct {
Node string `json:"node" validate:"required"`
VMID int `json:"vmid" validate:"required"`
}
) )
const ( const (
WakeTimeoutDefault = 30 * time.Second
StopTimeoutDefault = 1 * time.Minute
StopMethodPause StopMethod = "pause" StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop" StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill" StopMethodKill StopMethod = "kill"
) )
var validSignals = map[string]struct{}{ func (c *Config) Key() string {
"": {}, if c.Docker != nil {
"SIGINT": {}, "SIGTERM": {}, "SIGHUP": {}, "SIGQUIT": {}, return c.Docker.ContainerID
"INT": {}, "TERM": {}, "HUP": {}, "QUIT": {}, }
return c.Proxmox.Node + ":" + strconv.Itoa(c.Proxmox.VMID)
} }
func ValidateConfig(cont *docker.Container) (*Config, gperr.Error) { func (c *Config) ContainerName() string {
if cont == nil || cont.IdleTimeout == "" { if c.Docker != nil {
return nil, nil return c.Docker.ContainerName
} }
return "lxc " + strconv.Itoa(c.Proxmox.VMID)
errs := gperr.NewBuilder("invalid idlewatcher config")
idleTimeout := gperr.Collect(errs, validateDurationPostitive, cont.IdleTimeout)
wakeTimeout := gperr.Collect(errs, validateDurationPostitive, cont.WakeTimeout)
stopTimeout := gperr.Collect(errs, validateDurationPostitive, cont.StopTimeout)
stopMethod := gperr.Collect(errs, validateStopMethod, cont.StopMethod)
signal := gperr.Collect(errs, validateSignal, cont.StopSignal)
startEndpoint := gperr.Collect(errs, validateStartEndpoint, cont.StartEndpoint)
if errs.HasError() {
return nil, errs.Error()
}
return &Config{
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopTimeout: int(stopTimeout.Seconds()),
StopMethod: stopMethod,
StopSignal: signal,
StartEndpoint: startEndpoint,
}, nil
} }
func validateDurationPostitive(value string) (time.Duration, error) { func (c *Config) Validate() gperr.Error {
d, err := time.ParseDuration(value) if c.IdleTimeout == 0 { // no idle timeout means no idle watcher
if err != nil { return nil
return 0, err
} }
if d < 0 { errs := gperr.NewBuilder("idlewatcher config validation error")
return 0, errors.New("duration must be positive") errs.AddRange(
} c.validateProvider(),
return d, nil c.validateTimeouts(),
c.validateStopMethod(),
c.validateStopSignal(),
c.validateStartEndpoint(),
)
return errs.Error()
} }
func validateSignal(s string) (Signal, error) { func (c *Config) validateProvider() error {
if _, ok := validSignals[s]; ok { if c.Docker == nil && c.Proxmox == nil {
return Signal(s), nil return gperr.New("missing idlewatcher provider config")
} }
return "", errors.New("invalid signal " + s) return nil
} }
func validateStopMethod(s string) (StopMethod, error) { func (c *Config) validateTimeouts() error {
sm := StopMethod(s) if c.WakeTimeout == 0 {
switch sm { c.WakeTimeout = WakeTimeoutDefault
}
if c.StopTimeout == 0 {
c.StopTimeout = StopTimeoutDefault
}
return nil
}
func (c *Config) validateStopMethod() error {
switch c.StopMethod {
case "":
c.StopMethod = StopMethodStop
return nil
case StopMethodPause, StopMethodStop, StopMethodKill: case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil return nil
default: default:
return "", errors.New("invalid stop method " + s) return gperr.New("invalid stop method").Subject(string(c.StopMethod))
} }
} }
func validateStartEndpoint(s string) (string, error) { func (c *Config) validateStopSignal() error {
if s == "" { switch c.StopSignal {
return "", nil case "", "SIGINT", "SIGTERM", "SIGQUIT", "SIGHUP", "INT", "TERM", "QUIT", "HUP":
return nil
default:
return gperr.New("invalid stop signal").Subject(string(c.StopSignal))
}
}
func (c *Config) validateStartEndpoint() error {
if c.StartEndpoint == "" {
return nil
} }
// checks needed as of Go 1.6 because of change https://github.com/golang/go/commit/617c93ce740c3c3cc28cdd1a0d712be183d0b328#diff-6c2d018290e298803c0c9419d8739885L195 // checks needed as of Go 1.6 because of change https://github.com/golang/go/commit/617c93ce740c3c3cc28cdd1a0d712be183d0b328#diff-6c2d018290e298803c0c9419d8739885L195
// emulate browser and strip the '#' suffix prior to validation. see issue-#237 // emulate browser and strip the '#' suffix prior to validation. see issue-#237
if i := strings.Index(s, "#"); i > -1 { if i := strings.Index(c.StartEndpoint, "#"); i > -1 {
s = s[:i] c.StartEndpoint = c.StartEndpoint[:i]
} }
if len(s) == 0 { if len(c.StartEndpoint) == 0 {
return "", errors.New("start endpoint must not be empty if defined") return gperr.New("start endpoint must not be empty if defined")
} }
if _, err := url.ParseRequestURI(s); err != nil { _, err := url.ParseRequestURI(c.StartEndpoint)
return "", err return err
}
return s, nil
} }

View file

@ -35,9 +35,10 @@ func TestValidateStartEndpoint(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s, err := validateStartEndpoint(tc.input) cfg := Config{StartEndpoint: tc.input}
err := cfg.validateStartEndpoint()
if err == nil { if err == nil {
ExpectEqual(t, s, tc.input) ExpectEqual(t, cfg.StartEndpoint, tc.input)
} }
if (err != nil) != tc.wantErr { if (err != nil) != tc.wantErr {
t.Errorf("validateStartEndpoint() error = %v, wantErr %t", err, tc.wantErr) t.Errorf("validateStartEndpoint() error = %v, wantErr %t", err, tc.wantErr)

View file

@ -0,0 +1,14 @@
package idlewatcher
import "github.com/yusing/go-proxy/internal/gperr"
type ContainerStatus string
const (
ContainerStatusError ContainerStatus = "error"
ContainerStatusRunning ContainerStatus = "running"
ContainerStatusPaused ContainerStatus = "paused"
ContainerStatusStopped ContainerStatus = "stopped"
)
var ErrUnexpectedContainerStatus = gperr.New("unexpected container status")

View file

@ -0,0 +1,19 @@
package idlewatcher
import (
"context"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/watcher/events"
)
type Provider interface {
ContainerPause(ctx context.Context) error
ContainerUnpause(ctx context.Context) error
ContainerStart(ctx context.Context) error
ContainerStop(ctx context.Context, signal Signal, timeout int) error
ContainerKill(ctx context.Context, signal Signal) error
ContainerStatus(ctx context.Context) (ContainerStatus, error)
Watch(ctx context.Context) (eventCh <-chan events.Event, errCh <-chan gperr.Error)
Close()
}

View file

@ -1,172 +0,0 @@
package idlewatcher
import (
"time"
"github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
net "github.com/yusing/go-proxy/internal/net/types"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
type (
Waker = idlewatcher.Waker
waker struct {
_ U.NoCopy
rp *reverseproxy.ReverseProxy
stream net.Stream
hc health.HealthChecker
}
)
const (
idleWakerCheckInterval = 100 * time.Millisecond
idleWakerCheckTimeout = time.Second
)
// TODO: support stream
func newWaker(parent task.Parent, route route.Route, rp *reverseproxy.ReverseProxy, stream net.Stream) (Waker, gperr.Error) {
hcCfg := route.HealthCheckConfig()
hcCfg.Timeout = idleWakerCheckTimeout
waker := &waker{
rp: rp,
stream: stream,
}
watcher, err := registerWatcher(parent, route, waker)
if err != nil {
return nil, gperr.Errorf("register watcher: %w", err)
}
switch {
case route.IsAgent():
waker.hc = monitor.NewAgentProxiedMonitor(route.Agent(), hcCfg, monitor.AgentTargetFromURL(route.TargetURL()))
case rp != nil:
waker.hc = monitor.NewHTTPHealthChecker(route.TargetURL(), hcCfg)
case stream != nil:
waker.hc = monitor.NewRawHealthChecker(route.TargetURL(), hcCfg)
default:
panic("both nil")
}
return watcher, nil
}
// lifetime should follow route provider.
func NewHTTPWaker(parent task.Parent, route route.Route, rp *reverseproxy.ReverseProxy) (Waker, gperr.Error) {
return newWaker(parent, route, rp, nil)
}
func NewStreamWaker(parent task.Parent, route route.Route, stream net.Stream) (Waker, gperr.Error) {
return newWaker(parent, route, nil, stream)
}
// Start implements health.HealthMonitor.
func (w *Watcher) Start(parent task.Parent) gperr.Error {
w.task.OnCancel("route_cleanup", func() {
parent.Finish(w.task.FinishCause())
})
return nil
}
// Task implements health.HealthMonitor.
func (w *Watcher) Task() *task.Task {
return w.task
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) {
if w.stream != nil {
w.stream.Close()
}
}
// Name implements health.HealthMonitor.
func (w *Watcher) Name() string {
return w.String()
}
// String implements health.HealthMonitor.
func (w *Watcher) String() string {
return w.ContainerName()
}
// Uptime implements health.HealthMonitor.
func (w *Watcher) Uptime() time.Duration {
return 0
}
// Latency implements health.HealthMonitor.
func (w *Watcher) Latency() time.Duration {
return 0
}
// Status implements health.HealthMonitor.
func (w *Watcher) Status() health.Status {
state := w.state.Load()
if state.err != nil {
return health.StatusError
}
if state.ready {
return health.StatusHealthy
}
if state.running {
return health.StatusStarting
}
return health.StatusNapping
}
func (w *Watcher) checkUpdateState() (ready bool, err error) {
// already ready
if w.ready() {
return true, nil
}
if !w.running() {
return false, nil
}
// the new container info not yet updated
if w.hc.URL().Host == "" {
return false, nil
}
res, err := w.hc.CheckHealth()
if err != nil {
w.setError(err)
return false, err
}
if res.Healthy {
w.setReady()
return true, nil
}
w.setStarting()
return false, nil
}
// MarshalJSON implements health.HealthMonitor.
func (w *Watcher) MarshalJSON() ([]byte, error) {
var url *net.URL
if w.hc.URL().Port() != "0" {
url = w.hc.URL()
}
var detail string
if err := w.error(); err != nil {
detail = err.Error()
}
return (&monitor.JSONRepresentation{
Name: w.Name(),
Status: w.Status(),
Config: w.hc.Config(),
URL: url,
Detail: detail,
}).MarshalJSON()
}

View file

@ -7,195 +7,236 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/idlewatcher/provider"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
net "github.com/yusing/go-proxy/internal/net/types"
route "github.com/yusing/go-proxy/internal/route/types" route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events" "github.com/yusing/go-proxy/internal/watcher/events"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
) )
type ( type (
routeHelper struct {
rp *reverseproxy.ReverseProxy
stream net.Stream
hc health.HealthChecker
}
containerState struct {
status idlewatcher.ContainerStatus
ready bool
err error
}
Watcher struct { Watcher struct {
_ U.NoCopy _ U.NoCopy
routeHelper
zerolog.Logger l zerolog.Logger
*waker cfg *idlewatcher.Config
route route.Route provider idlewatcher.Provider
client *docker.SharedClient state atomic.Value[*containerState]
state atomic.Value[*containerState] lastReset atomic.Value[time.Time]
stopByMethod StopCallback // send a docker command w.r.t. `stop_method` ticker *time.Ticker
ticker *time.Ticker task *task.Task
lastReset time.Time
task *task.Task
} }
StopCallback func() error StopCallback func() error
) )
const ContextKey = "idlewatcher.watcher"
var ( var (
watcherMap = make(map[string]*Watcher) watcherMap = make(map[string]*Watcher)
watcherMapMu sync.RWMutex watcherMapMu sync.RWMutex
errShouldNotReachHere = errors.New("should not reach here")
) )
const dockerReqTimeout = 3 * time.Second const (
idleWakerCheckInterval = 100 * time.Millisecond
idleWakerCheckTimeout = time.Second
)
func registerWatcher(parent task.Parent, route route.Route, waker *waker) (*Watcher, error) { var dummyHealthCheckConfig = &health.HealthCheckConfig{
cfg := route.IdlewatcherConfig() Interval: idleWakerCheckInterval,
cont := route.ContainerInfo() Timeout: idleWakerCheckTimeout,
key := cont.ContainerID }
var (
causeReload = gperr.New("reloaded")
causeContainerDestroy = gperr.New("container destroyed")
)
const reqTimeout = 3 * time.Second
// TODO: fix stream type
func NewWatcher(parent task.Parent, r route.Route) (*Watcher, error) {
cfg := r.IdlewatcherConfig()
key := cfg.Key()
watcherMapMu.RLock()
// if the watcher already exists, finish it
w, exists := watcherMap[key]
if exists {
if w.cfg == cfg {
// same address, likely two routes from the same container
return w, nil
}
w.task.Finish(causeReload)
}
watcherMapMu.RUnlock()
w = &Watcher{
ticker: time.NewTicker(cfg.IdleTimeout),
cfg: cfg,
routeHelper: routeHelper{
hc: monitor.NewMonitor(r),
},
}
var p idlewatcher.Provider
var providerType string
var err error
switch {
case cfg.Docker != nil:
p, err = provider.NewDockerProvider(cfg.Docker.DockerHost, cfg.Docker.ContainerID)
providerType = "docker"
default:
p, err = provider.NewProxmoxProvider(cfg.Proxmox.Node, cfg.Proxmox.VMID)
providerType = "proxmox"
}
if err != nil {
return nil, err
}
w.provider = p
w.l = logging.With().
Str("provider", providerType).
Str("container", cfg.ContainerName()).
Logger()
switch r := r.(type) {
case route.ReverseProxyRoute:
w.rp = r.ReverseProxy()
case route.StreamRoute:
w.stream = r
default:
return nil, gperr.New("unexpected route type")
}
ctx, cancel := context.WithTimeout(parent.Context(), reqTimeout)
defer cancel()
status, err := w.provider.ContainerStatus(ctx)
if err != nil {
w.provider.Close()
return nil, gperr.Wrap(err, "failed to get container status")
}
switch p := w.provider.(type) {
case *provider.ProxmoxProvider:
shutdownTimeout := max(time.Second, cfg.StopTimeout-idleWakerCheckTimeout)
err = p.LXCSetShutdownTimeout(ctx, cfg.Proxmox.VMID, shutdownTimeout)
if err != nil {
w.l.Warn().Err(err).Msg("failed to set shutdown timeout")
}
}
w.state.Store(&containerState{status: status})
w.task = parent.Subtask("idlewatcher."+r.TargetName(), true)
watcherMapMu.Lock() watcherMapMu.Lock()
defer watcherMapMu.Unlock() defer watcherMapMu.Unlock()
w, ok := watcherMap[key] watcherMap[key] = w
if !ok { go func() {
client, err := docker.NewClient(cont.DockerHost) cause := w.watchUntilDestroy()
if err != nil { if cause.Is(causeContainerDestroy) {
return nil, err
}
w = &Watcher{
Logger: logging.With().Str("name", cont.ContainerName).Logger(),
client: client,
task: parent.Subtask("idlewatcher." + cont.ContainerName),
ticker: time.NewTicker(cfg.IdleTimeout),
}
}
// FIXME: possible race condition here
w.waker = waker
w.route = route
w.ticker.Reset(cfg.IdleTimeout)
if cont.Running {
w.setStarting()
} else {
w.setNapping()
}
if !ok {
w.stopByMethod = w.getStopCallback()
watcherMap[key] = w
go func() {
cause := w.watchUntilDestroy()
watcherMapMu.Lock() watcherMapMu.Lock()
defer watcherMapMu.Unlock() defer watcherMapMu.Unlock()
delete(watcherMap, key) delete(watcherMap, key)
w.l.Info().Msg("idlewatcher stopped")
} else if !cause.Is(causeReload) {
gperr.LogError("idlewatcher stopped unexpectedly", cause, &w.l)
}
w.ticker.Stop() w.ticker.Stop()
w.client.Close() w.provider.Close()
w.task.Finish(cause) w.task.Finish(cause)
}() }()
} w.l.Info().Msg("idlewatcher started")
return w, nil return w, nil
} }
func (w *Watcher) Config() *idlewatcher.Config { func (w *Watcher) Key() string {
return w.route.IdlewatcherConfig() return w.cfg.Key()
} }
func (w *Watcher) Wake() error { func (w *Watcher) Wake() error {
return w.wakeIfStopped() return w.wakeIfStopped()
} }
// WakeDebug logs a debug message related to waking the container.
func (w *Watcher) WakeDebug() *zerolog.Event {
//nolint:zerologlint
return w.Debug().Str("action", "wake")
}
func (w *Watcher) WakeTrace() *zerolog.Event {
//nolint:zerologlint
return w.Trace().Str("action", "wake")
}
func (w *Watcher) WakeError(err error) {
w.Err(err).Str("action", "wake").Msg("error")
}
func (w *Watcher) wakeIfStopped() error { func (w *Watcher) wakeIfStopped() error {
if w.running() { state := w.state.Load()
if state.status == idlewatcher.ContainerStatusRunning {
w.l.Debug().Msg("container is already running")
return nil return nil
} }
status, err := w.containerStatus() ctx, cancel := context.WithTimeout(w.task.Context(), w.cfg.WakeTimeout)
if err != nil { defer cancel()
return err switch state.status {
case idlewatcher.ContainerStatusStopped:
w.l.Info().Msg("starting container")
return w.provider.ContainerStart(ctx)
case idlewatcher.ContainerStatusPaused:
w.l.Info().Msg("unpausing container")
return w.provider.ContainerUnpause(ctx)
default:
return gperr.Errorf("unexpected container status: %s", state.status)
}
}
func (w *Watcher) stopByMethod() error {
if !w.running() {
return nil
} }
ctx, cancel := context.WithTimeout(w.task.Context(), w.Config().WakeTimeout) cfg := w.cfg
ctx, cancel := context.WithTimeout(w.task.Context(), cfg.StopTimeout)
defer cancel() defer cancel()
// !Hard coded here since theres no constants from Docker API switch cfg.StopMethod {
switch status {
case "exited", "dead":
return w.containerStart(ctx)
case "paused":
return w.containerUnpause(ctx)
case "running":
return nil
default:
return gperr.Errorf("unexpected container status: %s", status)
}
}
func (w *Watcher) getStopCallback() StopCallback {
var cb func(context.Context) error
switch w.Config().StopMethod {
case idlewatcher.StopMethodPause: case idlewatcher.StopMethodPause:
cb = w.containerPause return w.provider.ContainerPause(ctx)
case idlewatcher.StopMethodStop: case idlewatcher.StopMethodStop:
cb = w.containerStop return w.provider.ContainerStop(ctx, cfg.StopSignal, int(cfg.StopTimeout.Seconds()))
case idlewatcher.StopMethodKill: case idlewatcher.StopMethodKill:
cb = w.containerKill return w.provider.ContainerKill(ctx, cfg.StopSignal)
default: default:
panic(errShouldNotReachHere) return gperr.Errorf("unexpected stop method: %q", cfg.StopMethod)
}
return func() error {
ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.Config().StopTimeout)*time.Second)
defer cancel()
return cb(ctx)
} }
} }
func (w *Watcher) resetIdleTimer() { func (w *Watcher) resetIdleTimer() {
w.Trace().Msg("reset idle timer") w.ticker.Reset(w.cfg.IdleTimeout)
w.ticker.Reset(w.Config().IdleTimeout) w.lastReset.Store(time.Now())
w.lastReset = time.Now()
} }
func (w *Watcher) expires() time.Time { func (w *Watcher) expires() time.Time {
return w.lastReset.Add(w.Config().IdleTimeout) if !w.running() {
} return time.Time{}
}
func (w *Watcher) getEventCh(ctx context.Context, dockerWatcher *watcher.DockerWatcher) (eventCh <-chan events.Event, errCh <-chan gperr.Error) { return w.lastReset.Load().Add(w.cfg.IdleTimeout)
eventCh, errCh = dockerWatcher.EventsWithOptions(ctx, watcher.DockerListOptions{
Filters: watcher.NewDockerFilter(
watcher.DockerFilterContainer,
watcher.DockerFilterContainerNameID(w.route.ContainerInfo().ContainerID),
watcher.DockerFilterStart,
watcher.DockerFilterStop,
watcher.DockerFilterDie,
watcher.DockerFilterKill,
watcher.DockerFilterDestroy,
watcher.DockerFilterPause,
watcher.DockerFilterUnpause,
),
})
return
} }
// watchUntilDestroy waits for the container to be created, started, or unpaused, // watchUntilDestroy waits for the container to be created, started, or unpaused,
@ -209,55 +250,34 @@ func (w *Watcher) getEventCh(ctx context.Context, dockerWatcher *watcher.DockerW
// //
// it exits only if the context is canceled, the container is destroyed, // it exits only if the context is canceled, the container is destroyed,
// errors occurred on docker client, or route provider died (mainly caused by config reload). // errors occurred on docker client, or route provider died (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() (returnCause error) { func (w *Watcher) watchUntilDestroy() (returnCause gperr.Error) {
eventCtx, eventCancel := context.WithCancel(w.task.Context()) eventCh, errCh := w.provider.Watch(w.Task().Context())
defer eventCancel()
dockerWatcher := watcher.NewDockerWatcher(w.client.DaemonHost())
dockerEventCh, dockerEventErrCh := w.getEventCh(eventCtx, dockerWatcher)
for { for {
select { select {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
return w.task.FinishCause() return gperr.Wrap(w.task.FinishCause())
case err := <-dockerEventErrCh: case err := <-errCh:
if !err.Is(context.Canceled) {
gperr.LogError("idlewatcher error", err, &w.Logger)
}
return err return err
case e := <-dockerEventCh: case e := <-eventCh:
w.l.Debug().Stringer("action", e.Action).Msg("state changed")
if e.Action == events.ActionContainerDestroy {
return causeContainerDestroy
}
w.resetIdleTimer()
switch { switch {
case e.Action == events.ActionContainerDestroy: case e.Action.IsContainerStart(): // create / start / unpause
w.setError(errors.New("container destroyed"))
w.Info().Str("reason", "container destroyed").Msg("watcher stopped")
return errors.New("container destroyed")
// create / start / unpause
case e.Action.IsContainerWake():
w.setStarting() w.setStarting()
w.resetIdleTimer() w.l.Info().Msg("awaken")
w.Info().Msg("awaken") case e.Action.IsContainerStop(): // stop / kill / die
case e.Action.IsContainerSleep(): // stop / pause / kil w.setNapping(idlewatcher.ContainerStatusStopped)
w.setNapping() w.ticker.Stop()
w.resetIdleTimer() case e.Action.IsContainerPause(): // pause
w.setNapping(idlewatcher.ContainerStatusPaused)
w.ticker.Stop() w.ticker.Stop()
default: default:
w.Error().Msg("unexpected docker event: " + e.String()) w.l.Error().Stringer("action", e.Action).Msg("unexpected container action")
} }
// container name changed should also change the container id
// if w.ContainerName != e.ActorName {
// w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName)
// w.ContainerName = e.ActorName
// }
// if w.ContainerID != e.ActorID {
// w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID)
// w.ContainerID = e.ActorID
// // recreate event stream
// eventCancel()
// eventCtx, eventCancel = context.WithCancel(w.task.Context())
// defer eventCancel()
// dockerEventCh, dockerEventErrCh = w.getEventCh(eventCtx, dockerWatcher)
// }
case <-w.ticker.C: case <-w.ticker.C:
w.ticker.Stop() w.ticker.Stop()
if w.running() { if w.running() {
@ -269,11 +289,18 @@ func (w *Watcher) watchUntilDestroy() (returnCause error) {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
err = errors.New("timeout waiting for container to stop, please set a higher value for `stop_timeout`") err = errors.New("timeout waiting for container to stop, please set a higher value for `stop_timeout`")
} }
w.Err(err).Msgf("container stop with method %q failed", w.Config().StopMethod) w.l.Err(err).Msgf("container stop with method %q failed", w.cfg.StopMethod)
default: default:
w.Info().Str("reason", "idle timeout").Msg("container stopped") w.l.Info().Str("reason", "idle timeout").Msg("container stopped")
} }
} }
} }
} }
} }
func fmtErr(err error) string {
if err == nil {
return ""
}
return err.Error()
}

View file

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -39,14 +40,12 @@ const (
pollInterval = 1 * time.Second pollInterval = 1 * time.Second
gatherErrsInterval = 30 * time.Second gatherErrsInterval = 30 * time.Second
saveInterval = 5 * time.Minute saveInterval = 5 * time.Minute
saveBaseDir = "data/metrics"
) )
var initDataDirOnce sync.Once var initDataDirOnce sync.Once
func initDataDir() { func initDataDir() {
if err := os.MkdirAll(saveBaseDir, 0o755); err != nil { if err := os.MkdirAll(common.MetricsDataDir, 0o755); err != nil {
logging.Error().Err(err).Msg("failed to create metrics data directory") logging.Error().Err(err).Msg("failed to create metrics data directory")
} }
} }
@ -65,7 +64,7 @@ func NewPoller[T any, AggregateT json.Marshaler](
} }
func (p *Poller[T, AggregateT]) savePath() string { func (p *Poller[T, AggregateT]) savePath() string {
return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name)) return filepath.Join(common.MetricsDataDir, fmt.Sprintf("%s.json", p.name))
} }
func (p *Poller[T, AggregateT]) load() error { func (p *Poller[T, AggregateT]) load() error {

View file

@ -172,6 +172,45 @@ func (s *SystemInfo) collectMemoryInfo(ctx context.Context) error {
return nil return nil
} }
func shouldExcludeDisk(name string) bool {
// include only sd* and nvme* disk devices
// but not partitions like nvme0p1
if len(name) < 3 {
return true
}
switch {
case strings.HasPrefix(name, "nvme"),
strings.HasPrefix(name, "mmcblk"): // NVMe/SD/MMC
s := name[len(name)-2]
// skip namespaces/partitions
switch s {
case 'p', 'n':
return true
default:
return false
}
}
switch name[0] {
case 's', 'h', 'v': // SCSI/SATA/virtio disks
if name[1] != 'd' {
return true
}
case 'x': // Xen virtual disks
if name[1:3] != "vd" {
return true
}
default:
return true
}
last := name[len(name)-1]
if last >= '0' && last <= '9' {
// skip partitions
return true
}
return false
}
func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInfo) error { func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInfo) error {
ioCounters, err := disk.IOCountersWithContext(ctx) ioCounters, err := disk.IOCountersWithContext(ctx)
if err != nil { if err != nil {
@ -179,34 +218,9 @@ func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInf
} }
s.DisksIO = make(map[string]*DiskIO, len(ioCounters)) s.DisksIO = make(map[string]*DiskIO, len(ioCounters))
for name, io := range ioCounters { for name, io := range ioCounters {
// include only /dev/sd* and /dev/nvme* disk devices if shouldExcludeDisk(name) {
if len(name) < 3 {
continue continue
} }
switch {
case strings.HasPrefix(name, "nvme"),
strings.HasPrefix(name, "mmcblk"): // NVMe/SD/MMC
if name[len(name)-2] == 'p' {
continue // skip partitions
}
default:
switch name[0] {
case 's', 'h', 'v': // SCSI/SATA/virtio disks
if name[1] != 'd' {
continue
}
case 'x': // Xen virtual disks
if name[1:3] != "vd" {
continue
}
default:
continue
}
last := name[len(name)-1]
if last >= '0' && last <= '9' {
continue // skip partitions
}
}
s.DisksIO[name] = &DiskIO{ s.DisksIO[name] = &DiskIO{
ReadBytes: io.ReadBytes, ReadBytes: io.ReadBytes,
WriteBytes: io.WriteBytes, WriteBytes: io.WriteBytes,

View file

@ -10,6 +10,73 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestExcludeDisks(t *testing.T) {
tests := []struct {
name string
shouldExclude bool
}{
{
name: "nvme0",
shouldExclude: false,
},
{
name: "nvme0n1",
shouldExclude: true,
},
{
name: "nvme0n1p1",
shouldExclude: true,
},
{
name: "sda",
shouldExclude: false,
},
{
name: "sda1",
shouldExclude: true,
},
{
name: "hda",
shouldExclude: false,
},
{
name: "vda",
shouldExclude: false,
},
{
name: "xvda",
shouldExclude: false,
},
{
name: "xva",
shouldExclude: true,
},
{
name: "loop0",
shouldExclude: true,
},
{
name: "mmcblk0",
shouldExclude: false,
},
{
name: "mmcblk0p1",
shouldExclude: true,
},
{
name: "ab",
shouldExclude: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shouldExcludeDisk(tt.name)
ExpectEqual(t, result, tt.shouldExclude)
})
}
}
// Create test data // Create test data
var cpuAvg = 45.67 var cpuAvg = 45.67
var testInfo = &SystemInfo{ var testInfo = &SystemInfo{
@ -118,7 +185,7 @@ func TestSystemInfo(t *testing.T) {
func TestSerialize(t *testing.T) { func TestSerialize(t *testing.T) {
entries := make([]*SystemInfo, 5) entries := make([]*SystemInfo, 5)
for i := 0; i < 5; i++ { for i := range 5 {
entries[i] = testInfo entries[i] = testInfo
} }
for _, query := range allQueries { for _, query := range allQueries {
@ -140,9 +207,9 @@ func TestSerialize(t *testing.T) {
} }
} }
func BenchmarkSerialize(b *testing.B) { func BenchmarkJSONMarshal(b *testing.B) {
entries := make([]*SystemInfo, b.N) entries := make([]*SystemInfo, b.N)
for i := 0; i < b.N; i++ { for i := range b.N {
entries[i] = testInfo entries[i] = testInfo
} }
queries := map[string]Aggregated{} queries := map[string]Aggregated{}
@ -153,14 +220,14 @@ func BenchmarkSerialize(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
b.Run("optimized", func(b *testing.B) { b.Run("optimized", func(b *testing.B) {
for i := 0; i < b.N; i++ { for b.Loop() {
for _, query := range allQueries { for _, query := range allQueries {
_, _ = queries[query].MarshalJSON() _, _ = queries[query].MarshalJSON()
} }
} }
}) })
b.Run("json", func(b *testing.B) { b.Run("json", func(b *testing.B) {
for i := 0; i < b.N; i++ { for b.Loop() {
for _, query := range allQueries { for _, query := range allQueries {
_, _ = json.Marshal([]map[string]any(queries[query])) _, _ = json.Marshal([]map[string]any(queries[query]))
} }

View file

@ -25,11 +25,15 @@ type (
} }
AccessLogIO interface { AccessLogIO interface {
io.Writer
sync.Locker
Name() string // file name or path
}
supportRotate interface {
io.ReadWriteCloser io.ReadWriteCloser
io.ReadWriteSeeker io.ReadWriteSeeker
io.ReaderAt io.ReaderAt
sync.Locker
Name() string // file name or path
Truncate(size int64) error Truncate(size int64) error
} }
@ -40,7 +44,33 @@ type (
} }
) )
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
var ios []AccessLogIO
if cfg.Stdout {
ios = append(ios, stdoutIO)
}
if cfg.Path != "" {
io, err := newFileIO(cfg.Path)
if err != nil {
return nil, err
}
ios = append(ios, io)
}
if len(ios) == 0 {
return nil, nil
}
return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil
}
func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger {
return NewAccessLoggerWithIO(parent, &MockFile{}, cfg)
}
func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
if cfg.BufferSize == 0 { if cfg.BufferSize == 0 {
cfg.BufferSize = DefaultBufferSize cfg.BufferSize = DefaultBufferSize
} }
@ -152,7 +182,9 @@ func (l *AccessLogger) Flush() error {
func (l *AccessLogger) close() { func (l *AccessLogger) close() {
l.io.Lock() l.io.Lock()
defer l.io.Unlock() defer l.io.Unlock()
l.io.Close() if r, ok := l.io.(io.Closer); ok {
r.Close()
}
} }
func (l *AccessLogger) write(data []byte) { func (l *AccessLogger) write(data []byte) {

View file

@ -56,7 +56,7 @@ func fmtLog(cfg *Config) (ts string, line string) {
var buf bytes.Buffer var buf bytes.Buffer
t := time.Now() t := time.Now()
logger := NewAccessLogger(testTask, nil, cfg) logger := NewMockAccessLogger(testTask, cfg)
logger.Formatter.SetGetTimeNow(func() time.Time { logger.Formatter.SetGetTimeNow(func() time.Time {
return t return t
}) })

View file

@ -7,7 +7,7 @@ import (
// BackScanner provides an interface to read a file backward line by line. // BackScanner provides an interface to read a file backward line by line.
type BackScanner struct { type BackScanner struct {
file AccessLogIO file supportRotate
chunkSize int chunkSize int
offset int64 offset int64
buffer []byte buffer []byte
@ -18,7 +18,7 @@ type BackScanner struct {
// NewBackScanner creates a new Scanner to read the file backward. // NewBackScanner creates a new Scanner to read the file backward.
// chunkSize determines the size of each read chunk from the end of the file. // chunkSize determines the size of each read chunk from the end of the file.
func NewBackScanner(file AccessLogIO, chunkSize int) *BackScanner { func NewBackScanner(file supportRotate, chunkSize int) *BackScanner {
size, err := file.Seek(0, io.SeekEnd) size, err := file.Seek(0, io.SeekEnd)
if err != nil { if err != nil {
return &BackScanner{err: err} return &BackScanner{err: err}

View file

@ -1,6 +1,10 @@
package accesslog package accesslog
import "github.com/yusing/go-proxy/internal/utils" import (
"errors"
"github.com/yusing/go-proxy/internal/utils"
)
type ( type (
Format string Format string
@ -19,7 +23,8 @@ type (
Config struct { Config struct {
BufferSize int `json:"buffer_size"` BufferSize int `json:"buffer_size"`
Format Format `json:"format" validate:"oneof=common combined json"` Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"` Path string `json:"path"`
Stdout bool `json:"stdout"`
Filters Filters `json:"filters"` Filters Filters `json:"filters"`
Fields Fields `json:"fields"` Fields Fields `json:"fields"`
Retention *Retention `json:"retention"` Retention *Retention `json:"retention"`
@ -34,6 +39,13 @@ var (
const DefaultBufferSize = 64 * 1024 // 64KB const DefaultBufferSize = 64 * 1024 // 64KB
func (cfg *Config) Validate() error {
if cfg.Path == "" && !cfg.Stdout {
return errors.New("path or stdout is required")
}
return nil
}
func DefaultConfig() *Config { func DefaultConfig() *Config {
return &Config{ return &Config{
BufferSize: DefaultBufferSize, BufferSize: DefaultBufferSize,

View file

@ -3,11 +3,10 @@ package accesslog
import ( import (
"fmt" "fmt"
"os" "os"
"path" pathPkg "path"
"sync" "sync"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -27,16 +26,16 @@ var (
openedFilesMu sync.Mutex openedFilesMu sync.Mutex
) )
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { func newFileIO(path string) (AccessLogIO, error) {
openedFilesMu.Lock() openedFilesMu.Lock()
var file *File var file *File
path := path.Clean(cfg.Path) path = pathPkg.Clean(path)
if opened, ok := openedFiles[path]; ok { if opened, ok := openedFiles[path]; ok {
opened.refCount.Add() opened.refCount.Add()
file = opened file = opened
} else { } else {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
if err != nil { if err != nil {
openedFilesMu.Unlock() openedFilesMu.Unlock()
return nil, fmt.Errorf("access log open error: %w", err) return nil, fmt.Errorf("access log open error: %w", err)
@ -47,7 +46,7 @@ func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error)
} }
openedFilesMu.Unlock() openedFilesMu.Unlock()
return NewAccessLogger(parent, file, cfg), nil return file, nil
} }
func (f *File) Close() error { func (f *File) Close() error {

View file

@ -16,7 +16,6 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
cfg := DefaultConfig() cfg := DefaultConfig()
cfg.Path = "test.log" cfg.Path = "test.log"
parent := task.RootTask("test", false)
loggerCount := 10 loggerCount := 10
accessLogIOs := make([]AccessLogIO, loggerCount) accessLogIOs := make([]AccessLogIO, loggerCount)
@ -33,9 +32,9 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
logger, err := NewFileAccessLogger(parent, cfg) file, err := newFileIO(cfg.Path)
ExpectNoError(t, err) ExpectNoError(t, err)
accessLogIOs[index] = logger.io accessLogIOs[index] = file
}(i) }(i)
} }
@ -59,7 +58,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
loggers := make([]*AccessLogger, loggerCount) loggers := make([]*AccessLogger, loggerCount)
for i := range loggerCount { for i := range loggerCount {
loggers[i] = NewAccessLogger(parent, &file, cfg) loggers[i] = NewAccessLoggerWithIO(parent, &file, cfg)
} }
var wg sync.WaitGroup var wg sync.WaitGroup

View file

@ -6,7 +6,6 @@ import (
"strings" "strings"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -24,7 +23,7 @@ type (
Key, Value string Key, Value string
} }
Host string Host string
CIDR struct{ types.CIDR } CIDR net.IPNet
) )
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter") var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
@ -86,7 +85,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h) return req.Host == string(h)
} }
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool { func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr) ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil { if err != nil {
ip = req.RemoteAddr ip = req.RemoteAddr
@ -95,5 +94,9 @@ func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
if netIP == nil { if netIP == nil {
return false return false
} }
return cidr.Contains(netIP) return (*net.IPNet)(cidr).Contains(netIP)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
} }

View file

@ -1,6 +1,7 @@
package accesslog_test package accesslog_test
import ( import (
"net"
"net/http" "net/http"
"testing" "testing"
@ -155,9 +156,10 @@ func TestHeaderFilter(t *testing.T) {
} }
func TestCIDRFilter(t *testing.T) { func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{ cidr := []*CIDR{{
strutils.MustParse[*CIDR]("192.168.10.0/24"), IP: net.ParseIP("192.168.10.0"),
} Mask: net.CIDRMask(24, 32),
}}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24") ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{ inCIDR := &http.Request{
RemoteAddr: "192.168.10.1", RemoteAddr: "192.168.10.1",

View file

@ -0,0 +1,46 @@
package accesslog
import "strings"
type MultiWriter struct {
writers []AccessLogIO
}
func NewMultiWriter(writers ...AccessLogIO) AccessLogIO {
if len(writers) == 0 {
return nil
}
if len(writers) == 1 {
return writers[0]
}
return &MultiWriter{
writers: writers,
}
}
func (w *MultiWriter) Write(p []byte) (n int, err error) {
for _, writer := range w.writers {
writer.Write(p)
}
return len(p), nil
}
func (w *MultiWriter) Lock() {
for _, writer := range w.writers {
writer.Lock()
}
}
func (w *MultiWriter) Unlock() {
for _, writer := range w.writers {
writer.Unlock()
}
}
func (w *MultiWriter) Name() string {
names := make([]string, len(w.writers))
for i, writer := range w.writers {
names[i] = writer.Name()
}
return strings.Join(names, ", ")
}

View file

@ -2,11 +2,15 @@ package accesslog
import ( import (
"bytes" "bytes"
"io" ioPkg "io"
"time" "time"
) )
func (l *AccessLogger) rotate() (err error) { func (l *AccessLogger) rotate() (err error) {
io, ok := l.io.(supportRotate)
if !ok {
return nil
}
// Get retention configuration // Get retention configuration
config := l.Config().Retention config := l.Config().Retention
var shouldKeep func(t time.Time, lineCount int) bool var shouldKeep func(t time.Time, lineCount int) bool
@ -24,7 +28,7 @@ func (l *AccessLogger) rotate() (err error) {
return nil // No retention policy set return nil // No retention policy set
} }
s := NewBackScanner(l.io, defaultChunkSize) s := NewBackScanner(io, defaultChunkSize)
nRead := 0 nRead := 0
nLines := 0 nLines := 0
for s.Scan() { for s.Scan() {
@ -40,11 +44,11 @@ func (l *AccessLogger) rotate() (err error) {
} }
beg := int64(nRead) beg := int64(nRead)
if _, err := l.io.Seek(-beg, io.SeekEnd); err != nil { if _, err := io.Seek(-beg, ioPkg.SeekEnd); err != nil {
return err return err
} }
buf := make([]byte, nRead) buf := make([]byte, nRead)
if _, err := l.io.Read(buf); err != nil { if _, err := io.Read(buf); err != nil {
return err return err
} }
@ -55,8 +59,13 @@ func (l *AccessLogger) rotate() (err error) {
} }
func (l *AccessLogger) writeTruncate(buf []byte) (err error) { func (l *AccessLogger) writeTruncate(buf []byte) (err error) {
io, ok := l.io.(supportRotate)
if !ok {
return nil
}
// Seek to beginning and truncate // Seek to beginning and truncate
if _, err := l.io.Seek(0, 0); err != nil { if _, err := io.Seek(0, 0); err != nil {
return err return err
} }
@ -70,13 +79,13 @@ func (l *AccessLogger) writeTruncate(buf []byte) (err error) {
} }
// Truncate file // Truncate file
if err = l.io.Truncate(int64(nWritten)); err != nil { if err = io.Truncate(int64(nWritten)); err != nil {
return err return err
} }
// check bytes written == buffer size // check bytes written == buffer size
if nWritten != len(buf) { if nWritten != len(buf) {
return io.ErrShortWrite return ioPkg.ErrShortWrite
} }
return return
} }

View file

@ -33,7 +33,7 @@ func TestParseLogTime(t *testing.T) {
func TestRetentionCommonFormat(t *testing.T) { func TestRetentionCommonFormat(t *testing.T) {
var file MockFile var file MockFile
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{ logger := NewAccessLoggerWithIO(task.RootTask("test", false), &file, &Config{
Format: FormatCommon, Format: FormatCommon,
BufferSize: 1024, BufferSize: 1024,
}) })

View file

@ -0,0 +1,18 @@
package accesslog
import (
"io"
"os"
)
type StdoutLogger struct {
io.Writer
}
var stdoutIO = &StdoutLogger{os.Stdout}
func (l *StdoutLogger) Lock() {}
func (l *StdoutLogger) Unlock() {}
func (l *StdoutLogger) Name() string {
return "stdout"
}

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
@ -84,3 +85,18 @@ func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
} }
return true return true
} }
// DynamicJSONHandler serves a JSON response depending on the request type.
//
// If the request is a websocket, it serves the data for the given interval.
//
// Otherwise, it serves the data once.
func DynamicJSONHandler[ResultType any](w http.ResponseWriter, r *http.Request, getter func() ResultType, interval time.Duration) {
if httpheaders.IsWebsocket(r.Header) {
Periodic(w, r, interval, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, getter())
})
} else {
gphttp.RespondJSON(w, r, getter())
}
}

View file

@ -13,7 +13,6 @@ import (
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
) )
// TODO: stats of each server. // TODO: stats of each server.
@ -240,14 +239,14 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
lb.impl.ServeHTTP(srvs, rw, r) lb.impl.ServeHTTP(srvs, rw, r)
} }
// MarshalJSON implements health.HealthMonitor. // MarshalMap implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { func (lb *LoadBalancer) MarshalMap() map[string]any {
extra := make(map[string]any) extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v Server) { lb.pool.RangeAll(func(k string, v Server) {
extra[v.Key()] = v extra[v.Key()] = v
}) })
return (&monitor.JSONRepresentation{ return (&health.JSONRepresentation{
Name: lb.Name(), Name: lb.Name(),
Status: lb.Status(), Status: lb.Status(),
Started: lb.startTime, Started: lb.startTime,
@ -256,7 +255,7 @@ func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
"config": lb.Config, "config": lb.Config,
"pool": extra, "pool": extra,
}, },
}).MarshalJSON() }).MarshalMap()
} }
// Name implements health.HealthMonitor. // Name implements health.HealthMonitor.

View file

@ -2,9 +2,9 @@ package types
import ( import (
"net/http" "net/http"
"net/url"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
net "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
@ -15,7 +15,7 @@ type (
_ U.NoCopy _ U.NoCopy
name string name string
url *net.URL url *url.URL
weight Weight weight Weight
http.Handler `json:"-"` http.Handler `json:"-"`
@ -27,7 +27,7 @@ type (
health.HealthMonitor health.HealthMonitor
Name() string Name() string
Key() string Key() string
URL() *net.URL URL() *url.URL
Weight() Weight Weight() Weight
SetWeight(weight Weight) SetWeight(weight Weight)
TryWake() error TryWake() error
@ -38,7 +38,7 @@ type (
var NewServerPool = F.NewMap[Pool] var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server { func NewServer(name string, url *url.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
srv := &server{ srv := &server{
name: name, name: name,
url: url, url: url,
@ -52,7 +52,7 @@ func NewServer(name string, url *net.URL, weight Weight, handler http.Handler, h
func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server { func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv := &server{ srv := &server{
weight: Weight(weight), weight: Weight(weight),
url: net.MustParseURL("http://localhost"), url: &url.URL{Scheme: "http", Host: "localhost"},
} }
return srv return srv
} }
@ -61,7 +61,7 @@ func (srv *server) Name() string {
return srv.name return srv.name
} }
func (srv *server) URL() *net.URL { func (srv *server) URL() *url.URL {
return srv.url return srv.url
} }

View file

@ -6,7 +6,6 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp" gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
@ -18,8 +17,8 @@ type (
cachedAddr F.Map[string, bool] // cache for trusted IPs cachedAddr F.Map[string, bool] // cache for trusted IPs
} }
CIDRWhitelistOpts struct { CIDRWhitelistOpts struct {
Allow []*types.CIDR `validate:"min=1"` Allow []*net.IPNet `validate:"min=1"`
StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"` StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"`
Message string Message string
} }
) )
@ -27,7 +26,7 @@ type (
var ( var (
CIDRWhiteList = NewMiddleware[cidrWhitelist]() CIDRWhiteList = NewMiddleware[cidrWhitelist]()
cidrWhitelistDefaults = CIDRWhitelistOpts{ cidrWhitelistDefaults = CIDRWhitelistOpts{
Allow: []*types.CIDR{}, Allow: []*net.IPNet{},
StatusCode: http.StatusForbidden, StatusCode: http.StatusForbidden,
Message: "IP not allowed", Message: "IP not allowed",
} }

View file

@ -11,7 +11,6 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -33,7 +32,7 @@ var (
cfCIDRsMu sync.Mutex cfCIDRsMu sync.Mutex
// RFC 1918. // RFC 1918.
localCIDRs = []*types.CIDR{ localCIDRs = []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, // 127.0.0.1/32 {IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 255, 255, 255)}, // 127.0.0.1/32
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, // 10.0.0.0/8 {IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 240, 0, 0)}, // 172.16.0.0/12 {IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 240, 0, 0)}, // 172.16.0.0/12
@ -68,7 +67,7 @@ func (cri *cloudflareRealIP) getTracer() *Tracer {
return cri.realIP.getTracer() return cri.realIP.getTracer()
} }
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval { if time.Since(cfCIDRsLastUpdate.Load()) < cfCIDRsUpdateInterval {
return return
} }
@ -83,7 +82,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if common.IsTest { if common.IsTest {
cfCIDRs = localCIDRs cfCIDRs = localCIDRs
} else { } else {
cfCIDRs = make([]*types.CIDR, 0, 30) cfCIDRs = make([]*net.IPNet, 0, 30)
err := errors.Join( err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs), fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, &cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs), fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, &cfCIDRs),
@ -103,7 +102,7 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
return return
} }
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error { func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*net.IPNet) error {
resp, err := http.Get(endpoint) resp, err := http.Get(endpoint)
if err != nil { if err != nil {
return err return err
@ -124,7 +123,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line) return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} }
*cfCIDRs = append(*cfCIDRs, (*types.CIDR)(cidr)) *cfCIDRs = append(*cfCIDRs, (*net.IPNet)(cidr))
} }
*cfCIDRs = append(*cfCIDRs, localCIDRs...) *cfCIDRs = append(*cfCIDRs, localCIDRs...)
return nil return nil

View file

@ -16,8 +16,6 @@ import (
"github.com/yusing/go-proxy/internal/watcher/events" "github.com/yusing/go-proxy/internal/watcher/events"
) )
const errPagesBasePath = common.ErrorPagesBasePath
var ( var (
setupOnce sync.Once setupOnce sync.Once
dirWatcher W.Watcher dirWatcher W.Watcher
@ -26,7 +24,7 @@ var (
func setup() { func setup() {
t := task.RootTask("error_page", false) t := task.RootTask("error_page", false)
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath) dirWatcher = W.NewDirectoryWatcher(t, common.ErrorPagesDir)
loadContent() loadContent()
go watchDir() go watchDir()
} }
@ -46,7 +44,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
} }
func loadContent() { func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0) files, err := U.ListFiles(common.ErrorPagesDir, 0)
if err != nil { if err != nil {
logging.Err(err).Msg("failed to list error page resources") logging.Err(err).Msg("failed to list error page resources")
return return

View file

@ -55,7 +55,7 @@ func All() map[string]*Middleware {
func LoadComposeFiles() { func LoadComposeFiles() {
errs := gperr.NewBuilder("middleware compile errors") errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0) middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeDir, 0)
if err != nil { if err != nil {
logging.Err(err).Msg("failed to list middleware definitions") logging.Err(err).Msg("failed to list middleware definitions")
return return

View file

@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"net" "net"
"net/http" "net/http"
"net/url"
"slices" "slices"
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -51,8 +51,8 @@ func TestModifyRequest(t *testing.T) {
}) })
t.Run("request_headers", func(t *testing.T) { t.Run("request_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b") reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := types.MustParseURL("http://test.example.com") upstreamURL := Must(url.Parse("http://test.example.com"))
result, err := newMiddlewareTest(ModifyRequest, &testArgs{ result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts, middlewareOpt: opts,
reqURL: reqURL, reqURL: reqURL,
@ -128,8 +128,8 @@ func TestModifyRequest(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app" + tt.path) reqURL := Must(url.Parse("https://my.app" + tt.path))
upstreamURL := types.MustParseURL(tt.upstreamURL) upstreamURL := Must(url.Parse(tt.upstreamURL))
opts["add_prefix"] = tt.addPrefix opts["add_prefix"] = tt.addPrefix
result, err := newMiddlewareTest(ModifyRequest, &testArgs{ result, err := newMiddlewareTest(ModifyRequest, &testArgs{

View file

@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"net" "net"
"net/http" "net/http"
"net/url"
"slices" "slices"
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -54,8 +54,8 @@ func TestModifyResponse(t *testing.T) {
}) })
t.Run("response_headers", func(t *testing.T) { t.Run("response_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b") reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := types.MustParseURL("http://test.example.com") upstreamURL := Must(url.Parse("http://test.example.com"))
result, err := newMiddlewareTest(ModifyResponse, &testArgs{ result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts, middlewareOpt: opts,
reqURL: reqURL, reqURL: reqURL,

View file

@ -5,7 +5,6 @@ import (
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
) )
// https://nginx.org/en/docs/http/ngx_http_realip_module.html // https://nginx.org/en/docs/http/ngx_http_realip_module.html
@ -19,7 +18,7 @@ type (
// Header is the name of the header to use for the real client IP // Header is the name of the header to use for the real client IP
Header string `validate:"required"` Header string `validate:"required"`
// From is a list of Address / CIDRs to trust // From is a list of Address / CIDRs to trust
From []*types.CIDR `validate:"required,min=1"` From []*net.IPNet `validate:"required,min=1"`
/* /*
If recursive search is disabled, If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
@ -36,7 +35,7 @@ var (
RealIP = NewMiddleware[realIP]() RealIP = NewMiddleware[realIP]()
realIPOptsDefault = RealIPOpts{ realIPOptsDefault = RealIPOpts{
Header: "X-Real-IP", Header: "X-Real-IP",
From: []*types.CIDR{}, From: []*net.IPNet{},
} }
) )

View file

@ -7,7 +7,6 @@ import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -23,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
} }
optExpected := &RealIPOpts{ optExpected := &RealIPOpts{
Header: httpheaders.HeaderXRealIP, Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{ From: []*net.IPNet{
{ {
IP: net.ParseIP("127.0.0.0"), IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0),

View file

@ -2,15 +2,15 @@ package middleware
import ( import (
"net/http" "net/http"
"net/url"
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestRedirectToHTTPs(t *testing.T) { func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("http://example.com"), reqURL: Must(url.Parse("http://example.com")),
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect) ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
@ -19,7 +19,7 @@ func TestRedirectToHTTPs(t *testing.T) {
func TestNoRedirect(t *testing.T) { func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: types.MustParseURL("https://example.com"), reqURL: Must(url.Parse("https://example.com")),
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)

View file

@ -7,11 +7,11 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -80,11 +80,11 @@ type TestResult struct {
type testArgs struct { type testArgs struct {
middlewareOpt OptionsRaw middlewareOpt OptionsRaw
upstreamURL *types.URL upstreamURL *url.URL
realRoundTrip bool realRoundTrip bool
reqURL *types.URL reqURL *url.URL
reqMethod string reqMethod string
headers http.Header headers http.Header
body []byte body []byte
@ -96,13 +96,13 @@ type testArgs struct {
func (args *testArgs) setDefaults() { func (args *testArgs) setDefaults() {
if args.reqURL == nil { if args.reqURL == nil {
args.reqURL = Must(types.ParseURL("https://example.com")) args.reqURL = Must(url.Parse("https://example.com"))
} }
if args.reqMethod == "" { if args.reqMethod == "" {
args.reqMethod = http.MethodGet args.reqMethod = http.MethodGet
} }
if args.upstreamURL == nil { if args.upstreamURL == nil {
args.upstreamURL = Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect args.upstreamURL = Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect
} }
if args.respHeaders == nil { if args.respHeaders == nil {
args.respHeaders = http.Header{} args.respHeaders = http.Header{}

View file

@ -28,7 +28,6 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
) )
@ -93,7 +92,7 @@ type ReverseProxy struct {
HandlerFunc http.HandlerFunc HandlerFunc http.HandlerFunc
TargetName string TargetName string
TargetURL *types.URL TargetURL *url.URL
} }
func singleJoiningSlash(a, b string) string { func singleJoiningSlash(a, b string) string {
@ -133,7 +132,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// URLs to the scheme, host, and base path provided in target. If the // URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir", // target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir. // the target request will be for /base/dir.
func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper) *ReverseProxy { func NewReverseProxy(name string, target *url.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil { if transport == nil {
panic("nil transport") panic("nil transport")
} }
@ -151,7 +150,7 @@ func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
targetQuery := p.TargetURL.RawQuery targetQuery := p.TargetURL.RawQuery
req.URL.Scheme = p.TargetURL.Scheme req.URL.Scheme = p.TargetURL.Scheme
req.URL.Host = p.TargetURL.Host req.URL.Host = p.TargetURL.Host
req.URL.Path, req.URL.RawPath = joinURLPath(&p.TargetURL.URL, req.URL) req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" { if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else { } else {

View file

@ -0,0 +1,61 @@
package servemux
import (
"fmt"
"net/http"
"github.com/yusing/go-proxy/internal/api/v1/auth"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
ServeMux struct {
*http.ServeMux
cfg config.ConfigInstance
}
WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request)
)
func NewServeMux(cfg config.ConfigInstance) ServeMux {
return ServeMux{http.NewServeMux(), cfg}
}
func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) {
var handler http.HandlerFunc
switch h := h.(type) {
case func(http.ResponseWriter, *http.Request):
handler = h
case http.Handler:
handler = h.ServeHTTP
case WithCfgHandler:
handler = func(w http.ResponseWriter, r *http.Request) {
h(mux.cfg, w, r)
}
default:
panic(fmt.Errorf("unsupported handler type: %T", h))
}
matchDomains := mux.cfg.Value().MatchDomains
if len(matchDomains) > 0 {
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}
}
if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler)
}
if methods == "" {
mux.ServeMux.HandleFunc(endpoint, handler)
} else {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
}
}

120
internal/net/ping.go Normal file
View file

@ -0,0 +1,120 @@
package netutils
import (
"context"
"errors"
"fmt"
"net"
"os"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
ipv4EchoBytes []byte
ipv6EchoBytes []byte
)
func init() {
echoBody := &icmp.Echo{
ID: os.Getpid() & 0xffff,
Seq: 1,
Data: []byte("Hello"),
}
ipv4Echo := &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Body: echoBody,
}
ipv6Echo := &icmp.Message{
Type: ipv6.ICMPTypeEchoRequest,
Body: echoBody,
}
var err error
ipv4EchoBytes, err = ipv4Echo.Marshal(nil)
if err != nil {
panic(err)
}
ipv6EchoBytes, err = ipv6Echo.Marshal(nil)
if err != nil {
panic(err)
}
}
// Ping pings the IP address using ICMP.
func Ping(ctx context.Context, ip net.IP) (bool, error) {
var msgBytes []byte
if ip.To4() != nil {
msgBytes = ipv4EchoBytes
} else {
msgBytes = ipv6EchoBytes
}
conn, err := icmp.ListenPacket("ip:icmp", ip.String())
if err != nil {
return false, err
}
defer conn.Close()
err = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err != nil {
return false, err
}
_, err = conn.WriteTo(msgBytes, &net.IPAddr{IP: ip})
if err != nil {
return false, err
}
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if err != nil {
return false, err
}
buf := make([]byte, 1500)
for {
select {
case <-ctx.Done():
return false, ctx.Err()
default:
}
n, _, err := conn.ReadFrom(buf)
if err != nil {
return false, err
}
m, err := icmp.ParseMessage(ipv4.ICMPTypeEchoReply.Protocol(), buf[:n])
if err != nil {
continue
}
if m.Type == ipv4.ICMPTypeEchoReply {
return true, nil
}
}
}
var pingDialer = &net.Dialer{
Timeout: 2 * time.Second,
}
// PingWithTCPFallback pings the IP address using ICMP and TCP fallback.
//
// If the ICMP ping fails due to permission error, it will try to connect to the specified port.
func PingWithTCPFallback(ctx context.Context, ip net.IP, port int) (bool, error) {
ok, err := Ping(ctx, ip)
if err != nil {
if !errors.Is(err, os.ErrPermission) {
return false, err
}
} else {
return ok, nil
}
conn, err := pingDialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return false, err
}
defer conn.Close()
return true, nil
}

Some files were not shown because too many files have changed in this diff Show more