Cleaned up some validation code, stricter validation

This commit is contained in:
yusing 2025-01-26 14:43:48 +08:00
parent 254224c0e8
commit 1586610a44
23 changed files with 590 additions and 468 deletions

View file

@ -75,6 +75,38 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
U.WriteBody(w, content) U.WriteBody(w, content)
} }
func validateFile(fileType FileType, content []byte) error {
switch fileType {
case FileTypeConfig:
return config.Validate(content)
case FileTypeMiddleware:
errs := E.NewBuilder("middleware errors")
middleware.BuildMiddlewaresFromYAML("", content, errs)
return errs.Error()
}
return provider.Validate(content)
}
func ValidateFile(w http.ResponseWriter, r *http.Request) {
fileType := FileType(r.PathValue("type"))
if !fileType.IsValid() {
U.RespondError(w, U.ErrInvalidKey("type"), http.StatusBadRequest)
return
}
content, err := io.ReadAll(r.Body)
if err != nil {
U.HandleErr(w, r, err)
return
}
r.Body.Close()
err = validateFile(fileType, content)
if err != nil {
U.RespondError(w, err, http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
}
func SetFileContent(w http.ResponseWriter, r *http.Request) { func SetFileContent(w http.ResponseWriter, r *http.Request) {
fileType, filename, err := getArgs(r) fileType, filename, err := getArgs(r)
if err != nil { if err != nil {
@ -87,19 +119,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return return
} }
var valErr E.Error if valErr := validateFile(fileType, content); valErr != nil {
switch fileType {
case FileTypeConfig:
valErr = config.Validate(content)
case FileTypeMiddleware:
errs := E.NewBuilder("middleware errors")
middleware.BuildMiddlewaresFromYAML(filename, content, errs)
valErr = errs.Error()
default:
valErr = provider.Validate(content)
}
if valErr != nil {
U.RespondError(w, valErr, http.StatusBadRequest) U.RespondError(w, valErr, http.StatusBadRequest)
return return
} }

View file

@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"os" "os"
"regexp"
"github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
@ -13,63 +14,89 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/config/types"
) )
type Config types.AutoCertConfig type (
AutocertConfig struct {
Email string `json:"email,omitempty"`
Domains []string `json:"domains,omitempty"`
CertPath string `json:"cert_path,omitempty"`
KeyPath string `json:"key_path,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"`
Provider string `json:"provider,omitempty"`
Options ProviderOpt `json:"options,omitempty"`
}
ProviderOpt map[string]any
)
var ( var (
ErrMissingDomain = E.New("missing field 'domains'") ErrMissingDomain = E.New("missing field 'domains'")
ErrMissingEmail = E.New("missing field 'email'") ErrMissingEmail = E.New("missing field 'email'")
ErrMissingProvider = E.New("missing field 'provider'") ErrMissingProvider = E.New("missing field 'provider'")
ErrInvalidDomain = E.New("invalid domain")
ErrUnknownProvider = E.New("unknown provider") ErrUnknownProvider = E.New("unknown provider")
) )
func NewConfig(cfg *types.AutoCertConfig) *Config { var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
// Validate implements the utils.CustomValidator interface.
func (cfg *AutocertConfig) Validate() E.Error {
if cfg == nil { if cfg == nil {
cfg = new(types.AutoCertConfig) return nil
} }
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
return nil
}
b := E.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
// check if provider is implemented
providerConstructor, ok := providersGenMap[cfg.Provider]
if !ok {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
} else {
_, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
}
}
}
return b.Error()
}
func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
if cfg == nil {
cfg = new(AutocertConfig)
}
if err := cfg.Validate(); err != nil {
return nil, err
}
if cfg.CertPath == "" { if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault cfg.CertPath = CertFileDefault
} }
if cfg.KeyPath == "" { if cfg.KeyPath == "" {
cfg.KeyPath = KeyFileDefault cfg.KeyPath = KeyFileDefault
} }
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
}
if cfg.ACMEKeyPath == "" { if cfg.ACMEKeyPath == "" {
cfg.ACMEKeyPath = ACMEKeyFileDefault cfg.ACMEKeyPath = ACMEKeyFileDefault
} }
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (*Provider, E.Error) {
b := E.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
if cfg.Provider == "" {
b.Add(ErrMissingProvider)
}
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
}
}
if b.HasError() {
return nil, b.Error()
}
var privKey *ecdsa.PrivateKey var privKey *ecdsa.PrivateKey
var err error var err error
@ -103,7 +130,7 @@ func (cfg *Config) GetProvider() (*Provider, E.Error) {
}, nil }, nil
} }
func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) { func (cfg *AutocertConfig) loadACMEKey() (*ecdsa.PrivateKey, error) {
data, err := os.ReadFile(cfg.ACMEKeyPath) data, err := os.ReadFile(cfg.ACMEKeyPath)
if err != nil { if err != nil {
return nil, err return nil, err
@ -111,7 +138,7 @@ func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
return x509.ParseECPrivateKey(data) return x509.ParseECPrivateKey(data)
} }
func (cfg *Config) saveACMEKey(key *ecdsa.PrivateKey) error { func (cfg *AutocertConfig) saveACMEKey(key *ecdsa.PrivateKey) error {
data, err := x509.MarshalECPrivateKey(key) data, err := x509.MarshalECPrivateKey(key)
if err != nil { if err != nil {
return err return err

View file

@ -14,7 +14,6 @@ import (
"github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -24,7 +23,7 @@ import (
type ( type (
Provider struct { Provider struct {
cfg *Config cfg *AutocertConfig
user *User user *User
legoCfg *lego.Config legoCfg *lego.Config
client *lego.Client client *lego.Client
@ -33,7 +32,7 @@ type (
tlsCert *tls.Certificate tlsCert *tls.Certificate
certExpiries CertExpiries certExpiries CertExpiries
} }
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error) ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error)
CertExpiries map[string]time.Time CertExpiries map[string]time.Time
) )
@ -313,7 +312,7 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT, defaultCfg func() *CT,
newProvider func(*CT) (PT, error), newProvider func(*CT) (PT, error),
) ProviderGenerator { ) ProviderGenerator {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) { return func(opt ProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg() cfg := defaultCfg()
err := U.Deserialize(opt, cfg) err := U.Deserialize(opt, cfg)
if err != nil { if err != nil {

View file

@ -15,9 +15,11 @@ type User struct {
func (u *User) GetEmail() string { func (u *User) GetEmail() string {
return u.Email return u.Email
} }
func (u *User) GetRegistration() *registration.Resource { func (u *User) GetRegistration() *registration.Resource {
return u.Registration return u.Registration
} }
func (u *User) GetPrivateKey() crypto.PrivateKey { func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key return u.key
} }

View file

@ -234,7 +234,7 @@ func (cfg *Config) load() E.Error {
errs := E.NewBuilder(errMsg) errs := E.NewBuilder(errMsg)
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
errs.Add(cfg.initNotification(model.Providers.Notification)) cfg.initNotification(model.Providers.Notification)
errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.initAutoCert(model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers)) errs.Add(cfg.loadRouteProviders(&model.Providers))
@ -249,28 +249,22 @@ func (cfg *Config) load() E.Error {
return errs.Error() return errs.Error()
} }
func (cfg *Config) initNotification(notifCfg []types.NotificationConfig) (err E.Error) { func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
if len(notifCfg) == 0 { if len(notifCfg) == 0 {
return return
} }
dispatcher := notif.StartNotifDispatcher(cfg.task) dispatcher := notif.StartNotifDispatcher(cfg.task)
errs := E.NewBuilder("notification providers load errors") for _, notifier := range notifCfg {
for i, notifier := range notifCfg { dispatcher.RegisterProvider(&notifier)
_, err := dispatcher.RegisterProvider(notifier)
if err == nil {
continue
}
errs.Add(err.Subjectf("[%d]", i))
} }
return errs.Error()
} }
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) { func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err E.Error) {
if cfg.autocertProvider != nil { if cfg.autocertProvider != nil {
return return
} }
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() cfg.autocertProvider, err = autocertCfg.GetProvider()
return return
} }

View file

@ -1,14 +0,0 @@
package types
type (
AutoCertConfig struct {
Email string `json:"email,omitempty" validate:"email"`
Domains []string `json:"domains,omitempty"`
CertPath string `json:"cert_path,omitempty" validate:"omitempty,filepath"`
KeyPath string `json:"key_path,omitempty" validate:"omitempty,filepath"`
ACMEKeyPath string `json:"acme_key_path,omitempty" validate:"omitempty,filepath"`
Provider string `json:"provider,omitempty"`
Options AutocertProviderOpt `json:"options,omitempty"`
}
AutocertProviderOpt map[string]any
)

View file

@ -2,8 +2,12 @@ package types
import ( import (
"context" "context"
"regexp"
"github.com/go-playground/validator/v10"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -11,23 +15,22 @@ import (
type ( type (
Config struct { Config struct {
AutoCert *AutoCertConfig `json:"autocert" validate:"omitempty"` AutoCert *autocert.AutocertConfig `json:"autocert"`
Entrypoint Entrypoint `json:"entrypoint"` Entrypoint Entrypoint `json:"entrypoint"`
Providers Providers `json:"providers"` Providers Providers `json:"providers"`
MatchDomains []string `json:"match_domains" validate:"dive,fqdn"` MatchDomains []string `json:"match_domains" validate:"domain_name"`
Homepage HomepageConfig `json:"homepage"` Homepage HomepageConfig `json:"homepage"`
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"` TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
} }
Providers struct { Providers struct {
Files []string `json:"include" validate:"dive,filepath"` Files []string `json:"include" validate:"dive,filepath"`
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"` Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
Notification []NotificationConfig `json:"notification"` Notification []notif.NotificationConfig `json:"notification"`
} }
Entrypoint struct { Entrypoint struct {
Middlewares []map[string]any `json:"middlewares"` Middlewares []map[string]any `json:"middlewares"`
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"` AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
} }
NotificationConfig map[string]any
ConfigInstance interface { ConfigInstance interface {
Value() *Config Value() *Config
@ -52,6 +55,17 @@ func Validate(data []byte) E.Error {
return utils.DeserializeYAML(data, &model) return utils.DeserializeYAML(data, &model)
} }
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
func init() { func init() {
utils.RegisterDefaultValueFactory(DefaultConfig) utils.RegisterDefaultValueFactory(DefaultConfig)
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
domains := fl.Field().Interface().([]string)
for _, domain := range domains {
if !matchDomainsRegex.MatchString(domain) {
return false
}
}
return true
})
} }

View file

@ -19,6 +19,13 @@ func Errorf(format string, args ...any) Error {
return &baseError{fmt.Errorf(format, args...)} return &baseError{fmt.Errorf(format, args...)}
} }
func Wrap(err error, message ...string) Error {
if len(message) == 0 || message[0] == "" {
return From(err)
}
return Errorf("%w: %s", err, message[0])
}
func From(err error) Error { func From(err error) Error {
if err == nil { if err == nil {
return nil return nil

View file

@ -1,11 +1,13 @@
package homepage package homepage
import "github.com/yusing/go-proxy/internal/utils" import (
"github.com/yusing/go-proxy/internal/utils"
)
type ( type (
//nolint:recvcheck //nolint:recvcheck
Config map[string]Category Categories map[string]Category
Category []*Item Category []*Item
ItemConfig struct { ItemConfig struct {
Show bool `json:"show"` Show bool `json:"show"`
@ -48,6 +50,10 @@ func NewItem(alias string) *Item {
} }
} }
func NewHomePageConfig() Categories {
return Categories(make(map[string]Category))
}
func (item *Item) IsEmpty() bool { func (item *Item) IsEmpty() bool {
return item == nil || item.IsUnset || item.ItemConfig == nil return item == nil || item.IsUnset || item.ItemConfig == nil
} }
@ -56,15 +62,11 @@ func (item *Item) GetOverride() *Item {
return overrideConfigInstance.GetOverride(item) return overrideConfigInstance.GetOverride(item)
} }
func NewHomePageConfig() Config { func (c *Categories) Clear() {
return Config(make(map[string]Category)) *c = make(Categories)
} }
func (c *Config) Clear() { func (c Categories) Add(item *Item) {
*c = make(Config)
}
func (c Config) Add(item *Item) {
if c[item.Category] == nil { if c[item.Category] == nil {
c[item.Category] = make(Category, 0) c[item.Category] = make(Category, 0)
} }

View file

@ -53,10 +53,6 @@ func GetOverrideConfig() *OverrideConfig {
return overrideConfigInstance return overrideConfigInstance
} }
func (c *OverrideConfig) UnmarshalJSON(data []byte) error {
return utils.DeserializeJSON(data, c)
}
func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) { func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

47
internal/notif/base.go Normal file
View file

@ -0,0 +1,47 @@
package notif
import (
"net/url"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type ProviderBase struct {
Name string `json:"name" validate:"required"`
URL string `json:"url" validate:"url"`
Token string `json:"token"`
}
var (
ErrMissingToken = E.New("token is required")
ErrURLMissingScheme = E.New("url missing scheme, expect 'http://' or 'https://'")
)
// Validate implements the utils.CustomValidator interface.
func (base *ProviderBase) Validate() E.Error {
if base.Token == "" {
return ErrMissingToken
}
if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") {
return ErrURLMissingScheme
}
u, err := url.Parse(base.URL)
if err != nil {
return E.Wrap(err)
}
base.URL = u.String()
return nil
}
func (base *ProviderBase) GetName() string {
return base.Name
}
func (base *ProviderBase) GetURL() string {
return base.URL
}
func (base *ProviderBase) GetToken() string {
return base.Token
}

54
internal/notif/config.go Normal file
View file

@ -0,0 +1,54 @@
package notif
import (
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
)
type NotificationConfig struct {
ProviderName string `json:"provider"`
Provider Provider `json:"-"`
}
var (
ErrMissingNotifProvider = E.New("missing notification provider")
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
ErrUnknownNotifProvider = E.New("unknown notification provider")
)
// UnmarshalMap implements MapUnmarshaler.
func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err E.Error) {
// extract provider name
providerName := m["provider"]
switch providerName := providerName.(type) {
case string:
cfg.ProviderName = providerName
default:
return ErrInvalidNotifProviderType
}
delete(m, "provider")
if cfg.ProviderName == "" {
return ErrMissingNotifProvider
}
// validate provider name and initialize provider
switch cfg.ProviderName {
case ProviderWebhook:
cfg.Provider = &Webhook{}
case ProviderGotify:
cfg.Provider = &GotifyClient{}
default:
return ErrUnknownNotifProvider.
Subject(cfg.ProviderName).
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
}
// unmarshal provider config
if err := utils.Deserialize(m, cfg.Provider); err != nil {
return err
}
// validate provider
return cfg.Provider.Validate()
}

View file

@ -0,0 +1,163 @@
package notif
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNotificationConfig(t *testing.T) {
tests := []struct {
name string
cfg map[string]any
expected Provider
wantErr bool
}{
{
name: "valid_webhook",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"template": "discord",
"url": "https://example.com",
},
expected: &Webhook{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
},
Template: "discord",
Method: http.MethodPost,
MIMEType: "application/json",
ColorMode: "dec",
Payload: discordPayload,
},
wantErr: false,
},
{
name: "valid_gotify",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"url": "https://example.com",
"token": "token",
},
expected: &GotifyClient{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
Token: "token",
},
},
wantErr: false,
},
{
name: "invalid_provider",
cfg: map[string]any{
"name": "test",
"provider": "invalid",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "missing_url",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "missing_provider",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "gotify_missing_token",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "webhook_missing_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "webhook_missing_url",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "webhook_invalid_template",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"template": "invalid",
},
wantErr: true,
},
{
name: "webhook_invalid_json_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"mime_type": "application/json",
"payload": "invalid",
},
wantErr: true,
},
{
name: "webhook_empty_text_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"mime_type": "text/plain",
},
wantErr: true,
},
{
name: "webhook_invalid_method",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"method": "invalid",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var cfg NotificationConfig
provider := tt.cfg["provider"]
err := utils.Deserialize(tt.cfg, &cfg)
if tt.wantErr {
ExpectHasError(t, err)
} else {
ExpectNoError(t, err)
ExpectEqual(t, provider.(string), cfg.ProviderName)
ExpectDeepEqual(t, cfg.Provider, tt.expected)
}
})
}
}

View file

@ -2,13 +2,10 @@ package notif
import ( import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
@ -27,12 +24,6 @@ type (
var dispatcher *Dispatcher var dispatcher *Dispatcher
var (
ErrMissingNotifProvider = E.New("missing notification provider")
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
ErrUnknownNotifProvider = E.New("unknown notification provider")
)
const dispatchErr = "notification dispatch error" const dispatchErr = "notification dispatch error"
func StartNotifDispatcher(parent task.Parent) *Dispatcher { func StartNotifDispatcher(parent task.Parent) *Dispatcher {
@ -57,29 +48,8 @@ func Notify(msg *LogMessage) {
} }
} }
func (disp *Dispatcher) RegisterProvider(cfg types.NotificationConfig) (Provider, E.Error) { func (disp *Dispatcher) RegisterProvider(cfg *NotificationConfig) {
providerName, ok := cfg["provider"] disp.providers.Add(cfg.Provider)
if !ok {
return nil, ErrMissingNotifProvider
}
switch providerName := providerName.(type) {
case string:
delete(cfg, "provider")
createFunc, ok := Providers[providerName]
if !ok {
return nil, ErrUnknownNotifProvider.
Subject(providerName).
Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers)))
}
provider, err := createFunc(cfg)
if err == nil {
disp.providers.Add(provider)
}
return provider, err
default:
return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName)
}
} }
func (disp *Dispatcher) start() { func (disp *Dispatcher) start() {
@ -110,7 +80,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
errs := E.NewBuilder(dispatchErr) errs := E.NewBuilder(dispatchErr)
disp.providers.RangeAllParallel(func(p Provider) { disp.providers.RangeAllParallel(func(p Provider) {
if err := notifyProvider(task.Context(), p, msg); err != nil { if err := notifyProvider(task.Context(), p, msg); err != nil {
errs.Add(E.PrependSubject(p.Name(), err)) errs.Add(E.PrependSubject(p.GetName(), err))
} }
}) })
if errs.HasError() { if errs.HasError() {

View file

@ -13,37 +13,24 @@ import (
type ( type (
GotifyClient struct { GotifyClient struct {
N string `json:"name" validate:"required"` ProviderBase
U string `json:"url" validate:"url"`
Tok string `json:"token" validate:"required"`
} }
GotifyMessage model.MessageExternal GotifyMessage model.MessageExternal
) )
const gotifyMsgEndpoint = "/message" const gotifyMsgEndpoint = "/message"
// Name implements Provider. func (client *GotifyClient) GetURL() string {
func (client *GotifyClient) Name() string { return client.URL + gotifyMsgEndpoint
return client.N
} }
// Method implements Provider. // GetMethod implements Provider.
func (client *GotifyClient) Method() string { func (client *GotifyClient) GetMethod() string {
return http.MethodPost return http.MethodPost
} }
// URL implements Provider. // GetMIMEType implements Provider.
func (client *GotifyClient) URL() string { func (client *GotifyClient) GetMIMEType() string {
return client.U + gotifyMsgEndpoint
}
// Token implements Provider.
func (client *GotifyClient) Token() string {
return client.Tok
}
// MIMEType implements Provider.
func (client *GotifyClient) MIMEType() string {
return "application/json" return "application/json"
} }

View file

@ -1,52 +0,0 @@
package notif
import (
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestGotifyValidation(t *testing.T) {
t.Parallel()
newGotify := Providers[ProviderGotify]
t.Run("valid", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "https://example.com",
"token": "token",
})
ExpectNoError(t, err)
})
t.Run("missing url", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"token": "token",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("missing token", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "https://example.com",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid url", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "example.com",
"token": "token",
})
ExpectError(t, utils.ErrValidationError, err)
})
}

View file

@ -2,22 +2,24 @@ package notif
import ( import (
"context" "context"
"fmt"
"io" "io"
"net/http" "net/http"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
type ( type (
Provider interface { Provider interface {
Name() string utils.CustomValidator
URL() string
Method() string GetName() string
Token() string GetURL() string
MIMEType() string GetToken() string
GetMethod() string
GetMIMEType() string
MakeBody(logMsg *LogMessage) (io.Reader, error) MakeBody(logMsg *LogMessage) (io.Reader, error)
makeRespError(resp *http.Response) error makeRespError(resp *http.Response) error
@ -31,47 +33,29 @@ const (
ProviderWebhook = "webhook" ProviderWebhook = "webhook"
) )
var Providers = map[string]ProviderCreateFunc{
ProviderGotify: newNotifProvider[*GotifyClient],
ProviderWebhook: newNotifProvider[*Webhook],
}
func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) {
var client T
err := U.Deserialize(cfg, &client)
if err != nil {
return nil, err.Subject(client.Name())
}
return client, nil
}
func formatError(p Provider, err error) error {
return fmt.Errorf("%s error: %w", p.Name(), err)
}
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error { func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
body, err := provider.MakeBody(msg) body, err := provider.MakeBody(msg)
if err != nil { if err != nil {
return formatError(provider, err) return E.PrependSubject(provider.GetName(), err)
} }
req, err := http.NewRequestWithContext( req, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
provider.URL(), provider.GetURL(),
body, body,
) )
if err != nil { if err != nil {
return formatError(provider, err) return E.PrependSubject(provider.GetName(), err)
} }
req.Header.Set("Content-Type", provider.MIMEType()) req.Header.Set("Content-Type", provider.GetMIMEType())
if provider.Token() != "" { if provider.GetToken() != "" {
req.Header.Set("Authorization", "Bearer "+provider.Token()) req.Header.Set("Authorization", "Bearer "+provider.GetToken())
} }
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return formatError(provider, err) return E.PrependSubject(provider.GetName(), err)
} }
defer resp.Body.Close() defer resp.Body.Close()

View file

@ -8,19 +8,16 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/go-playground/validator/v10" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
) )
type Webhook struct { type Webhook struct {
N string `json:"name" validate:"required"` ProviderBase
U string `json:"url" validate:"url"` Template string `json:"template"`
Template string `json:"template" validate:"omitempty,oneof=discord"` Payload string `json:"payload"`
Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"` Method string `json:"method"`
Tok string `json:"token"` MIMEType string `json:"mime_type"`
Meth string `json:"method" validate:"oneof=GET POST PUT"` ColorMode string `json:"color_mode"`
MIMETyp string `json:"mime_type"`
ColorM string `json:"color_mode" validate:"oneof=hex dec"`
} }
//go:embed templates/discord.json //go:embed templates/discord.json
@ -30,60 +27,65 @@ var webhookTemplates = map[string]string{
"discord": discordPayload, "discord": discordPayload,
} }
func DefaultValue() *Webhook { func (webhook *Webhook) Validate() E.Error {
return &Webhook{ if err := webhook.ProviderBase.Validate(); err != nil && !err.Is(ErrMissingToken) {
Meth: "POST", return err
ColorM: "hex",
MIMETyp: "application/json",
} }
}
func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool { switch webhook.MIMEType {
template := fl.Parent().FieldByName("Template").String() case "":
if template != "" { webhook.MIMEType = "application/json"
return true case "application/json", "application/x-www-form-urlencoded", "text/plain":
}
payload := fl.Field().String()
return json.Valid([]byte(payload))
}
func init() {
utils.RegisterDefaultValueFactory(DefaultValue)
utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
}
// Name implements Provider.
func (webhook *Webhook) Name() string {
return webhook.N
}
// Method implements Provider.
func (webhook *Webhook) Method() string {
return webhook.Meth
}
// URL implements Provider.
func (webhook *Webhook) URL() string {
return webhook.U
}
// Token implements Provider.
func (webhook *Webhook) Token() string {
return webhook.Tok
}
// MIMEType implements Provider.
func (webhook *Webhook) MIMEType() string {
return webhook.MIMETyp
}
func (webhook *Webhook) ColorMode() string {
switch webhook.Template {
case "discord":
return "dec"
default: default:
return webhook.ColorM return E.New("invalid mime_type, expect empty, 'application/json', 'application/x-www-form-urlencoded' or 'text/plain'")
} }
switch webhook.Template {
case "":
if webhook.MIMEType == "application/json" && !json.Valid([]byte(webhook.Payload)) {
return E.New("invalid payload, expect valid JSON")
}
if webhook.Payload == "" {
return E.New("invalid payload, expect non-empty")
}
case "discord":
webhook.ColorMode = "dec"
webhook.Method = http.MethodPost
webhook.MIMEType = "application/json"
if webhook.Payload == "" {
webhook.Payload = discordPayload
}
default:
return E.New("invalid template, expect empty or 'discord'")
}
switch webhook.Method {
case "":
webhook.Method = http.MethodPost
case http.MethodGet, http.MethodPost, http.MethodPut:
default:
return E.New("invalid method, expect empty, 'GET', 'POST' or 'PUT'")
}
switch webhook.ColorMode {
case "":
webhook.ColorMode = "hex"
case "hex", "dec":
default:
return E.New("invalid color_mode, expect empty, 'hex' or 'dec'")
}
return nil
}
// GetMethod implements Provider.
func (webhook *Webhook) GetMethod() string {
return webhook.Method
}
// GetMIMEType implements Provider.
func (webhook *Webhook) GetMIMEType() string {
return webhook.MIMEType
} }
// makeRespError implements Provider. // makeRespError implements Provider.
@ -108,7 +110,7 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
return nil, err return nil, err
} }
var color string var color string
if webhook.ColorMode() == "hex" { if webhook.ColorMode == "hex" {
color = logMsg.Color.HexString() color = logMsg.Color.HexString()
} else { } else {
color = logMsg.Color.DecString() color = logMsg.Color.DecString()

View file

@ -1,121 +0,0 @@
package notif
import (
"encoding/json"
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestWebhookValidation(t *testing.T) {
t.Parallel()
newWebhook := Providers[ProviderWebhook]
t.Run("valid", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "{}",
})
ExpectNoError(t, err)
})
t.Run("valid template", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"template": "discord",
})
ExpectNoError(t, err)
})
t.Run("missing url", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"payload": "{}",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("missing payload", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid url", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "example.com",
"payload": "{}",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid payload", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid method", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "{}",
"method": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid template", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"template": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestWebhookBody(t *testing.T) {
t.Parallel()
var webhook Webhook
webhook.Payload = discordPayload
bodyReader, err := webhook.MakeBody(&LogMessage{
Title: "abc",
Extras: map[string]any{
"foo": "bar",
},
})
ExpectNoError(t, err)
var body struct {
Embeds []struct {
Title string `json:"title"`
Fields []struct {
Name string `json:"name"`
Value string `json:"value"`
} `json:"fields"`
} `json:"embeds"`
}
err = json.NewDecoder(bodyReader).Decode(&body)
ExpectNoError(t, err)
ExpectEqual(t, body.Embeds[0].Title, "abc")
fields := body.Embeds[0].Fields
ExpectEqual(t, fields[0].Name, "foo")
ExpectEqual(t, fields[0].Value, "bar")
}

View file

@ -57,7 +57,7 @@ func HomepageCategories() []string {
return categories return categories
} }
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Config { func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Categories {
hpCfg := homepage.NewHomePageConfig() hpCfg := homepage.NewHomePageConfig()
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) { routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) {

View file

@ -1,13 +1,10 @@
package utils package utils
// FIXME: some times [%d] is not in correct order
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"os" "os"
"reflect" "reflect"
"runtime/debug"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -21,6 +18,10 @@ import (
type SerializedObject = map[string]any type SerializedObject = map[string]any
type MapUnmarshaller interface {
UnmarshalMap(m map[string]any) E.Error
}
var ( var (
ErrInvalidType = E.New("invalid type") ErrInvalidType = E.New("invalid type")
ErrNilValue = E.New("nil") ErrNilValue = E.New("nil")
@ -29,6 +30,8 @@ var (
ErrUnknownField = E.New("unknown field") ErrUnknownField = E.New("unknown field")
) )
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
var defaultValues = functional.NewMapOf[reflect.Type, func() any]() var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) { func RegisterDefaultValueFactory[T any](factory func() *T) {
@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
if t.Kind() != reflect.Struct { if t.Kind() != reflect.Struct {
return nil, nil return nil, nil
} }
var fields []reflect.StructField n := t.NumField()
for i := range t.NumField() { fields := make([]reflect.StructField, 0, n)
for i := range n {
field := t.Field(i) field := t.Field(i)
if !field.IsExported() { if !field.IsExported() {
continue continue
@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
return fields, anonymous return fields, anonymous
} }
func ValidateWithFieldTags(s any) E.Error {
errs := E.NewBuilder("validate error")
err := validate.Struct(s)
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(e.Namespace()).
Withf("require %q", detail))
}
}
return errs.Error()
}
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value. // Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
// Deserialize ignores case differences between the field names in the SerializedObject and the target. // Deserialize ignores case differences between the field names in the SerializedObject and the target.
// //
// The target value must be a struct or a map[string]any. // The target value must be a struct or a map[string]any.
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed. // If the target value is a struct , and implements the MapUnmarshaller interface,
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map. // the UnmarshalMap method will be called.
//
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
// the SerializedObject will be deserialized into the struct fields and validate if needed.
//
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
// //
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. // The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.Error { func Deserialize(src SerializedObject, dst any) E.Error {
if src == nil {
return E.Errorf("deserialize: src is %w", ErrNilValue)
}
if dst == nil {
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
}
dstV := reflect.ValueOf(dst) dstV := reflect.ValueOf(dst)
dstT := dstV.Type() dstT := dstV.Type()
if src == nil {
if dstV.CanSet() {
dstV.Set(reflect.Zero(dstT))
return nil
}
return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
}
if dstT.Implements(mapUnmarshalerType) {
for dstV.IsNil() {
switch dstT.Kind() {
case reflect.Struct:
dstV.Set(New(dstT))
case reflect.Map:
dstV.Set(reflect.MakeMap(dstT))
case reflect.Slice:
dstV.Set(reflect.MakeSlice(dstT, 0, 0))
case reflect.Ptr:
dstV.Set(reflect.New(dstT.Elem()))
default:
return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
}
dstV = dstV.Elem()
}
return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src)
}
for dstT.Kind() == reflect.Ptr { for dstT.Kind() == reflect.Ptr {
if dstV.IsNil() { if dstV.IsNil() {
if dstV.CanSet() { if dstV.CanSet() {
dstV.Set(New(dstT.Elem())) dstV.Set(New(dstT.Elem()))
} else { } else {
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack()) return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue)
} }
} }
dstV = dstV.Elem() dstV = dstV.Elem()
@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
switch dstV.Kind() { switch dstV.Kind() {
case reflect.Struct: case reflect.Struct:
needValidate := false hasValidateTag := false
mapping := make(map[string]reflect.Value) mapping := make(map[string]reflect.Value)
fieldName := make(map[string]string)
fields, anonymous := extractFields(dstT) fields, anonymous := extractFields(dstT)
for _, anon := range anonymous { for _, anon := range anonymous {
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() { if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error {
} }
key = strutils.ToLowerNoSnake(key) key = strutils.ToLowerNoSnake(key)
mapping[key] = dstV.FieldByName(field.Name) mapping[key] = dstV.FieldByName(field.Name)
fieldName[field.Name] = key
if !needValidate { if !hasValidateTag {
_, needValidate = field.Tag.Lookup("validate") _, hasValidateTag = field.Tag.Lookup("validate")
} }
aliases, ok := field.Tag.Lookup("aliases") aliases, ok := field.Tag.Lookup("aliases")
if ok { if ok {
for _, alias := range strutils.CommaSeperatedList(aliases) { for _, alias := range strutils.CommaSeperatedList(aliases) {
mapping[alias] = dstV.FieldByName(field.Name) mapping[alias] = dstV.FieldByName(field.Name)
fieldName[field.Name] = alias
} }
} }
} }
@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping)))) errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
} }
} }
if needValidate { if hasValidateTag {
err := validate.Struct(dstV.Interface()) errs.Add(ValidateWithFieldTags(dstV.Interface()))
var valErrs validator.ValidationErrors } else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
if errors.As(err, &valErrs) { errs.Add(validator.Validate())
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(e.StructNamespace()).
Withf("require %q", detail))
}
}
} }
return errs.Error() return errs.Error()
case reflect.Map: case reflect.Map:
@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(err.Subject(k)) errs.Add(err.Subject(k))
} }
} }
if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
errs.Add(validator.Validate())
}
return errs.Error() return errs.Error()
default: default:
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String()) return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.
return functional.NewMapFrom(m2), nil return functional.NewMapFrom(m2), nil
} }
func DeserializeJSON[T any](data []byte, target T) E.Error {
m := make(map[string]any)
if err := json.Unmarshal(data, &m); err != nil {
return E.From(err)
}
return Deserialize(m, target)
}
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error { func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {

View file

@ -35,6 +35,14 @@ func ExpectNoError(t *testing.T, err error) {
} }
} }
func ExpectHasError(t *testing.T, err error) {
t.Helper()
if errors.Is(err, nil) {
t.Error("expected err not nil")
t.FailNow()
}
}
func ExpectError(t *testing.T, expected error, err error) { func ExpectError(t *testing.T, expected error, err error) {
t.Helper() t.Helper()
if !errors.Is(err, expected) { if !errors.Is(err, expected) {

View file

@ -9,6 +9,10 @@ var validate = validator.New()
var ErrValidationError = E.New("validation error") var ErrValidationError = E.New("validation error")
type CustomValidator interface {
Validate() E.Error
}
func Validator() *validator.Validate { func Validator() *validator.Validate {
return validate return validate
} }