mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-31 00:52:35 +02:00
add test for path traversal attack, small fix on FileServer.Start method
This commit is contained in:
parent
4059e373e6
commit
1eb3cb3ddb
3 changed files with 139 additions and 3 deletions
|
@ -2,6 +2,8 @@ package route
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
|
@ -34,7 +36,14 @@ func handler(root string) http.Handler {
|
|||
}
|
||||
|
||||
func NewFileServer(base *Route) (*FileServer, E.Error) {
|
||||
s := &FileServer{Route: base, handler: handler(base.Root)}
|
||||
s := &FileServer{Route: base}
|
||||
|
||||
s.Root = filepath.Clean(s.Root)
|
||||
if !path.IsAbs(s.Root) {
|
||||
return nil, E.Errorf("root must be absolute")
|
||||
}
|
||||
|
||||
s.handler = handler(s.Root)
|
||||
|
||||
if len(s.Middlewares) > 0 {
|
||||
mid, err := middleware.BuildMiddlewareFromMap(s.Alias, s.Middlewares)
|
||||
|
@ -91,7 +100,9 @@ func (s *FileServer) Start(parent task.Parent) E.Error {
|
|||
|
||||
if s.UseHealthCheck() {
|
||||
s.Health = monitor.NewFileServerHealthMonitor(s.TargetName(), s.HealthCheck, s.Root)
|
||||
s.Health.Start(s.task)
|
||||
if err := s.Health.Start(s.task); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
routes.SetHTTPRoute(s.TargetName(), s)
|
||||
|
|
122
internal/route/fileserver_test.go
Normal file
122
internal/route/fileserver_test.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
//nolint:gofumpt
|
||||
package route
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestPathTraversalAttack(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
root := filepath.Join(tmp, "static")
|
||||
if err := os.Mkdir(root, 0755); err != nil {
|
||||
t.Fatalf("Failed to create root directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a file inside the root
|
||||
validPath := "test.txt"
|
||||
validContent := "test content"
|
||||
if err := os.WriteFile(filepath.Join(root, validPath), []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
// create one at ..
|
||||
secretFile := "secret.txt"
|
||||
if err := os.WriteFile(filepath.Join(tmp, secretFile), []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
traversals := []string{
|
||||
"../",
|
||||
"./../",
|
||||
"./.././",
|
||||
"..%2f",
|
||||
".%2f..%2f",
|
||||
".%2f%2e%2e",
|
||||
".%2e",
|
||||
".%2e/",
|
||||
".%2e%2f",
|
||||
"%2e.",
|
||||
"%2e%2e",
|
||||
}
|
||||
|
||||
for _, traversal := range traversals {
|
||||
traversals = append(traversals, "%2f"+traversal)
|
||||
traversals = append(traversals, traversal+"%2f")
|
||||
traversals = append(traversals, "%2f"+traversal+"%2f")
|
||||
traversals = append(traversals, "/"+traversal)
|
||||
traversals = append(traversals, traversal+"/")
|
||||
traversals = append(traversals, "/"+traversal+"/")
|
||||
}
|
||||
|
||||
// Setup the FileServer
|
||||
fs, err := NewFileServer(&Route{Root: root})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create FileServer: %v", err)
|
||||
}
|
||||
|
||||
// Create a test server with the handler
|
||||
ts := httptest.NewServer(fs.handler)
|
||||
defer ts.Close()
|
||||
|
||||
// Test valid path
|
||||
t.Run("valid path", func(t *testing.T) {
|
||||
validURL := ts.URL + "/" + validPath
|
||||
resp, err := http.Get(validURL)
|
||||
if err != nil {
|
||||
t.Errorf("Error making request to %s: %v", validURL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200 OK, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error reading response body: %v", err)
|
||||
}
|
||||
|
||||
if string(body) != validContent {
|
||||
t.Errorf("Expected %q, got %q", validContent, string(body))
|
||||
}
|
||||
})
|
||||
|
||||
// Test ../ path
|
||||
// tsURL := Must(url.Parse(ts.URL))
|
||||
for _, traversal := range traversals {
|
||||
p := traversal + secretFile
|
||||
t.Run(p, func(t *testing.T) {
|
||||
u := &url.URL{Scheme: "http", Host: ts.Listener.Addr().String(), Path: p}
|
||||
resp, err := http.DefaultClient.Do(&http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: u,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Error making request to %s: %v", p, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 404 or 400, got %d in url %s", resp.StatusCode, u.String())
|
||||
}
|
||||
|
||||
u = Must(url.Parse(ts.URL + "/" + p))
|
||||
resp, err = http.DefaultClient.Do(&http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: u,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Error making request to %s: %v", u.String(), err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
})
|
||||
}
|
||||
}
|
|
@ -16,7 +16,10 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
func IgnoreError[Result any](r Result, _ error) Result {
|
||||
func Must[Result any](r Result, err error) Result {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue