api: fix validation and http response

This commit is contained in:
yusing 2025-01-04 09:01:52 +08:00
parent 112859caa5
commit c30d3f585f
6 changed files with 33 additions and 16 deletions

View file

@ -48,7 +48,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
// no validation for include files // no validation for include files
if valErr != nil { if valErr != nil {
U.RespondJSON(w, r, valErr, http.StatusBadRequest) U.RespondError(w, valErr, http.StatusBadRequest)
return return
} }

View file

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
) )
// HandleErr logs the error and returns an HTTP error response to the client. // HandleErr logs the error and returns an HTTP error response to the client.
@ -23,6 +24,13 @@ func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...in
http.Error(w, http.StatusText(statusCode), statusCode) http.Error(w, http.StatusText(statusCode), statusCode)
} }
func RespondError(w http.ResponseWriter, err error, code ...int) {
if len(code) > 0 {
w.WriteHeader(code[0])
}
WriteBody(w, []byte(ansi.StripANSI(err.Error())))
}
func ErrMissingKey(k string) error { func ErrMissingKey(k string) error {
return E.New("missing key '" + k + "' in query or request body") return E.New("missing key '" + k + "' in query or request body")
} }

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
) )
func WriteBody(w http.ResponseWriter, body []byte) { func WriteBody(w http.ResponseWriter, body []byte) {
@ -27,13 +28,17 @@ func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int)
j = []byte(fmt.Sprintf("%q", data)) j = []byte(fmt.Sprintf("%q", data))
case []byte: case []byte:
j = data j = data
case error:
j, err = json.Marshal(ansi.StripANSI(data.Error()))
default: default:
j, err = json.MarshalIndent(data, "", " ") j, err = json.MarshalIndent(data, "", " ")
}
if err != nil { if err != nil {
logging.Panic().Err(err).Msg("failed to marshal json") logging.Panic().Err(err).Msg("failed to marshal json")
return false return false
} }
}
_, err = w.Write(j) _, err = w.Write(j)
if err != nil { if err != nil {
HandleErr(w, r, err) HandleErr(w, r, err)

View file

@ -67,8 +67,8 @@ func Load() (*Config, E.Error) {
} }
func Validate(data []byte) E.Error { func Validate(data []byte) E.Error {
var model *types.Config var model types.Config
return utils.DeserializeYAML(data, model) return utils.DeserializeYAML(data, &model)
} }
func MatchDomains() []string { func MatchDomains() []string {

View file

@ -31,8 +31,12 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
return impl, nil return impl, nil
} }
func validate(data []byte) (route.RawEntries, E.Error) { func validate(data []byte) (route.Routes, E.Error) {
return utils.DeserializeYAMLMap[*route.RawEntry](data) entries, err := utils.DeserializeYAMLMap[*route.RawEntry](data)
if err != nil {
return route.NewRoutes(), err
}
return route.FromEntries(entries)
} }
func Validate(data []byte) (err E.Error) { func Validate(data []byte) (err E.Error) {
@ -56,11 +60,7 @@ func (p *FileProvider) loadRoutesImpl() (route.Routes, E.Error) {
return routes, E.From(err) return routes, E.From(err)
} }
entries, err := validate(data) return validate(data)
if err == nil {
return route.FromEntries(entries)
}
return routes, E.From(err)
} }
func (p *FileProvider) NewWatcher() W.Watcher { func (p *FileProvider) NewWatcher() W.Watcher {

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"os" "os"
"reflect" "reflect"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -158,7 +159,7 @@ func Deserialize(src SerializedObject, dst any) E.Error {
return E.Errorf("deserialize: src is %w", ErrNilValue) return E.Errorf("deserialize: src is %w", ErrNilValue)
} }
if dst == nil { if dst == nil {
return E.Errorf("deserialize: dst is %w", ErrNilValue) return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
} }
dstV := reflect.ValueOf(dst) dstV := reflect.ValueOf(dst)
@ -169,7 +170,7 @@ func Deserialize(src SerializedObject, dst any) E.Error {
if dstV.CanSet() { if dstV.CanSet() {
dstV.Set(New(dstT.Elem())) dstV.Set(New(dstT.Elem()))
} else { } else {
return E.Errorf("deserialize: dst is %w", ErrNilValue) return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
} }
} }
dstV = dstV.Elem() dstV = dstV.Elem()
@ -191,6 +192,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
for _, field := range fields { for _, field := range fields {
var key string var key string
if jsonTag, ok := field.Tag.Lookup("json"); ok { if jsonTag, ok := field.Tag.Lookup("json"); ok {
if jsonTag == "-" {
continue
}
key = strutils.CommaSeperatedList(jsonTag)[0] key = strutils.CommaSeperatedList(jsonTag)[0]
} else { } else {
key = field.Name key = field.Name