diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index 35b8d6b..3bec3f6 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -2,6 +2,8 @@ package route import ( "net/http" + "path" + "path/filepath" "github.com/yusing/go-proxy/internal/common" gphttp "github.com/yusing/go-proxy/internal/net/http" @@ -34,7 +36,14 @@ func handler(root string) http.Handler { } func NewFileServer(base *Route) (*FileServer, E.Error) { - s := &FileServer{Route: base, handler: handler(base.Root)} + s := &FileServer{Route: base} + + s.Root = filepath.Clean(s.Root) + if !path.IsAbs(s.Root) { + return nil, E.Errorf("root must be absolute") + } + + s.handler = handler(s.Root) if len(s.Middlewares) > 0 { mid, err := middleware.BuildMiddlewareFromMap(s.Alias, s.Middlewares) @@ -91,7 +100,9 @@ func (s *FileServer) Start(parent task.Parent) E.Error { if s.UseHealthCheck() { s.Health = monitor.NewFileServerHealthMonitor(s.TargetName(), s.HealthCheck, s.Root) - s.Health.Start(s.task) + if err := s.Health.Start(s.task); err != nil { + return err + } } routes.SetHTTPRoute(s.TargetName(), s) diff --git a/internal/route/fileserver_test.go b/internal/route/fileserver_test.go new file mode 100644 index 0000000..f93e523 --- /dev/null +++ b/internal/route/fileserver_test.go @@ -0,0 +1,122 @@ +//nolint:gofumpt +package route + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestPathTraversalAttack(t *testing.T) { + tmp := t.TempDir() + root := filepath.Join(tmp, "static") + if err := os.Mkdir(root, 0755); err != nil { + t.Fatalf("Failed to create root directory: %v", err) + } + + // Create a file inside the root + validPath := "test.txt" + validContent := "test content" + if err := os.WriteFile(filepath.Join(root, validPath), []byte(validContent), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // create one at .. + secretFile := "secret.txt" + if err := os.WriteFile(filepath.Join(tmp, secretFile), []byte(validContent), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + traversals := []string{ + "../", + "./../", + "./.././", + "..%2f", + ".%2f..%2f", + ".%2f%2e%2e", + ".%2e", + ".%2e/", + ".%2e%2f", + "%2e.", + "%2e%2e", + } + + for _, traversal := range traversals { + traversals = append(traversals, "%2f"+traversal) + traversals = append(traversals, traversal+"%2f") + traversals = append(traversals, "%2f"+traversal+"%2f") + traversals = append(traversals, "/"+traversal) + traversals = append(traversals, traversal+"/") + traversals = append(traversals, "/"+traversal+"/") + } + + // Setup the FileServer + fs, err := NewFileServer(&Route{Root: root}) + if err != nil { + t.Fatalf("Failed to create FileServer: %v", err) + } + + // Create a test server with the handler + ts := httptest.NewServer(fs.handler) + defer ts.Close() + + // Test valid path + t.Run("valid path", func(t *testing.T) { + validURL := ts.URL + "/" + validPath + resp, err := http.Get(validURL) + if err != nil { + t.Errorf("Error making request to %s: %v", validURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 OK, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Error reading response body: %v", err) + } + + if string(body) != validContent { + t.Errorf("Expected %q, got %q", validContent, string(body)) + } + }) + + // Test ../ path + // tsURL := Must(url.Parse(ts.URL)) + for _, traversal := range traversals { + p := traversal + secretFile + t.Run(p, func(t *testing.T) { + u := &url.URL{Scheme: "http", Host: ts.Listener.Addr().String(), Path: p} + resp, err := http.DefaultClient.Do(&http.Request{ + Method: http.MethodGet, + URL: u, + }) + if err != nil { + t.Errorf("Error making request to %s: %v", p, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 404 or 400, got %d in url %s", resp.StatusCode, u.String()) + } + + u = Must(url.Parse(ts.URL + "/" + p)) + resp, err = http.DefaultClient.Do(&http.Request{ + Method: http.MethodGet, + URL: u, + }) + if err != nil { + t.Errorf("Error making request to %s: %v", u.String(), err) + } + defer resp.Body.Close() + }) + } +} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index b95ba7d..fcd9e2e 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -16,7 +16,10 @@ func init() { } } -func IgnoreError[Result any](r Result, _ error) Result { +func Must[Result any](r Result, err error) Result { + if err != nil { + panic(err) + } return r }