fix: bypass rules should not check first

This commit is contained in:
yusing 2025-01-09 18:17:05 +08:00
parent 8109c9ac4f
commit 4ebe0abba0

View file

@ -2,8 +2,6 @@ package rules
import ( import (
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/logging"
) )
type ( type (
@ -48,76 +46,58 @@ type (
// BuildHandler returns a http.HandlerFunc that implements the rules. // BuildHandler returns a http.HandlerFunc that implements the rules.
// //
// Bypass rules are executed first
// if a bypass rule matches, // if a bypass rule matches,
// the request is passed to the upstream and no more rules are executed. // the request is passed to the upstream and no more rules are executed.
// //
// Other rules are executed later
// if no rule matches, the default rule is executed // if no rule matches, the default rule is executed
// if no rule matches and default rule is not set, // if no rule matches and default rule is not set,
// the request is passed to the upstream. // the request is passed to the upstream.
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
// move bypass rules to the front. var (
bypassRules := make(Rules, 0, len(rules)) defaultRule Rule
otherRules := make(Rules, 0, len(rules)) defaultRuleIndex int
)
var defaultRule Rule for i, rule := range rules {
if rule.Name == "default" {
for _, rule := range rules {
switch {
case rule.Do.isBypass():
bypassRules = append(bypassRules, rule)
case rule.Name == "default":
defaultRule = rule defaultRule = rule
default: defaultRuleIndex = i
otherRules = append(otherRules, rule) break
} }
} }
rules = append(rules[:defaultRuleIndex], rules[defaultRuleIndex+1:]...)
// free allocated empty slices // free allocated empty slices
// before encapsulating them into the handlerFunc. // before encapsulating them into the handlerFunc.
if len(bypassRules) == 0 { if len(rules) == 0 {
bypassRules = []Rule{} if defaultRule.Do.isBypass() {
return up.ServeHTTP
} }
if len(otherRules) == 0 { rules = []Rule{}
otherRules = []Rule{}
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
for _, rule := range bypassRules { hasMatch := false
for _, rule := range rules {
if rule.On.check(r) { if rule.On.check(r) {
logging.Debug(). if rule.Do.isBypass() {
Str("rule", rule.Name).
Msg("matched: bypass")
up.ServeHTTP(w, r) up.ServeHTTP(w, r)
return return
} }
}
hasMatch := false
for _, rule := range otherRules {
if rule.On.check(r) {
logging.Debug().
Str("rule", rule.Name).
Msgf("matched proceed=%t", rule.Do.exec.proceed)
hasMatch = true
rule.Do.exec.HandlerFunc(w, r) rule.Do.exec.HandlerFunc(w, r)
if !rule.Do.exec.proceed { if !rule.Do.exec.proceed {
return return
} }
hasMatch = true
} }
} }
if hasMatch || defaultRule.Do.isBypass() { if hasMatch || defaultRule.Do.isBypass() {
logging.Debug().
Str("rule", defaultRule.Name).
Msg("matched: bypass")
up.ServeHTTP(w, r) up.ServeHTTP(w, r)
return return
} }
logging.Debug().
Str("rule", defaultRule.Name).
Msg("matched: default")
defaultRule.Do.exec.HandlerFunc(w, r) defaultRule.Do.exec.HandlerFunc(w, r)
} }
} }