package middleware

import (
	"net"
	"strings"

	E "github.com/yusing/go-proxy/internal/error"
	gphttp "github.com/yusing/go-proxy/internal/net/http"
	"github.com/yusing/go-proxy/internal/net/types"
)

// https://nginx.org/en/docs/http/ngx_http_realip_module.html

// realIP is the middleware implementation that replaces a request's client
// address with one taken from a request header, according to the embedded
// realIPOpts.
type realIP struct {
	realIPOpts
	// m is the Middleware wrapping this instance; it is created in NewRealIP
	// and used by setRealIP to record trace entries.
	m *Middleware
}

// realIPOpts holds the configurable options of the realIP middleware.
type realIPOpts struct {
	// Header is the name of the request header that carries the real client IP.
	Header string `validate:"required"`
	// From is the list of trusted addresses / CIDRs. At least one entry is
	// required; NewRealIP rejects an empty list.
	From []*types.CIDR `validate:"min=1"`
	/*
		Recursive selects which address from the header replaces the original
		client address once that client address matches one of the trusted
		addresses in From.
		If recursive search is disabled, the last address sent in the request
		header field named by Header is used.
		If recursive search is enabled, the last non-trusted address sent in
		that request header field is used.
	*/
	Recursive bool
}

// RealIP is the exported real-IP middleware entry; its withOptions
// constructor is NewRealIP.
//
// realIPOptsDefault is the set of option values NewRealIP starts from before
// deserializing user options over them: the "X-Real-IP" header and an empty
// list of trusted CIDRs (which the user must therefore supply).
var (
	RealIP            = &Middleware{withOptions: NewRealIP}
	realIPOptsDefault = realIPOpts{
		Header: "X-Real-IP",
		From:   []*types.CIDR{},
	}
)

// NewRealIP builds a real-IP middleware from raw options.
//
// The options start out as realIPOptsDefault and are then overridden by
// whatever opts deserializes to. An error is returned when deserialization
// fails or when the resulting list of trusted CIDRs is empty.
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
	ri := &realIP{realIPOpts: realIPOptsDefault}
	ri.m = &Middleware{
		impl:   ri,
		before: Rewrite(ri.setRealIP),
	}

	if err := Deserialize(opts, &ri.realIPOpts); err != nil {
		return nil, err
	}

	// Without any trusted source the middleware could never act.
	if len(ri.From) < 1 {
		return nil, E.New("no allowed CIDRs").Subject("from")
	}
	return ri.m, nil
}

// isInCIDRList reports whether ip is contained in at least one of the
// trusted CIDRs configured in From.
func (ri *realIP) isInCIDRList(ip net.IP) bool {
	for i := range ri.From {
		if ri.From[i].Contains(ip) {
			return true
		}
	}
	return false
}

// setRealIP rewrites the client address of req from the configured header
// when the direct peer (req.RemoteAddr) is one of the trusted addresses in
// From.
//
// The header may be sent several times, and each occurrence may itself hold
// a comma-separated list (the usual X-Forwarded-For form), so all occurrences
// are first flattened into one ordered list of candidates.
// With Recursive disabled the last candidate is used; with Recursive enabled
// the last candidate that is not itself in From is used.
//
// On success req.RemoteAddr, the configured header, X-Real-IP and
// X-Forwarded-For are all set to the chosen address. In every other case the
// request is left untouched and a trace entry explains why.
func (ri *realIP) setRealIP(req *Request) {
	clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		// Not in host:port form; fall back to the raw value.
		clientIPStr = req.RemoteAddr
	}

	clientIP := net.ParseIP(clientIPStr)
	if !ri.isInCIDRList(clientIP) {
		ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
		return
	}

	realIPs := splitRealIPHeaderValues(req.Header.Values(ri.Header))
	if len(realIPs) == 0 {
		ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
		return
	}

	var lastNonTrustedIP string
	if !ri.Recursive {
		lastNonTrustedIP = realIPs[len(realIPs)-1]
	} else {
		// Walk from the end so the first hit is the last non-trusted entry.
		for i := len(realIPs) - 1; i >= 0; i-- {
			if !ri.isInCIDRList(net.ParseIP(realIPs[i])) {
				lastNonTrustedIP = realIPs[i]
				break
			}
		}
	}

	if lastNonTrustedIP == "" {
		ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
		return
	}

	req.RemoteAddr = lastNonTrustedIP
	req.Header.Set(ri.Header, lastNonTrustedIP)
	req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
	req.Header.Set(gphttp.HeaderXForwardedFor, lastNonTrustedIP)
	ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
}

// splitRealIPHeaderValues flattens header values, each of which may be a
// comma-separated list, into individual whitespace-trimmed, non-empty entries
// in their original order.
func splitRealIPHeaderValues(values []string) []string {
	entries := make([]string, 0, len(values))
	for _, value := range values {
		for _, entry := range strings.Split(value, ",") {
			if entry = strings.TrimSpace(entry); entry != "" {
				entries = append(entries, entry)
			}
		}
	}
	return entries
}