refactor: use stretchr/testify, replace ExpectBytesEqual and ExpectDeepEqual with ExpectEqual in tests

This commit is contained in:
yusing 2025-03-28 08:45:06 +08:00
parent 2f476603d3
commit 232f720e77
18 changed files with 65 additions and 132 deletions

View file

@ -28,8 +28,8 @@ func TestPEMPair(t *testing.T) {
var pp PEMPair
err := pp.Load(p.String())
ExpectNoError(t, err)
ExpectBytesEqual(t, p.Cert, pp.Cert)
ExpectBytesEqual(t, p.Key, pp.Key)
ExpectEqual(t, p.Cert, pp.Cert)
ExpectEqual(t, p.Key, pp.Key)
})
}
}

View file

@ -13,7 +13,7 @@ func TestZipCert(t *testing.T) {
ca2, crt2, key2, err := ExtractCert(zipData)
ExpectNoError(t, err)
ExpectBytesEqual(t, ca, ca2)
ExpectBytesEqual(t, crt, crt2)
ExpectBytesEqual(t, key, key2)
ExpectEqual(t, ca, ca2)
ExpectEqual(t, crt, crt2)
ExpectEqual(t, key, key2)
}

View file

@ -46,5 +46,5 @@ oauth2_config:
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectNoError(t, U.Deserialize(opt, cfg))
ExpectDeepEqual(t, cfg, cfgExpected)
ExpectEqual(t, cfg, cfgExpected)
}

View file

@ -32,5 +32,5 @@ func TestOverrideItem(t *testing.T) {
overrides := GetOverrideConfig()
overrides.OverrideItem(a.Alias, want)
got := a.GetOverride(a.Alias)
ExpectDeepEqual(t, got, want)
ExpectEqual(t, got, want)
}

View file

@ -118,7 +118,7 @@ func TestIconURL(t *testing.T) {
} else {
tc.wantValue.FullValue = tc.input
ExpectNoError(t, err)
ExpectDeepEqual(t, u, tc.wantValue)
ExpectEqual(t, u, tc.wantValue)
}
})
}

View file

@ -90,11 +90,11 @@ func TestSystemInfo(t *testing.T) {
// Compare original and decoded
ExpectEqual(t, decoded.Timestamp, testInfo.Timestamp)
ExpectEqual(t, *decoded.CPUAverage, *testInfo.CPUAverage)
ExpectDeepEqual(t, decoded.Memory, testInfo.Memory)
ExpectDeepEqual(t, decoded.Disks, testInfo.Disks)
ExpectDeepEqual(t, decoded.DisksIO, testInfo.DisksIO)
ExpectDeepEqual(t, decoded.Network, testInfo.Network)
ExpectDeepEqual(t, decoded.Sensors, testInfo.Sensors)
ExpectEqual(t, decoded.Memory, testInfo.Memory)
ExpectEqual(t, decoded.Disks, testInfo.Disks)
ExpectEqual(t, decoded.DisksIO, testInfo.DisksIO)
ExpectEqual(t, decoded.Network, testInfo.Network)
ExpectEqual(t, decoded.Sensors, testInfo.Sensors)
// Test nil fields
nilInfo := &SystemInfo{
@ -108,7 +108,7 @@ func TestSystemInfo(t *testing.T) {
err = json.Unmarshal(data, &decodedNil)
ExpectNoError(t, err)
ExpectDeepEqual(t, decodedNil.Timestamp, nilInfo.Timestamp)
ExpectEqual(t, decodedNil.Timestamp, nilInfo.Timestamp)
ExpectTrue(t, decodedNil.CPUAverage == nil)
ExpectTrue(t, decodedNil.Memory == nil)
ExpectTrue(t, decodedNil.Disks == nil)

View file

@ -36,11 +36,11 @@ func TestNewConfig(t *testing.T) {
ExpectEqual(t, config.BufferSize, 10)
ExpectEqual(t, config.Format, FormatCombined)
ExpectEqual(t, config.Path, "/tmp/access.log")
ExpectDeepEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
ExpectEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
ExpectEqual(t, len(config.Filters.Method.Values), 2)
ExpectDeepEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
ExpectEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
ExpectEqual(t, len(config.Filters.Headers.Values), 2)
ExpectDeepEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
ExpectEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
ExpectTrue(t, config.Filters.Headers.Negative)
ExpectEqual(t, len(config.Filters.CIDR.Values), 1)
ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")

View file

@ -15,7 +15,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
entry := getJSONEntry(t, config)
for k, v := range req.Header {
if k != "Cookie" {
ExpectDeepEqual(t, entry.Headers[k], v)
ExpectEqual(t, entry.Headers[k], v)
}
}
@ -24,8 +24,8 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
"User-Agent": FieldModeDrop,
}
entry = getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Headers["Referer"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Headers["User-Agent"], nil)
ExpectEqual(t, entry.Headers["Referer"], []string{RedactedValue})
ExpectEqual(t, entry.Headers["User-Agent"], nil)
}
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
@ -33,7 +33,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
config.Fields.Headers.Default = FieldModeDrop
entry := getJSONEntry(t, config)
for k := range req.Header {
ExpectDeepEqual(t, entry.Headers[k], nil)
ExpectEqual(t, entry.Headers[k], nil)
}
config.Fields.Headers.Config = map[string]FieldMode{
@ -41,8 +41,8 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
"User-Agent": FieldModeRedact,
}
entry = getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
ExpectDeepEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
ExpectEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
ExpectEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
}
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
@ -52,7 +52,7 @@ func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for k := range req.Header {
if k != "Cookie" {
ExpectDeepEqual(t, entry.Headers[k], []string{RedactedValue})
ExpectEqual(t, entry.Headers[k], []string{RedactedValue})
}
}
}
@ -83,14 +83,14 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeDrop
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], nil)
ExpectDeepEqual(t, entry.Query["bar"], nil)
ExpectEqual(t, entry.Query["foo"], nil)
ExpectEqual(t, entry.Query["bar"], nil)
}
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectDeepEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectDeepEqual(t, entry.Query["bar"], []string{RedactedValue})
ExpectEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectEqual(t, entry.Query["bar"], []string{RedactedValue})
}

View file

@ -26,7 +26,7 @@ func TestParseRetention(t *testing.T) {
if !test.shouldErr {
ExpectNoError(t, err)
} else {
ExpectDeepEqual(t, r, test.expected)
ExpectEqual(t, r, test.expected)
}
})
}

View file

@ -45,9 +45,9 @@ func TestModifyRequest(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
ExpectEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {

View file

@ -48,9 +48,9 @@ func TestModifyResponse(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
ExpectEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("response_headers", func(t *testing.T) {

View file

@ -156,7 +156,7 @@ func TestNotificationConfig(t *testing.T) {
} else {
ExpectNoError(t, err)
ExpectEqual(t, provider.(string), cfg.ProviderName)
ExpectDeepEqual(t, cfg.Provider, tt.expected)
ExpectEqual(t, cfg.Provider, tt.expected)
}
})
}

View file

@ -104,10 +104,10 @@ func TestApplyLabel(t *testing.T) {
ExpectTrue(t, a.NoTLSVerify)
ExpectTrue(t, b.NoTLSVerify)
ExpectDeepEqual(t, a.PathPatterns, pathPatternsExpect)
ExpectEqual(t, a.PathPatterns, pathPatternsExpect)
ExpectEqual(t, len(b.PathPatterns), 0)
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.Container.IdleTimeout, "")

View file

@ -43,7 +43,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
if err != nil {
ExpectNoError(t, err)
}
ExpectDeepEqual(t, cfg.HTTPConfig, tt.expected)
ExpectEqual(t, cfg.HTTPConfig, tt.expected)
})
}
}

View file

@ -42,7 +42,7 @@ func TestDeserialize(t *testing.T) {
var s2 S
err := Deserialize(testStructSerialized, &s2)
ExpectNoError(t, err)
ExpectDeepEqual(t, s2, testStruct)
ExpectEqual(t, s2, testStruct)
})
}
@ -177,21 +177,21 @@ func TestStringToSlice(t *testing.T) {
convertible, err := ConvertString("a,b,c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectDeepEqual(t, dst, []string{"a", "b", "c"})
ExpectEqual(t, dst, []string{"a", "b", "c"})
})
t.Run("yaml-like", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectDeepEqual(t, dst, []string{"a", "b", "c"})
ExpectEqual(t, dst, []string{"a", "b", "c"})
})
t.Run("single-line-yaml-like", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("- a", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectDeepEqual(t, dst, []string{"a"})
ExpectEqual(t, dst, []string{"a"})
})
}
@ -215,7 +215,7 @@ func TestStringToMap(t *testing.T) {
convertible, err := ConvertString(" a: b\n c: d", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectDeepEqual(t, dst, map[string]string{"a": "b", "c": "d"})
ExpectEqual(t, dst, map[string]string{"a": "b", "c": "d"})
})
}
@ -242,7 +242,7 @@ func TestStringToStruct(t *testing.T) {
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectDeepEqual(t, dst, struct {
ExpectEqual(t, dst, struct {
A string
B int
}{"a", 123})

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
utils "github.com/yusing/go-proxy/internal/utils/testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestIntersect(t *testing.T) {
@ -19,7 +19,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
ExpectEqual(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
@ -30,7 +30,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
ExpectEqual(t, result, want)
})
})
t.Run("ints", func(t *testing.T) {
@ -43,7 +43,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
ExpectEqual(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
@ -54,7 +54,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
utils.ExpectDeepEqual(t, result, want)
ExpectEqual(t, result, want)
})
})
t.Run("complex", func(t *testing.T) {
@ -75,7 +75,7 @@ func TestIntersect(t *testing.T) {
slices.SortFunc(want, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
utils.ExpectDeepEqual(t, result, want)
ExpectEqual(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
@ -90,7 +90,7 @@ func TestIntersect(t *testing.T) {
slices.SortFunc(want, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
utils.ExpectDeepEqual(t, result, want)
ExpectEqual(t, result, want)
})
})
}

View file

@ -31,7 +31,7 @@ func TestSplit(t *testing.T) {
for sep, rsep := range tests {
t.Run(sep, func(t *testing.T) {
expected := strings.Split(alphaNumeric, sep)
ExpectDeepEqual(t, SplitRune(alphaNumeric, rsep), expected)
ExpectEqual(t, SplitRune(alphaNumeric, rsep), expected)
ExpectEqual(t, JoinRune(expected, rsep), alphaNumeric)
})
}

View file

@ -1,14 +1,11 @@
package utils
import (
"bytes"
"errors"
"os"
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
)
func init() {
@ -24,114 +21,50 @@ func Must[Result any](r Result, err error) Result {
return r
}
func fmtError(err error) string {
if err == nil {
return "<nil>"
}
return ansi.StripANSI(err.Error())
}
func ExpectNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Errorf("expected err=nil, got %s", fmtError(err))
t.FailNow()
}
require.NoError(t, err)
}
func ExpectHasError(t *testing.T, err error) {
t.Helper()
if errors.Is(err, nil) {
t.Error("expected err not nil")
t.FailNow()
}
require.Error(t, err)
}
func ExpectError(t *testing.T, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("expected err %s, got %s", expected, fmtError(err))
t.FailNow()
}
}
func ExpectError2(t *testing.T, input any, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("%v: expected err %s, got %s", input, expected, fmtError(err))
t.FailNow()
}
require.ErrorIs(t, err, expected)
}
func ExpectErrorT[T error](t *testing.T, err error) {
t.Helper()
var errAs T
if !errors.As(err, &errAs) {
t.Errorf("expected err %T, got %s", errAs, fmtError(err))
t.FailNow()
}
require.ErrorAs(t, err, &errAs)
}
func ExpectEqual[T comparable](t *testing.T, got T, want T) {
func ExpectEqual[T any](t *testing.T, got T, want T) {
t.Helper()
if gotStr, ok := any(got).(string); ok {
ExpectDeepEqual(t, ansi.StripANSI(gotStr), any(want).(string))
return
}
if got != want {
t.Errorf("expected:\n%v, got\n%v", want, got)
t.FailNow()
}
require.EqualValues(t, got, want)
}
func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) {
func ExpectContains[T any](t *testing.T, got T, wants []T) {
t.Helper()
for _, want := range wants {
if got == want {
return
}
}
t.Errorf("expected any of:\n%v, got\n%v", wants, got)
t.FailNow()
}
func ExpectDeepEqual[T any](t *testing.T, got T, want T) {
t.Helper()
if !reflect.DeepEqual(got, want) {
t.Errorf("expected:\n%v, got\n%v", want, got)
t.FailNow()
}
}
func ExpectBytesEqual(t *testing.T, got []byte, want []byte) {
t.Helper()
if !bytes.Equal(got, want) {
t.Errorf("expected:\n%v, got\n%v", want, got)
t.FailNow()
}
require.Contains(t, wants, got)
}
func ExpectTrue(t *testing.T, got bool) {
t.Helper()
if !got {
t.Error("expected true")
t.FailNow()
}
require.True(t, got)
}
func ExpectFalse(t *testing.T, got bool) {
t.Helper()
if got {
t.Error("expected false")
t.FailNow()
}
require.False(t, got)
}
func ExpectType[T any](t *testing.T, got any) (_ T) {
t.Helper()
_, ok := got.(T)
if !ok {
t.Fatalf("expected type %s, got %T", reflect.TypeFor[T](), got)
}
require.True(t, ok)
return got.(T)
}