package rules

import (
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"testing"

	"github.com/yusing/go-proxy/internal/gperr"
	. "github.com/yusing/go-proxy/internal/utils/testing"
	"golang.org/x/crypto/bcrypt"
)

// TestParseOn verifies that RuleOn.Parse accepts well-formed "on" expressions
// and rejects malformed ones with the expected sentinel error. A nil wantErr
// means the expression must parse without error.
func TestParseOn(t *testing.T) {
	cases := []struct {
		name    string
		input   string
		wantErr gperr.Error
	}{
		// header
		{"header_valid_kv", "header Connection Upgrade", nil},
		{"header_valid_k", "header Connection", nil},
		{"header_missing_arg", "header", ErrExpectKVOptionalV},
		// query
		{"query_valid_kv", "query key value", nil},
		{"query_valid_k", "query key", nil},
		{"query_missing_arg", "query", ErrExpectKVOptionalV},
		// cookie
		{"cookie_valid_kv", "cookie key value", nil},
		{"cookie_valid_k", "cookie key", nil},
		{"cookie_missing_arg", "cookie", ErrExpectKVOptionalV},
		// method
		{"method_valid", "method GET", nil},
		{"method_invalid", "method invalid", ErrInvalidArguments},
		{"method_missing_arg", "method", ErrExpectOneArg},
		// path
		{"path_valid", "path /home", nil},
		{"path_missing_arg", "path", ErrExpectOneArg},
		// remote
		{"remote_valid", "remote 127.0.0.1", nil},
		{"remote_invalid", "remote abcd", ErrInvalidArguments},
		{"remote_missing_arg", "remote", ErrExpectOneArg},
		// unrecognized target keyword
		{"unknown_target", "unknown", ErrInvalidOnTarget},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := new(RuleOn).Parse(c.input)
			if c.wantErr == nil {
				ExpectNoError(t, err)
				return
			}
			ExpectError(t, c.wantErr, err)
		})
	}
}

// testCorrectness describes a single table-driven case for TestOnCorrectness:
// the rule expression in checker is parsed, evaluated against input, and the
// result is compared with want. name becomes the subtest name.
type testCorrectness struct {
	name, checker string
	input         *http.Request
	want          bool
}

// genCorrectnessTestCases produces the shared match / no-match / exists /
// not-exists cases for a key/value-style rule target named field (e.g.
// "header", "query"). genRequest builds a request carrying the single
// key k with value v in whatever location that target inspects.
func genCorrectnessTestCases(field string, genRequest func(k, v string) *http.Request) []testCorrectness {
	specs := []struct {
		suffix   string // appended to field to form the subtest name
		args     string // rule arguments following the field keyword
		key, val string // key/value placed into the generated request
		want     bool
	}{
		{"_match", "foo bar", "foo", "bar", true},
		{"_no_match", "foo baz", "foo", "bar", false},
		{"_exists", "foo", "foo", "abcd", true},
		{"_not_exists", "foo", "bar", "abcd", false},
	}

	cases := make([]testCorrectness, 0, len(specs))
	for _, s := range specs {
		cases = append(cases, testCorrectness{
			name:    field + s.suffix,
			checker: field + " " + s.args,
			input:   genRequest(s.key, s.val),
			want:    s.want,
		})
	}
	return cases
}

func TestOnCorrectness(t *testing.T) {
	tests := []testCorrectness{
		{
			name:    "method_match",
			checker: "method GET",
			input:   &http.Request{Method: http.MethodGet},
			want:    true,
		},
		{
			name:    "method_no_match",
			checker: "method GET",
			input:   &http.Request{Method: http.MethodPost},
			want:    false,
		},
		{
			name:    "path_exact_match",
			checker: "path /example",
			input: &http.Request{
				URL: &url.URL{Path: "/example"},
			},
			want: true,
		},
		{
			name:    "path_wildcard_match",
			checker: "path /example/*",
			input: &http.Request{
				URL: &url.URL{Path: "/example/123"},
			},
			want: true,
		},
		{
			name:    "remote_match",
			checker: "remote 192.168.1.0/24",
			input: &http.Request{
				RemoteAddr: "192.168.1.5",
			},
			want: true,
		},
		{
			name:    "remote_no_match",
			checker: "remote 192.168.1.0/24",
			input: &http.Request{
				RemoteAddr: "192.168.2.5",
			},
			want: false,
		},
		{
			name:    "basic_auth_correct",
			checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
			input: &http.Request{
				Header: http.Header{
					"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
				},
			},
			want: true,
		},
		{
			name:    "basic_auth_incorrect",
			checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
			input: &http.Request{
				Header: http.Header{
					"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
				},
			},
			want: false,
		},
	}

	tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request {
		return &http.Request{
			Header: http.Header{k: []string{v}},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("query", func(k, v string) *http.Request {
		return &http.Request{
			URL: &url.URL{
				RawQuery: fmt.Sprintf("%s=%s", k, v),
			},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("cookie", func(k, v string) *http.Request {
		return &http.Request{
			Header: http.Header{
				"Cookie": {fmt.Sprintf("%s=%s", k, v)},
			},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("form", func(k, v string) *http.Request {
		return &http.Request{
			Form: url.Values{
				k: []string{v},
			},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("postform", func(k, v string) *http.Request {
		return &http.Request{
			PostForm: url.Values{
				k: []string{v},
			},
		}
	})...)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			on, err := parseOn(tt.checker)
			ExpectNoError(t, err)
			got := on.Check(Cache{}, tt.input)
			if tt.want != got {
				t.Errorf("want %v, got %v", tt.want, got)
			}
		})
	}
}