diff --git a/internal/autocert/user.go b/internal/autocert/user.go index 3117f2a..9ced682 100644 --- a/internal/autocert/user.go +++ b/internal/autocert/user.go @@ -1,8 +1,9 @@ package autocert import ( - "github.com/go-acme/lego/v4/registration" "crypto" + + "github.com/go-acme/lego/v4/registration" ) type User struct { @@ -19,4 +20,4 @@ func (u *User) GetRegistration() *registration.Resource { } func (u *User) GetPrivateKey() crypto.PrivateKey { return u.key -} \ No newline at end of file +} diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 3fde40f..90d23dc 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -64,7 +64,7 @@ func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Requ defer cancel() accept := gphttp.GetAccept(r.Header) - acceptHTML := r.Method == http.MethodGet && accept.AcceptHTML() + acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) isCheckRedirect := r.Header.Get(headerCheckRedirect) != "" if !isCheckRedirect && acceptHTML { diff --git a/internal/route/tcp.go b/internal/route/tcp.go index 4ce5901..d5e6621 100755 --- a/internal/route/tcp.go +++ b/internal/route/tcp.go @@ -36,7 +36,7 @@ func (route *TCPRoute) Setup() error { if err != nil { return err } - //! this read the allocated port from orginal ':0' + //! this read the allocated port from original ':0' route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) route.listener = in return nil diff --git a/internal/route/udp.go b/internal/route/udp.go index 071d660..83227b4 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -51,7 +51,7 @@ func (route *UDPRoute) Setup() error { return err } - //! this read the allocated listeningPort from orginal ':0' + //! this read the allocated listeningPort from original ':0' route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) route.listeningConn = source diff --git a/internal/utils/io.go b/internal/utils/io.go index 2875f45..5987b7c 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -6,6 +6,7 @@ import ( "errors" "io" "os" + "sync" "syscall" E "github.com/yusing/go-proxy/internal/error" @@ -28,10 +29,8 @@ type ( } Pipe struct { - r ContextReader - w ContextWriter - ctx context.Context - cancel context.CancelFunc + r ContextReader + w ContextWriter } BidirectionalPipe struct { @@ -59,12 +58,9 @@ func (w *ContextWriter) Write(p []byte) (int, error) { } func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { - _, cancel := context.WithCancel(ctx) return &Pipe{ - r: ContextReader{ctx: ctx, Reader: r}, - w: ContextWriter{ctx: ctx, Writer: w}, - ctx: ctx, - cancel: cancel, + r: ContextReader{ctx: ctx, Reader: r}, + w: ContextWriter{ctx: ctx, Writer: w}, } } @@ -87,22 +83,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re } } -func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe { - return &BidirectionalPipe{ - pSrcDst: NewPipe(ctx, listener, client), - pDstSrc: NewPipe(ctx, client, target), - } -} - func (p BidirectionalPipe) Start() error { - errCh := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + b := E.NewBuilder("bidirectional pipe error") go func() { - errCh <- p.pSrcDst.Start() + b.AddE(p.pSrcDst.Start()) + wg.Done() }() go func() { - errCh <- p.pDstSrc.Start() + b.AddE(p.pDstSrc.Start()) + wg.Done() }() - return E.JoinE("bidirectional pipe error", <-errCh, <-errCh).Error() + wg.Wait() + return b.Build().Error() } func Copy(dst *ContextWriter, src *ContextReader) error { diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 147ada0..12d564e 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -149,9 +149,8 @@ func Deserialize(src SerializedObject, dst any) E.NestedError { if dstV.Kind() == reflect.Struct { mapping := make(map[string]reflect.Value) - for i := 0; i < dstV.NumField(); i++ { - field := dstT.Field(i) - mapping[ToLowerNoSnake(field.Name)] = dstV.Field(i) + for _, field := range reflect.VisibleFields(dstT) { + mapping[ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name) } for k, v := range src { if field, ok := mapping[ToLowerNoSnake(k)]; ok { @@ -322,7 +321,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N var tmp any switch dst.Kind() { case reflect.Slice: - // one liner is comma seperated list + // one liner is comma separated list if len(lines) == 0 { dst.Set(reflect.ValueOf(CommaSeperatedList(src))) return