mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
Cleaned up some validation code, stricter validation
This commit is contained in:
parent
254224c0e8
commit
1586610a44
23 changed files with 590 additions and 468 deletions
|
@ -75,6 +75,38 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
|
|||
U.WriteBody(w, content)
|
||||
}
|
||||
|
||||
func validateFile(fileType FileType, content []byte) error {
|
||||
switch fileType {
|
||||
case FileTypeConfig:
|
||||
return config.Validate(content)
|
||||
case FileTypeMiddleware:
|
||||
errs := E.NewBuilder("middleware errors")
|
||||
middleware.BuildMiddlewaresFromYAML("", content, errs)
|
||||
return errs.Error()
|
||||
}
|
||||
return provider.Validate(content)
|
||||
}
|
||||
|
||||
func ValidateFile(w http.ResponseWriter, r *http.Request) {
|
||||
fileType := FileType(r.PathValue("type"))
|
||||
if !fileType.IsValid() {
|
||||
U.RespondError(w, U.ErrInvalidKey("type"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
content, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err)
|
||||
return
|
||||
}
|
||||
r.Body.Close()
|
||||
err = validateFile(fileType, content)
|
||||
if err != nil {
|
||||
U.RespondError(w, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||
fileType, filename, err := getArgs(r)
|
||||
if err != nil {
|
||||
|
@ -87,19 +119,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var valErr E.Error
|
||||
switch fileType {
|
||||
case FileTypeConfig:
|
||||
valErr = config.Validate(content)
|
||||
case FileTypeMiddleware:
|
||||
errs := E.NewBuilder("middleware errors")
|
||||
middleware.BuildMiddlewaresFromYAML(filename, content, errs)
|
||||
valErr = errs.Error()
|
||||
default:
|
||||
valErr = provider.Validate(content)
|
||||
}
|
||||
|
||||
if valErr != nil {
|
||||
if valErr := validateFile(fileType, content); valErr != nil {
|
||||
U.RespondError(w, valErr, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
"regexp"
|
||||
|
||||
"github.com/go-acme/lego/v4/certcrypto"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
|
@ -13,63 +14,89 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
)
|
||||
|
||||
type Config types.AutoCertConfig
|
||||
type (
|
||||
AutocertConfig struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
CertPath string `json:"cert_path,omitempty"`
|
||||
KeyPath string `json:"key_path,omitempty"`
|
||||
ACMEKeyPath string `json:"acme_key_path,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Options ProviderOpt `json:"options,omitempty"`
|
||||
}
|
||||
ProviderOpt map[string]any
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingDomain = E.New("missing field 'domains'")
|
||||
ErrMissingEmail = E.New("missing field 'email'")
|
||||
ErrMissingProvider = E.New("missing field 'provider'")
|
||||
ErrInvalidDomain = E.New("invalid domain")
|
||||
ErrUnknownProvider = E.New("unknown provider")
|
||||
)
|
||||
|
||||
func NewConfig(cfg *types.AutoCertConfig) *Config {
|
||||
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
||||
|
||||
// Validate implements the utils.CustomValidator interface.
|
||||
func (cfg *AutocertConfig) Validate() E.Error {
|
||||
if cfg == nil {
|
||||
cfg = new(types.AutoCertConfig)
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Provider == "" {
|
||||
cfg.Provider = ProviderLocal
|
||||
return nil
|
||||
}
|
||||
|
||||
b := E.NewBuilder("autocert errors")
|
||||
if cfg.Provider != ProviderLocal {
|
||||
if len(cfg.Domains) == 0 {
|
||||
b.Add(ErrMissingDomain)
|
||||
}
|
||||
if cfg.Email == "" {
|
||||
b.Add(ErrMissingEmail)
|
||||
}
|
||||
for i, d := range cfg.Domains {
|
||||
if !domainOrWildcardRE.MatchString(d) {
|
||||
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
||||
}
|
||||
}
|
||||
// check if provider is implemented
|
||||
providerConstructor, ok := providersGenMap[cfg.Provider]
|
||||
if !ok {
|
||||
b.Add(ErrUnknownProvider.
|
||||
Subject(cfg.Provider).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
|
||||
} else {
|
||||
_, err := providerConstructor(cfg.Options)
|
||||
if err != nil {
|
||||
b.Add(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.Error()
|
||||
}
|
||||
|
||||
func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
|
||||
if cfg == nil {
|
||||
cfg = new(AutocertConfig)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.CertPath == "" {
|
||||
cfg.CertPath = CertFileDefault
|
||||
}
|
||||
if cfg.KeyPath == "" {
|
||||
cfg.KeyPath = KeyFileDefault
|
||||
}
|
||||
if cfg.Provider == "" {
|
||||
cfg.Provider = ProviderLocal
|
||||
}
|
||||
if cfg.ACMEKeyPath == "" {
|
||||
cfg.ACMEKeyPath = ACMEKeyFileDefault
|
||||
}
|
||||
return (*Config)(cfg)
|
||||
}
|
||||
|
||||
func (cfg *Config) GetProvider() (*Provider, E.Error) {
|
||||
b := E.NewBuilder("autocert errors")
|
||||
|
||||
if cfg.Provider != ProviderLocal {
|
||||
if len(cfg.Domains) == 0 {
|
||||
b.Add(ErrMissingDomain)
|
||||
}
|
||||
if cfg.Provider == "" {
|
||||
b.Add(ErrMissingProvider)
|
||||
}
|
||||
if cfg.Email == "" {
|
||||
b.Add(ErrMissingEmail)
|
||||
}
|
||||
// check if provider is implemented
|
||||
_, ok := providersGenMap[cfg.Provider]
|
||||
if !ok {
|
||||
b.Add(ErrUnknownProvider.
|
||||
Subject(cfg.Provider).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
|
||||
}
|
||||
}
|
||||
|
||||
if b.HasError() {
|
||||
return nil, b.Error()
|
||||
}
|
||||
|
||||
var privKey *ecdsa.PrivateKey
|
||||
var err error
|
||||
|
@ -103,7 +130,7 @@ func (cfg *Config) GetProvider() (*Provider, E.Error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||
func (cfg *AutocertConfig) loadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||
data, err := os.ReadFile(cfg.ACMEKeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -111,7 +138,7 @@ func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
|
|||
return x509.ParseECPrivateKey(data)
|
||||
}
|
||||
|
||||
func (cfg *Config) saveACMEKey(key *ecdsa.PrivateKey) error {
|
||||
func (cfg *AutocertConfig) saveACMEKey(key *ecdsa.PrivateKey) error {
|
||||
data, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -24,7 +23,7 @@ import (
|
|||
|
||||
type (
|
||||
Provider struct {
|
||||
cfg *Config
|
||||
cfg *AutocertConfig
|
||||
user *User
|
||||
legoCfg *lego.Config
|
||||
client *lego.Client
|
||||
|
@ -33,7 +32,7 @@ type (
|
|||
tlsCert *tls.Certificate
|
||||
certExpiries CertExpiries
|
||||
}
|
||||
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
|
||||
ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error)
|
||||
|
||||
CertExpiries map[string]time.Time
|
||||
)
|
||||
|
@ -313,7 +312,7 @@ func providerGenerator[CT any, PT challenge.Provider](
|
|||
defaultCfg func() *CT,
|
||||
newProvider func(*CT) (PT, error),
|
||||
) ProviderGenerator {
|
||||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
|
||||
return func(opt ProviderOpt) (challenge.Provider, E.Error) {
|
||||
cfg := defaultCfg()
|
||||
err := U.Deserialize(opt, cfg)
|
||||
if err != nil {
|
||||
|
|
|
@ -15,9 +15,11 @@ type User struct {
|
|||
func (u *User) GetEmail() string {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
func (u *User) GetRegistration() *registration.Resource {
|
||||
return u.Registration
|
||||
}
|
||||
|
||||
func (u *User) GetPrivateKey() crypto.PrivateKey {
|
||||
return u.key
|
||||
}
|
||||
|
|
|
@ -234,7 +234,7 @@ func (cfg *Config) load() E.Error {
|
|||
errs := E.NewBuilder(errMsg)
|
||||
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
|
||||
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
|
||||
errs.Add(cfg.initNotification(model.Providers.Notification))
|
||||
cfg.initNotification(model.Providers.Notification)
|
||||
errs.Add(cfg.initAutoCert(model.AutoCert))
|
||||
errs.Add(cfg.loadRouteProviders(&model.Providers))
|
||||
|
||||
|
@ -249,28 +249,22 @@ func (cfg *Config) load() E.Error {
|
|||
return errs.Error()
|
||||
}
|
||||
|
||||
func (cfg *Config) initNotification(notifCfg []types.NotificationConfig) (err E.Error) {
|
||||
func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
|
||||
if len(notifCfg) == 0 {
|
||||
return
|
||||
}
|
||||
dispatcher := notif.StartNotifDispatcher(cfg.task)
|
||||
errs := E.NewBuilder("notification providers load errors")
|
||||
for i, notifier := range notifCfg {
|
||||
_, err := dispatcher.RegisterProvider(notifier)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
errs.Add(err.Subjectf("[%d]", i))
|
||||
for _, notifier := range notifCfg {
|
||||
dispatcher.RegisterProvider(¬ifier)
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
|
||||
func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err E.Error) {
|
||||
if cfg.autocertProvider != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
|
||||
cfg.autocertProvider, err = autocertCfg.GetProvider()
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
package types
|
||||
|
||||
type (
|
||||
AutoCertConfig struct {
|
||||
Email string `json:"email,omitempty" validate:"email"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
CertPath string `json:"cert_path,omitempty" validate:"omitempty,filepath"`
|
||||
KeyPath string `json:"key_path,omitempty" validate:"omitempty,filepath"`
|
||||
ACMEKeyPath string `json:"acme_key_path,omitempty" validate:"omitempty,filepath"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Options AutocertProviderOpt `json:"options,omitempty"`
|
||||
}
|
||||
AutocertProviderOpt map[string]any
|
||||
)
|
|
@ -2,8 +2,12 @@ package types
|
|||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -11,23 +15,22 @@ import (
|
|||
|
||||
type (
|
||||
Config struct {
|
||||
AutoCert *AutoCertConfig `json:"autocert" validate:"omitempty"`
|
||||
Entrypoint Entrypoint `json:"entrypoint"`
|
||||
Providers Providers `json:"providers"`
|
||||
MatchDomains []string `json:"match_domains" validate:"dive,fqdn"`
|
||||
Homepage HomepageConfig `json:"homepage"`
|
||||
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
|
||||
AutoCert *autocert.AutocertConfig `json:"autocert"`
|
||||
Entrypoint Entrypoint `json:"entrypoint"`
|
||||
Providers Providers `json:"providers"`
|
||||
MatchDomains []string `json:"match_domains" validate:"domain_name"`
|
||||
Homepage HomepageConfig `json:"homepage"`
|
||||
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
|
||||
}
|
||||
Providers struct {
|
||||
Files []string `json:"include" validate:"dive,filepath"`
|
||||
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
|
||||
Notification []NotificationConfig `json:"notification"`
|
||||
Files []string `json:"include" validate:"dive,filepath"`
|
||||
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
|
||||
Notification []notif.NotificationConfig `json:"notification"`
|
||||
}
|
||||
Entrypoint struct {
|
||||
Middlewares []map[string]any `json:"middlewares"`
|
||||
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
|
||||
}
|
||||
NotificationConfig map[string]any
|
||||
|
||||
ConfigInstance interface {
|
||||
Value() *Config
|
||||
|
@ -52,6 +55,17 @@ func Validate(data []byte) E.Error {
|
|||
return utils.DeserializeYAML(data, &model)
|
||||
}
|
||||
|
||||
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
|
||||
|
||||
func init() {
|
||||
utils.RegisterDefaultValueFactory(DefaultConfig)
|
||||
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
|
||||
domains := fl.Field().Interface().([]string)
|
||||
for _, domain := range domains {
|
||||
if !matchDomainsRegex.MatchString(domain) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -19,6 +19,13 @@ func Errorf(format string, args ...any) Error {
|
|||
return &baseError{fmt.Errorf(format, args...)}
|
||||
}
|
||||
|
||||
func Wrap(err error, message ...string) Error {
|
||||
if len(message) == 0 || message[0] == "" {
|
||||
return From(err)
|
||||
}
|
||||
return Errorf("%w: %s", err, message[0])
|
||||
}
|
||||
|
||||
func From(err error) Error {
|
||||
if err == nil {
|
||||
return nil
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package homepage
|
||||
|
||||
import "github.com/yusing/go-proxy/internal/utils"
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
//nolint:recvcheck
|
||||
Config map[string]Category
|
||||
Category []*Item
|
||||
Categories map[string]Category
|
||||
Category []*Item
|
||||
|
||||
ItemConfig struct {
|
||||
Show bool `json:"show"`
|
||||
|
@ -48,6 +50,10 @@ func NewItem(alias string) *Item {
|
|||
}
|
||||
}
|
||||
|
||||
func NewHomePageConfig() Categories {
|
||||
return Categories(make(map[string]Category))
|
||||
}
|
||||
|
||||
func (item *Item) IsEmpty() bool {
|
||||
return item == nil || item.IsUnset || item.ItemConfig == nil
|
||||
}
|
||||
|
@ -56,15 +62,11 @@ func (item *Item) GetOverride() *Item {
|
|||
return overrideConfigInstance.GetOverride(item)
|
||||
}
|
||||
|
||||
func NewHomePageConfig() Config {
|
||||
return Config(make(map[string]Category))
|
||||
func (c *Categories) Clear() {
|
||||
*c = make(Categories)
|
||||
}
|
||||
|
||||
func (c *Config) Clear() {
|
||||
*c = make(Config)
|
||||
}
|
||||
|
||||
func (c Config) Add(item *Item) {
|
||||
func (c Categories) Add(item *Item) {
|
||||
if c[item.Category] == nil {
|
||||
c[item.Category] = make(Category, 0)
|
||||
}
|
||||
|
|
|
@ -53,10 +53,6 @@ func GetOverrideConfig() *OverrideConfig {
|
|||
return overrideConfigInstance
|
||||
}
|
||||
|
||||
func (c *OverrideConfig) UnmarshalJSON(data []byte) error {
|
||||
return utils.DeserializeJSON(data, c)
|
||||
}
|
||||
|
||||
func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
|
47
internal/notif/base.go
Normal file
47
internal/notif/base.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type ProviderBase struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
URL string `json:"url" validate:"url"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMissingToken = E.New("token is required")
|
||||
ErrURLMissingScheme = E.New("url missing scheme, expect 'http://' or 'https://'")
|
||||
)
|
||||
|
||||
// Validate implements the utils.CustomValidator interface.
|
||||
func (base *ProviderBase) Validate() E.Error {
|
||||
if base.Token == "" {
|
||||
return ErrMissingToken
|
||||
}
|
||||
if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") {
|
||||
return ErrURLMissingScheme
|
||||
}
|
||||
u, err := url.Parse(base.URL)
|
||||
if err != nil {
|
||||
return E.Wrap(err)
|
||||
}
|
||||
base.URL = u.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (base *ProviderBase) GetName() string {
|
||||
return base.Name
|
||||
}
|
||||
|
||||
func (base *ProviderBase) GetURL() string {
|
||||
return base.URL
|
||||
}
|
||||
|
||||
func (base *ProviderBase) GetToken() string {
|
||||
return base.Token
|
||||
}
|
54
internal/notif/config.go
Normal file
54
internal/notif/config.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type NotificationConfig struct {
|
||||
ProviderName string `json:"provider"`
|
||||
Provider Provider `json:"-"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMissingNotifProvider = E.New("missing notification provider")
|
||||
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
|
||||
ErrUnknownNotifProvider = E.New("unknown notification provider")
|
||||
)
|
||||
|
||||
// UnmarshalMap implements MapUnmarshaler.
|
||||
func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err E.Error) {
|
||||
// extract provider name
|
||||
providerName := m["provider"]
|
||||
switch providerName := providerName.(type) {
|
||||
case string:
|
||||
cfg.ProviderName = providerName
|
||||
default:
|
||||
return ErrInvalidNotifProviderType
|
||||
}
|
||||
delete(m, "provider")
|
||||
|
||||
if cfg.ProviderName == "" {
|
||||
return ErrMissingNotifProvider
|
||||
}
|
||||
|
||||
// validate provider name and initialize provider
|
||||
switch cfg.ProviderName {
|
||||
case ProviderWebhook:
|
||||
cfg.Provider = &Webhook{}
|
||||
case ProviderGotify:
|
||||
cfg.Provider = &GotifyClient{}
|
||||
default:
|
||||
return ErrUnknownNotifProvider.
|
||||
Subject(cfg.ProviderName).
|
||||
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
|
||||
}
|
||||
|
||||
// unmarshal provider config
|
||||
if err := utils.Deserialize(m, cfg.Provider); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate provider
|
||||
return cfg.Provider.Validate()
|
||||
}
|
163
internal/notif/config_test.go
Normal file
163
internal/notif/config_test.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestNotificationConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg map[string]any
|
||||
expected Provider
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid_webhook",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
"template": "discord",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
expected: &Webhook{
|
||||
ProviderBase: ProviderBase{
|
||||
Name: "test",
|
||||
URL: "https://example.com",
|
||||
},
|
||||
Template: "discord",
|
||||
Method: http.MethodPost,
|
||||
MIMEType: "application/json",
|
||||
ColorMode: "dec",
|
||||
Payload: discordPayload,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid_gotify",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "gotify",
|
||||
"url": "https://example.com",
|
||||
"token": "token",
|
||||
},
|
||||
expected: &GotifyClient{
|
||||
ProviderBase: ProviderBase{
|
||||
Name: "test",
|
||||
URL: "https://example.com",
|
||||
Token: "token",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid_provider",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "invalid",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing_url",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing_provider",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "gotify_missing_token",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "gotify",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "webhook_missing_payload",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "webhook_missing_url",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "webhook_invalid_template",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
"url": "https://example.com",
|
||||
"template": "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "webhook_invalid_json_payload",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
"url": "https://example.com",
|
||||
"mime_type": "application/json",
|
||||
"payload": "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "webhook_empty_text_payload",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
"url": "https://example.com",
|
||||
"mime_type": "text/plain",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "webhook_invalid_method",
|
||||
cfg: map[string]any{
|
||||
"name": "test",
|
||||
"provider": "webhook",
|
||||
"url": "https://example.com",
|
||||
"method": "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var cfg NotificationConfig
|
||||
provider := tt.cfg["provider"]
|
||||
err := utils.Deserialize(tt.cfg, &cfg)
|
||||
if tt.wantErr {
|
||||
ExpectHasError(t, err)
|
||||
} else {
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, provider.(string), cfg.ProviderName)
|
||||
ExpectDeepEqual(t, cfg.Provider, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,13 +2,10 @@ package notif
|
|||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -27,12 +24,6 @@ type (
|
|||
|
||||
var dispatcher *Dispatcher
|
||||
|
||||
var (
|
||||
ErrMissingNotifProvider = E.New("missing notification provider")
|
||||
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
|
||||
ErrUnknownNotifProvider = E.New("unknown notification provider")
|
||||
)
|
||||
|
||||
const dispatchErr = "notification dispatch error"
|
||||
|
||||
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
|
||||
|
@ -57,29 +48,8 @@ func Notify(msg *LogMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
func (disp *Dispatcher) RegisterProvider(cfg types.NotificationConfig) (Provider, E.Error) {
|
||||
providerName, ok := cfg["provider"]
|
||||
if !ok {
|
||||
return nil, ErrMissingNotifProvider
|
||||
}
|
||||
switch providerName := providerName.(type) {
|
||||
case string:
|
||||
delete(cfg, "provider")
|
||||
createFunc, ok := Providers[providerName]
|
||||
if !ok {
|
||||
return nil, ErrUnknownNotifProvider.
|
||||
Subject(providerName).
|
||||
Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers)))
|
||||
}
|
||||
|
||||
provider, err := createFunc(cfg)
|
||||
if err == nil {
|
||||
disp.providers.Add(provider)
|
||||
}
|
||||
return provider, err
|
||||
default:
|
||||
return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName)
|
||||
}
|
||||
func (disp *Dispatcher) RegisterProvider(cfg *NotificationConfig) {
|
||||
disp.providers.Add(cfg.Provider)
|
||||
}
|
||||
|
||||
func (disp *Dispatcher) start() {
|
||||
|
@ -110,7 +80,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
|
|||
errs := E.NewBuilder(dispatchErr)
|
||||
disp.providers.RangeAllParallel(func(p Provider) {
|
||||
if err := notifyProvider(task.Context(), p, msg); err != nil {
|
||||
errs.Add(E.PrependSubject(p.Name(), err))
|
||||
errs.Add(E.PrependSubject(p.GetName(), err))
|
||||
}
|
||||
})
|
||||
if errs.HasError() {
|
||||
|
|
|
@ -13,37 +13,24 @@ import (
|
|||
|
||||
type (
|
||||
GotifyClient struct {
|
||||
N string `json:"name" validate:"required"`
|
||||
U string `json:"url" validate:"url"`
|
||||
Tok string `json:"token" validate:"required"`
|
||||
ProviderBase
|
||||
}
|
||||
GotifyMessage model.MessageExternal
|
||||
)
|
||||
|
||||
const gotifyMsgEndpoint = "/message"
|
||||
|
||||
// Name implements Provider.
|
||||
func (client *GotifyClient) Name() string {
|
||||
return client.N
|
||||
func (client *GotifyClient) GetURL() string {
|
||||
return client.URL + gotifyMsgEndpoint
|
||||
}
|
||||
|
||||
// Method implements Provider.
|
||||
func (client *GotifyClient) Method() string {
|
||||
// GetMethod implements Provider.
|
||||
func (client *GotifyClient) GetMethod() string {
|
||||
return http.MethodPost
|
||||
}
|
||||
|
||||
// URL implements Provider.
|
||||
func (client *GotifyClient) URL() string {
|
||||
return client.U + gotifyMsgEndpoint
|
||||
}
|
||||
|
||||
// Token implements Provider.
|
||||
func (client *GotifyClient) Token() string {
|
||||
return client.Tok
|
||||
}
|
||||
|
||||
// MIMEType implements Provider.
|
||||
func (client *GotifyClient) MIMEType() string {
|
||||
// GetMIMEType implements Provider.
|
||||
func (client *GotifyClient) GetMIMEType() string {
|
||||
return "application/json"
|
||||
}
|
||||
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestGotifyValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newGotify := Providers[ProviderGotify]
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newGotify(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
"token": "token",
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newGotify(map[string]any{
|
||||
"name": "test",
|
||||
"token": "token",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
|
||||
t.Run("missing token", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newGotify(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
|
||||
t.Run("invalid url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newGotify(map[string]any{
|
||||
"name": "test",
|
||||
"url": "example.com",
|
||||
"token": "token",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
}
|
|
@ -2,22 +2,24 @@ package notif
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Provider interface {
|
||||
Name() string
|
||||
URL() string
|
||||
Method() string
|
||||
Token() string
|
||||
MIMEType() string
|
||||
utils.CustomValidator
|
||||
|
||||
GetName() string
|
||||
GetURL() string
|
||||
GetToken() string
|
||||
GetMethod() string
|
||||
GetMIMEType() string
|
||||
|
||||
MakeBody(logMsg *LogMessage) (io.Reader, error)
|
||||
|
||||
makeRespError(resp *http.Response) error
|
||||
|
@ -31,47 +33,29 @@ const (
|
|||
ProviderWebhook = "webhook"
|
||||
)
|
||||
|
||||
var Providers = map[string]ProviderCreateFunc{
|
||||
ProviderGotify: newNotifProvider[*GotifyClient],
|
||||
ProviderWebhook: newNotifProvider[*Webhook],
|
||||
}
|
||||
|
||||
func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) {
|
||||
var client T
|
||||
err := U.Deserialize(cfg, &client)
|
||||
if err != nil {
|
||||
return nil, err.Subject(client.Name())
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func formatError(p Provider, err error) error {
|
||||
return fmt.Errorf("%s error: %w", p.Name(), err)
|
||||
}
|
||||
|
||||
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
|
||||
body, err := provider.MakeBody(msg)
|
||||
if err != nil {
|
||||
return formatError(provider, err)
|
||||
return E.PrependSubject(provider.GetName(), err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
provider.URL(),
|
||||
provider.GetURL(),
|
||||
body,
|
||||
)
|
||||
if err != nil {
|
||||
return formatError(provider, err)
|
||||
return E.PrependSubject(provider.GetName(), err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", provider.MIMEType())
|
||||
if provider.Token() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+provider.Token())
|
||||
req.Header.Set("Content-Type", provider.GetMIMEType())
|
||||
if provider.GetToken() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+provider.GetToken())
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return formatError(provider, err)
|
||||
return E.PrependSubject(provider.GetName(), err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
|
|
@ -8,19 +8,16 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type Webhook struct {
|
||||
N string `json:"name" validate:"required"`
|
||||
U string `json:"url" validate:"url"`
|
||||
Template string `json:"template" validate:"omitempty,oneof=discord"`
|
||||
Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"`
|
||||
Tok string `json:"token"`
|
||||
Meth string `json:"method" validate:"oneof=GET POST PUT"`
|
||||
MIMETyp string `json:"mime_type"`
|
||||
ColorM string `json:"color_mode" validate:"oneof=hex dec"`
|
||||
ProviderBase
|
||||
Template string `json:"template"`
|
||||
Payload string `json:"payload"`
|
||||
Method string `json:"method"`
|
||||
MIMEType string `json:"mime_type"`
|
||||
ColorMode string `json:"color_mode"`
|
||||
}
|
||||
|
||||
//go:embed templates/discord.json
|
||||
|
@ -30,60 +27,65 @@ var webhookTemplates = map[string]string{
|
|||
"discord": discordPayload,
|
||||
}
|
||||
|
||||
func DefaultValue() *Webhook {
|
||||
return &Webhook{
|
||||
Meth: "POST",
|
||||
ColorM: "hex",
|
||||
MIMETyp: "application/json",
|
||||
func (webhook *Webhook) Validate() E.Error {
|
||||
if err := webhook.ProviderBase.Validate(); err != nil && !err.Is(ErrMissingToken) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
|
||||
template := fl.Parent().FieldByName("Template").String()
|
||||
if template != "" {
|
||||
return true
|
||||
}
|
||||
payload := fl.Field().String()
|
||||
return json.Valid([]byte(payload))
|
||||
}
|
||||
|
||||
func init() {
|
||||
utils.RegisterDefaultValueFactory(DefaultValue)
|
||||
utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
|
||||
}
|
||||
|
||||
// Name implements Provider.
|
||||
func (webhook *Webhook) Name() string {
|
||||
return webhook.N
|
||||
}
|
||||
|
||||
// Method implements Provider.
|
||||
func (webhook *Webhook) Method() string {
|
||||
return webhook.Meth
|
||||
}
|
||||
|
||||
// URL implements Provider.
|
||||
func (webhook *Webhook) URL() string {
|
||||
return webhook.U
|
||||
}
|
||||
|
||||
// Token implements Provider.
|
||||
func (webhook *Webhook) Token() string {
|
||||
return webhook.Tok
|
||||
}
|
||||
|
||||
// MIMEType implements Provider.
|
||||
func (webhook *Webhook) MIMEType() string {
|
||||
return webhook.MIMETyp
|
||||
}
|
||||
|
||||
func (webhook *Webhook) ColorMode() string {
|
||||
switch webhook.Template {
|
||||
case "discord":
|
||||
return "dec"
|
||||
switch webhook.MIMEType {
|
||||
case "":
|
||||
webhook.MIMEType = "application/json"
|
||||
case "application/json", "application/x-www-form-urlencoded", "text/plain":
|
||||
default:
|
||||
return webhook.ColorM
|
||||
return E.New("invalid mime_type, expect empty, 'application/json', 'application/x-www-form-urlencoded' or 'text/plain'")
|
||||
}
|
||||
|
||||
switch webhook.Template {
|
||||
case "":
|
||||
if webhook.MIMEType == "application/json" && !json.Valid([]byte(webhook.Payload)) {
|
||||
return E.New("invalid payload, expect valid JSON")
|
||||
}
|
||||
if webhook.Payload == "" {
|
||||
return E.New("invalid payload, expect non-empty")
|
||||
}
|
||||
case "discord":
|
||||
webhook.ColorMode = "dec"
|
||||
webhook.Method = http.MethodPost
|
||||
webhook.MIMEType = "application/json"
|
||||
if webhook.Payload == "" {
|
||||
webhook.Payload = discordPayload
|
||||
}
|
||||
default:
|
||||
return E.New("invalid template, expect empty or 'discord'")
|
||||
}
|
||||
|
||||
switch webhook.Method {
|
||||
case "":
|
||||
webhook.Method = http.MethodPost
|
||||
case http.MethodGet, http.MethodPost, http.MethodPut:
|
||||
default:
|
||||
return E.New("invalid method, expect empty, 'GET', 'POST' or 'PUT'")
|
||||
}
|
||||
|
||||
switch webhook.ColorMode {
|
||||
case "":
|
||||
webhook.ColorMode = "hex"
|
||||
case "hex", "dec":
|
||||
default:
|
||||
return E.New("invalid color_mode, expect empty, 'hex' or 'dec'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMethod implements Provider.
|
||||
func (webhook *Webhook) GetMethod() string {
|
||||
return webhook.Method
|
||||
}
|
||||
|
||||
// GetMIMEType implements Provider.
|
||||
func (webhook *Webhook) GetMIMEType() string {
|
||||
return webhook.MIMEType
|
||||
}
|
||||
|
||||
// makeRespError implements Provider.
|
||||
|
@ -108,7 +110,7 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
|
|||
return nil, err
|
||||
}
|
||||
var color string
|
||||
if webhook.ColorMode() == "hex" {
|
||||
if webhook.ColorMode == "hex" {
|
||||
color = logMsg.Color.HexString()
|
||||
} else {
|
||||
color = logMsg.Color.DecString()
|
||||
|
|
|
@ -1,121 +0,0 @@
|
|||
package notif
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestWebhookValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newWebhook := Providers[ProviderWebhook]
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
"payload": "{}",
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
t.Run("valid template", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
"template": "discord",
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"payload": "{}",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
|
||||
t.Run("missing payload", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "example.com",
|
||||
"payload": "{}",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid payload", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
"payload": "abcd",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid method", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
"payload": "{}",
|
||||
"method": "abcd",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid template", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := newWebhook(map[string]any{
|
||||
"name": "test",
|
||||
"url": "https://example.com",
|
||||
"template": "abcd",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWebhookBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var webhook Webhook
|
||||
webhook.Payload = discordPayload
|
||||
bodyReader, err := webhook.MakeBody(&LogMessage{
|
||||
Title: "abc",
|
||||
Extras: map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
|
||||
var body struct {
|
||||
Embeds []struct {
|
||||
Title string `json:"title"`
|
||||
Fields []struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
} `json:"fields"`
|
||||
} `json:"embeds"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(bodyReader).Decode(&body)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
ExpectEqual(t, body.Embeds[0].Title, "abc")
|
||||
fields := body.Embeds[0].Fields
|
||||
ExpectEqual(t, fields[0].Name, "foo")
|
||||
ExpectEqual(t, fields[0].Value, "bar")
|
||||
}
|
|
@ -57,7 +57,7 @@ func HomepageCategories() []string {
|
|||
return categories
|
||||
}
|
||||
|
||||
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Config {
|
||||
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Categories {
|
||||
hpCfg := homepage.NewHomePageConfig()
|
||||
|
||||
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) {
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
package utils
|
||||
|
||||
// FIXME: some times [%d] is not in correct order
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -21,6 +18,10 @@ import (
|
|||
|
||||
type SerializedObject = map[string]any
|
||||
|
||||
type MapUnmarshaller interface {
|
||||
UnmarshalMap(m map[string]any) E.Error
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidType = E.New("invalid type")
|
||||
ErrNilValue = E.New("nil")
|
||||
|
@ -29,6 +30,8 @@ var (
|
|||
ErrUnknownField = E.New("unknown field")
|
||||
)
|
||||
|
||||
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
||||
|
||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
||||
|
||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||
|
@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
|||
if t.Kind() != reflect.Struct {
|
||||
return nil, nil
|
||||
}
|
||||
var fields []reflect.StructField
|
||||
for i := range t.NumField() {
|
||||
n := t.NumField()
|
||||
fields := make([]reflect.StructField, 0, n)
|
||||
for i := range n {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
|
@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
|||
return fields, anonymous
|
||||
}
|
||||
|
||||
func ValidateWithFieldTags(s any) E.Error {
|
||||
errs := E.NewBuilder("validate error")
|
||||
err := validate.Struct(s)
|
||||
var valErrs validator.ValidationErrors
|
||||
if errors.As(err, &valErrs) {
|
||||
for _, e := range valErrs {
|
||||
detail := e.ActualTag()
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(e.Namespace()).
|
||||
Withf("require %q", detail))
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
|
||||
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
|
||||
//
|
||||
// The target value must be a struct or a map[string]any.
|
||||
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed.
|
||||
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
|
||||
// If the target value is a struct , and implements the MapUnmarshaller interface,
|
||||
// the UnmarshalMap method will be called.
|
||||
//
|
||||
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
|
||||
// the SerializedObject will be deserialized into the struct fields and validate if needed.
|
||||
//
|
||||
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
|
||||
//
|
||||
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
||||
func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
if src == nil {
|
||||
return E.Errorf("deserialize: src is %w", ErrNilValue)
|
||||
}
|
||||
if dst == nil {
|
||||
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
|
||||
}
|
||||
|
||||
dstV := reflect.ValueOf(dst)
|
||||
dstT := dstV.Type()
|
||||
|
||||
if src == nil {
|
||||
if dstV.CanSet() {
|
||||
dstV.Set(reflect.Zero(dstT))
|
||||
return nil
|
||||
}
|
||||
return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
|
||||
}
|
||||
|
||||
if dstT.Implements(mapUnmarshalerType) {
|
||||
for dstV.IsNil() {
|
||||
switch dstT.Kind() {
|
||||
case reflect.Struct:
|
||||
dstV.Set(New(dstT))
|
||||
case reflect.Map:
|
||||
dstV.Set(reflect.MakeMap(dstT))
|
||||
case reflect.Slice:
|
||||
dstV.Set(reflect.MakeSlice(dstT, 0, 0))
|
||||
case reflect.Ptr:
|
||||
dstV.Set(reflect.New(dstT.Elem()))
|
||||
default:
|
||||
return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
|
||||
}
|
||||
dstV = dstV.Elem()
|
||||
}
|
||||
return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src)
|
||||
}
|
||||
|
||||
for dstT.Kind() == reflect.Ptr {
|
||||
if dstV.IsNil() {
|
||||
if dstV.CanSet() {
|
||||
dstV.Set(New(dstT.Elem()))
|
||||
} else {
|
||||
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
|
||||
return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue)
|
||||
}
|
||||
}
|
||||
dstV = dstV.Elem()
|
||||
|
@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
|||
|
||||
switch dstV.Kind() {
|
||||
case reflect.Struct:
|
||||
needValidate := false
|
||||
hasValidateTag := false
|
||||
mapping := make(map[string]reflect.Value)
|
||||
fieldName := make(map[string]string)
|
||||
fields, anonymous := extractFields(dstT)
|
||||
for _, anon := range anonymous {
|
||||
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
|
||||
|
@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
|||
}
|
||||
key = strutils.ToLowerNoSnake(key)
|
||||
mapping[key] = dstV.FieldByName(field.Name)
|
||||
fieldName[field.Name] = key
|
||||
|
||||
if !needValidate {
|
||||
_, needValidate = field.Tag.Lookup("validate")
|
||||
if !hasValidateTag {
|
||||
_, hasValidateTag = field.Tag.Lookup("validate")
|
||||
}
|
||||
|
||||
aliases, ok := field.Tag.Lookup("aliases")
|
||||
if ok {
|
||||
for _, alias := range strutils.CommaSeperatedList(aliases) {
|
||||
mapping[alias] = dstV.FieldByName(field.Name)
|
||||
fieldName[field.Name] = alias
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
|||
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
|
||||
}
|
||||
}
|
||||
if needValidate {
|
||||
err := validate.Struct(dstV.Interface())
|
||||
var valErrs validator.ValidationErrors
|
||||
if errors.As(err, &valErrs) {
|
||||
for _, e := range valErrs {
|
||||
detail := e.ActualTag()
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(e.StructNamespace()).
|
||||
Withf("require %q", detail))
|
||||
}
|
||||
}
|
||||
if hasValidateTag {
|
||||
errs.Add(ValidateWithFieldTags(dstV.Interface()))
|
||||
} else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
|
||||
errs.Add(validator.Validate())
|
||||
}
|
||||
return errs.Error()
|
||||
case reflect.Map:
|
||||
|
@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
|||
errs.Add(err.Subject(k))
|
||||
}
|
||||
}
|
||||
if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
|
||||
errs.Add(validator.Validate())
|
||||
}
|
||||
return errs.Error()
|
||||
default:
|
||||
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
|
||||
|
@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.
|
|||
return functional.NewMapFrom(m2), nil
|
||||
}
|
||||
|
||||
func DeserializeJSON[T any](data []byte, target T) E.Error {
|
||||
m := make(map[string]any)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return E.From(err)
|
||||
}
|
||||
return Deserialize(m, target)
|
||||
}
|
||||
|
||||
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
|
|
|
@ -35,6 +35,14 @@ func ExpectNoError(t *testing.T, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func ExpectHasError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if errors.Is(err, nil) {
|
||||
t.Error("expected err not nil")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectError(t *testing.T, expected error, err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, expected) {
|
||||
|
|
|
@ -9,6 +9,10 @@ var validate = validator.New()
|
|||
|
||||
var ErrValidationError = E.New("validation error")
|
||||
|
||||
type CustomValidator interface {
|
||||
Validate() E.Error
|
||||
}
|
||||
|
||||
func Validator() *validator.Validate {
|
||||
return validate
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue