mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-06 06:24:03 +02:00
Cleaned up some validation code, stricter validation
This commit is contained in:
parent
254224c0e8
commit
1586610a44
23 changed files with 590 additions and 468 deletions
|
@ -75,6 +75,38 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||||
U.WriteBody(w, content)
|
U.WriteBody(w, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateFile(fileType FileType, content []byte) error {
|
||||||
|
switch fileType {
|
||||||
|
case FileTypeConfig:
|
||||||
|
return config.Validate(content)
|
||||||
|
case FileTypeMiddleware:
|
||||||
|
errs := E.NewBuilder("middleware errors")
|
||||||
|
middleware.BuildMiddlewaresFromYAML("", content, errs)
|
||||||
|
return errs.Error()
|
||||||
|
}
|
||||||
|
return provider.Validate(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateFile(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fileType := FileType(r.PathValue("type"))
|
||||||
|
if !fileType.IsValid() {
|
||||||
|
U.RespondError(w, U.ErrInvalidKey("type"), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
U.HandleErr(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Body.Close()
|
||||||
|
err = validateFile(fileType, content)
|
||||||
|
if err != nil {
|
||||||
|
U.RespondError(w, err, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||||
fileType, filename, err := getArgs(r)
|
fileType, filename, err := getArgs(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -87,19 +119,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var valErr E.Error
|
if valErr := validateFile(fileType, content); valErr != nil {
|
||||||
switch fileType {
|
|
||||||
case FileTypeConfig:
|
|
||||||
valErr = config.Validate(content)
|
|
||||||
case FileTypeMiddleware:
|
|
||||||
errs := E.NewBuilder("middleware errors")
|
|
||||||
middleware.BuildMiddlewaresFromYAML(filename, content, errs)
|
|
||||||
valErr = errs.Error()
|
|
||||||
default:
|
|
||||||
valErr = provider.Validate(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
if valErr != nil {
|
|
||||||
U.RespondError(w, valErr, http.StatusBadRequest)
|
U.RespondError(w, valErr, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
"github.com/go-acme/lego/v4/certcrypto"
|
"github.com/go-acme/lego/v4/certcrypto"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
|
@ -13,63 +14,89 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/config/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config types.AutoCertConfig
|
type (
|
||||||
|
AutocertConfig struct {
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Domains []string `json:"domains,omitempty"`
|
||||||
|
CertPath string `json:"cert_path,omitempty"`
|
||||||
|
KeyPath string `json:"key_path,omitempty"`
|
||||||
|
ACMEKeyPath string `json:"acme_key_path,omitempty"`
|
||||||
|
Provider string `json:"provider,omitempty"`
|
||||||
|
Options ProviderOpt `json:"options,omitempty"`
|
||||||
|
}
|
||||||
|
ProviderOpt map[string]any
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingDomain = E.New("missing field 'domains'")
|
ErrMissingDomain = E.New("missing field 'domains'")
|
||||||
ErrMissingEmail = E.New("missing field 'email'")
|
ErrMissingEmail = E.New("missing field 'email'")
|
||||||
ErrMissingProvider = E.New("missing field 'provider'")
|
ErrMissingProvider = E.New("missing field 'provider'")
|
||||||
|
ErrInvalidDomain = E.New("invalid domain")
|
||||||
ErrUnknownProvider = E.New("unknown provider")
|
ErrUnknownProvider = E.New("unknown provider")
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewConfig(cfg *types.AutoCertConfig) *Config {
|
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
||||||
|
|
||||||
|
// Validate implements the utils.CustomValidator interface.
|
||||||
|
func (cfg *AutocertConfig) Validate() E.Error {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = new(types.AutoCertConfig)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.Provider == "" {
|
||||||
|
cfg.Provider = ProviderLocal
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b := E.NewBuilder("autocert errors")
|
||||||
|
if cfg.Provider != ProviderLocal {
|
||||||
|
if len(cfg.Domains) == 0 {
|
||||||
|
b.Add(ErrMissingDomain)
|
||||||
|
}
|
||||||
|
if cfg.Email == "" {
|
||||||
|
b.Add(ErrMissingEmail)
|
||||||
|
}
|
||||||
|
for i, d := range cfg.Domains {
|
||||||
|
if !domainOrWildcardRE.MatchString(d) {
|
||||||
|
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check if provider is implemented
|
||||||
|
providerConstructor, ok := providersGenMap[cfg.Provider]
|
||||||
|
if !ok {
|
||||||
|
b.Add(ErrUnknownProvider.
|
||||||
|
Subject(cfg.Provider).
|
||||||
|
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
|
||||||
|
} else {
|
||||||
|
_, err := providerConstructor(cfg.Options)
|
||||||
|
if err != nil {
|
||||||
|
b.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = new(AutocertConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.CertPath == "" {
|
if cfg.CertPath == "" {
|
||||||
cfg.CertPath = CertFileDefault
|
cfg.CertPath = CertFileDefault
|
||||||
}
|
}
|
||||||
if cfg.KeyPath == "" {
|
if cfg.KeyPath == "" {
|
||||||
cfg.KeyPath = KeyFileDefault
|
cfg.KeyPath = KeyFileDefault
|
||||||
}
|
}
|
||||||
if cfg.Provider == "" {
|
|
||||||
cfg.Provider = ProviderLocal
|
|
||||||
}
|
|
||||||
if cfg.ACMEKeyPath == "" {
|
if cfg.ACMEKeyPath == "" {
|
||||||
cfg.ACMEKeyPath = ACMEKeyFileDefault
|
cfg.ACMEKeyPath = ACMEKeyFileDefault
|
||||||
}
|
}
|
||||||
return (*Config)(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *Config) GetProvider() (*Provider, E.Error) {
|
|
||||||
b := E.NewBuilder("autocert errors")
|
|
||||||
|
|
||||||
if cfg.Provider != ProviderLocal {
|
|
||||||
if len(cfg.Domains) == 0 {
|
|
||||||
b.Add(ErrMissingDomain)
|
|
||||||
}
|
|
||||||
if cfg.Provider == "" {
|
|
||||||
b.Add(ErrMissingProvider)
|
|
||||||
}
|
|
||||||
if cfg.Email == "" {
|
|
||||||
b.Add(ErrMissingEmail)
|
|
||||||
}
|
|
||||||
// check if provider is implemented
|
|
||||||
_, ok := providersGenMap[cfg.Provider]
|
|
||||||
if !ok {
|
|
||||||
b.Add(ErrUnknownProvider.
|
|
||||||
Subject(cfg.Provider).
|
|
||||||
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.HasError() {
|
|
||||||
return nil, b.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
var privKey *ecdsa.PrivateKey
|
var privKey *ecdsa.PrivateKey
|
||||||
var err error
|
var err error
|
||||||
|
@ -103,7 +130,7 @@ func (cfg *Config) GetProvider() (*Provider, E.Error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
|
func (cfg *AutocertConfig) loadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||||
data, err := os.ReadFile(cfg.ACMEKeyPath)
|
data, err := os.ReadFile(cfg.ACMEKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -111,7 +138,7 @@ func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||||
return x509.ParseECPrivateKey(data)
|
return x509.ParseECPrivateKey(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) saveACMEKey(key *ecdsa.PrivateKey) error {
|
func (cfg *AutocertConfig) saveACMEKey(key *ecdsa.PrivateKey) error {
|
||||||
data, err := x509.MarshalECPrivateKey(key)
|
data, err := x509.MarshalECPrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"github.com/go-acme/lego/v4/challenge"
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
"github.com/go-acme/lego/v4/registration"
|
"github.com/go-acme/lego/v4/registration"
|
||||||
"github.com/yusing/go-proxy/internal/config/types"
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -24,7 +23,7 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Provider struct {
|
Provider struct {
|
||||||
cfg *Config
|
cfg *AutocertConfig
|
||||||
user *User
|
user *User
|
||||||
legoCfg *lego.Config
|
legoCfg *lego.Config
|
||||||
client *lego.Client
|
client *lego.Client
|
||||||
|
@ -33,7 +32,7 @@ type (
|
||||||
tlsCert *tls.Certificate
|
tlsCert *tls.Certificate
|
||||||
certExpiries CertExpiries
|
certExpiries CertExpiries
|
||||||
}
|
}
|
||||||
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
|
ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error)
|
||||||
|
|
||||||
CertExpiries map[string]time.Time
|
CertExpiries map[string]time.Time
|
||||||
)
|
)
|
||||||
|
@ -313,7 +312,7 @@ func providerGenerator[CT any, PT challenge.Provider](
|
||||||
defaultCfg func() *CT,
|
defaultCfg func() *CT,
|
||||||
newProvider func(*CT) (PT, error),
|
newProvider func(*CT) (PT, error),
|
||||||
) ProviderGenerator {
|
) ProviderGenerator {
|
||||||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
|
return func(opt ProviderOpt) (challenge.Provider, E.Error) {
|
||||||
cfg := defaultCfg()
|
cfg := defaultCfg()
|
||||||
err := U.Deserialize(opt, cfg)
|
err := U.Deserialize(opt, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -15,9 +15,11 @@ type User struct {
|
||||||
func (u *User) GetEmail() string {
|
func (u *User) GetEmail() string {
|
||||||
return u.Email
|
return u.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) GetRegistration() *registration.Resource {
|
func (u *User) GetRegistration() *registration.Resource {
|
||||||
return u.Registration
|
return u.Registration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) GetPrivateKey() crypto.PrivateKey {
|
func (u *User) GetPrivateKey() crypto.PrivateKey {
|
||||||
return u.key
|
return u.key
|
||||||
}
|
}
|
||||||
|
|
|
@ -234,7 +234,7 @@ func (cfg *Config) load() E.Error {
|
||||||
errs := E.NewBuilder(errMsg)
|
errs := E.NewBuilder(errMsg)
|
||||||
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
|
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
|
||||||
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
|
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
|
||||||
errs.Add(cfg.initNotification(model.Providers.Notification))
|
cfg.initNotification(model.Providers.Notification)
|
||||||
errs.Add(cfg.initAutoCert(model.AutoCert))
|
errs.Add(cfg.initAutoCert(model.AutoCert))
|
||||||
errs.Add(cfg.loadRouteProviders(&model.Providers))
|
errs.Add(cfg.loadRouteProviders(&model.Providers))
|
||||||
|
|
||||||
|
@ -249,28 +249,22 @@ func (cfg *Config) load() E.Error {
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) initNotification(notifCfg []types.NotificationConfig) (err E.Error) {
|
func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
|
||||||
if len(notifCfg) == 0 {
|
if len(notifCfg) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dispatcher := notif.StartNotifDispatcher(cfg.task)
|
dispatcher := notif.StartNotifDispatcher(cfg.task)
|
||||||
errs := E.NewBuilder("notification providers load errors")
|
for _, notifier := range notifCfg {
|
||||||
for i, notifier := range notifCfg {
|
dispatcher.RegisterProvider(¬ifier)
|
||||||
_, err := dispatcher.RegisterProvider(notifier)
|
|
||||||
if err == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
errs.Add(err.Subjectf("[%d]", i))
|
|
||||||
}
|
}
|
||||||
return errs.Error()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
|
func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err E.Error) {
|
||||||
if cfg.autocertProvider != nil {
|
if cfg.autocertProvider != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
|
cfg.autocertProvider, err = autocertCfg.GetProvider()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
package types
|
|
||||||
|
|
||||||
type (
|
|
||||||
AutoCertConfig struct {
|
|
||||||
Email string `json:"email,omitempty" validate:"email"`
|
|
||||||
Domains []string `json:"domains,omitempty"`
|
|
||||||
CertPath string `json:"cert_path,omitempty" validate:"omitempty,filepath"`
|
|
||||||
KeyPath string `json:"key_path,omitempty" validate:"omitempty,filepath"`
|
|
||||||
ACMEKeyPath string `json:"acme_key_path,omitempty" validate:"omitempty,filepath"`
|
|
||||||
Provider string `json:"provider,omitempty"`
|
|
||||||
Options AutocertProviderOpt `json:"options,omitempty"`
|
|
||||||
}
|
|
||||||
AutocertProviderOpt map[string]any
|
|
||||||
)
|
|
|
@ -2,8 +2,12 @@ package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/go-playground/validator/v10"
|
||||||
|
"github.com/yusing/go-proxy/internal/autocert"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/accesslog"
|
"github.com/yusing/go-proxy/internal/net/http/accesslog"
|
||||||
|
"github.com/yusing/go-proxy/internal/notif"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
@ -11,23 +15,22 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Config struct {
|
Config struct {
|
||||||
AutoCert *AutoCertConfig `json:"autocert" validate:"omitempty"`
|
AutoCert *autocert.AutocertConfig `json:"autocert"`
|
||||||
Entrypoint Entrypoint `json:"entrypoint"`
|
Entrypoint Entrypoint `json:"entrypoint"`
|
||||||
Providers Providers `json:"providers"`
|
Providers Providers `json:"providers"`
|
||||||
MatchDomains []string `json:"match_domains" validate:"dive,fqdn"`
|
MatchDomains []string `json:"match_domains" validate:"domain_name"`
|
||||||
Homepage HomepageConfig `json:"homepage"`
|
Homepage HomepageConfig `json:"homepage"`
|
||||||
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
|
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
|
||||||
}
|
}
|
||||||
Providers struct {
|
Providers struct {
|
||||||
Files []string `json:"include" validate:"dive,filepath"`
|
Files []string `json:"include" validate:"dive,filepath"`
|
||||||
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
|
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
|
||||||
Notification []NotificationConfig `json:"notification"`
|
Notification []notif.NotificationConfig `json:"notification"`
|
||||||
}
|
}
|
||||||
Entrypoint struct {
|
Entrypoint struct {
|
||||||
Middlewares []map[string]any `json:"middlewares"`
|
Middlewares []map[string]any `json:"middlewares"`
|
||||||
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
|
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
|
||||||
}
|
}
|
||||||
NotificationConfig map[string]any
|
|
||||||
|
|
||||||
ConfigInstance interface {
|
ConfigInstance interface {
|
||||||
Value() *Config
|
Value() *Config
|
||||||
|
@ -52,6 +55,17 @@ func Validate(data []byte) E.Error {
|
||||||
return utils.DeserializeYAML(data, &model)
|
return utils.DeserializeYAML(data, &model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
utils.RegisterDefaultValueFactory(DefaultConfig)
|
utils.RegisterDefaultValueFactory(DefaultConfig)
|
||||||
|
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
|
||||||
|
domains := fl.Field().Interface().([]string)
|
||||||
|
for _, domain := range domains {
|
||||||
|
if !matchDomainsRegex.MatchString(domain) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,13 @@ func Errorf(format string, args ...any) Error {
|
||||||
return &baseError{fmt.Errorf(format, args...)}
|
return &baseError{fmt.Errorf(format, args...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Wrap(err error, message ...string) Error {
|
||||||
|
if len(message) == 0 || message[0] == "" {
|
||||||
|
return From(err)
|
||||||
|
}
|
||||||
|
return Errorf("%w: %s", err, message[0])
|
||||||
|
}
|
||||||
|
|
||||||
func From(err error) Error {
|
func From(err error) Error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package homepage
|
package homepage
|
||||||
|
|
||||||
import "github.com/yusing/go-proxy/internal/utils"
|
import (
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
//nolint:recvcheck
|
//nolint:recvcheck
|
||||||
Config map[string]Category
|
Categories map[string]Category
|
||||||
Category []*Item
|
Category []*Item
|
||||||
|
|
||||||
ItemConfig struct {
|
ItemConfig struct {
|
||||||
Show bool `json:"show"`
|
Show bool `json:"show"`
|
||||||
|
@ -48,6 +50,10 @@ func NewItem(alias string) *Item {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewHomePageConfig() Categories {
|
||||||
|
return Categories(make(map[string]Category))
|
||||||
|
}
|
||||||
|
|
||||||
func (item *Item) IsEmpty() bool {
|
func (item *Item) IsEmpty() bool {
|
||||||
return item == nil || item.IsUnset || item.ItemConfig == nil
|
return item == nil || item.IsUnset || item.ItemConfig == nil
|
||||||
}
|
}
|
||||||
|
@ -56,15 +62,11 @@ func (item *Item) GetOverride() *Item {
|
||||||
return overrideConfigInstance.GetOverride(item)
|
return overrideConfigInstance.GetOverride(item)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHomePageConfig() Config {
|
func (c *Categories) Clear() {
|
||||||
return Config(make(map[string]Category))
|
*c = make(Categories)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Clear() {
|
func (c Categories) Add(item *Item) {
|
||||||
*c = make(Config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Config) Add(item *Item) {
|
|
||||||
if c[item.Category] == nil {
|
if c[item.Category] == nil {
|
||||||
c[item.Category] = make(Category, 0)
|
c[item.Category] = make(Category, 0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,10 +53,6 @@ func GetOverrideConfig() *OverrideConfig {
|
||||||
return overrideConfigInstance
|
return overrideConfigInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OverrideConfig) UnmarshalJSON(data []byte) error {
|
|
||||||
return utils.DeserializeJSON(data, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) {
|
func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
47
internal/notif/base.go
Normal file
47
internal/notif/base.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package notif
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProviderBase struct {
|
||||||
|
Name string `json:"name" validate:"required"`
|
||||||
|
URL string `json:"url" validate:"url"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMissingToken = E.New("token is required")
|
||||||
|
ErrURLMissingScheme = E.New("url missing scheme, expect 'http://' or 'https://'")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Validate implements the utils.CustomValidator interface.
|
||||||
|
func (base *ProviderBase) Validate() E.Error {
|
||||||
|
if base.Token == "" {
|
||||||
|
return ErrMissingToken
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") {
|
||||||
|
return ErrURLMissingScheme
|
||||||
|
}
|
||||||
|
u, err := url.Parse(base.URL)
|
||||||
|
if err != nil {
|
||||||
|
return E.Wrap(err)
|
||||||
|
}
|
||||||
|
base.URL = u.String()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (base *ProviderBase) GetName() string {
|
||||||
|
return base.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (base *ProviderBase) GetURL() string {
|
||||||
|
return base.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (base *ProviderBase) GetToken() string {
|
||||||
|
return base.Token
|
||||||
|
}
|
54
internal/notif/config.go
Normal file
54
internal/notif/config.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package notif
|
||||||
|
|
||||||
|
import (
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NotificationConfig struct {
|
||||||
|
ProviderName string `json:"provider"`
|
||||||
|
Provider Provider `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMissingNotifProvider = E.New("missing notification provider")
|
||||||
|
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
|
||||||
|
ErrUnknownNotifProvider = E.New("unknown notification provider")
|
||||||
|
)
|
||||||
|
|
||||||
|
// UnmarshalMap implements MapUnmarshaler.
|
||||||
|
func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err E.Error) {
|
||||||
|
// extract provider name
|
||||||
|
providerName := m["provider"]
|
||||||
|
switch providerName := providerName.(type) {
|
||||||
|
case string:
|
||||||
|
cfg.ProviderName = providerName
|
||||||
|
default:
|
||||||
|
return ErrInvalidNotifProviderType
|
||||||
|
}
|
||||||
|
delete(m, "provider")
|
||||||
|
|
||||||
|
if cfg.ProviderName == "" {
|
||||||
|
return ErrMissingNotifProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate provider name and initialize provider
|
||||||
|
switch cfg.ProviderName {
|
||||||
|
case ProviderWebhook:
|
||||||
|
cfg.Provider = &Webhook{}
|
||||||
|
case ProviderGotify:
|
||||||
|
cfg.Provider = &GotifyClient{}
|
||||||
|
default:
|
||||||
|
return ErrUnknownNotifProvider.
|
||||||
|
Subject(cfg.ProviderName).
|
||||||
|
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshal provider config
|
||||||
|
if err := utils.Deserialize(m, cfg.Provider); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate provider
|
||||||
|
return cfg.Provider.Validate()
|
||||||
|
}
|
163
internal/notif/config_test.go
Normal file
163
internal/notif/config_test.go
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
package notif
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNotificationConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg map[string]any
|
||||||
|
expected Provider
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid_webhook",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
"template": "discord",
|
||||||
|
"url": "https://example.com",
|
||||||
|
},
|
||||||
|
expected: &Webhook{
|
||||||
|
ProviderBase: ProviderBase{
|
||||||
|
Name: "test",
|
||||||
|
URL: "https://example.com",
|
||||||
|
},
|
||||||
|
Template: "discord",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
MIMEType: "application/json",
|
||||||
|
ColorMode: "dec",
|
||||||
|
Payload: discordPayload,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid_gotify",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "gotify",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"token": "token",
|
||||||
|
},
|
||||||
|
expected: &GotifyClient{
|
||||||
|
ProviderBase: ProviderBase{
|
||||||
|
Name: "test",
|
||||||
|
URL: "https://example.com",
|
||||||
|
Token: "token",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_provider",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "invalid",
|
||||||
|
"url": "https://example.com",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_url",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_provider",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gotify_missing_token",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "gotify",
|
||||||
|
"url": "https://example.com",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "webhook_missing_payload",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
"url": "https://example.com",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "webhook_missing_url",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "webhook_invalid_template",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"template": "invalid",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "webhook_invalid_json_payload",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"mime_type": "application/json",
|
||||||
|
"payload": "invalid",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "webhook_empty_text_payload",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"mime_type": "text/plain",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "webhook_invalid_method",
|
||||||
|
cfg: map[string]any{
|
||||||
|
"name": "test",
|
||||||
|
"provider": "webhook",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"method": "invalid",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var cfg NotificationConfig
|
||||||
|
provider := tt.cfg["provider"]
|
||||||
|
err := utils.Deserialize(tt.cfg, &cfg)
|
||||||
|
if tt.wantErr {
|
||||||
|
ExpectHasError(t, err)
|
||||||
|
} else {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
ExpectEqual(t, provider.(string), cfg.ProviderName)
|
||||||
|
ExpectDeepEqual(t, cfg.Provider, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,13 +2,10 @@ package notif
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/yusing/go-proxy/internal/config/types"
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -27,12 +24,6 @@ type (
|
||||||
|
|
||||||
var dispatcher *Dispatcher
|
var dispatcher *Dispatcher
|
||||||
|
|
||||||
var (
|
|
||||||
ErrMissingNotifProvider = E.New("missing notification provider")
|
|
||||||
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
|
|
||||||
ErrUnknownNotifProvider = E.New("unknown notification provider")
|
|
||||||
)
|
|
||||||
|
|
||||||
const dispatchErr = "notification dispatch error"
|
const dispatchErr = "notification dispatch error"
|
||||||
|
|
||||||
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
|
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
|
||||||
|
@ -57,29 +48,8 @@ func Notify(msg *LogMessage) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (disp *Dispatcher) RegisterProvider(cfg types.NotificationConfig) (Provider, E.Error) {
|
func (disp *Dispatcher) RegisterProvider(cfg *NotificationConfig) {
|
||||||
providerName, ok := cfg["provider"]
|
disp.providers.Add(cfg.Provider)
|
||||||
if !ok {
|
|
||||||
return nil, ErrMissingNotifProvider
|
|
||||||
}
|
|
||||||
switch providerName := providerName.(type) {
|
|
||||||
case string:
|
|
||||||
delete(cfg, "provider")
|
|
||||||
createFunc, ok := Providers[providerName]
|
|
||||||
if !ok {
|
|
||||||
return nil, ErrUnknownNotifProvider.
|
|
||||||
Subject(providerName).
|
|
||||||
Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers)))
|
|
||||||
}
|
|
||||||
|
|
||||||
provider, err := createFunc(cfg)
|
|
||||||
if err == nil {
|
|
||||||
disp.providers.Add(provider)
|
|
||||||
}
|
|
||||||
return provider, err
|
|
||||||
default:
|
|
||||||
return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (disp *Dispatcher) start() {
|
func (disp *Dispatcher) start() {
|
||||||
|
@ -110,7 +80,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
|
||||||
errs := E.NewBuilder(dispatchErr)
|
errs := E.NewBuilder(dispatchErr)
|
||||||
disp.providers.RangeAllParallel(func(p Provider) {
|
disp.providers.RangeAllParallel(func(p Provider) {
|
||||||
if err := notifyProvider(task.Context(), p, msg); err != nil {
|
if err := notifyProvider(task.Context(), p, msg); err != nil {
|
||||||
errs.Add(E.PrependSubject(p.Name(), err))
|
errs.Add(E.PrependSubject(p.GetName(), err))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if errs.HasError() {
|
if errs.HasError() {
|
||||||
|
|
|
@ -13,37 +13,24 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
GotifyClient struct {
|
GotifyClient struct {
|
||||||
N string `json:"name" validate:"required"`
|
ProviderBase
|
||||||
U string `json:"url" validate:"url"`
|
|
||||||
Tok string `json:"token" validate:"required"`
|
|
||||||
}
|
}
|
||||||
GotifyMessage model.MessageExternal
|
GotifyMessage model.MessageExternal
|
||||||
)
|
)
|
||||||
|
|
||||||
const gotifyMsgEndpoint = "/message"
|
const gotifyMsgEndpoint = "/message"
|
||||||
|
|
||||||
// Name implements Provider.
|
func (client *GotifyClient) GetURL() string {
|
||||||
func (client *GotifyClient) Name() string {
|
return client.URL + gotifyMsgEndpoint
|
||||||
return client.N
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Method implements Provider.
|
// GetMethod implements Provider.
|
||||||
func (client *GotifyClient) Method() string {
|
func (client *GotifyClient) GetMethod() string {
|
||||||
return http.MethodPost
|
return http.MethodPost
|
||||||
}
|
}
|
||||||
|
|
||||||
// URL implements Provider.
|
// GetMIMEType implements Provider.
|
||||||
func (client *GotifyClient) URL() string {
|
func (client *GotifyClient) GetMIMEType() string {
|
||||||
return client.U + gotifyMsgEndpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token implements Provider.
|
|
||||||
func (client *GotifyClient) Token() string {
|
|
||||||
return client.Tok
|
|
||||||
}
|
|
||||||
|
|
||||||
// MIMEType implements Provider.
|
|
||||||
func (client *GotifyClient) MIMEType() string {
|
|
||||||
return "application/json"
|
return "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
package notif
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGotifyValidation(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
newGotify := Providers[ProviderGotify]
|
|
||||||
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newGotify(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
"token": "token",
|
|
||||||
})
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("missing url", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newGotify(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"token": "token",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("missing token", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newGotify(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("invalid url", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newGotify(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "example.com",
|
|
||||||
"token": "token",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -2,22 +2,24 @@ package notif
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Provider interface {
|
Provider interface {
|
||||||
Name() string
|
utils.CustomValidator
|
||||||
URL() string
|
|
||||||
Method() string
|
GetName() string
|
||||||
Token() string
|
GetURL() string
|
||||||
MIMEType() string
|
GetToken() string
|
||||||
|
GetMethod() string
|
||||||
|
GetMIMEType() string
|
||||||
|
|
||||||
MakeBody(logMsg *LogMessage) (io.Reader, error)
|
MakeBody(logMsg *LogMessage) (io.Reader, error)
|
||||||
|
|
||||||
makeRespError(resp *http.Response) error
|
makeRespError(resp *http.Response) error
|
||||||
|
@ -31,47 +33,29 @@ const (
|
||||||
ProviderWebhook = "webhook"
|
ProviderWebhook = "webhook"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Providers = map[string]ProviderCreateFunc{
|
|
||||||
ProviderGotify: newNotifProvider[*GotifyClient],
|
|
||||||
ProviderWebhook: newNotifProvider[*Webhook],
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) {
|
|
||||||
var client T
|
|
||||||
err := U.Deserialize(cfg, &client)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err.Subject(client.Name())
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatError(p Provider, err error) error {
|
|
||||||
return fmt.Errorf("%s error: %w", p.Name(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
|
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
|
||||||
body, err := provider.MakeBody(msg)
|
body, err := provider.MakeBody(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return formatError(provider, err)
|
return E.PrependSubject(provider.GetName(), err)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(
|
req, err := http.NewRequestWithContext(
|
||||||
ctx,
|
ctx,
|
||||||
http.MethodPost,
|
http.MethodPost,
|
||||||
provider.URL(),
|
provider.GetURL(),
|
||||||
body,
|
body,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return formatError(provider, err)
|
return E.PrependSubject(provider.GetName(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", provider.MIMEType())
|
req.Header.Set("Content-Type", provider.GetMIMEType())
|
||||||
if provider.Token() != "" {
|
if provider.GetToken() != "" {
|
||||||
req.Header.Set("Authorization", "Bearer "+provider.Token())
|
req.Header.Set("Authorization", "Bearer "+provider.GetToken())
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return formatError(provider, err)
|
return E.PrependSubject(provider.GetName(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
|
@ -8,19 +8,16 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Webhook struct {
|
type Webhook struct {
|
||||||
N string `json:"name" validate:"required"`
|
ProviderBase
|
||||||
U string `json:"url" validate:"url"`
|
Template string `json:"template"`
|
||||||
Template string `json:"template" validate:"omitempty,oneof=discord"`
|
Payload string `json:"payload"`
|
||||||
Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"`
|
Method string `json:"method"`
|
||||||
Tok string `json:"token"`
|
MIMEType string `json:"mime_type"`
|
||||||
Meth string `json:"method" validate:"oneof=GET POST PUT"`
|
ColorMode string `json:"color_mode"`
|
||||||
MIMETyp string `json:"mime_type"`
|
|
||||||
ColorM string `json:"color_mode" validate:"oneof=hex dec"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:embed templates/discord.json
|
//go:embed templates/discord.json
|
||||||
|
@ -30,60 +27,65 @@ var webhookTemplates = map[string]string{
|
||||||
"discord": discordPayload,
|
"discord": discordPayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultValue() *Webhook {
|
func (webhook *Webhook) Validate() E.Error {
|
||||||
return &Webhook{
|
if err := webhook.ProviderBase.Validate(); err != nil && !err.Is(ErrMissingToken) {
|
||||||
Meth: "POST",
|
return err
|
||||||
ColorM: "hex",
|
|
||||||
MIMETyp: "application/json",
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
|
switch webhook.MIMEType {
|
||||||
template := fl.Parent().FieldByName("Template").String()
|
case "":
|
||||||
if template != "" {
|
webhook.MIMEType = "application/json"
|
||||||
return true
|
case "application/json", "application/x-www-form-urlencoded", "text/plain":
|
||||||
}
|
|
||||||
payload := fl.Field().String()
|
|
||||||
return json.Valid([]byte(payload))
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
utils.RegisterDefaultValueFactory(DefaultValue)
|
|
||||||
utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name implements Provider.
|
|
||||||
func (webhook *Webhook) Name() string {
|
|
||||||
return webhook.N
|
|
||||||
}
|
|
||||||
|
|
||||||
// Method implements Provider.
|
|
||||||
func (webhook *Webhook) Method() string {
|
|
||||||
return webhook.Meth
|
|
||||||
}
|
|
||||||
|
|
||||||
// URL implements Provider.
|
|
||||||
func (webhook *Webhook) URL() string {
|
|
||||||
return webhook.U
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token implements Provider.
|
|
||||||
func (webhook *Webhook) Token() string {
|
|
||||||
return webhook.Tok
|
|
||||||
}
|
|
||||||
|
|
||||||
// MIMEType implements Provider.
|
|
||||||
func (webhook *Webhook) MIMEType() string {
|
|
||||||
return webhook.MIMETyp
|
|
||||||
}
|
|
||||||
|
|
||||||
func (webhook *Webhook) ColorMode() string {
|
|
||||||
switch webhook.Template {
|
|
||||||
case "discord":
|
|
||||||
return "dec"
|
|
||||||
default:
|
default:
|
||||||
return webhook.ColorM
|
return E.New("invalid mime_type, expect empty, 'application/json', 'application/x-www-form-urlencoded' or 'text/plain'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch webhook.Template {
|
||||||
|
case "":
|
||||||
|
if webhook.MIMEType == "application/json" && !json.Valid([]byte(webhook.Payload)) {
|
||||||
|
return E.New("invalid payload, expect valid JSON")
|
||||||
|
}
|
||||||
|
if webhook.Payload == "" {
|
||||||
|
return E.New("invalid payload, expect non-empty")
|
||||||
|
}
|
||||||
|
case "discord":
|
||||||
|
webhook.ColorMode = "dec"
|
||||||
|
webhook.Method = http.MethodPost
|
||||||
|
webhook.MIMEType = "application/json"
|
||||||
|
if webhook.Payload == "" {
|
||||||
|
webhook.Payload = discordPayload
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return E.New("invalid template, expect empty or 'discord'")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch webhook.Method {
|
||||||
|
case "":
|
||||||
|
webhook.Method = http.MethodPost
|
||||||
|
case http.MethodGet, http.MethodPost, http.MethodPut:
|
||||||
|
default:
|
||||||
|
return E.New("invalid method, expect empty, 'GET', 'POST' or 'PUT'")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch webhook.ColorMode {
|
||||||
|
case "":
|
||||||
|
webhook.ColorMode = "hex"
|
||||||
|
case "hex", "dec":
|
||||||
|
default:
|
||||||
|
return E.New("invalid color_mode, expect empty, 'hex' or 'dec'")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMethod implements Provider.
|
||||||
|
func (webhook *Webhook) GetMethod() string {
|
||||||
|
return webhook.Method
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMIMEType implements Provider.
|
||||||
|
func (webhook *Webhook) GetMIMEType() string {
|
||||||
|
return webhook.MIMEType
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeRespError implements Provider.
|
// makeRespError implements Provider.
|
||||||
|
@ -108,7 +110,7 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var color string
|
var color string
|
||||||
if webhook.ColorMode() == "hex" {
|
if webhook.ColorMode == "hex" {
|
||||||
color = logMsg.Color.HexString()
|
color = logMsg.Color.HexString()
|
||||||
} else {
|
} else {
|
||||||
color = logMsg.Color.DecString()
|
color = logMsg.Color.DecString()
|
||||||
|
|
|
@ -1,121 +0,0 @@
|
||||||
package notif
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWebhookValidation(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
newWebhook := Providers[ProviderWebhook]
|
|
||||||
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
"payload": "{}",
|
|
||||||
})
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
})
|
|
||||||
t.Run("valid template", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
"template": "discord",
|
|
||||||
})
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("missing url", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"payload": "{}",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("missing payload", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
t.Run("invalid url", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "example.com",
|
|
||||||
"payload": "{}",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
t.Run("invalid payload", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
"payload": "abcd",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
t.Run("invalid method", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
"payload": "{}",
|
|
||||||
"method": "abcd",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
t.Run("invalid template", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
_, err := newWebhook(map[string]any{
|
|
||||||
"name": "test",
|
|
||||||
"url": "https://example.com",
|
|
||||||
"template": "abcd",
|
|
||||||
})
|
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWebhookBody(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var webhook Webhook
|
|
||||||
webhook.Payload = discordPayload
|
|
||||||
bodyReader, err := webhook.MakeBody(&LogMessage{
|
|
||||||
Title: "abc",
|
|
||||||
Extras: map[string]any{
|
|
||||||
"foo": "bar",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
|
|
||||||
var body struct {
|
|
||||||
Embeds []struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
Fields []struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
} `json:"fields"`
|
|
||||||
} `json:"embeds"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err = json.NewDecoder(bodyReader).Decode(&body)
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
|
|
||||||
ExpectEqual(t, body.Embeds[0].Title, "abc")
|
|
||||||
fields := body.Embeds[0].Fields
|
|
||||||
ExpectEqual(t, fields[0].Name, "foo")
|
|
||||||
ExpectEqual(t, fields[0].Value, "bar")
|
|
||||||
}
|
|
|
@ -57,7 +57,7 @@ func HomepageCategories() []string {
|
||||||
return categories
|
return categories
|
||||||
}
|
}
|
||||||
|
|
||||||
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Config {
|
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Categories {
|
||||||
hpCfg := homepage.NewHomePageConfig()
|
hpCfg := homepage.NewHomePageConfig()
|
||||||
|
|
||||||
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) {
|
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) {
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
// FIXME: some times [%d] is not in correct order
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime/debug"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -21,6 +18,10 @@ import (
|
||||||
|
|
||||||
type SerializedObject = map[string]any
|
type SerializedObject = map[string]any
|
||||||
|
|
||||||
|
type MapUnmarshaller interface {
|
||||||
|
UnmarshalMap(m map[string]any) E.Error
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidType = E.New("invalid type")
|
ErrInvalidType = E.New("invalid type")
|
||||||
ErrNilValue = E.New("nil")
|
ErrNilValue = E.New("nil")
|
||||||
|
@ -29,6 +30,8 @@ var (
|
||||||
ErrUnknownField = E.New("unknown field")
|
ErrUnknownField = E.New("unknown field")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
||||||
|
|
||||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
||||||
|
|
||||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||||
|
@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
||||||
if t.Kind() != reflect.Struct {
|
if t.Kind() != reflect.Struct {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
var fields []reflect.StructField
|
n := t.NumField()
|
||||||
for i := range t.NumField() {
|
fields := make([]reflect.StructField, 0, n)
|
||||||
|
for i := range n {
|
||||||
field := t.Field(i)
|
field := t.Field(i)
|
||||||
if !field.IsExported() {
|
if !field.IsExported() {
|
||||||
continue
|
continue
|
||||||
|
@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
|
||||||
return fields, anonymous
|
return fields, anonymous
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ValidateWithFieldTags(s any) E.Error {
|
||||||
|
errs := E.NewBuilder("validate error")
|
||||||
|
err := validate.Struct(s)
|
||||||
|
var valErrs validator.ValidationErrors
|
||||||
|
if errors.As(err, &valErrs) {
|
||||||
|
for _, e := range valErrs {
|
||||||
|
detail := e.ActualTag()
|
||||||
|
if e.Param() != "" {
|
||||||
|
detail += ":" + e.Param()
|
||||||
|
}
|
||||||
|
errs.Add(ErrValidationError.
|
||||||
|
Subject(e.Namespace()).
|
||||||
|
Withf("require %q", detail))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
|
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
|
||||||
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
|
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
|
||||||
//
|
//
|
||||||
// The target value must be a struct or a map[string]any.
|
// The target value must be a struct or a map[string]any.
|
||||||
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed.
|
// If the target value is a struct , and implements the MapUnmarshaller interface,
|
||||||
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
|
// the UnmarshalMap method will be called.
|
||||||
|
//
|
||||||
|
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
|
||||||
|
// the SerializedObject will be deserialized into the struct fields and validate if needed.
|
||||||
|
//
|
||||||
|
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
|
||||||
//
|
//
|
||||||
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
||||||
func Deserialize(src SerializedObject, dst any) E.Error {
|
func Deserialize(src SerializedObject, dst any) E.Error {
|
||||||
if src == nil {
|
|
||||||
return E.Errorf("deserialize: src is %w", ErrNilValue)
|
|
||||||
}
|
|
||||||
if dst == nil {
|
|
||||||
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
|
|
||||||
}
|
|
||||||
|
|
||||||
dstV := reflect.ValueOf(dst)
|
dstV := reflect.ValueOf(dst)
|
||||||
dstT := dstV.Type()
|
dstT := dstV.Type()
|
||||||
|
|
||||||
|
if src == nil {
|
||||||
|
if dstV.CanSet() {
|
||||||
|
dstV.Set(reflect.Zero(dstT))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dstT.Implements(mapUnmarshalerType) {
|
||||||
|
for dstV.IsNil() {
|
||||||
|
switch dstT.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
dstV.Set(New(dstT))
|
||||||
|
case reflect.Map:
|
||||||
|
dstV.Set(reflect.MakeMap(dstT))
|
||||||
|
case reflect.Slice:
|
||||||
|
dstV.Set(reflect.MakeSlice(dstT, 0, 0))
|
||||||
|
case reflect.Ptr:
|
||||||
|
dstV.Set(reflect.New(dstT.Elem()))
|
||||||
|
default:
|
||||||
|
return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
|
||||||
|
}
|
||||||
|
dstV = dstV.Elem()
|
||||||
|
}
|
||||||
|
return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src)
|
||||||
|
}
|
||||||
|
|
||||||
for dstT.Kind() == reflect.Ptr {
|
for dstT.Kind() == reflect.Ptr {
|
||||||
if dstV.IsNil() {
|
if dstV.IsNil() {
|
||||||
if dstV.CanSet() {
|
if dstV.CanSet() {
|
||||||
dstV.Set(New(dstT.Elem()))
|
dstV.Set(New(dstT.Elem()))
|
||||||
} else {
|
} else {
|
||||||
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
|
return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dstV = dstV.Elem()
|
dstV = dstV.Elem()
|
||||||
|
@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||||
|
|
||||||
switch dstV.Kind() {
|
switch dstV.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
needValidate := false
|
hasValidateTag := false
|
||||||
mapping := make(map[string]reflect.Value)
|
mapping := make(map[string]reflect.Value)
|
||||||
fieldName := make(map[string]string)
|
|
||||||
fields, anonymous := extractFields(dstT)
|
fields, anonymous := extractFields(dstT)
|
||||||
for _, anon := range anonymous {
|
for _, anon := range anonymous {
|
||||||
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
|
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
|
||||||
|
@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||||
}
|
}
|
||||||
key = strutils.ToLowerNoSnake(key)
|
key = strutils.ToLowerNoSnake(key)
|
||||||
mapping[key] = dstV.FieldByName(field.Name)
|
mapping[key] = dstV.FieldByName(field.Name)
|
||||||
fieldName[field.Name] = key
|
|
||||||
|
|
||||||
if !needValidate {
|
if !hasValidateTag {
|
||||||
_, needValidate = field.Tag.Lookup("validate")
|
_, hasValidateTag = field.Tag.Lookup("validate")
|
||||||
}
|
}
|
||||||
|
|
||||||
aliases, ok := field.Tag.Lookup("aliases")
|
aliases, ok := field.Tag.Lookup("aliases")
|
||||||
if ok {
|
if ok {
|
||||||
for _, alias := range strutils.CommaSeperatedList(aliases) {
|
for _, alias := range strutils.CommaSeperatedList(aliases) {
|
||||||
mapping[alias] = dstV.FieldByName(field.Name)
|
mapping[alias] = dstV.FieldByName(field.Name)
|
||||||
fieldName[field.Name] = alias
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||||
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
|
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if needValidate {
|
if hasValidateTag {
|
||||||
err := validate.Struct(dstV.Interface())
|
errs.Add(ValidateWithFieldTags(dstV.Interface()))
|
||||||
var valErrs validator.ValidationErrors
|
} else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
|
||||||
if errors.As(err, &valErrs) {
|
errs.Add(validator.Validate())
|
||||||
for _, e := range valErrs {
|
|
||||||
detail := e.ActualTag()
|
|
||||||
if e.Param() != "" {
|
|
||||||
detail += ":" + e.Param()
|
|
||||||
}
|
|
||||||
errs.Add(ErrValidationError.
|
|
||||||
Subject(e.StructNamespace()).
|
|
||||||
Withf("require %q", detail))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
|
@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||||
errs.Add(err.Subject(k))
|
errs.Add(err.Subject(k))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
|
||||||
|
errs.Add(validator.Validate())
|
||||||
|
}
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
default:
|
default:
|
||||||
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
|
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
|
||||||
|
@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.
|
||||||
return functional.NewMapFrom(m2), nil
|
return functional.NewMapFrom(m2), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeserializeJSON[T any](data []byte, target T) E.Error {
|
|
||||||
m := make(map[string]any)
|
|
||||||
if err := json.Unmarshal(data, &m); err != nil {
|
|
||||||
return E.From(err)
|
|
||||||
}
|
|
||||||
return Deserialize(m, target)
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
|
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -35,6 +35,14 @@ func ExpectNoError(t *testing.T, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ExpectHasError(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if errors.Is(err, nil) {
|
||||||
|
t.Error("expected err not nil")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ExpectError(t *testing.T, expected error, err error) {
|
func ExpectError(t *testing.T, expected error, err error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
|
|
|
@ -9,6 +9,10 @@ var validate = validator.New()
|
||||||
|
|
||||||
var ErrValidationError = E.New("validation error")
|
var ErrValidationError = E.New("validation error")
|
||||||
|
|
||||||
|
type CustomValidator interface {
|
||||||
|
Validate() E.Error
|
||||||
|
}
|
||||||
|
|
||||||
func Validator() *validator.Validate {
|
func Validator() *validator.Validate {
|
||||||
return validate
|
return validate
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue