package rules

import (
	"net/http"
	"path"
	"strconv"
	"strings"

	"github.com/yusing/go-proxy/internal/gperr"
	gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
	"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
	"github.com/yusing/go-proxy/internal/net/types"
	"github.com/yusing/go-proxy/internal/utils/strutils"
)

// Command is a parsed sequence of rule directives.
//
// raw keeps the exact text that was successfully parsed, so the command can be
// printed or marshaled back; exec is the combined handler that Parse built
// from that text (nil until a Parse call produces at least one handler).
type Command struct {
	raw  string
	exec CommandHandler
}

// Directive names recognized by (*Command).Parse.
const (
	// rewrite replaces a path prefix and lets later directives run.
	CommandRewrite = "rewrite"

	// Directives that answer the request themselves; Parse only accepts them
	// as the final directive of a command.
	CommandServe            = "serve"
	CommandProxy            = "proxy"
	CommandRedirect         = "redirect"
	CommandError            = "error"
	CommandRequireBasicAuth = "require_basic_auth"

	// Directives that modify a named field (validated by validateModField).
	CommandSet    = "set"
	CommandAdd    = "add"
	CommandRemove = "remove"

	// pass and its alias bypass take no arguments and must also come last.
	CommandPass    = "pass"
	CommandPassAlt = "bypass"
)

// commands maps every directive name except pass/bypass (which Parse handles
// itself) to its help text, argument validator, and handler builder.
//
// validate receives the raw arguments of one directive line and returns a
// value that build type-asserts back to the concrete type its validator
// produced, so each validate/build pair must agree on that type.
var commands = map[string]struct {
	help     Help
	validate ValidateFunc
	build    func(args any) CommandHandler
}{
	CommandRewrite: {
		help: Help{
			command: CommandRewrite,
			args: map[string]string{
				"from": "the path to rewrite, must start with /",
				"to":   "the path to rewrite to, must start with /",
			},
		},
		validate: func(args []string) (any, gperr.Error) {
			if len(args) != 2 {
				return nil, ErrExpectTwoArgs
			}
			return validateURLPaths(args)
		},
		build: func(args any) CommandHandler {
			a := args.([]string)
			orig, repl := a[0], a[1]
			return StaticCommand(func(w http.ResponseWriter, r *http.Request) {
				// Named reqPath (not path) so the path package is not shadowed.
				reqPath := r.URL.Path
				if len(reqPath) > 0 && reqPath[0] != '/' {
					reqPath = "/" + reqPath
				}
				// Plain string-prefix match: "/api" also matches "/apiv2".
				if !strings.HasPrefix(reqPath, orig) {
					return
				}
				r.URL.Path = repl + reqPath[len(orig):]
				// The previous RawPath no longer encodes Path; EscapedPath falls
				// back to re-escaping Path when RawPath is not a valid encoding.
				r.URL.RawPath = r.URL.EscapedPath()
				r.RequestURI = r.URL.RequestURI()
			})
		},
	},
	CommandServe: {
		help: Help{
			command: CommandServe,
			args: map[string]string{
				"root": "the file system path to serve, must be an existing directory",
			},
		},
		validate: validateFSPath,
		build: func(args any) CommandHandler {
			root := args.(string)
			return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
				// Rooting the request path before cleaning collapses any
				// leading ".." at "/", so the joined name can never escape
				// root even if the path arrives without a leading slash. This
				// does not rely on http.ServeFile's own ".." precaution.
				http.ServeFile(w, r, path.Join(root, path.Clean("/"+r.URL.Path)))
			})
		},
	},
	CommandRedirect: {
		help: Help{
			command: CommandRedirect,
			args: map[string]string{
				"to": "the url to redirect to, can be relative or absolute URL",
			},
		},
		validate: validateURL,
		build: func(args any) CommandHandler {
			target := args.(*types.URL).String()
			return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
				http.Redirect(w, r, target, http.StatusTemporaryRedirect)
			})
		},
	},
	CommandError: {
		help: Help{
			command: CommandError,
			args: map[string]string{
				"code": "the http status code to return",
				"text": "the error message to return",
			},
		},
		validate: func(args []string) (any, gperr.Error) {
			if len(args) != 2 {
				return nil, ErrExpectTwoArgs
			}
			codeStr, text := args[0], args[1]
			code, err := strconv.Atoi(codeStr)
			if err != nil {
				return nil, ErrInvalidArguments.With(err)
			}
			if !gphttp.IsStatusCodeValid(code) {
				return nil, ErrInvalidArguments.Subject(codeStr)
			}
			return &Tuple[int, string]{code, text}, nil
		},
		build: func(args any) CommandHandler {
			code, text := args.(*Tuple[int, string]).Unpack()
			return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
				http.Error(w, text, code)
			})
		},
	},
	CommandRequireBasicAuth: {
		help: Help{
			command: CommandRequireBasicAuth,
			args: map[string]string{
				"realm": "the authentication realm",
			},
		},
		validate: func(args []string) (any, gperr.Error) {
			if len(args) == 1 {
				return args[0], nil
			}
			return nil, ErrExpectOneArg
		},
		build: func(args any) CommandHandler {
			realm := args.(string)
			// The realm is sent inside a quoted-string, so backslashes and
			// double quotes must be escaped or the header becomes malformed.
			// The challenge never changes, so build it once here rather than
			// per request.
			challenge := `Basic realm="` +
				strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(realm) + `"`
			return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("WWW-Authenticate", challenge)
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
			})
		},
	},
	CommandProxy: {
		help: Help{
			command: CommandProxy,
			args: map[string]string{
				"to": "the url to proxy to, must be an absolute URL",
			},
		},
		validate: validateAbsoluteURL,
		build: func(args any) CommandHandler {
			target := args.(*types.URL)
			if target.Scheme == "" {
				target.Scheme = "http"
			}
			rp := reverseproxy.NewReverseProxy("", target, gphttp.NewTransport())
			return ReturningCommand(rp.ServeHTTP)
		},
	},
	CommandSet: {
		help: Help{
			command: CommandSet,
			args: map[string]string{
				"field": "the field to set",
				"value": "the value to set",
			},
		},
		validate: func(args []string) (any, gperr.Error) {
			return validateModField(ModFieldSet, args)
		},
		build: func(args any) CommandHandler {
			// validateModField already returns the finished handler.
			return args.(CommandHandler)
		},
	},
	CommandAdd: {
		help: Help{
			command: CommandAdd,
			args: map[string]string{
				"field": "the field to add",
				"value": "the value to add",
			},
		},
		validate: func(args []string) (any, gperr.Error) {
			return validateModField(ModFieldAdd, args)
		},
		build: func(args any) CommandHandler {
			return args.(CommandHandler)
		},
	},
	CommandRemove: {
		help: Help{
			command: CommandRemove,
			args: map[string]string{
				"field": "the field to remove",
			},
		},
		validate: func(args []string) (any, gperr.Error) {
			return validateModField(ModFieldRemove, args)
		},
		build: func(args any) CommandHandler {
			return args.(CommandHandler)
		},
	},
}

// Parse implements strutils.Parser.
//
// v holds one directive per line. Blank lines are skipped. "pass" and its
// alias "bypass" must have no arguments. Every other directive is looked up
// in commands, its arguments are validated, and its handler is built. The
// resulting handlers are combined by buildCmd.
//
// cmd is modified only on success. Input that produces no handlers returns
// nil and leaves cmd unchanged.
func (cmd *Command) Parse(v string) error {
	lines := strutils.SplitLine(v)
	if len(lines) == 0 {
		return nil
	}

	handlers := make([]CommandHandler, 0, len(lines))
	for _, line := range lines {
		if line == "" {
			continue
		}

		directive, args, parseErr := parse(line)
		if parseErr != nil {
			return parseErr
		}

		switch directive {
		case CommandPass, CommandPassAlt:
			if len(args) > 0 {
				return ErrInvalidArguments.Subject(directive)
			}
			handlers = append(handlers, BypassCommand{})
			continue
		}

		spec, known := commands[directive]
		if !known {
			return ErrUnknownDirective.Subject(directive)
		}

		validated, validateErr := spec.validate(args)
		if validateErr != nil {
			// Attach the directive's usage text so the config can be fixed.
			return validateErr.Subject(directive).Withf("%s", spec.help.String())
		}
		handlers = append(handlers, spec.build(validated))
	}

	if len(handlers) == 0 {
		return nil
	}

	combined, err := buildCmd(handlers)
	if err != nil {
		return err
	}

	cmd.raw, cmd.exec = v, combined
	return nil
}

// buildCmd wraps the parsed handlers into a single Commands handler.
//
// A ReturningCommand or BypassCommand may only be the final element. Any
// occurrence earlier in the sequence is reported as ErrInvalidCommandSequence
// (presumably because the handlers after it would never take effect).
func buildCmd(executors []CommandHandler) (CommandHandler, error) {
	// Every element except the last must be a non-terminating handler.
	for idx := 0; idx < len(executors)-1; idx++ {
		switch executors[idx].(type) {
		case ReturningCommand, BypassCommand:
			return nil, ErrInvalidCommandSequence.
				Withf("a returning / bypass command must be the last command")
		}
	}
	return Commands(executors), nil
}

// isBypass reports whether cmd ends by passing the request through.
//
// It returns true for:
//   - a nil *Command;
//   - an exec that is a lone BypassCommand;
//   - a Commands sequence whose last handler is a BypassCommand.
//
// Any other exec, including a nil exec on a non-nil Command or an empty
// Commands, yields false. Only the last element of a sequence is inspected
// because buildCmd rejects a bypass anywhere else. Non-terminating handlers
// (e.g. rewrite) may still precede it.
func (cmd *Command) isBypass() bool {
	if cmd == nil {
		return true
	}
	switch exec := cmd.exec.(type) {
	case BypassCommand:
		return true
	case Commands:
		// Parse never stores an empty sequence, but guard the index so an
		// empty Commands built elsewhere cannot cause a panic.
		if len(exec) == 0 {
			return false
		}
		_, ok := exec[len(exec)-1].(BypassCommand)
		return ok
	default:
		return false
	}
}

// String returns the raw text from the last successful Parse, or "" for a
// zero Command. A nil *Command also yields "", mirroring isBypass, which
// accepts nil commands, instead of panicking on the field access.
func (cmd *Command) String() string {
	if cmd == nil {
		return ""
	}
	return cmd.raw
}

// MarshalText implements encoding.TextMarshaler by emitting the same text
// that String returns. It never fails.
func (cmd *Command) MarshalText() ([]byte, error) {
	text := cmd.String()
	return []byte(text), nil
}