refactor: notifications

This commit is contained in:
yusing 2025-05-02 05:51:15 +08:00
parent 28d9a72908
commit 69ee8495d8
11 changed files with 202 additions and 104 deletions

View file

@ -10,9 +10,10 @@ import (
) )
type ProviderBase struct { type ProviderBase struct {
Name string `json:"name" validate:"required"` Name string `json:"name" validate:"required"`
URL string `json:"url" validate:"url"` URL string `json:"url" validate:"url"`
Token string `json:"token"` Token string `json:"token"`
Format *LogFormat `json:"format"`
} }
var ( var (
@ -22,8 +23,8 @@ var (
// Validate implements the utils.CustomValidator interface. // Validate implements the utils.CustomValidator interface.
func (base *ProviderBase) Validate() gperr.Error { func (base *ProviderBase) Validate() gperr.Error {
if base.Token == "" { if base.Format == nil {
return ErrMissingToken base.Format = LogFormatMarkdown
} }
if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") { if !strings.HasPrefix(base.URL, "http://") && !strings.HasPrefix(base.URL, "https://") {
return ErrURLMissingScheme return ErrURLMissingScheme

111
internal/notif/body.go Normal file
View file

@ -0,0 +1,111 @@
package notif
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"github.com/yusing/go-proxy/internal/gperr"
)
type (
LogField struct {
Name string `json:"name"`
Value string `json:"value"`
}
LogFormat struct {
string
}
LogBody interface {
Format(format *LogFormat) ([]byte, error)
}
)
type (
FieldsBody []LogField
ListBody []string
MessageBody string
)
var (
LogFormatMarkdown = &LogFormat{"markdown"}
LogFormatPlain = &LogFormat{"plain"}
LogFormatRawJSON = &LogFormat{"json"} // internal use only
)
func MakeLogFields(fields ...LogField) LogBody {
return FieldsBody(fields)
}
func (f *LogFormat) Parse(format string) error {
switch format {
case "":
f.string = LogFormatMarkdown.string
case LogFormatPlain.string, LogFormatMarkdown.string:
f.string = format
default:
return gperr.Multiline().
Addf("invalid log format %s, supported formats:", format).
AddLines(
LogFormatPlain,
LogFormatMarkdown,
)
}
return nil
}
func (f FieldsBody) Format(format *LogFormat) ([]byte, error) {
switch format {
case LogFormatMarkdown:
var msg bytes.Buffer
for _, field := range f {
msg.WriteString("#### ")
msg.WriteString(field.Name)
msg.WriteRune('\n')
msg.WriteString(field.Value)
msg.WriteRune('\n')
}
return msg.Bytes(), nil
case LogFormatPlain:
var msg bytes.Buffer
for _, field := range f {
msg.WriteString(field.Name)
msg.WriteString(": ")
msg.WriteString(field.Value)
msg.WriteRune('\n')
}
return msg.Bytes(), nil
case LogFormatRawJSON:
return json.Marshal(f)
}
return nil, fmt.Errorf("unknown format: %v", format)
}
func (l ListBody) Format(format *LogFormat) ([]byte, error) {
switch format {
case LogFormatPlain:
return []byte(strings.Join(l, "\n")), nil
case LogFormatMarkdown:
var msg bytes.Buffer
for _, item := range l {
msg.WriteString("* ")
msg.WriteString(item)
msg.WriteRune('\n')
}
return msg.Bytes(), nil
case LogFormatRawJSON:
return json.Marshal(l)
}
return nil, fmt.Errorf("unknown format: %v", format)
}
func (m MessageBody) Format(format *LogFormat) ([]byte, error) {
switch format {
case LogFormatPlain, LogFormatMarkdown:
return []byte(m), nil
case LogFormatRawJSON:
return json.Marshal(m)
}
return nil, fmt.Errorf("unknown format: %v", format)
}

View file

@ -46,11 +46,5 @@ func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err gperr.Error)
Withf("expect %s or %s", ProviderWebhook, ProviderGotify) Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
} }
// unmarshal provider config return utils.MapUnmarshalValidate(m, cfg.Provider)
if err := utils.MapUnmarshalValidate(m, cfg.Provider); err != nil {
return err
}
// validate provider
return cfg.Provider.Validate()
} }

View file

@ -25,8 +25,9 @@ func TestNotificationConfig(t *testing.T) {
}, },
expected: &Webhook{ expected: &Webhook{
ProviderBase: ProviderBase{ ProviderBase: ProviderBase{
Name: "test", Name: "test",
URL: "https://example.com", URL: "https://example.com",
Format: LogFormatMarkdown,
}, },
Template: "discord", Template: "discord",
Method: http.MethodPost, Method: http.MethodPost,
@ -43,12 +44,32 @@ func TestNotificationConfig(t *testing.T) {
"provider": "gotify", "provider": "gotify",
"url": "https://example.com", "url": "https://example.com",
"token": "token", "token": "token",
"format": "plain",
}, },
expected: &GotifyClient{ expected: &GotifyClient{
ProviderBase: ProviderBase{ ProviderBase: ProviderBase{
Name: "test", Name: "test",
URL: "https://example.com", URL: "https://example.com",
Token: "token", Token: "token",
Format: LogFormatPlain,
},
},
wantErr: false,
},
{
name: "default_format",
cfg: map[string]any{
"name": "test",
"provider": "gotify",
"token": "token",
"url": "https://example.com",
},
expected: &GotifyClient{
ProviderBase: ProviderBase{
Name: "test",
URL: "https://example.com",
Token: "token",
Format: LogFormatMarkdown,
}, },
}, },
wantErr: false, wantErr: false,
@ -62,6 +83,16 @@ func TestNotificationConfig(t *testing.T) {
}, },
wantErr: true, wantErr: true,
}, },
{
name: "invalid_format",
cfg: map[string]any{
"name": "test",
"provider": "webhook",
"url": "https://example.com",
"format": "invalid",
},
wantErr: true,
},
{ {
name: "missing_url", name: "missing_url",
cfg: map[string]any{ cfg: map[string]any{

View file

@ -14,16 +14,11 @@ type (
logCh chan *LogMessage logCh chan *LogMessage
providers F.Set[Provider] providers F.Set[Provider]
} }
LogField struct {
Name string `json:"name"`
Value string `json:"value"`
}
LogFields []LogField
LogMessage struct { LogMessage struct {
Level zerolog.Level Level zerolog.Level
Title string Title string
Extras LogFields Body LogBody
Color Color Color Color
} }
) )
@ -53,7 +48,7 @@ func Notify(msg *LogMessage) {
} }
} }
func (f *LogFields) Add(name, value string) { func (f *FieldsBody) Add(name, value string) {
*f = append(*f, LogField{Name: name, Value: value}) *f = append(*f, LogField{Name: name, Value: value})
} }

View file

@ -1,26 +0,0 @@
package notif
import (
"bytes"
"encoding/json"
)
func formatMarkdown(extras LogFields) string {
msg := bytes.NewBufferString("")
for _, field := range extras {
msg.WriteString("#### ")
msg.WriteString(field.Name)
msg.WriteRune('\n')
msg.WriteString(field.Value)
msg.WriteRune('\n')
}
return msg.String()
}
func formatDiscord(extras LogFields) (string, error) {
fields, err := json.Marshal(extras)
if err != nil {
return "", err
}
return string(fields), nil
}

View file

@ -1,10 +1,8 @@
package notif package notif
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
@ -24,8 +22,8 @@ func (client *GotifyClient) GetURL() string {
return client.URL + gotifyMsgEndpoint return client.URL + gotifyMsgEndpoint
} }
// MakeBody implements Provider. // MarshalMessage implements Provider.
func (client *GotifyClient) MakeBody(logMsg *LogMessage) (io.Reader, error) { func (client *GotifyClient) MarshalMessage(logMsg *LogMessage) ([]byte, error) {
var priority int var priority int
switch logMsg.Level { switch logMsg.Level {
@ -37,15 +35,23 @@ func (client *GotifyClient) MakeBody(logMsg *LogMessage) (io.Reader, error) {
priority = 8 priority = 8
} }
body, err := logMsg.Body.Format(client.Format)
if err != nil {
return nil, err
}
msg := &GotifyMessage{ msg := &GotifyMessage{
Title: logMsg.Title, Title: logMsg.Title,
Message: formatMarkdown(logMsg.Extras), Message: string(body),
Priority: &priority, Priority: &priority,
Extras: map[string]interface{}{ }
if client.Format == LogFormatMarkdown {
msg.Extras = map[string]interface{}{
"client::display": map[string]string{ "client::display": map[string]string{
"contentType": "text/markdown", "contentType": "text/markdown",
}, },
}, }
} }
data, err := json.Marshal(msg) data, err := json.Marshal(msg)
@ -53,7 +59,7 @@ func (client *GotifyClient) MakeBody(logMsg *LogMessage) (io.Reader, error) {
return nil, err return nil, err
} }
return bytes.NewReader(data), nil return data, nil
} }
// makeRespError implements Provider. // makeRespError implements Provider.

View file

@ -1,10 +1,7 @@
package notif package notif
import ( import (
"bytes"
"io"
"net/http" "net/http"
"strings"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
@ -13,18 +10,14 @@ import (
// See https://docs.ntfy.sh/publish // See https://docs.ntfy.sh/publish
type Ntfy struct { type Ntfy struct {
ProviderBase ProviderBase
Topic string `json:"topic"` Topic string `json:"topic"`
Style NtfyStyle `json:"style"`
} }
type NtfyStyle string // Validate implements the utils.CustomValidator interface.
const (
NtfyStyleMarkdown NtfyStyle = "markdown"
NtfyStylePlain NtfyStyle = "plain"
)
func (n *Ntfy) Validate() gperr.Error { func (n *Ntfy) Validate() gperr.Error {
if err := n.ProviderBase.Validate(); err != nil {
return err
}
if n.URL == "" { if n.URL == "" {
return gperr.New("url is required") return gperr.New("url is required")
} }
@ -34,16 +27,10 @@ func (n *Ntfy) Validate() gperr.Error {
if n.Topic[0] == '/' { if n.Topic[0] == '/' {
return gperr.New("topic should not start with a slash") return gperr.New("topic should not start with a slash")
} }
switch n.Style {
case "":
n.Style = NtfyStyleMarkdown
case NtfyStyleMarkdown, NtfyStylePlain:
default:
return gperr.Errorf("invalid style, expecting %q or %q, got %q", NtfyStyleMarkdown, NtfyStylePlain, n.Style)
}
return nil return nil
} }
// GetURL implements Provider.
func (n *Ntfy) GetURL() string { func (n *Ntfy) GetURL() string {
if n.URL[len(n.URL)-1] == '/' { if n.URL[len(n.URL)-1] == '/' {
return n.URL + n.Topic return n.URL + n.Topic
@ -51,23 +38,22 @@ func (n *Ntfy) GetURL() string {
return n.URL + "/" + n.Topic return n.URL + "/" + n.Topic
} }
// GetMIMEType implements Provider.
func (n *Ntfy) GetMIMEType() string { func (n *Ntfy) GetMIMEType() string {
return "" return ""
} }
// GetToken implements Provider.
func (n *Ntfy) GetToken() string { func (n *Ntfy) GetToken() string {
return n.Token return n.Token
} }
func (n *Ntfy) MakeBody(logMsg *LogMessage) (io.Reader, error) { // MarshalMessage implements Provider.
switch n.Style { func (n *Ntfy) MarshalMessage(logMsg *LogMessage) ([]byte, error) {
case NtfyStyleMarkdown: return logMsg.Body.Format(n.Format)
return strings.NewReader(formatMarkdown(logMsg.Extras)), nil
default:
return &bytes.Buffer{}, nil
}
} }
// SetHeaders implements Provider.
func (n *Ntfy) SetHeaders(logMsg *LogMessage, headers http.Header) { func (n *Ntfy) SetHeaders(logMsg *LogMessage, headers http.Header) {
headers.Set("Title", logMsg.Title) headers.Set("Title", logMsg.Title)
@ -83,7 +69,7 @@ func (n *Ntfy) SetHeaders(logMsg *LogMessage, headers http.Header) {
headers.Set("Priority", "min") headers.Set("Priority", "min")
} }
if n.Style == NtfyStyleMarkdown { if n.Format == LogFormatMarkdown {
headers.Set("Markdown", "yes") headers.Set("Markdown", "yes")
} }
} }

View file

@ -1,8 +1,8 @@
package notif package notif
import ( import (
"bytes"
"context" "context"
"io"
"net/http" "net/http"
"time" "time"
@ -21,7 +21,7 @@ type (
GetMethod() string GetMethod() string
GetMIMEType() string GetMIMEType() string
MakeBody(logMsg *LogMessage) (io.Reader, error) MarshalMessage(logMsg *LogMessage) ([]byte, error)
SetHeaders(logMsg *LogMessage, headers http.Header) SetHeaders(logMsg *LogMessage, headers http.Header)
makeRespError(resp *http.Response) error makeRespError(resp *http.Response) error
@ -37,7 +37,7 @@ const (
) )
func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error { func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) error {
body, err := provider.MakeBody(msg) body, err := provider.MarshalMessage(msg)
if err != nil { if err != nil {
return gperr.PrependSubject(provider.GetName(), err) return gperr.PrependSubject(provider.GetName(), err)
} }
@ -49,7 +49,7 @@ func notifyProvider(ctx context.Context, provider Provider, msg *LogMessage) err
ctx, ctx,
http.MethodPost, http.MethodPost,
provider.GetURL(), provider.GetURL(),
body, bytes.NewReader(body),
) )
if err != nil { if err != nil {
return gperr.PrependSubject(provider.GetName(), err) return gperr.PrependSubject(provider.GetName(), err)

View file

@ -100,12 +100,12 @@ func (webhook *Webhook) makeRespError(resp *http.Response) error {
return fmt.Errorf("%s status %d", webhook.Name, resp.StatusCode) return fmt.Errorf("%s status %d", webhook.Name, resp.StatusCode)
} }
func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) { func (webhook *Webhook) MarshalMessage(logMsg *LogMessage) ([]byte, error) {
title, err := json.Marshal(logMsg.Title) title, err := json.Marshal(logMsg.Title)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fields, err := formatDiscord(logMsg.Extras) fields, err := logMsg.Body.Format(LogFormatRawJSON)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -115,14 +115,14 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
} else { } else {
color = logMsg.Color.DecString() color = logMsg.Color.DecString()
} }
message, err := json.Marshal(formatMarkdown(logMsg.Extras)) message, err := logMsg.Body.Format(LogFormatMarkdown)
if err != nil { if err != nil {
return nil, err return nil, err
} }
plTempl := strings.NewReplacer( plTempl := strings.NewReplacer(
"$title", string(title), "$title", string(title),
"$message", string(message), "$message", string(message),
"$fields", fields, "$fields", string(fields),
"$color", color, "$color", color,
) )
var pl string var pl string
@ -132,5 +132,5 @@ func (webhook *Webhook) MakeBody(logMsg *LogMessage) (io.Reader, error) {
pl = webhook.Payload pl = webhook.Payload
} }
pl = plTempl.Replace(pl) pl = plTempl.Replace(pl)
return strings.NewReader(pl), nil return []byte(pl), nil
} }

View file

@ -222,7 +222,7 @@ func (mon *monitor) checkUpdateHealth() error {
status = health.StatusUnhealthy status = health.StatusUnhealthy
} }
if result.Healthy != (mon.status.Swap(status) == health.StatusHealthy) { if result.Healthy != (mon.status.Swap(status) == health.StatusHealthy) {
extras := notif.LogFields{ extras := notif.FieldsBody{
{Name: "Service Name", Value: mon.service}, {Name: "Service Name", Value: mon.service},
{Name: "Time", Value: strutils.FormatTime(time.Now())}, {Name: "Time", Value: strutils.FormatTime(time.Now())},
} }
@ -239,16 +239,16 @@ func (mon *monitor) checkUpdateHealth() error {
logger.Info().Msg("service is up") logger.Info().Msg("service is up")
extras.Add("Ping", fmt.Sprintf("%d ms", result.Latency.Milliseconds())) extras.Add("Ping", fmt.Sprintf("%d ms", result.Latency.Milliseconds()))
notif.Notify(&notif.LogMessage{ notif.Notify(&notif.LogMessage{
Title: "✅ Service is up ✅", Title: "✅ Service is up ✅",
Extras: extras, Body: extras,
Color: notif.ColorSuccess, Color: notif.ColorSuccess,
}) })
} else { } else {
logger.Warn().Msg("service went down") logger.Warn().Msg("service went down")
notif.Notify(&notif.LogMessage{ notif.Notify(&notif.LogMessage{
Title: "❌ Service went down ❌", Title: "❌ Service went down ❌",
Extras: extras, Body: extras,
Color: notif.ColorError, Color: notif.ColorError,
}) })
} }
} }