mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-06 14:34:04 +02:00
refactor: move api/v1/utils to net/gphttp
This commit is contained in:
parent
d315710310
commit
dfd2f3962c
8 changed files with 227 additions and 147 deletions
|
@ -1,55 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HandleErr logs the error and returns an error code to the client.
|
|
||||||
// If code is specified, it will be used as the HTTP status code; otherwise,
|
|
||||||
// http.StatusInternalServerError is used.
|
|
||||||
//
|
|
||||||
// The error is only logged but not returned to the client.
|
|
||||||
func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) {
|
|
||||||
if err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
LogError(r).Msg(err.Error())
|
|
||||||
if len(code) == 0 {
|
|
||||||
code = []int{http.StatusInternalServerError}
|
|
||||||
}
|
|
||||||
http.Error(w, http.StatusText(code[0]), code[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// RespondError returns error details to the client.
|
|
||||||
// If code is specified, it will be used as the HTTP status code; otherwise,
|
|
||||||
// http.StatusBadRequest is used.
|
|
||||||
func RespondError(w http.ResponseWriter, err error, code ...int) {
|
|
||||||
if len(code) == 0 {
|
|
||||||
code = []int{http.StatusBadRequest}
|
|
||||||
}
|
|
||||||
buf, err := json.Marshal(err)
|
|
||||||
if err != nil { // just in case
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
http.Error(w, ansi.StripANSI(err.Error()), code[0])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
w.WriteHeader(code[0])
|
|
||||||
_, _ = w.Write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ErrMissingKey(k string) error {
|
|
||||||
return E.New("missing key '" + k + "' in query or request body")
|
|
||||||
}
|
|
||||||
|
|
||||||
func ErrInvalidKey(k string) error {
|
|
||||||
return E.New("invalid key '" + k + "' in query or request body")
|
|
||||||
}
|
|
||||||
|
|
||||||
func ErrNotFound(k, v string) error {
|
|
||||||
return E.Errorf("key %q with value %q not found", k, v)
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
func warnNoMatchDomains() {
|
|
||||||
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
|
||||||
}
|
|
||||||
|
|
||||||
var warnNoMatchDomainOnce sync.Once
|
|
||||||
|
|
||||||
func InitiateWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
|
||||||
var originPats []string
|
|
||||||
|
|
||||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
|
||||||
|
|
||||||
if len(cfg.Value().MatchDomains) == 0 {
|
|
||||||
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
|
|
||||||
originPats = []string{"*"}
|
|
||||||
} else {
|
|
||||||
originPats = make([]string, len(cfg.Value().MatchDomains))
|
|
||||||
for i, domain := range cfg.Value().MatchDomains {
|
|
||||||
originPats[i] = "*" + domain
|
|
||||||
}
|
|
||||||
originPats = append(originPats, localAddresses...)
|
|
||||||
}
|
|
||||||
if common.IsDebug {
|
|
||||||
originPats = []string{"*"}
|
|
||||||
}
|
|
||||||
return websocket.Accept(w, r, &websocket.AcceptOptions{
|
|
||||||
OriginPatterns: originPats,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func PeriodicWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
|
||||||
conn, err := InitiateWS(cfg, w, r)
|
|
||||||
if err != nil {
|
|
||||||
HandleErr(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
/* trunk-ignore(golangci-lint/errcheck) */
|
|
||||||
defer conn.CloseNow()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(interval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-cfg.Context().Done():
|
|
||||||
return
|
|
||||||
case <-r.Context().Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := do(conn); err != nil {
|
|
||||||
LogError(r).Msg(err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,17 +1,24 @@
|
||||||
package utils
|
package gphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||||
if _, err := w.Write(body); err != nil {
|
if _, err := w.Write(body); err != nil {
|
||||||
logging.Err(err).Msg("failed to write body")
|
switch {
|
||||||
|
case errors.Is(err, http.ErrHandlerTimeout),
|
||||||
|
errors.Is(err, context.DeadlineExceeded):
|
||||||
|
logging.Err(err).Msg("timeout writing body")
|
||||||
|
default:
|
||||||
|
logging.Err(err).Msg("failed to write body")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,28 +27,19 @@ func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int)
|
||||||
w.WriteHeader(code[0])
|
w.WriteHeader(code[0])
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
var j []byte
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch data := data.(type) {
|
switch data := data.(type) {
|
||||||
case string:
|
case string:
|
||||||
j = []byte(fmt.Sprintf("%q", data))
|
_, err = w.Write([]byte(fmt.Sprintf("%q", data)))
|
||||||
case []byte:
|
case []byte:
|
||||||
j = data
|
panic("use WriteBody instead")
|
||||||
case error:
|
|
||||||
j, err = json.Marshal(ansi.StripANSI(data.Error()))
|
|
||||||
default:
|
default:
|
||||||
j, err = json.MarshalIndent(data, "", " ")
|
err = json.NewEncoder(w).Encode(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Panic().Err(err).Msg("failed to marshal json")
|
LogError(r).Err(err).Msg("failed to encode json")
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = w.Write(j)
|
|
||||||
if err != nil {
|
|
||||||
HandleErr(w, r, err)
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
|
@ -1,22 +1,21 @@
|
||||||
package utils
|
package gphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
httpClient = &http.Client{
|
httpClient = &http.Client{
|
||||||
Timeout: common.ConnectionTimeout,
|
Timeout: 5 * time.Second,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
DisableKeepAlives: true,
|
DisableKeepAlives: true,
|
||||||
ForceAttemptHTTP2: false,
|
ForceAttemptHTTP2: false,
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
Timeout: common.DialTimeout,
|
Timeout: 3 * time.Second,
|
||||||
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
|
KeepAlive: 60 * time.Second, // this is different from DisableKeepAlives
|
||||||
}).DialContext,
|
}).DialContext,
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
},
|
},
|
100
internal/net/gphttp/error.go
Normal file
100
internal/net/gphttp/error.go
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
package gphttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerError is for handling server errors.
|
||||||
|
//
|
||||||
|
// It logs the error and returns http.StatusInternalServerError to the client.
|
||||||
|
// Status code can be specified as an argument.
|
||||||
|
func ServerError(w http.ResponseWriter, r *http.Request, err error, code ...int) {
|
||||||
|
switch {
|
||||||
|
case err == nil,
|
||||||
|
errors.Is(err, context.Canceled),
|
||||||
|
errors.Is(err, syscall.EPIPE),
|
||||||
|
errors.Is(err, syscall.ECONNRESET):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
LogError(r).Msg(err.Error())
|
||||||
|
if httpheaders.IsWebsocket(r.Header) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(code) == 0 {
|
||||||
|
code = []int{http.StatusInternalServerError}
|
||||||
|
}
|
||||||
|
http.Error(w, http.StatusText(code[0]), code[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientError is for responding to client errors.
|
||||||
|
//
|
||||||
|
// It returns http.StatusBadRequest with reason to the client.
|
||||||
|
// Status code can be specified as an argument.
|
||||||
|
//
|
||||||
|
// For JSON marshallable errors (e.g. gperr.Error), it returns the error details as JSON.
|
||||||
|
// Otherwise, it returns the error details as plain text.
|
||||||
|
func ClientError(w http.ResponseWriter, err error, code ...int) {
|
||||||
|
if len(code) == 0 {
|
||||||
|
code = []int{http.StatusBadRequest}
|
||||||
|
}
|
||||||
|
if gperr.IsJSONMarshallable(err) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(err)
|
||||||
|
} else {
|
||||||
|
http.Error(w, err.Error(), code[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSONError returns a JSON response of gperr.Error with the given status code.
|
||||||
|
func JSONError(w http.ResponseWriter, err gperr.Error, code int) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(code)
|
||||||
|
json.NewEncoder(w).Encode(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadRequest returns a Bad Request response with the given error message.
|
||||||
|
func BadRequest(w http.ResponseWriter, err string, code ...int) {
|
||||||
|
if len(code) == 0 {
|
||||||
|
code = []int{http.StatusBadRequest}
|
||||||
|
}
|
||||||
|
w.WriteHeader(code[0])
|
||||||
|
w.Write([]byte(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unauthorized returns an Unauthorized response with the given error message.
|
||||||
|
func Unauthorized(w http.ResponseWriter, err string) {
|
||||||
|
BadRequest(w, err, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forbidden returns a Forbidden response with the given error message.
|
||||||
|
func Forbidden(w http.ResponseWriter, err string) {
|
||||||
|
BadRequest(w, err, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotFound returns a Not Found response with the given error message.
|
||||||
|
func NotFound(w http.ResponseWriter, err string) {
|
||||||
|
BadRequest(w, err, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrMissingKey(k string) error {
|
||||||
|
return gperr.New(k + " is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrInvalidKey(k string) error {
|
||||||
|
return gperr.New(k + " is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrAlreadyExists(k, v string) error {
|
||||||
|
return gperr.Errorf("%s %q already exists", k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrNotFound(k, v string) error {
|
||||||
|
return gperr.Errorf("%s %q not found", k, v)
|
||||||
|
}
|
86
internal/net/gphttp/gpwebsocket/utils.go
Normal file
86
internal/net/gphttp/gpwebsocket/utils.go
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
package gpwebsocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
|
)
|
||||||
|
|
||||||
|
func warnNoMatchDomains() {
|
||||||
|
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
||||||
|
}
|
||||||
|
|
||||||
|
var warnNoMatchDomainOnce sync.Once
|
||||||
|
|
||||||
|
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||||
|
var originPats []string
|
||||||
|
|
||||||
|
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||||
|
|
||||||
|
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
|
||||||
|
if len(allowedDomains) == 0 || common.IsDebug {
|
||||||
|
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
|
||||||
|
originPats = []string{"*"}
|
||||||
|
} else {
|
||||||
|
originPats = make([]string, len(allowedDomains))
|
||||||
|
for i, domain := range allowedDomains {
|
||||||
|
if domain[0] != '.' {
|
||||||
|
originPats[i] = "*." + domain
|
||||||
|
} else {
|
||||||
|
originPats[i] = "*" + domain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
originPats = append(originPats, localAddresses...)
|
||||||
|
}
|
||||||
|
return websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: originPats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||||
|
conn, err := Initiate(w, r)
|
||||||
|
if err != nil {
|
||||||
|
gphttp.ServerError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//nolint:errcheck
|
||||||
|
defer conn.CloseNow()
|
||||||
|
|
||||||
|
if err := do(conn); err != nil {
|
||||||
|
gphttp.ServerError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := do(conn); err != nil {
|
||||||
|
gphttp.ServerError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteText writes a text message to the websocket connection.
|
||||||
|
// It returns true if the message was written successfully, false otherwise.
|
||||||
|
// It logs an error if the message is not written successfully.
|
||||||
|
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
|
||||||
|
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
|
||||||
|
gperr.LogError("failed to write text message", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
21
internal/net/gphttp/httpheaders/websocket.go
Normal file
21
internal/net/gphttp/httpheaders/websocket.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package httpheaders
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WebsocketAllowedDomains(h http.Header) []string {
|
||||||
|
return h[HeaderXGoDoxyWebsocketAllowedDomains]
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
|
||||||
|
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsWebsocket(h http.Header) bool {
|
||||||
|
return UpgradeType(h) == "websocket"
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package utils
|
package gphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -9,7 +9,6 @@ import (
|
||||||
|
|
||||||
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||||
return logging.WithLevel(level).
|
return logging.WithLevel(level).
|
||||||
Str("module", "api").
|
|
||||||
Str("remote", r.RemoteAddr).
|
Str("remote", r.RemoteAddr).
|
||||||
Str("host", r.Host).
|
Str("host", r.Host).
|
||||||
Str("uri", r.Method+" "+r.RequestURI)
|
Str("uri", r.Method+" "+r.RequestURI)
|
Loading…
Add table
Reference in a new issue