diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go index df46f8e..800680b 100644 --- a/internal/auth/oauth_refresh.go +++ b/internal/auth/oauth_refresh.go @@ -34,7 +34,7 @@ type sessionClaims struct { type sessionID string -var oauthRefreshTokens jsonstore.Typed[oauthRefreshToken] +var oauthRefreshTokens jsonstore.MapStore[oauthRefreshToken] var ( defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month diff --git a/internal/homepage/override_config.go b/internal/homepage/override_config.go index fe0ca1f..8f7b746 100644 --- a/internal/homepage/override_config.go +++ b/internal/homepage/override_config.go @@ -1,6 +1,7 @@ package homepage import ( + "maps" "sync" "github.com/yusing/go-proxy/internal/common" @@ -15,12 +16,19 @@ type OverrideConfig struct { mu sync.RWMutex } -var overrideConfigInstance = jsonstore.Object[OverrideConfig](common.NamespaceHomepageOverrides) +var overrideConfigInstance = jsonstore.Object[*OverrideConfig](common.NamespaceHomepageOverrides) func GetOverrideConfig() *OverrideConfig { return overrideConfigInstance } +func (c *OverrideConfig) Initialize() { + c.ItemOverrides = make(map[string]*ItemConfig) + c.DisplayOrder = make(map[string]int) + c.CategoryOrder = make(map[string]int) + c.ItemVisibility = make(map[string]bool) +} + func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) { c.mu.Lock() defer c.mu.Unlock() @@ -30,9 +38,7 @@ func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) { func (c *OverrideConfig) OverrideItems(items map[string]*ItemConfig) { c.mu.Lock() defer c.mu.Unlock() - for key, value := range items { - c.ItemOverrides[key] = value - } + maps.Copy(c.ItemOverrides, items) } func (c *OverrideConfig) GetOverride(alias string, item *ItemConfig) *ItemConfig { diff --git a/internal/jsonstore/jsonstore.go b/internal/jsonstore/jsonstore.go index e04fbc3..ac5c39e 100644 --- a/internal/jsonstore/jsonstore.go +++ b/internal/jsonstore/jsonstore.go @@ -4,8 +4,11 @@ import ( "encoding/json" "fmt" "path/filepath" + "reflect" "sync" + "maps" + "github.com/puzpuzpuz/xsync/v3" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" @@ -16,16 +19,29 @@ import ( type namespace string -type Typed[VT any] struct { +type MapStore[VT any] struct { *xsync.MapOf[string, VT] } -type storesMap struct { - sync.RWMutex - m map[namespace]any +type ObjectStore[Pointer Initializer] struct { + ptr Pointer } -var stores = storesMap{m: make(map[namespace]any)} +type Initializer interface { + Initialize() +} + +type storeByNamespace struct { + sync.RWMutex + m map[namespace]store +} + +type store interface { + json.Marshaler + json.Unmarshaler +} + +var stores = storeByNamespace{m: make(map[namespace]store)} var storesPath = common.DataDir func init() { @@ -44,12 +60,16 @@ func load() error { stores.Lock() defer stores.Unlock() errs := gperr.NewBuilder("failed to load data stores") - for ns, store := range stores.m { - if err := utils.LoadJSONIfExist(filepath.Join(storesPath, string(ns)+".json"), &store); err != nil { + for ns, obj := range stores.m { + if init, ok := obj.(Initializer); ok { + init.Initialize() + } + if err := utils.LoadJSONIfExist(filepath.Join(storesPath, string(ns)+".json"), &obj); err != nil { errs.Add(err) } else { logging.Info().Str("name", string(ns)).Msg("store loaded") } + stores.m[ns] = obj } return errs.Error() } @@ -66,41 +86,42 @@ func save() error { return errs.Error() } -func Store[VT any](namespace namespace) Typed[VT] { +func Store[VT any](namespace namespace) MapStore[VT] { stores.Lock() defer stores.Unlock() if s, ok := stores.m[namespace]; ok { - return s.(Typed[VT]) - } - m := Typed[VT]{MapOf: xsync.NewMapOf[string, VT]()} - stores.m[namespace] = m - return m -} - -func Object[VT any](namespace namespace) *VT { - stores.Lock() - defer stores.Unlock() - if s, ok := stores.m[namespace]; ok { - v, ok := s.(*VT) - if ok { - return v + v, ok := s.(*MapStore[VT]) + if !ok { + panic(fmt.Errorf("type mismatch: %T != %T", s, v)) } - panic(fmt.Errorf("type mismatch: %T != %T", s, v)) + return *v } - v := new(VT) - stores.m[namespace] = v - return v + m := &MapStore[VT]{MapOf: xsync.NewMapOf[string, VT]()} + stores.m[namespace] = m + return *m } -func (s Typed[VT]) MarshalJSON() ([]byte, error) { - tmp := make(map[string]VT, s.Size()) - for k, v := range s.Range { - tmp[k] = v +func Object[Ptr Initializer](namespace namespace) Ptr { + stores.Lock() + defer stores.Unlock() + if s, ok := stores.m[namespace]; ok { + v, ok := s.(*ObjectStore[Ptr]) + if !ok { + panic(fmt.Errorf("type mismatch: %T != %T", s, v)) + } + return v.ptr } - return json.Marshal(tmp) + obj := &ObjectStore[Ptr]{} + obj.init() + stores.m[namespace] = obj + return obj.ptr } -func (s Typed[VT]) UnmarshalJSON(data []byte) error { +func (s MapStore[VT]) MarshalJSON() ([]byte, error) { + return json.Marshal(maps.Collect(s.Range)) +} + +func (s *MapStore[VT]) UnmarshalJSON(data []byte) error { tmp := make(map[string]VT) if err := json.Unmarshal(data, &tmp); err != nil { return err @@ -111,3 +132,17 @@ func (s Typed[VT]) UnmarshalJSON(data []byte) error { } return nil } + +func (obj *ObjectStore[Ptr]) init() { + obj.ptr = reflect.New(reflect.TypeFor[Ptr]().Elem()).Interface().(Ptr) + obj.ptr.Initialize() +} + +func (obj ObjectStore[Ptr]) MarshalJSON() ([]byte, error) { + return json.Marshal(obj.ptr) +} + +func (obj *ObjectStore[Ptr]) UnmarshalJSON(data []byte) error { + obj.init() + return json.Unmarshal(data, obj.ptr) +} diff --git a/internal/jsonstore/jsonstore_test.go b/internal/jsonstore/jsonstore_test.go index c2ffbb1..b3a82c7 100644 --- a/internal/jsonstore/jsonstore_test.go +++ b/internal/jsonstore/jsonstore_test.go @@ -12,19 +12,52 @@ func TestNewJSON(t *testing.T) { } } -func TestSaveLoad(t *testing.T) { +func TestSaveLoadStore(t *testing.T) { storesPath = t.TempDir() store := Store[string]("test") store.Store("a", "1") if err := save(); err != nil { t.Fatal(err) } - stores.m = nil if err := load(); err != nil { t.Fatal(err) } - store = Store[string]("test") - if v, _ := store.Load("a"); v != "1" { - t.Fatal("expected 1, got", v) + loaded := Store[string]("test") + v, ok := loaded.Load("a") + if !ok { + t.Fatal("expected key exists") + } + if v != "1" { + t.Fatalf("expected 1, got %q", v) + } + if loaded.MapOf == store.MapOf { + t.Fatal("expected different objects") + } +} + +type testObject struct { + I int `json:"i"` + S string `json:"s"` +} + +func (*testObject) Initialize() {} + +func TestSaveLoadObject(t *testing.T) { + storesPath = t.TempDir() + obj := Object[*testObject]("test") + obj.I = 1 + obj.S = "1" + if err := save(); err != nil { + t.Fatal(err) + } + if err := load(); err != nil { + t.Fatal(err) + } + loaded := Object[*testObject]("test") + if loaded.I != 1 || loaded.S != "1" { + t.Fatalf("expected 1, got %d, %s", loaded.I, loaded.S) + } + if loaded == obj { + t.Fatal("expected different objects") } } diff --git a/internal/route/routes/query.go b/internal/route/routes/query.go index 9500219..8f3195e 100644 --- a/internal/route/routes/query.go +++ b/internal/route/routes/query.go @@ -25,7 +25,7 @@ func getHealthInfo(r Route) map[string]string { } type HealthInfoRaw struct { - Status health.Status `json:"status,string"` + Status health.Status `json:"status"` Latency time.Duration `json:"latency"` }