package httpheaders

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestAppendCSP verifies that AppendCSP merges the requested sources into the
// given directives of the Content-Security-Policy carried by the incoming
// request and writes the resulting policy to the response.
//
// Each directive found in the response is compared against expectedCSP as a
// set of whitespace-separated source tokens: every expected token must be
// present and no other token may appear (order and duplicates are ignored).
// Directives listed in absentDirectives must not appear in the response.
func TestAppendCSP(t *testing.T) {
	tests := []struct {
		name             string
		initialHeaders   map[string][]string
		sources          []string
		directives       []string
		expectedCSP      map[string]string
		absentDirectives []string
	}{
		{
			name:           "No CSP header",
			initialHeaders: map[string][]string{},
			sources:        []string{},
			directives:     []string{"default-src", "script-src", "frame-src", "style-src", "connect-src"},
			expectedCSP:    map[string]string{"default-src": "'self'", "script-src": "'self'", "frame-src": "'self'", "style-src": "'self'", "connect-src": "'self'"},
		},
		{
			name:           "No CSP header with sources",
			initialHeaders: map[string][]string{},
			sources:        []string{"https://example.com"},
			directives:     []string{"default-src", "script-src", "frame-src", "style-src", "connect-src"},
			expectedCSP:    map[string]string{"default-src": "'self' https://example.com", "script-src": "'self' https://example.com", "frame-src": "'self' https://example.com", "style-src": "'self' https://example.com", "connect-src": "'self' https://example.com"},
		},
		{
			name: "replace 'none' with sources",
			initialHeaders: map[string][]string{
				"Content-Security-Policy": {"default-src 'none'"},
			},
			sources:     []string{"https://example.com"},
			directives:  []string{"default-src"},
			expectedCSP: map[string]string{"default-src": "https://example.com"},
		},
		{
			name: "CSP header with some directives",
			initialHeaders: map[string][]string{
				"Content-Security-Policy": {"default-src 'none'", "script-src 'unsafe-inline'"},
			},
			sources:    []string{"https://example.com"},
			directives: []string{"script-src"},
			expectedCSP: map[string]string{
				"default-src": "'none'",
				"script-src":  "'unsafe-inline' https://example.com",
			},
		},
		{
			name: "CSP header with some directives with self",
			initialHeaders: map[string][]string{
				"Content-Security-Policy": {"default-src 'self'", "connect-src 'self'"},
			},
			sources:    []string{"https://api.example.com"},
			directives: []string{"default-src", "connect-src"},
			expectedCSP: map[string]string{
				"default-src": "'self' https://api.example.com",
				"connect-src": "'self' https://api.example.com",
			},
		},
		{
			name: "AppendCSP sources conflict with existing CSP header",
			initialHeaders: map[string][]string{
				"Content-Security-Policy": {"default-src 'self' https://cdn.example.com", "script-src 'unsafe-inline'"},
			},
			sources:    []string{"https://cdn.example.com", "https://api.example.com"},
			directives: []string{"default-src", "script-src"},
			expectedCSP: map[string]string{
				"default-src": "'self' https://cdn.example.com https://api.example.com",
				"script-src":  "'unsafe-inline' https://cdn.example.com https://api.example.com",
			},
		},
		{
			name: "Non-standard CSP directive",
			initialHeaders: map[string][]string{
				"Content-Security-Policy": {
					"default-src 'self'",
					"script-src 'unsafe-inline'",
					"img-src 'self'", // img-src is not in cspDirectives list
				},
			},
			sources:    []string{"https://example.com"},
			directives: []string{"default-src", "script-src"},
			expectedCSP: map[string]string{
				"default-src": "'self' https://example.com",
				"script-src":  "'unsafe-inline' https://example.com",
			},
			// img-src is not in cspDirectives, so it must not be present in the response.
			absentDirectives: []string{"img-src"},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Build the request carrying the initial CSP. Header.Add copies each
			// value into a fresh slice (and canonicalizes the key), so the
			// request never aliases the shared test table.
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			for header, values := range tc.initialHeaders {
				for _, v := range values {
					req.Header.Add(header, v)
				}
			}

			w := httptest.NewRecorder()

			// Call the function under test.
			AppendCSP(w, req, tc.directives, tc.sources)

			cspValues := w.Header().Values("Content-Security-Policy")

			// If we expect no CSP headers, verify none exist.
			if len(tc.expectedCSP) == 0 {
				if len(cspValues) > 0 {
					t.Errorf("Expected no CSP header, but got %v", cspValues)
				}
				return
			}

			if len(cspValues) == 0 {
				t.Fatal("Expected CSP header to be set, but it was not")
			}

			// Collect "directive -> raw source list" from every header value.
			// Directives are ';'-separated; the name is separated from its
			// sources by the first space.
			foundDirectives := make(map[string]string)
			for _, cspValue := range cspValues {
				for _, part := range strings.Split(cspValue, ";") {
					part = strings.TrimSpace(part)
					if part == "" {
						continue
					}

					directive, value, ok := strings.Cut(part, " ")
					if !ok {
						t.Errorf("Invalid CSP directive format: %s", part)
						continue
					}
					foundDirectives[directive] = value
				}
			}

			// Compare each expected directive token-by-token. Whole-token
			// comparison (instead of substring matching) prevents partial
			// matches and also catches leftover sources, such as a 'none'
			// that should have been replaced.
			for directive, expectedValue := range tc.expectedCSP {
				actualValue, ok := foundDirectives[directive]
				if !ok {
					t.Errorf("Expected directive %s not found in response", directive)
					continue
				}

				expected := strings.Fields(expectedValue)
				actual := strings.Fields(actualValue)

				wantSet := make(map[string]bool, len(expected))
				for _, src := range expected {
					wantSet[src] = true
				}
				gotSet := make(map[string]bool, len(actual))
				for _, src := range actual {
					gotSet[src] = true
				}

				for _, src := range expected {
					if !gotSet[src] {
						t.Errorf("Directive %s missing expected source %s. Got: %s", directive, src, actualValue)
					}
				}
				for _, src := range actual {
					if !wantSet[src] {
						t.Errorf("Directive %s has unexpected source %s. Got: %s", directive, src, actualValue)
					}
				}
			}

			// Verify directives that must not be emitted.
			for _, directive := range tc.absentDirectives {
				if value, ok := foundDirectives[directive]; ok {
					t.Errorf("Directive %s should not be present in response, but got: %s", directive, value)
				}
			}
		})
	}
}