//nolint:gofumpt package route import ( "io" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "testing" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestPathTraversalAttack(t *testing.T) { tmp := t.TempDir() root := filepath.Join(tmp, "static") if err := os.Mkdir(root, 0755); err != nil { t.Fatalf("Failed to create root directory: %v", err) } // Create a file inside the root validPath := "test.txt" validContent := "test content" if err := os.WriteFile(filepath.Join(root, validPath), []byte(validContent), 0644); err != nil { t.Fatalf("Failed to create test file: %v", err) } // create one at .. secretFile := "secret.txt" if err := os.WriteFile(filepath.Join(tmp, secretFile), []byte(validContent), 0644); err != nil { t.Fatalf("Failed to create test file: %v", err) } traversals := []string{ "../", "./../", "./.././", "..%2f", ".%2f..%2f", ".%2f%2e%2e", ".%2e", ".%2e/", ".%2e%2f", "%2e.", "%2e%2e", } for _, traversal := range traversals { traversals = append(traversals, "%2f"+traversal) traversals = append(traversals, traversal+"%2f") traversals = append(traversals, "%2f"+traversal+"%2f") traversals = append(traversals, "/"+traversal) traversals = append(traversals, traversal+"/") traversals = append(traversals, "/"+traversal+"/") } // Setup the FileServer fs, err := NewFileServer(&Route{Root: root}) if err != nil { t.Fatalf("Failed to create FileServer: %v", err) } // Create a test server with the handler ts := httptest.NewServer(fs.handler) defer ts.Close() // Test valid path t.Run("valid path", func(t *testing.T) { validURL := ts.URL + "/" + validPath resp, err := http.Get(validURL) if err != nil { t.Errorf("Error making request to %s: %v", validURL, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("Expected 200 OK, got %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("Error reading response body: %v", err) } if string(body) != validContent { t.Errorf("Expected %q, got %q", validContent, string(body)) } }) // Test ../ path // tsURL := Must(url.Parse(ts.URL)) for _, traversal := range traversals { p := traversal + secretFile t.Run(p, func(t *testing.T) { u := &url.URL{Scheme: "http", Host: ts.Listener.Addr().String(), Path: p} resp, err := http.DefaultClient.Do(&http.Request{ Method: http.MethodGet, URL: u, }) if err != nil { t.Errorf("Error making request to %s: %v", p, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusBadRequest { t.Errorf("Expected status 404 or 400, got %d in url %s", resp.StatusCode, u.String()) } u = Must(url.Parse(ts.URL + "/" + p)) resp, err = http.DefaultClient.Do(&http.Request{ Method: http.MethodGet, URL: u, }) if err != nil { t.Errorf("Error making request to %s: %v", u.String(), err) } defer resp.Body.Close() }) } }