diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index dbb89f9..1e88eca 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -37,36 +37,38 @@ var ForwardAuth = &forwardAuth{ } func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) { - faWithOpts := new(forwardAuth) - faWithOpts.forwardAuthOpts = new(forwardAuthOpts) - err := Deserialize(optsRaw, faWithOpts.forwardAuthOpts) + fa := new(forwardAuth) + fa.forwardAuthOpts = new(forwardAuthOpts) + err := Deserialize(optsRaw, fa.forwardAuthOpts) if err != nil { return nil, err } - _, err = E.Check(url.Parse(faWithOpts.Address)) + _, err = E.Check(url.Parse(fa.Address)) if err != nil { - return nil, E.Invalid("address", faWithOpts.Address) + return nil, E.Invalid("address", fa.Address) } - faWithOpts.m = &Middleware{ - impl: faWithOpts, - before: faWithOpts.forward, + fa.m = &Middleware{ + impl: fa, + before: fa.forward, } // TODO: use tr from reverse proxy - tr, ok := faWithOpts.forwardAuthOpts.transport.(*http.Transport) + tr, ok := fa.forwardAuthOpts.transport.(*http.Transport) if ok { tr = tr.Clone() + } else { + tr = gpHTTP.DefaultTransport.Clone() } - faWithOpts.client = http.Client{ + fa.client = http.Client{ CheckRedirect: func(r *Request, via []*Request) error { return http.ErrUseLastResponse }, Timeout: 30 * time.Second, Transport: tr, } - return faWithOpts.m, nil + return fa.m, nil } func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { @@ -106,6 +108,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req } if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { + fa.m.AddTracef("status %d", faResp.StatusCode) gpHTTP.CopyHeader(w.Header(), faResp.Header) gpHTTP.RemoveHop(w.Header()) @@ -116,6 +119,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req return } else if redirectURL.String() != "" { w.Header().Set("Location", redirectURL.String()) + fa.m.AddTracef("redirect to %q", redirectURL.String()) } w.WriteHeader(faResp.StatusCode) diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index d654d81..a9520c4 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -16,6 +16,7 @@ type Trace struct { Message string `json:"msg"` ReqHeaders http.Header `json:"req_headers,omitempty"` RespHeaders http.Header `json:"resp_headers,omitempty"` + RespStatus int `json:"resp_status,omitempty"` Additional map[string]any `json:"additional,omitempty"` } @@ -46,6 +47,7 @@ func (tr *Trace) WithResponse(resp *Response) *Trace { tr.URL = resp.Request.RequestURI tr.ReqHeaders = resp.Request.Header.Clone() tr.RespHeaders = resp.Header.Clone() + tr.RespStatus = resp.StatusCode return tr } diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index e29b31c..147ada0 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -277,6 +277,16 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError { func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) { convertible = true + if dst.Kind() == reflect.Ptr { + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + dst = dst.Elem() + } + if dst.Kind() == reflect.String { + dst.SetString(src) + return + } // primitive types / simple types switch dst.Kind() { case reflect.Bool: diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 387e8b5..00213e5 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -1,6 +1,7 @@ package utils import ( + "reflect" "testing" . "github.com/yusing/go-proxy/internal/utils/testing" @@ -45,3 +46,59 @@ func TestDeserialize(t *testing.T) { ExpectNoError(t, err.Error()) ExpectDeepEqual(t, s, testStruct) } + +func TestStringIntConvert(t *testing.T) { + s := "127" + + test := struct { + i8 int8 + i16 int16 + i32 int32 + i64 int64 + u8 uint8 + u16 uint16 + u32 uint32 + u64 uint64 + }{} + + ok, err := ConvertString(s, reflect.ValueOf(&test.i8)) + + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.i8, int8(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.i16)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.i16, int16(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.i32)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.i32, int32(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.i64)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.i64, int64(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.u8)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.u8, uint8(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.u16)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.u16, uint16(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.u32)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.u32, uint32(127)) + + ok, err = ConvertString(s, reflect.ValueOf(&test.u64)) + ExpectTrue(t, ok) + ExpectNoError(t, err.Error()) + ExpectEqual(t, test.u64, uint64(127)) +}