package middleware_test import ( "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/yusing/go-proxy/internal/entrypoint" . "github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/route" routeTypes "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/task" expect "github.com/yusing/go-proxy/internal/utils/testing" ) func noOpHandler(w http.ResponseWriter, r *http.Request) {} func TestBypassCIDR(t *testing.T) { mr, err := ModifyRequest.New(map[string]any{ "set_headers": map[string]string{ "Test-Header": "test-value", }, "bypass": []string{"remote 127.0.0.1/32"}, }) expect.NoError(t, err) tests := []struct { name string remoteAddr string expectBypass bool }{ {"bypass", "127.0.0.1:8080", true}, {"no_bypass", "192.168.1.1:8080", false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com", nil) req.RemoteAddr = test.remoteAddr recorder := httptest.NewRecorder() mr.ModifyRequest(noOpHandler, recorder, req) expect.NoError(t, err) if test.expectBypass { expect.Equal(t, req.Header.Get("Test-Header"), "") } else { expect.Equal(t, req.Header.Get("Test-Header"), "test-value") } }) } } func TestBypassPath(t *testing.T) { mr, err := ModifyRequest.New(map[string]any{ "bypass": []string{"path /test/*", "path /api"}, "set_headers": map[string]string{ "Test-Header": "test-value", }, }) expect.NoError(t, err) tests := []struct { name string path string expectBypass bool }{ {"bypass", "/test/123", true}, {"bypass2", "/test/123/456", true}, {"bypass3", "/api", true}, {"no_bypass", "/test1/123/456", false}, {"no_bypass2", "/api/123", false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) recorder := httptest.NewRecorder() mr.ModifyRequest(noOpHandler, recorder, req) expect.NoError(t, err) if test.expectBypass { expect.Equal(t, req.Header.Get("Test-Header"), "") } else { expect.Equal(t, req.Header.Get("Test-Header"), "test-value") } }) } } type fakeRoundTripper struct{} func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("")), Request: req, Header: make(http.Header), }, nil } func TestReverseProxyBypass(t *testing.T) { rp := reverseproxy.NewReverseProxy("test", types.MustParseURL("http://example.com"), fakeRoundTripper{}) err := PatchReverseProxy(rp, map[string]OptionsRaw{ "response": { "bypass": "path /test/* | path /api", "set_headers": map[string]string{ "Test-Header": "test-value", }, }, }) expect.NoError(t, err) tests := []struct { name string path string expectBypass bool }{ {"bypass", "/test/123", true}, {"bypass2", "/test/123/456", true}, {"bypass3", "/api", true}, {"no_bypass", "/test1/123/456", false}, {"no_bypass2", "/api/123", false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) recorder := httptest.NewRecorder() rp.ServeHTTP(recorder, req) if test.expectBypass { expect.Equal(t, recorder.Header().Get("Test-Header"), "") } else { expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value") } }) } } func TestEntrypointBypassRoute(t *testing.T) { go http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("test")) })) entry := entrypoint.NewEntrypoint() r := &route.Route{ Alias: "test-route", Port: routeTypes.Port{ Proxy: 8080, }, } err := entry.SetMiddlewares([]map[string]any{ { "use": "redirectHTTP", "bypass": []string{"route test-route"}, }, { "use": "response", "set_headers": map[string]string{ "Test-Header": "test-value", }, }, }) expect.NoError(t, err) err = r.Validate() expect.NoError(t, err) r.Start(task.RootTask("test", false)) recorder := httptest.NewRecorder() req := httptest.NewRequest("GET", "http://test-route.example.com", nil) entry.ServeHTTP(recorder, req) expect.Equal(t, recorder.Code, http.StatusOK, "should bypass http redirect") expect.Equal(t, recorder.Body.String(), "test") expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value") }