Cleaned up some validation code, stricter validation

This commit is contained in:
yusing 2025-01-26 14:43:48 +08:00
parent 254224c0e8
commit 1586610a44
23 changed files with 590 additions and 468 deletions

View file

@ -75,6 +75,38 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) {
U.WriteBody(w, content)
}
func validateFile(fileType FileType, content []byte) error {
switch fileType {
case FileTypeConfig:
return config.Validate(content)
case FileTypeMiddleware:
errs := E.NewBuilder("middleware errors")
middleware.BuildMiddlewaresFromYAML("", content, errs)
return errs.Error()
}
return provider.Validate(content)
}
func ValidateFile(w http.ResponseWriter, r *http.Request) {
fileType := FileType(r.PathValue("type"))
if !fileType.IsValid() {
U.RespondError(w, U.ErrInvalidKey("type"), http.StatusBadRequest)
return
}
content, err := io.ReadAll(r.Body)
if err != nil {
U.HandleErr(w, r, err)
return
}
r.Body.Close()
err = validateFile(fileType, content)
if err != nil {
U.RespondError(w, err, http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
}
func SetFileContent(w http.ResponseWriter, r *http.Request) {
fileType, filename, err := getArgs(r)
if err != nil {
@ -87,19 +119,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return
}
var valErr E.Error
switch fileType {
case FileTypeConfig:
valErr = config.Validate(content)
case FileTypeMiddleware:
errs := E.NewBuilder("middleware errors")
middleware.BuildMiddlewaresFromYAML(filename, content, errs)
valErr = errs.Error()
default:
valErr = provider.Validate(content)
}
if valErr != nil {
if valErr := validateFile(fileType, content); valErr != nil {
U.RespondError(w, valErr, http.StatusBadRequest)
return
}

View file

@ -6,6 +6,7 @@ import (
"crypto/rand"
"crypto/x509"
"os"
"regexp"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego"
@ -13,63 +14,89 @@ import (
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/config/types"
)
type Config types.AutoCertConfig
type (
AutocertConfig struct {
Email string `json:"email,omitempty"`
Domains []string `json:"domains,omitempty"`
CertPath string `json:"cert_path,omitempty"`
KeyPath string `json:"key_path,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"`
Provider string `json:"provider,omitempty"`
Options ProviderOpt `json:"options,omitempty"`
}
ProviderOpt map[string]any
)
var (
ErrMissingDomain = E.New("missing field 'domains'")
ErrMissingEmail = E.New("missing field 'email'")
ErrMissingProvider = E.New("missing field 'provider'")
ErrInvalidDomain = E.New("invalid domain")
ErrUnknownProvider = E.New("unknown provider")
)
func NewConfig(cfg *types.AutoCertConfig) *Config {
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
// Validate implements the utils.CustomValidator interface.
func (cfg *AutocertConfig) Validate() E.Error {
if cfg == nil {
cfg = new(types.AutoCertConfig)
return nil
}
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
return nil
}
b := E.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
// check if provider is implemented
providerConstructor, ok := providersGenMap[cfg.Provider]
if !ok {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
} else {
_, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
}
}
}
return b.Error()
}
func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
if cfg == nil {
cfg = new(AutocertConfig)
}
if err := cfg.Validate(); err != nil {
return nil, err
}
if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault
}
if cfg.KeyPath == "" {
cfg.KeyPath = KeyFileDefault
}
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
}
if cfg.ACMEKeyPath == "" {
cfg.ACMEKeyPath = ACMEKeyFileDefault
}
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (*Provider, E.Error) {
b := E.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
if cfg.Provider == "" {
b.Add(ErrMissingProvider)
}
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
}
}
if b.HasError() {
return nil, b.Error()
}
var privKey *ecdsa.PrivateKey
var err error
@ -103,7 +130,7 @@ func (cfg *Config) GetProvider() (*Provider, E.Error) {
}, nil
}
func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
func (cfg *AutocertConfig) loadACMEKey() (*ecdsa.PrivateKey, error) {
data, err := os.ReadFile(cfg.ACMEKeyPath)
if err != nil {
return nil, err
@ -111,7 +138,7 @@ func (cfg *Config) loadACMEKey() (*ecdsa.PrivateKey, error) {
return x509.ParseECPrivateKey(data)
}
func (cfg *Config) saveACMEKey(key *ecdsa.PrivateKey) error {
func (cfg *AutocertConfig) saveACMEKey(key *ecdsa.PrivateKey) error {
data, err := x509.MarshalECPrivateKey(key)
if err != nil {
return err

View file

@ -14,7 +14,6 @@ import (
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
@ -24,7 +23,7 @@ import (
type (
Provider struct {
cfg *Config
cfg *AutocertConfig
user *User
legoCfg *lego.Config
client *lego.Client
@ -33,7 +32,7 @@ type (
tlsCert *tls.Certificate
certExpiries CertExpiries
}
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error)
CertExpiries map[string]time.Time
)
@ -313,7 +312,7 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT,
newProvider func(*CT) (PT, error),
) ProviderGenerator {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
return func(opt ProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg()
err := U.Deserialize(opt, cfg)
if err != nil {

View file

@ -15,9 +15,11 @@ type User struct {
func (u *User) GetEmail() string {
return u.Email
}
func (u *User) GetRegistration() *registration.Resource {
return u.Registration
}
func (u *User) GetPrivateKey() crypto.PrivateKey {
return u.key
}

View file

@ -234,7 +234,7 @@ func (cfg *Config) load() E.Error {
errs := E.NewBuilder(errMsg)
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
errs.Add(cfg.initNotification(model.Providers.Notification))
cfg.initNotification(model.Providers.Notification)
errs.Add(cfg.initAutoCert(model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers))
@ -249,28 +249,22 @@ func (cfg *Config) load() E.Error {
return errs.Error()
}
func (cfg *Config) initNotification(notifCfg []types.NotificationConfig) (err E.Error) {
func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
if len(notifCfg) == 0 {
return
}
dispatcher := notif.StartNotifDispatcher(cfg.task)
errs := E.NewBuilder("notification providers load errors")
for i, notifier := range notifCfg {
_, err := dispatcher.RegisterProvider(notifier)
if err == nil {
continue
}
errs.Add(err.Subjectf("[%d]", i))
for _, notifier := range notifCfg {
dispatcher.RegisterProvider(&notifier)
}
return errs.Error()
}
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err E.Error) {
if cfg.autocertProvider != nil {
return
}
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
cfg.autocertProvider, err = autocertCfg.GetProvider()
return
}

View file

@ -1,14 +0,0 @@
package types
type (
AutoCertConfig struct {
Email string `json:"email,omitempty" validate:"email"`
Domains []string `json:"domains,omitempty"`
CertPath string `json:"cert_path,omitempty" validate:"omitempty,filepath"`
KeyPath string `json:"key_path,omitempty" validate:"omitempty,filepath"`
ACMEKeyPath string `json:"acme_key_path,omitempty" validate:"omitempty,filepath"`
Provider string `json:"provider,omitempty"`
Options AutocertProviderOpt `json:"options,omitempty"`
}
AutocertProviderOpt map[string]any
)

View file

@ -2,8 +2,12 @@ package types
import (
"context"
"regexp"
"github.com/go-playground/validator/v10"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/utils"
E "github.com/yusing/go-proxy/internal/error"
@ -11,23 +15,22 @@ import (
type (
Config struct {
AutoCert *AutoCertConfig `json:"autocert" validate:"omitempty"`
Entrypoint Entrypoint `json:"entrypoint"`
Providers Providers `json:"providers"`
MatchDomains []string `json:"match_domains" validate:"dive,fqdn"`
Homepage HomepageConfig `json:"homepage"`
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
AutoCert *autocert.AutocertConfig `json:"autocert"`
Entrypoint Entrypoint `json:"entrypoint"`
Providers Providers `json:"providers"`
MatchDomains []string `json:"match_domains" validate:"domain_name"`
Homepage HomepageConfig `json:"homepage"`
TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"`
}
Providers struct {
Files []string `json:"include" validate:"dive,filepath"`
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
Notification []NotificationConfig `json:"notification"`
Files []string `json:"include" validate:"dive,filepath"`
Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"`
Notification []notif.NotificationConfig `json:"notification"`
}
Entrypoint struct {
Middlewares []map[string]any `json:"middlewares"`
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
}
NotificationConfig map[string]any
ConfigInstance interface {
Value() *Config
@ -52,6 +55,17 @@ func Validate(data []byte) E.Error {
return utils.DeserializeYAML(data, &model)
}
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
func init() {
utils.RegisterDefaultValueFactory(DefaultConfig)
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
domains := fl.Field().Interface().([]string)
for _, domain := range domains {
if !matchDomainsRegex.MatchString(domain) {
return false
}
}
return true
})
}

View file

@ -19,6 +19,13 @@ func Errorf(format string, args ...any) Error {
return &baseError{fmt.Errorf(format, args...)}
}
func Wrap(err error, message ...string) Error {
if len(message) == 0 || message[0] == "" {
return From(err)
}
return Errorf("%w: %s", err, message[0])
}
func From(err error) Error {
if err == nil {
return nil

View file

@ -1,11 +1,13 @@
package homepage
import "github.com/yusing/go-proxy/internal/utils"
import (
"github.com/yusing/go-proxy/internal/utils"
)
type (
//nolint:recvcheck
Config map[string]Category
Category []*Item
Categories map[string]Category
Category []*Item
ItemConfig struct {
Show bool `json:"show"`
@ -48,6 +50,10 @@ func NewItem(alias string) *Item {
}
}
func NewHomePageConfig() Categories {
return Categories(make(map[string]Category))
}
func (item *Item) IsEmpty() bool {
return item == nil || item.IsUnset || item.ItemConfig == nil
}
@ -56,15 +62,11 @@ func (item *Item) GetOverride() *Item {
return overrideConfigInstance.GetOverride(item)
}
func NewHomePageConfig() Config {
return Config(make(map[string]Category))
func (c *Categories) Clear() {
*c = make(Categories)
}
func (c *Config) Clear() {
*c = make(Config)
}
func (c Config) Add(item *Item) {
func (c Categories) Add(item *Item) {
if c[item.Category] == nil {
c[item.Category] = make(Category, 0)
}

View file

@ -53,10 +53,6 @@ func GetOverrideConfig() *OverrideConfig {
return overrideConfigInstance
}
func (c *OverrideConfig) UnmarshalJSON(data []byte) error {
return utils.DeserializeJSON(data, c)
}
func (c *OverrideConfig) OverrideItem(alias string, override *ItemConfig) {
c.mu.Lock()
defer c.mu.Unlock()

47
internal/notif/base.go Normal file
View file

@ -0,0 +1,47 @@
package notif
import (
"net/url"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type ProviderBase struct {
Name string `json:"name" validate:"required"`
URL string `json:"url" validate:"url"`
Token string `json:"token"`
}
var (
ErrMissingToken = E.New("token is required")
ErrURLMissingScheme = E.New("url missing scheme, expect 'http://' or 'https://'")
)
// Validate implements the utils.CustomValidator interface.
func (base *ProviderBase) Validate() E.Error {
if base.Token == "" {
return ErrMissingToken
}
if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") {
return ErrURLMissingScheme
}
u, err := url.Parse(base.URL)
if err != nil {
return E.Wrap(err)
}
base.URL = u.String()
return nil
}
func (base *ProviderBase) GetName() string {
return base.Name
}
func (base *ProviderBase) GetURL() string {
return base.URL
}
func (base *ProviderBase) GetToken() string {
return base.Token
}

54
internal/notif/config.go Normal file
View file

@ -0,0 +1,54 @@
package notif
import (
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
)
type NotificationConfig struct {
ProviderName string `json:"provider"`
Provider Provider `json:"-"`
}
var (
ErrMissingNotifProvider = E.New("missing notification provider")
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
ErrUnknownNotifProvider = E.New("unknown notification provider")
)
// UnmarshalMap implements MapUnmarshaler.
func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err E.Error) {
// extract provider name
providerName := m["provider"]
switch providerName := providerName.(type) {
case string:
cfg.ProviderName = providerName
default:
return ErrInvalidNotifProviderType
}
delete(m, "provider")
if cfg.ProviderName == "" {
return ErrMissingNotifProvider
}
// validate provider name and initialize provider
switch cfg.ProviderName {
case ProviderWebhook:
cfg.Provider = &Webhook{}
case ProviderGotify:
cfg.Provider = &GotifyClient{}
default:
return ErrUnknownNotifProvider.
Subject(cfg.ProviderName).
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
}
// unmarshal provider config
if err := utils.Deserialize(m, cfg.Provider); err != nil {
return err
}
// validate provider
return cfg.Provider.Validate()
}

View file

@ -0,0 +1,163 @@
package notif
import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNotificationConfig(t *testing.T) {
tests := []struct {
name string
cfg map[string]any
expected Provider
wantErr bool
}{
{
name: "valid_webhook",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"template": "discord",
"url": "https://example.com",
},
expected: &Webhook{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
},
Template: "discord",
Method: http.MethodPost,
MIMEType: "application/json",
ColorMode: "dec",
Payload: discordPayload,
},
wantErr: false,
},
{
name: "valid_gotify",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"url": "https://example.com",
"token": "token",
},
expected: &GotifyClient{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
Token: "token",
},
},
wantErr: false,
},
{
name: "invalid_provider",
cfg: map[string]any{
"name": "test",
"provider": "invalid",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "missing_url",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "missing_provider",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "gotify_missing_token",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "webhook_missing_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
},
wantErr: true,
},
{
name: "webhook_missing_url",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
},
wantErr: true,
},
{
name: "webhook_invalid_template",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"template": "invalid",
},
wantErr: true,
},
{
name: "webhook_invalid_json_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"mime_type": "application/json",
"payload": "invalid",
},
wantErr: true,
},
{
name: "webhook_empty_text_payload",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"mime_type": "text/plain",
},
wantErr: true,
},
{
name: "webhook_invalid_method",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"method": "invalid",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var cfg NotificationConfig
provider := tt.cfg["provider"]
err := utils.Deserialize(tt.cfg, &cfg)
if tt.wantErr {
ExpectHasError(t, err)
} else {
ExpectNoError(t, err)
ExpectEqual(t, provider.(string), cfg.ProviderName)
ExpectDeepEqual(t, cfg.Provider, tt.expected)
}
})
}
}

View file

@ -2,13 +2,10 @@ package notif
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
@ -27,12 +24,6 @@ type (
var dispatcher *Dispatcher
var (
ErrMissingNotifProvider = E.New("missing notification provider")
ErrInvalidNotifProviderType = E.New("invalid notification provider type")
ErrUnknownNotifProvider = E.New("unknown notification provider")
)
const dispatchErr = "notification dispatch error"
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
@ -57,29 +48,8 @@ func Notify(msg *LogMessage) {
}
}
func (disp *Dispatcher) RegisterProvider(cfg types.NotificationConfig) (Provider, E.Error) {
providerName, ok := cfg["provider"]
if !ok {
return nil, ErrMissingNotifProvider
}
switch providerName := providerName.(type) {
case string:
delete(cfg, "provider")
createFunc, ok := Providers[providerName]
if !ok {
return nil, ErrUnknownNotifProvider.
Subject(providerName).
Withf(strutils.DoYouMean(utils.NearestField(providerName, Providers)))
}
provider, err := createFunc(cfg)
if err == nil {
disp.providers.Add(provider)
}
return provider, err
default:
return nil, ErrInvalidNotifProviderType.Subjectf("%T", providerName)
}
func (disp *Dispatcher) RegisterProvider(cfg *NotificationConfig) {
disp.providers.Add(cfg.Provider)
}
func (disp *Dispatcher) start() {
@ -110,7 +80,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
errs := E.NewBuilder(dispatchErr)
disp.providers.RangeAllParallel(func(p Provider) {
if err := notifyProvider(task.Context(), p, msg); err != nil {
errs.Add(E.PrependSubject(p.Name(), err))
errs.Add(E.PrependSubject(p.GetName(), err))
}
})
if errs.HasError() {

View file

@ -13,37 +13,24 @@ import (
type (
GotifyClient struct {
N string `json:"name" validate:"required"`
U string `json:"url" validate:"url"`
Tok string `json:"token" validate:"required"`
ProviderBase
}
GotifyMessage model.MessageExternal
)
const gotifyMsgEndpoint = "/message"
// Name implements Provider.
func (client *GotifyClient) Name() string {
return client.N
func (client *GotifyClient) GetURL() string {
return client.URL + gotifyMsgEndpoint
}
// Method implements Provider.
func (client *GotifyClient) Method() string {
// GetMethod implements Provider.
func (client *GotifyClient) GetMethod() string {
return http.MethodPost
}
// URL implements Provider.
func (client *GotifyClient) URL() string {
return client.U + gotifyMsgEndpoint
}
// Token implements Provider.
func (client *GotifyClient) Token() string {
return client.Tok
}
// MIMEType implements Provider.
func (client *GotifyClient) MIMEType() string {
// GetMIMEType implements Provider.
func (client *GotifyClient) GetMIMEType() string {
return "application/json"
}

View file

@ -1,52 +0,0 @@
package notif
import (
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestGotifyValidation(t *testing.T) {
t.Parallel()
newGotify := Providers[ProviderGotify]
t.Run("valid", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "https://example.com",
"token": "token",
})
ExpectNoError(t, err)
})
t.Run("missing url", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"token": "token",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("missing token", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "https://example.com",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid url", func(t *testing.T) {
t.Parallel()
_, err := newGotify(map[string]any{
"name": "test",
"url": "example.com",
"token": "token",
})
ExpectError(t, utils.ErrValidationError, err)
})
}

View file

@ -2,22 +2,24 @@ package notif
import (
"context"
"fmt"
"io"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Provider interface {
Name() string
URL() string
Method() string
Token() string
MIMEType() string
utils.CustomValidator
GetName() string
GetURL() string
GetToken() string
GetMethod() string
GetMIMEType() string
MakeBody(logMsg *LogMessage) (io.Reader, error)
makeRespError(resp *http.Response) error
@ -31,47 +33,29 @@ const (
ProviderWebhook = "webhook"
)
var Providers = map[string]ProviderCreateFunc{
ProviderGotify: newNotifProvider[*GotifyClient],
ProviderWebhook: newNotifProvider[*Webhook],
}
func newNotifProvider[T Provider](cfg map[string]any) (Provider, E.Error) {
var client T
err := U.Deserialize(cfg, &client)
if err != nil {
return nil, err.Subject(client.Name())
}
return client, nil
}
func formatError(p Provider, err error) error {
return fmt.Errorf("%s error: %w", p.Name(), err)
}
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
body, err := provider.MakeBody(msg)
if err != nil {
return formatError(provider, err)
return E.PrependSubject(provider.GetName(), err)
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
provider.URL(),
provider.GetURL(),
body,
)
if err != nil {
return formatError(provider, err)
return E.PrependSubject(provider.GetName(), err)
}
req.Header.Set("Content-Type", provider.MIMEType())
if provider.Token() != "" {
req.Header.Set("Authorization", "Bearer "+provider.Token())
req.Header.Set("Content-Type", provider.GetMIMEType())
if provider.GetToken() != "" {
req.Header.Set("Authorization", "Bearer "+provider.GetToken())
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return formatError(provider, err)
return E.PrependSubject(provider.GetName(), err)
}
defer resp.Body.Close()

View file

@ -8,19 +8,16 @@ import (
"net/http"
"strings"
"github.com/go-playground/validator/v10"
"github.com/yusing/go-proxy/internal/utils"
E "github.com/yusing/go-proxy/internal/error"
)
type Webhook struct {
N string `json:"name" validate:"required"`
U string `json:"url" validate:"url"`
Template string `json:"template" validate:"omitempty,oneof=discord"`
Payload string `json:"payload" validate:"jsonIfTemplateNotUsed"`
Tok string `json:"token"`
Meth string `json:"method" validate:"oneof=GET POST PUT"`
MIMETyp string `json:"mime_type"`
ColorM string `json:"color_mode" validate:"oneof=hex dec"`
ProviderBase
Template string `json:"template"`
Payload string `json:"payload"`
Method string `json:"method"`
MIMEType string `json:"mime_type"`
ColorMode string `json:"color_mode"`
}
//go:embed templates/discord.json
@ -30,60 +27,65 @@ var webhookTemplates = map[string]string{
"discord": discordPayload,
}
func DefaultValue() *Webhook {
return &Webhook{
Meth: "POST",
ColorM: "hex",
MIMETyp: "application/json",
func (webhook *Webhook) Validate() E.Error {
if err := webhook.ProviderBase.Validate(); err != nil && !err.Is(ErrMissingToken) {
return err
}
}
func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool {
template := fl.Parent().FieldByName("Template").String()
if template != "" {
return true
}
payload := fl.Field().String()
return json.Valid([]byte(payload))
}
func init() {
utils.RegisterDefaultValueFactory(DefaultValue)
utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed)
}
// Name implements Provider.
func (webhook *Webhook) Name() string {
return webhook.N
}
// Method implements Provider.
func (webhook *Webhook) Method() string {
return webhook.Meth
}
// URL implements Provider.
func (webhook *Webhook) URL() string {
return webhook.U
}
// Token implements Provider.
func (webhook *Webhook) Token() string {
return webhook.Tok
}
// MIMEType implements Provider.
func (webhook *Webhook) MIMEType() string {
return webhook.MIMETyp
}
func (webhook *Webhook) ColorMode() string {
switch webhook.Template {
case "discord":
return "dec"
switch webhook.MIMEType {
case "":
webhook.MIMEType = "application/json"
case "application/json", "application/x-www-form-urlencoded", "text/plain":
default:
return webhook.ColorM
return E.New("invalid mime_type, expect empty, 'application/json', 'application/x-www-form-urlencoded' or 'text/plain'")
}
switch webhook.Template {
case "":
if webhook.MIMEType == "application/json" && !json.Valid([]byte(webhook.Payload)) {
return E.New("invalid payload, expect valid JSON")
}
if webhook.Payload == "" {
return E.New("invalid payload, expect non-empty")
}
case "discord":
webhook.ColorMode = "dec"
webhook.Method = http.MethodPost
webhook.MIMEType = "application/json"
if webhook.Payload == "" {
webhook.Payload = discordPayload
}
default:
return E.New("invalid template, expect empty or 'discord'")
}
switch webhook.Method {
case "":
webhook.Method = http.MethodPost
case http.MethodGet, http.MethodPost, http.MethodPut:
default:
return E.New("invalid method, expect empty, 'GET', 'POST' or 'PUT'")
}
switch webhook.ColorMode {
case "":
webhook.ColorMode = "hex"
case "hex", "dec":
default:
return E.New("invalid color_mode, expect empty, 'hex' or 'dec'")
}
return nil
}
// GetMethod implements Provider.
func (webhook *Webhook) GetMethod() string {
return webhook.Method
}
// GetMIMEType implements Provider.
func (webhook *Webhook) GetMIMEType() string {
return webhook.MIMEType
}
// makeRespError implements Provider.
@ -108,7 +110,7 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
return nil, err
}
var color string
if webhook.ColorMode() == "hex" {
if webhook.ColorMode == "hex" {
color = logMsg.Color.HexString()
} else {
color = logMsg.Color.DecString()

View file

@ -1,121 +0,0 @@
package notif
import (
"encoding/json"
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestWebhookValidation(t *testing.T) {
t.Parallel()
newWebhook := Providers[ProviderWebhook]
t.Run("valid", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "{}",
})
ExpectNoError(t, err)
})
t.Run("valid template", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"template": "discord",
})
ExpectNoError(t, err)
})
t.Run("missing url", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"payload": "{}",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("missing payload", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid url", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "example.com",
"payload": "{}",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid payload", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid method", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"payload": "{}",
"method": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid template", func(t *testing.T) {
t.Parallel()
_, err := newWebhook(map[string]any{
"name": "test",
"url": "https://example.com",
"template": "abcd",
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestWebhookBody(t *testing.T) {
t.Parallel()
var webhook Webhook
webhook.Payload = discordPayload
bodyReader, err := webhook.MakeBody(&LogMessage{
Title: "abc",
Extras: map[string]any{
"foo": "bar",
},
})
ExpectNoError(t, err)
var body struct {
Embeds []struct {
Title string `json:"title"`
Fields []struct {
Name string `json:"name"`
Value string `json:"value"`
} `json:"fields"`
} `json:"embeds"`
}
err = json.NewDecoder(bodyReader).Decode(&body)
ExpectNoError(t, err)
ExpectEqual(t, body.Embeds[0].Title, "abc")
fields := body.Embeds[0].Fields
ExpectEqual(t, fields[0].Name, "foo")
ExpectEqual(t, fields[0].Value, "bar")
}

View file

@ -57,7 +57,7 @@ func HomepageCategories() []string {
return categories
}
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Config {
func HomepageConfig(useDefaultCategories bool, categoryFilter, providerFilter string) homepage.Categories {
hpCfg := homepage.NewHomePageConfig()
routes.GetHTTPRoutes().RangeAll(func(alias string, r route.HTTPRoute) {

View file

@ -1,13 +1,10 @@
package utils
// FIXME: some times [%d] is not in correct order
import (
"encoding/json"
"errors"
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"
@ -21,6 +18,10 @@ import (
type SerializedObject = map[string]any
type MapUnmarshaller interface {
UnmarshalMap(m map[string]any) E.Error
}
var (
ErrInvalidType = E.New("invalid type")
ErrNilValue = E.New("nil")
@ -29,6 +30,8 @@ var (
ErrUnknownField = E.New("unknown field")
)
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) {
@ -56,8 +59,9 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
if t.Kind() != reflect.Struct {
return nil, nil
}
var fields []reflect.StructField
for i := range t.NumField() {
n := t.NumField()
fields := make([]reflect.StructField, 0, n)
for i := range n {
field := t.Field(i)
if !field.IsExported() {
continue
@ -74,31 +78,74 @@ func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
return fields, anonymous
}
func ValidateWithFieldTags(s any) E.Error {
errs := E.NewBuilder("validate error")
err := validate.Struct(s)
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(e.Namespace()).
Withf("require %q", detail))
}
}
return errs.Error()
}
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
//
// The target value must be a struct or a map[string]any.
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields and validate if needed.
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
// If the target value is a struct , and implements the MapUnmarshaller interface,
// the UnmarshalMap method will be called.
//
// If the target value is a struct, but does not implements the MapUnmarshaller interface,
// the SerializedObject will be deserialized into the struct fields and validate if needed.
//
// If the target value is a map[string]any the SerializedObject will be deserialized into the map.
//
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.Error {
if src == nil {
return E.Errorf("deserialize: src is %w", ErrNilValue)
}
if dst == nil {
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
}
dstV := reflect.ValueOf(dst)
dstT := dstV.Type()
if src == nil {
if dstV.CanSet() {
dstV.Set(reflect.Zero(dstT))
return nil
}
return E.Errorf("deserialize: src is %w and dst is not settable", ErrNilValue)
}
if dstT.Implements(mapUnmarshalerType) {
for dstV.IsNil() {
switch dstT.Kind() {
case reflect.Struct:
dstV.Set(New(dstT))
case reflect.Map:
dstV.Set(reflect.MakeMap(dstT))
case reflect.Slice:
dstV.Set(reflect.MakeSlice(dstT, 0, 0))
case reflect.Ptr:
dstV.Set(reflect.New(dstT.Elem()))
default:
return E.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
}
dstV = dstV.Elem()
}
return dstV.Interface().(MapUnmarshaller).UnmarshalMap(src)
}
for dstT.Kind() == reflect.Ptr {
if dstV.IsNil() {
if dstV.CanSet() {
dstV.Set(New(dstT.Elem()))
} else {
return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
return E.Errorf("deserialize: dst is %w and not settable", ErrNilValue)
}
}
dstV = dstV.Elem()
@ -113,9 +160,8 @@ func Deserialize(src SerializedObject, dst any) E.Error {
switch dstV.Kind() {
case reflect.Struct:
needValidate := false
hasValidateTag := false
mapping := make(map[string]reflect.Value)
fieldName := make(map[string]string)
fields, anonymous := extractFields(dstT)
for _, anon := range anonymous {
if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
@ -134,17 +180,15 @@ func Deserialize(src SerializedObject, dst any) E.Error {
}
key = strutils.ToLowerNoSnake(key)
mapping[key] = dstV.FieldByName(field.Name)
fieldName[field.Name] = key
if !needValidate {
_, needValidate = field.Tag.Lookup("validate")
if !hasValidateTag {
_, hasValidateTag = field.Tag.Lookup("validate")
}
aliases, ok := field.Tag.Lookup("aliases")
if ok {
for _, alias := range strutils.CommaSeperatedList(aliases) {
mapping[alias] = dstV.FieldByName(field.Name)
fieldName[field.Name] = alias
}
}
}
@ -158,20 +202,10 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
}
}
if needValidate {
err := validate.Struct(dstV.Interface())
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(e.StructNamespace()).
Withf("require %q", detail))
}
}
if hasValidateTag {
errs.Add(ValidateWithFieldTags(dstV.Interface()))
} else if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
errs.Add(validator.Validate())
}
return errs.Error()
case reflect.Map:
@ -188,6 +222,9 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(err.Subject(k))
}
}
if validator, ok := dstV.Addr().Interface().(CustomValidator); ok {
errs.Add(validator.Validate())
}
return errs.Error()
default:
return ErrUnsupportedConversion.Subject("mapping to " + dstT.String())
@ -421,14 +458,6 @@ func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.
return functional.NewMapFrom(m2), nil
}
func DeserializeJSON[T any](data []byte, target T) E.Error {
m := make(map[string]any)
if err := json.Unmarshal(data, &m); err != nil {
return E.From(err)
}
return Deserialize(m, target)
}
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
data, err := os.ReadFile(path)
if err != nil {

View file

@ -35,6 +35,14 @@ func ExpectNoError(t *testing.T, err error) {
}
}
func ExpectHasError(t *testing.T, err error) {
t.Helper()
if errors.Is(err, nil) {
t.Error("expected err not nil")
t.FailNow()
}
}
func ExpectError(t *testing.T, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {

View file

@ -9,6 +9,10 @@ var validate = validator.New()
var ErrValidationError = E.New("validation error")
type CustomValidator interface {
Validate() E.Error
}
func Validator() *validator.Validate {
return validate
}