From ec8cca1245337957897c3abc76827dc7525351de Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 24 Apr 2025 05:56:03 +0800 Subject: [PATCH] feat: trie implementation --- go.mod | 6 ++ go.sum | 22 ++++++ internal/utils/trie/any.go | 49 ++++++++++++++ internal/utils/trie/any_debug.go | 13 ++++ internal/utils/trie/any_prod.go | 7 ++ internal/utils/trie/any_test.go | 16 +++++ internal/utils/trie/json.go | 26 +++++++ internal/utils/trie/json_test.go | 37 ++++++++++ internal/utils/trie/key.go | 80 ++++++++++++++++++++++ internal/utils/trie/key_test.go | 86 +++++++++++++++++++++++ internal/utils/trie/node.go | 54 +++++++++++++++ internal/utils/trie/trie.go | 44 ++++++++++++ internal/utils/trie/trie_test.go | 35 ++++++++++ internal/utils/trie/walk.go | 111 ++++++++++++++++++++++++++++++ internal/utils/trie/walk_test.go | 113 +++++++++++++++++++++++++++++++ 15 files changed, 699 insertions(+) create mode 100644 internal/utils/trie/any.go create mode 100644 internal/utils/trie/any_debug.go create mode 100644 internal/utils/trie/any_prod.go create mode 100644 internal/utils/trie/any_test.go create mode 100644 internal/utils/trie/json.go create mode 100644 internal/utils/trie/json_test.go create mode 100644 internal/utils/trie/key.go create mode 100644 internal/utils/trie/key_test.go create mode 100644 internal/utils/trie/node.go create mode 100644 internal/utils/trie/trie.go create mode 100644 internal/utils/trie/trie_test.go create mode 100644 internal/utils/trie/walk.go create mode 100644 internal/utils/trie/walk_test.go diff --git a/go.mod b/go.mod index 2d8870f..3c249a4 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2 require ( + github.com/bytedance/sonic v1.13.2 github.com/docker/cli v28.1.1+incompatible github.com/docker/go-connections v0.5.0 github.com/stretchr/testify v1.10.0 @@ -39,9 +40,11 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bytedance/sonic/loader v0.2.4 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/cloudflare-go v0.115.0 // indirect + github.com/cloudwego/base64x v0.1.5 // indirect github.com/containerd/log v0.1.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect @@ -58,6 +61,7 @@ require ( github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -81,6 +85,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect @@ -88,6 +93,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect + golang.org/x/arch v0.8.0 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect diff --git a/go.sum b/go.sum index d530579..b46df07 100644 --- a/go.sum +++ b/go.sum @@ -8,12 +8,20 @@ github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kk github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= +github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= +github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudflare/cloudflare-go v0.115.0 h1:84/dxeeXweCc0PN5Cto44iTA8AkG1fyT11yPO5ZB7sM= github.com/cloudflare/cloudflare-go v0.115.0/go.mod h1:Ds6urDwn/TF2uIU24mu7H91xkKP8gSAHxQ44DSZgVmU= +github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= +github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= @@ -90,6 +98,10 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -160,13 +172,20 @@ github.com/shirou/gopsutil/v4 v4.25.3/go.mod h1:xbuxyoZj+UsgnZrENu3lQivsngRR5Bdj github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8AbShPRpg2CI= github.com/vincent-petithory/dataurl v1.0.0/go.mod h1:FHafX5vmDzyP+1CQATJn7WFKc9CvnvxyvZy6I1MrG/U= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -194,6 +213,8 @@ go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -316,3 +337,4 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/internal/utils/trie/any.go b/internal/utils/trie/any.go new file mode 100644 index 0000000..e98ec5e --- /dev/null +++ b/internal/utils/trie/any.go @@ -0,0 +1,49 @@ +package trie + +import ( + "sync/atomic" +) + +// AnyValue is a wrapper of atomic.Value +// It is used to store values in trie nodes +// And allowed to assign to empty struct value when node +// is not an end node anymore +type AnyValue struct { + v atomic.Value +} + +type zeroValue struct{} + +var zero zeroValue + +func (av *AnyValue) Store(v any) { + if v == nil { + av.v.Store(zero) + return + } + defer panicInvalidAssignment() + av.v.Store(v) +} + +func (av *AnyValue) Swap(v any) any { + defer panicInvalidAssignment() + return av.v.Swap(v) +} + +func (av *AnyValue) Load() any { + switch v := av.v.Load().(type) { + case zeroValue: + return nil + default: + return v + } +} + +func (av *AnyValue) IsNil() bool { + switch v := av.v.Load().(type) { + case zeroValue: + return true // assigned nil manually + default: + return v == nil // uninitialized + } +} diff --git a/internal/utils/trie/any_debug.go b/internal/utils/trie/any_debug.go new file mode 100644 index 0000000..5010ace --- /dev/null +++ b/internal/utils/trie/any_debug.go @@ -0,0 +1,13 @@ +//go:build debug + +package trie + +import "fmt" + +func panicInvalidAssignment() { + // assigned anything after manually assigning nil + // will panic because of type mismatch (zeroValue and v.(type)) + if r := recover(); r != nil { + panic(fmt.Errorf("attempt to assign non-nil value on edge node or assigning mismatched type: %v", r)) + } +} diff --git a/internal/utils/trie/any_prod.go b/internal/utils/trie/any_prod.go new file mode 100644 index 0000000..382cb04 --- /dev/null +++ b/internal/utils/trie/any_prod.go @@ -0,0 +1,7 @@ +//go:build !debug + +package trie + +func panicInvalidAssignment() { + // no-op +} diff --git a/internal/utils/trie/any_test.go b/internal/utils/trie/any_test.go new file mode 100644 index 0000000..5d91d0d --- /dev/null +++ b/internal/utils/trie/any_test.go @@ -0,0 +1,16 @@ +package trie + +import ( + "testing" +) + +func TestStoreNil(t *testing.T) { + var v AnyValue + v.Store(nil) + if v.Load() != nil { + t.Fatal("expected nil") + } + if v.IsNil() { + t.Fatal("expected true") + } +} diff --git a/internal/utils/trie/json.go b/internal/utils/trie/json.go new file mode 100644 index 0000000..9cc705f --- /dev/null +++ b/internal/utils/trie/json.go @@ -0,0 +1,26 @@ +package trie + +import ( + "maps" + + "github.com/bytedance/sonic" +) + +var sonicConfig = sonic.Config{ + EncodeNullForInfOrNan: true, +}.Froze() + +func (r *Root) MarshalJSON() ([]byte, error) { + return sonicConfig.Marshal(maps.Collect(r.Walk)) +} + +func (r *Root) UnmarshalJSON(data []byte) error { + var m map[string]any + if err := sonicConfig.Unmarshal(data, &m); err != nil { + return err + } + for k, v := range m { + r.Store(NewKey(k), v) + } + return nil +} diff --git a/internal/utils/trie/json_test.go b/internal/utils/trie/json_test.go new file mode 100644 index 0000000..af48044 --- /dev/null +++ b/internal/utils/trie/json_test.go @@ -0,0 +1,37 @@ +package trie + +import ( + "testing" + + "github.com/bytedance/sonic" +) + +func TestMarshalUnmarshalJSON(t *testing.T) { + trie := NewTrie() + data := map[string]any{ + "foo.bar": 42.12, + "foo.baz": "hello", + "qwe.rt.yu.io": 123.45, + } + for k, v := range data { + trie.Store(NewKey(k), v) + } + + // MarshalJSON + bytesFromTrie, err := sonic.Marshal(trie) + if err != nil { + t.Fatalf("sonic.Marshal error: %v", err) + } + + // UnmarshalJSON + newTrie := NewTrie() + if err := sonic.Unmarshal(bytesFromTrie, newTrie); err != nil { + t.Fatalf("UnmarshalJSON error: %v", err) + } + for k, v := range data { + got, ok := newTrie.Get(NewKey(k)) + if !ok || got != v { + t.Errorf("UnmarshalJSON: key %q got %v, want %v", k, got, v) + } + } +} diff --git a/internal/utils/trie/key.go b/internal/utils/trie/key.go new file mode 100644 index 0000000..f57885f --- /dev/null +++ b/internal/utils/trie/key.go @@ -0,0 +1,80 @@ +package trie + +import ( + "slices" + "strings" + + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type Key struct { + segments []string // escaped segments + full string // unescaped original key + hasWildcard bool +} + +func Namespace(ns string) *Key { + return &Key{ + segments: []string{ns}, + full: ns, + hasWildcard: false, + } +} + +func NewKey(keyStr string) *Key { + key := &Key{ + segments: strutils.SplitRune(keyStr, '.'), + full: keyStr, + } + for _, seg := range key.segments { + if seg == "*" || seg == "**" { + key.hasWildcard = true + } + } + return key +} + +func EscapeSegment(seg string) string { + var sb strings.Builder + for _, r := range seg { + switch r { + case '.', '*': + sb.WriteString("__") + default: + sb.WriteRune(r) + } + } + return sb.String() +} + +func (ns Key) With(segment string) *Key { + ns.segments = append(ns.segments, segment) + ns.full = ns.full + "." + segment + ns.hasWildcard = ns.hasWildcard || segment == "*" || segment == "**" + return &ns +} + +func (ns Key) WithEscaped(segment string) *Key { + ns.segments = append(ns.segments, EscapeSegment(segment)) + ns.full = ns.full + "." + segment + return &ns +} + +func (ns *Key) NumSegments() int { + return len(ns.segments) +} + +func (ns *Key) HasWildcard() bool { + return ns.hasWildcard +} + +func (ns *Key) String() string { + return ns.full +} + +func (ns *Key) Clone() *Key { + clone := *ns + clone.segments = slices.Clone(ns.segments) + clone.full = strings.Clone(ns.full) + return &clone +} diff --git a/internal/utils/trie/key_test.go b/internal/utils/trie/key_test.go new file mode 100644 index 0000000..3dabcfc --- /dev/null +++ b/internal/utils/trie/key_test.go @@ -0,0 +1,86 @@ +package trie + +import ( + "reflect" + "testing" +) + +func TestNamespace(t *testing.T) { + k := Namespace("foo") + if k.String() != "foo" { + t.Errorf("Namespace.String() = %q, want %q", k.String(), "foo") + } + if k.NumSegments() != 1 { + t.Errorf("Namespace.NumSegments() = %d, want 1", k.NumSegments()) + } + if k.HasWildcard() { + t.Error("Namespace.HasWildcard() = true, want false") + } +} + +func TestNewKey(t *testing.T) { + k := NewKey("a.b.c") + if !reflect.DeepEqual(k.segments, []string{"a", "b", "c"}) { + t.Errorf("NewKey.segments = %v, want [a b c]", k.segments) + } + if k.String() != "a.b.c" { + t.Errorf("NewKey.String() = %q, want %q", k.String(), "a.b.c") + } + if k.NumSegments() != 3 { + t.Errorf("NewKey.NumSegments() = %d, want 3", k.NumSegments()) + } + if k.HasWildcard() { + t.Error("NewKey.HasWildcard() = true, want false") + } + + kw := NewKey("foo.*.bar") + if !kw.HasWildcard() { + t.Error("NewKey.HasWildcard() = false, want true for wildcard") + } +} + +func TestWithAndWithEscaped(t *testing.T) { + k := Namespace("foo") + k2 := k.Clone().With("bar") + if k2.String() != "foo.bar" { + t.Errorf("With.String() = %q, want %q", k2.String(), "foo.bar") + } + if k2.NumSegments() != 2 { + t.Errorf("With.NumSegments() = %d, want 2", k2.NumSegments()) + } + + k3 := Namespace("foo").WithEscaped("b.r*") + esc := EscapeSegment("b.r*") + if k3.segments[1] != esc { + t.Errorf("WithEscaped.segment = %q, want %q", k3.segments[1], esc) + } +} + +func TestEscapeSegment(t *testing.T) { + cases := map[string]string{ + "foo": "foo", + "f.o": "f__o", + "*": "__", + "a*b.c": "a__b__c", + } + for in, want := range cases { + if got := EscapeSegment(in); got != want { + t.Errorf("EscapeSegment(%q) = %q, want %q", in, got, want) + } + } +} + +func TestClone(t *testing.T) { + k := NewKey("x.y.z") + cl := k.Clone() + if !reflect.DeepEqual(k, cl) { + t.Errorf("Clone() = %v, want %v", cl, k) + } + cl.With("new") + if cl == k { + t.Error("Clone() returns same pointer") + } + if reflect.DeepEqual(k.segments, cl.segments) { + t.Error("Clone is not deep copy: segments slice is shared") + } +} diff --git a/internal/utils/trie/node.go b/internal/utils/trie/node.go new file mode 100644 index 0000000..527ba22 --- /dev/null +++ b/internal/utils/trie/node.go @@ -0,0 +1,54 @@ +package trie + +import ( + "github.com/puzpuzpuz/xsync/v3" +) + +type Node struct { + key string + children *xsync.MapOf[string, *Node] // lock-free map which allows concurrent access + value AnyValue // only end nodes have values +} + +func mayPrefix(key, part string) string { + if key == "" { + return part + } + return key + "." + part +} + +func (node *Node) newChild(part string) *Node { + return &Node{ + key: mayPrefix(node.key, part), + children: xsync.NewMapOf[string, *Node](), + } +} + +func (node *Node) Get(key *Key) (any, bool) { + for _, seg := range key.segments { + child, ok := node.children.Load(seg) + if !ok { + return nil, false + } + node = child + } + v := node.value.Load() + if v == nil { + return nil, false + } + return v, true +} + +func (node *Node) loadOrStore(key *Key, newFunc func() any) *Node { + for i, seg := range key.segments { + child, _ := node.children.LoadOrCompute(seg, func() *Node { + newNode := node.newChild(seg) + if i == len(key.segments)-1 { + newNode.value.Store(newFunc()) + } + return newNode + }) + node = child + } + return node +} diff --git a/internal/utils/trie/trie.go b/internal/utils/trie/trie.go new file mode 100644 index 0000000..a90ad05 --- /dev/null +++ b/internal/utils/trie/trie.go @@ -0,0 +1,44 @@ +package trie + +import "github.com/puzpuzpuz/xsync/v3" + +type Root struct { + *Node + cached *xsync.MapOf[string, *Node] +} + +func NewTrie() *Root { + return &Root{ + Node: &Node{ + children: xsync.NewMapOf[string, *Node](), + }, + cached: xsync.NewMapOf[string, *Node](), + } +} + +func (r *Root) getNode(key *Key, newFunc func() any) *Node { + if key.hasWildcard { + panic("should not call Load or Store on a key with any wildcard: " + key.full) + } + node, _ := r.cached.LoadOrCompute(key.full, func() *Node { + return r.Node.loadOrStore(key, newFunc) + }) + return node +} + +// LoadOrStore loads or stores the value for the key +// Returns the value loaded/stored +func (r *Root) LoadOrStore(key *Key, newFunc func() any) any { + return r.getNode(key, newFunc).value.Load() +} + +// LoadAndStore loads or stores the value for the key +// Returns the old value if exists, nil otherwise +func (r *Root) LoadAndStore(key *Key, val any) any { + return r.getNode(key, func() any { return val }).value.Swap(val) +} + +// Store stores the value for the key +func (r *Root) Store(key *Key, val any) { + r.getNode(key, func() any { return val }).value.Store(val) +} diff --git a/internal/utils/trie/trie_test.go b/internal/utils/trie/trie_test.go new file mode 100644 index 0000000..2b677f4 --- /dev/null +++ b/internal/utils/trie/trie_test.go @@ -0,0 +1,35 @@ +package trie + +import "testing" + +var nsCPU = Namespace("cpu") + +// Test functions +func TestLoadOrStore(t *testing.T) { + trie := NewTrie() + ptr := trie.LoadOrStore(nsCPU, func() any { + return new(int) + }) + if ptr == nil { + t.Fatal("expected pointer to be created") + } + if ptr != trie.LoadOrStore(nsCPU, func() any { + return new(int) + }) { + t.Fatal("expected same pointer to be returned") + } + got, ok := trie.Get(nsCPU) + if !ok || got != ptr { + t.Fatal("expected same pointer to be returned") + } +} + +func TestStore(t *testing.T) { + trie := NewTrie() + ptr := new(int) + trie.Store(nsCPU, ptr) + got, ok := trie.Get(nsCPU) + if !ok || got != ptr { + t.Fatal("expected same pointer to be returned") + } +} diff --git a/internal/utils/trie/walk.go b/internal/utils/trie/walk.go new file mode 100644 index 0000000..d7ac2c4 --- /dev/null +++ b/internal/utils/trie/walk.go @@ -0,0 +1,111 @@ +package trie + +import ( + "maps" + "slices" +) + +type YieldFunc = func(part string, value any) bool +type YieldKeyFunc = func(key string) bool +type Iterator = func(YieldFunc) +type KeyIterator = func(YieldKeyFunc) + +// WalkAll walks all nodes in the trie, yields full key and series +func (node *Node) Walk(yield YieldFunc) { + node.walkAll(yield) +} + +func (node *Node) walkAll(yield YieldFunc) bool { + if !node.value.IsNil() { + if !yield(node.key, node.value.Load()) { + return false + } + return true + } + for _, v := range node.children.Range { + if !v.walkAll(yield) { + return false + } + } + return true +} + +func (node *Node) WalkKeys(yield YieldKeyFunc) { + node.walkKeys(yield) +} + +func (node *Node) walkKeys(yield YieldKeyFunc) bool { + if !node.value.IsNil() { + return !yield(node.key) + } + for _, v := range node.children.Range { + if !v.walkKeys(yield) { + return false + } + } + return true +} + +func (node *Node) Keys() []string { + return slices.Collect(node.WalkKeys) +} + +func (node *Node) Map() map[string]any { + return maps.Collect(node.Walk) +} + +func (tree Root) Query(key *Key) Iterator { + if !key.hasWildcard { + return func(yield YieldFunc) { + if v, ok := tree.Node.Get(key); ok { + yield(key.full, v) + } + return + } + } + return func(yield YieldFunc) { + tree.walkQuery(key.segments, tree.Node, yield, false) + } +} + +func (tree Root) walkQuery(patternParts []string, node *Node, yield YieldFunc, recursive bool) bool { + if len(patternParts) == 0 { + if !node.value.IsNil() { // end + if !yield(node.key, node.value.Load()) { + return true + } + } else if recursive { + return tree.walkAll(yield) + } + return true + } + pat := patternParts[0] + + switch pat { + case "**": + // ** matches zero or more segments + // Option 1: ** matches zero segment, move to next pattern part + if !tree.walkQuery(patternParts[1:], node, yield, false) { + return false + } + // Option 2: ** matches one or more segments + for _, child := range node.children.Range { + if !tree.walkQuery(patternParts, child, yield, true) { + return false + } + } + case "*": + // * matches any single segment + for _, child := range node.children.Range { + if !tree.walkQuery(patternParts[1:], child, yield, false) { + return false + } + } + default: + // Exact match + if child, ok := node.children.Load(pat); ok { + return tree.walkQuery(patternParts[1:], child, yield, false) + } + } + return true +} diff --git a/internal/utils/trie/walk_test.go b/internal/utils/trie/walk_test.go new file mode 100644 index 0000000..7315d0a --- /dev/null +++ b/internal/utils/trie/walk_test.go @@ -0,0 +1,113 @@ +package trie_test + +import ( + "maps" + "slices" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/trie" +) + +// Test data for trie tests +var ( + testData = map[string]any{ + "routes.route1": new(int), + "routes.route2": new(int), + "routes.route3": new(int), + "system.cpu_average": new(int), + "system.mem.used": new(int), + "system.mem.percentage_used": new(int), + "system.disks.disk0.used": new(int), + "system.disks.disk0.percentage_used": new(int), + "system.disks.disk1.used": new(int), + "system.disks.disk1.percentage_used": new(int), + } + + testWalkDisksWants = []string{ + "system.disks.disk0.used", + "system.disks.disk0.percentage_used", + "system.disks.disk1.used", + "system.disks.disk1.percentage_used", + } + testWalkDisksUsedWants = []string{ + "system.disks.disk0.used", + "system.disks.disk1.used", + } + testUsedWants = []string{ + "system.mem.used", + "system.disks.disk0.used", + "system.disks.disk1.used", + } +) + +// Helper functions +func keys(m map[string]any) []string { + return slices.Sorted(maps.Keys(m)) +} + +func keysEqual(m map[string]any, want []string) bool { + slices.Sort(want) + return slices.Equal(keys(m), want) +} + +func TestWalkAll(t *testing.T) { + trie := NewTrie() + for key, series := range testData { + trie.Store(NewKey(key), series) + } + + walked := maps.Collect(trie.Walk) + for k, v := range testData { + if _, ok := walked[k]; !ok { + t.Fatalf("expected key %s not found", k) + } + if v != walked[k] { + t.Fatalf("key %s expected %v, got %v", k, v, walked[k]) + } + } +} + +func TestWalk(t *testing.T) { + trie := NewTrie() + for key, series := range testData { + trie.Store(NewKey(key), series) + } + + tests := []struct { + query string + want []string + wantEmpty bool + }{ + {"system.disks.*.used", testWalkDisksUsedWants, false}, + {"system.*.*.used", testWalkDisksUsedWants, false}, + {"*.disks.*.used", testWalkDisksUsedWants, false}, + {"*.*.*.used", testWalkDisksUsedWants, false}, + {"system.disks.**", testWalkDisksWants, false}, // note: original code uses '*' not '**' + {"system.disks", nil, true}, + {"**.used", testUsedWants, false}, + } + + for _, tc := range tests { + t.Run(tc.query, func(t *testing.T) { + got := maps.Collect(trie.Query(NewKey(tc.query))) + if tc.wantEmpty { + if len(got) != 0 { + t.Fatalf("expected empty, got %v", keys(got)) + } + return + } + if !keysEqual(got, tc.want) { + t.Fatalf("expected %v, got %v", tc.want, keys(got)) + } + for _, k := range tc.want { + want, ok := testData[k] + if !ok { + t.Fatalf("expected key %s not found", k) + } + if got[k] != want { + t.Fatalf("key %s expected %v, got %v", k, want, got[k]) + } + } + }) + } +}