package middleware

import (
	"bytes"
	"net"
	"net/http"
	"slices"
	"testing"

	"github.com/yusing/go-proxy/internal/net/types"
	. "github.com/yusing/go-proxy/internal/utils/testing"
)

func TestModifyResponse(t *testing.T) {
	opts := OptionsRaw{
		"set_headers": map[string]string{
			"X-Test-Resp-Status":              VarRespStatusCode,
			"X-Test-Resp-Content-Type":        VarRespContentType,
			"X-Test-Resp-Content-Length":      VarRespContentLen,
			"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",

			"X-Test-Req-Method":          VarRequestMethod,
			"X-Test-Req-Scheme":          VarRequestScheme,
			"X-Test-Req-Host":            VarRequestHost,
			"X-Test-Req-Port":            VarRequestPort,
			"X-Test-Req-Addr":            VarRequestAddr,
			"X-Test-Req-Path":            VarRequestPath,
			"X-Test-Req-Query":           VarRequestQuery,
			"X-Test-Req-Url":             VarRequestURL,
			"X-Test-Req-Uri":             VarRequestURI,
			"X-Test-Req-Content-Type":    VarRequestContentType,
			"X-Test-Req-Content-Length":  VarRequestContentLen,
			"X-Test-Remote-Host":         VarRemoteHost,
			"X-Test-Remote-Port":         VarRemotePort,
			"X-Test-Remote-Addr":         VarRemoteAddr,
			"X-Test-Upstream-Scheme":     VarUpstreamScheme,
			"X-Test-Upstream-Host":       VarUpstreamHost,
			"X-Test-Upstream-Port":       VarUpstreamPort,
			"X-Test-Upstream-Addr":       VarUpstreamAddr,
			"X-Test-Upstream-Url":        VarUpstreamURL,
			"X-Test-Header-Content-Type": "$header(Content-Type)",
			"X-Test-Arg-Arg_1":           "$arg(arg_1)",
		},
		"add_headers":  map[string]string{"Accept-Encoding": "test-value"},
		"hide_headers": []string{"Accept"},
	}

	t.Run("set_options", func(t *testing.T) {
		mr, err := ModifyResponse.New(opts)
		ExpectNoError(t, err)
		ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
		ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
		ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
	})

	t.Run("response_headers", func(t *testing.T) {
		reqURL := types.MustParseURL("https://my.app/?arg_1=b")
		upstreamURL := types.MustParseURL("http://test.example.com")
		result, err := newMiddlewareTest(ModifyResponse, &testArgs{
			middlewareOpt: opts,
			reqURL:        reqURL,
			upstreamURL:   upstreamURL,
			body:          bytes.Repeat([]byte("a"), 100),
			headers: http.Header{
				"Content-Type": []string{"application/json"},
			},
			respHeaders: http.Header{
				"Content-Type": []string{"application/json"},
			},
			respBody:   bytes.Repeat([]byte("a"), 50),
			respStatus: http.StatusOK,
		})
		ExpectNoError(t, err)
		ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
		ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")

		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")

		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")

		remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)

		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())

		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
		ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
	})
}