fix forward auth attempt#1

This commit is contained in:
yusing 2024-10-06 03:18:06 +08:00
parent 99746bad8e
commit 01ff63a007
4 changed files with 84 additions and 11 deletions

View file

@ -37,36 +37,38 @@ var ForwardAuth = &forwardAuth{
} }
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) { func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
faWithOpts := new(forwardAuth) fa := new(forwardAuth)
faWithOpts.forwardAuthOpts = new(forwardAuthOpts) fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, faWithOpts.forwardAuthOpts) err := Deserialize(optsRaw, fa.forwardAuthOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = E.Check(url.Parse(faWithOpts.Address)) _, err = E.Check(url.Parse(fa.Address))
if err != nil { if err != nil {
return nil, E.Invalid("address", faWithOpts.Address) return nil, E.Invalid("address", fa.Address)
} }
faWithOpts.m = &Middleware{ fa.m = &Middleware{
impl: faWithOpts, impl: fa,
before: faWithOpts.forward, before: fa.forward,
} }
// TODO: use tr from reverse proxy // TODO: use tr from reverse proxy
tr, ok := faWithOpts.forwardAuthOpts.transport.(*http.Transport) tr, ok := fa.forwardAuthOpts.transport.(*http.Transport)
if ok { if ok {
tr = tr.Clone() tr = tr.Clone()
} else {
tr = gpHTTP.DefaultTransport.Clone()
} }
faWithOpts.client = http.Client{ fa.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error { CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, },
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: tr, Transport: tr,
} }
return faWithOpts.m, nil return fa.m, nil
} }
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
@ -106,6 +108,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
} }
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices { if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTracef("status %d", faResp.StatusCode)
gpHTTP.CopyHeader(w.Header(), faResp.Header) gpHTTP.CopyHeader(w.Header(), faResp.Header)
gpHTTP.RemoveHop(w.Header()) gpHTTP.RemoveHop(w.Header())
@ -116,6 +119,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
return return
} else if redirectURL.String() != "" { } else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String()) w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String())
} }
w.WriteHeader(faResp.StatusCode) w.WriteHeader(faResp.StatusCode)

View file

@ -16,6 +16,7 @@ type Trace struct {
Message string `json:"msg"` Message string `json:"msg"`
ReqHeaders http.Header `json:"req_headers,omitempty"` ReqHeaders http.Header `json:"req_headers,omitempty"`
RespHeaders http.Header `json:"resp_headers,omitempty"` RespHeaders http.Header `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"` Additional map[string]any `json:"additional,omitempty"`
} }
@ -46,6 +47,7 @@ func (tr *Trace) WithResponse(resp *Response) *Trace {
tr.URL = resp.Request.RequestURI tr.URL = resp.Request.RequestURI
tr.ReqHeaders = resp.Request.Header.Clone() tr.ReqHeaders = resp.Request.Header.Clone()
tr.RespHeaders = resp.Header.Clone() tr.RespHeaders = resp.Header.Clone()
tr.RespStatus = resp.StatusCode
return tr return tr
} }

View file

@ -277,6 +277,16 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) { func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) {
convertible = true convertible = true
if dst.Kind() == reflect.Ptr {
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
}
dst = dst.Elem()
}
if dst.Kind() == reflect.String {
dst.SetString(src)
return
}
// primitive types / simple types // primitive types / simple types
switch dst.Kind() { switch dst.Kind() {
case reflect.Bool: case reflect.Bool:

View file

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"reflect"
"testing" "testing"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
@ -45,3 +46,59 @@ func TestDeserialize(t *testing.T) {
ExpectNoError(t, err.Error()) ExpectNoError(t, err.Error())
ExpectDeepEqual(t, s, testStruct) ExpectDeepEqual(t, s, testStruct)
} }
func TestStringIntConvert(t *testing.T) {
s := "127"
test := struct {
i8 int8
i16 int16
i32 int32
i64 int64
u8 uint8
u16 uint16
u32 uint32
u64 uint64
}{}
ok, err := ConvertString(s, reflect.ValueOf(&test.i8))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.i8, int8(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i16))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.i16, int16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i32))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.i32, int32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i64))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.i64, int64(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u8))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.u8, uint8(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u16))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.u16, uint16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u32))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.u32, uint32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u64))
ExpectTrue(t, ok)
ExpectNoError(t, err.Error())
ExpectEqual(t, test.u64, uint64(127))
}