diff --git a/agent/pkg/agent/new_agent_test.go b/agent/pkg/agent/new_agent_test.go index f9b417d..9809a7d 100644 --- a/agent/pkg/agent/new_agent_test.go +++ b/agent/pkg/agent/new_agent_test.go @@ -8,59 +8,59 @@ import ( "net/http/httptest" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestNewAgent(t *testing.T) { ca, srv, client, err := NewAgent() - ExpectNoError(t, err) - ExpectTrue(t, ca != nil) - ExpectTrue(t, srv != nil) - ExpectTrue(t, client != nil) + expect.NoError(t, err) + expect.True(t, ca != nil) + expect.True(t, srv != nil) + expect.True(t, client != nil) } func TestPEMPair(t *testing.T) { ca, srv, client, err := NewAgent() - ExpectNoError(t, err) + expect.NoError(t, err) for i, p := range []*PEMPair{ca, srv, client} { t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) { var pp PEMPair err := pp.Load(p.String()) - ExpectNoError(t, err) - ExpectEqual(t, p.Cert, pp.Cert) - ExpectEqual(t, p.Key, pp.Key) + expect.NoError(t, err) + expect.Equal(t, p.Cert, pp.Cert) + expect.Equal(t, p.Key, pp.Key) }) } } func TestPEMPairToTLSCert(t *testing.T) { ca, srv, client, err := NewAgent() - ExpectNoError(t, err) + expect.NoError(t, err) for i, p := range []*PEMPair{ca, srv, client} { t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) { cert, err := p.ToTLSCert() - ExpectNoError(t, err) - ExpectTrue(t, cert != nil) + expect.NoError(t, err) + expect.True(t, cert != nil) }) } } func TestServerClient(t *testing.T) { ca, srv, client, err := NewAgent() - ExpectNoError(t, err) + expect.NoError(t, err) srvTLS, err := srv.ToTLSCert() - ExpectNoError(t, err) - ExpectTrue(t, srvTLS != nil) + expect.NoError(t, err) + expect.True(t, srvTLS != nil) clientTLS, err := client.ToTLSCert() - ExpectNoError(t, err) - ExpectTrue(t, clientTLS != nil) + expect.NoError(t, err) + expect.True(t, clientTLS != nil) caPool := x509.NewCertPool() - ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert)) + expect.True(t, caPool.AppendCertsFromPEM(ca.Cert)) srvTLSConfig := &tls.Config{ Certificates: []tls.Certificate{*srvTLS}, @@ -86,6 +86,6 @@ func TestServerClient(t *testing.T) { } resp, err := httpClient.Get(server.URL) - ExpectNoError(t, err) - ExpectEqual(t, resp.StatusCode, http.StatusOK) + expect.NoError(t, err) + expect.Equal(t, resp.StatusCode, http.StatusOK) } diff --git a/agent/pkg/certs/zip_test.go b/agent/pkg/certs/zip_test.go index 0ebe6ac..707c8e4 100644 --- a/agent/pkg/certs/zip_test.go +++ b/agent/pkg/certs/zip_test.go @@ -3,17 +3,17 @@ package certs import ( "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestZipCert(t *testing.T) { ca, crt, key := []byte("test1"), []byte("test2"), []byte("test3") zipData, err := ZipCert(ca, crt, key) - ExpectNoError(t, err) + expect.NoError(t, err) ca2, crt2, key2, err := ExtractCert(zipData) - ExpectNoError(t, err) - ExpectEqual(t, ca, ca2) - ExpectEqual(t, crt, crt2) - ExpectEqual(t, key, key2) + expect.NoError(t, err) + expect.Equal(t, ca, ca2) + expect.Equal(t, crt, crt2) + expect.Equal(t, key, key2) } diff --git a/agent/pkg/handler/check_health_test.go b/agent/pkg/handler/check_health_test.go index d90d98e..8280834 100644 --- a/agent/pkg/handler/check_health_test.go +++ b/agent/pkg/handler/check_health_test.go @@ -10,10 +10,11 @@ import ( "github.com/yusing/go-proxy/pkg/json" - "github.com/stretchr/testify/require" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/handler" "github.com/yusing/go-proxy/internal/watcher/health" + + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestCheckHealthHTTP(t *testing.T) { @@ -79,12 +80,12 @@ func TestCheckHealthHTTP(t *testing.T) { request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil) handler.CheckHealth(recorder, request) - require.Equal(t, recorder.Code, tt.expectedStatus) + expect.Equal(t, recorder.Code, tt.expectedStatus) if tt.expectedStatus == http.StatusOK { var result health.HealthCheckResult - require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) - require.Equal(t, result.Healthy, tt.expectedHealthy) + expect.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) + expect.Equal(t, result.Healthy, tt.expectedHealthy) } }) } @@ -124,32 +125,32 @@ func TestCheckHealthFileServer(t *testing.T) { request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil) handler.CheckHealth(recorder, request) - require.Equal(t, recorder.Code, tt.expectedStatus) + expect.Equal(t, recorder.Code, tt.expectedStatus) var result health.HealthCheckResult - require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) - require.Equal(t, result.Healthy, tt.expectedHealthy) - require.Equal(t, result.Detail, tt.expectedDetail) + expect.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) + expect.Equal(t, result.Healthy, tt.expectedHealthy) + expect.Equal(t, result.Detail, tt.expectedDetail) }) } } func TestCheckHealthTCPUDP(t *testing.T) { tcp, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) + expect.NoError(t, err) go func() { conn, err := tcp.Accept() - require.NoError(t, err) + expect.NoError(t, err) conn.Close() }() udp, err := net.ListenPacket("udp", "localhost:0") - require.NoError(t, err) + expect.NoError(t, err) go func() { buf := make([]byte, 1024) n, addr, err := udp.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, string(buf[:n]), "ping") + expect.NoError(t, err) + expect.Equal(t, string(buf[:n]), "ping") _, _ = udp.WriteTo([]byte("pong"), addr) udp.Close() }() @@ -207,11 +208,11 @@ func TestCheckHealthTCPUDP(t *testing.T) { request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil) handler.CheckHealth(recorder, request) - require.Equal(t, recorder.Code, tt.expectedStatus) + expect.Equal(t, recorder.Code, tt.expectedStatus) var result health.HealthCheckResult - require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) - require.Equal(t, result.Healthy, tt.expectedHealthy) + expect.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) + expect.Equal(t, result.Healthy, tt.expectedHealthy) }) } } diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index ad2a5e2..c43ac54 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -17,7 +17,7 @@ import ( "github.com/yusing/go-proxy/internal/common" "golang.org/x/oauth2" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) // setupMockOIDC configures mock OIDC provider for testing. @@ -75,7 +75,7 @@ func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string { token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = keyID signed, err := token.SignedString(j.key) - ExpectNoError(t, err) + expect.NoError(t, err) return signed } @@ -84,7 +84,7 @@ func setupProvider(t *testing.T) *provider { // Generate an RSA key pair for the test. privKey, err := rsa.GenerateKey(rand.Reader, 2048) - ExpectNoError(t, err) + expect.NoError(t, err) // Build the matching public JWK that will be served by the endpoint. jwk := buildRSAJWK(t, &privKey.PublicKey, keyID) @@ -227,12 +227,12 @@ func TestOIDCCallbackHandler(t *testing.T) { } if tt.wantStatus == http.StatusTemporaryRedirect { - setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) - ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName()) - ExpectTrue(t, setCookie.Value != "") - ExpectEqual(t, setCookie.Path, "/") - ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) - ExpectEqual(t, setCookie.HttpOnly, true) + setCookie := expect.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) + expect.Equal(t, setCookie.Name, defaultAuth.TokenCookieName()) + expect.True(t, setCookie.Value != "") + expect.Equal(t, setCookie.Path, "/") + expect.Equal(t, setCookie.SameSite, http.SameSiteLaxMode) + expect.Equal(t, setCookie.HttpOnly, true) } }) } @@ -245,7 +245,7 @@ func TestInitOIDC(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server))) + expect.NoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server))) }) server = httptest.NewServer(mux) t.Cleanup(server.Close) @@ -446,9 +446,9 @@ func TestCheckToken(t *testing.T) { // Call CheckToken and verify the result. err := auth.CheckToken(req) if tc.wantErr == nil { - ExpectNoError(t, err) + expect.NoError(t, err) } else { - ExpectError(t, tc.wantErr, err) + expect.ErrorIs(t, tc.wantErr, err) } }) } diff --git a/internal/api/v1/auth/userpass_test.go b/internal/api/v1/auth/userpass_test.go index cd958ed..be4ae8f 100644 --- a/internal/api/v1/auth/userpass_test.go +++ b/internal/api/v1/auth/userpass_test.go @@ -10,14 +10,14 @@ import ( "github.com/yusing/go-proxy/pkg/json" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" "golang.org/x/crypto/bcrypt" ) func newMockUserPassAuth() *UserPassAuth { return &UserPassAuth{ username: "username", - pwdHash: Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)), + pwdHash: expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)), secret: []byte("abcdefghijklmnopqrstuvwxyz"), tokenTTL: time.Hour, } @@ -26,17 +26,17 @@ func newMockUserPassAuth() *UserPassAuth { func TestUserPassValidateCredentials(t *testing.T) { auth := newMockUserPassAuth() err := auth.validatePassword("username", "password") - ExpectNoError(t, err) + expect.NoError(t, err) err = auth.validatePassword("username", "wrong-password") - ExpectError(t, ErrInvalidPassword, err) + expect.ErrorIs(t, ErrInvalidPassword, err) err = auth.validatePassword("wrong-username", "password") - ExpectError(t, ErrInvalidUsername, err) + expect.ErrorIs(t, ErrInvalidUsername, err) } func TestUserPassCheckToken(t *testing.T) { auth := newMockUserPassAuth() token, err := auth.NewToken() - ExpectNoError(t, err) + expect.NoError(t, err) tests := []struct { token string wantErr bool @@ -61,9 +61,9 @@ func TestUserPassCheckToken(t *testing.T) { } err = auth.CheckToken(req) if tt.wantErr { - ExpectTrue(t, err != nil) + expect.True(t, err != nil) } else { - ExpectNoError(t, err) + expect.NoError(t, err) } } } @@ -97,20 +97,20 @@ func TestUserPassLoginCallbackHandler(t *testing.T) { w := httptest.NewRecorder() req := &http.Request{ Host: "app.example.com", - Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))), + Body: io.NopCloser(bytes.NewReader(expect.Must(json.Marshal(tt.creds)))), } auth.LoginCallbackHandler(w, req) if tt.wantErr { - ExpectEqual(t, w.Code, http.StatusUnauthorized) + expect.Equal(t, w.Code, http.StatusUnauthorized) } else { - setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) - ExpectTrue(t, setCookie.Name == auth.TokenCookieName()) - ExpectTrue(t, setCookie.Value != "") - ExpectEqual(t, setCookie.Domain, "example.com") - ExpectEqual(t, setCookie.Path, "/") - ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) - ExpectEqual(t, setCookie.HttpOnly, true) - ExpectEqual(t, w.Code, http.StatusOK) + setCookie := expect.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) + expect.True(t, setCookie.Name == auth.TokenCookieName()) + expect.True(t, setCookie.Value != "") + expect.Equal(t, setCookie.Domain, "example.com") + expect.Equal(t, setCookie.Path, "/") + expect.Equal(t, setCookie.SameSite, http.SameSiteLaxMode) + expect.Equal(t, setCookie.HttpOnly, true) + expect.Equal(t, w.Code, http.StatusOK) } } } diff --git a/internal/autocert/provider_test/ovh_test.go b/internal/autocert/provider_test/ovh_test.go index b59ee29..87f074f 100644 --- a/internal/autocert/provider_test/ovh_test.go +++ b/internal/autocert/provider_test/ovh_test.go @@ -6,7 +6,8 @@ import ( "github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/goccy/go-yaml" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + + expect "github.com/yusing/go-proxy/internal/utils/testing" ) // type Config struct { @@ -44,7 +45,7 @@ oauth2_config: } testYaml = testYaml[1:] // remove first \n opt := make(map[string]any) - ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), &opt)) - ExpectNoError(t, utils.MapUnmarshalValidate(opt, cfg)) - ExpectEqual(t, cfg, cfgExpected) + expect.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt)) + expect.NoError(t, utils.MapUnmarshalValidate(opt, cfg)) + expect.Equal(t, cfg, cfgExpected) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 75f040f..be8feda 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -6,13 +6,12 @@ import ( "testing" "github.com/goccy/go-yaml" - "github.com/stretchr/testify/assert" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestFileProviderValidate(t *testing.T) { @@ -57,10 +56,10 @@ func TestFileProviderValidate(t *testing.T) { if tt.init != nil { for _, filename := range tt.filenames { filepath := path.Join(common.ConfigDir, filename) - assert.NoError(t, tt.init(filepath)) + expect.NoError(t, tt.init(filepath)) } } - err := utils.UnmarshalValidateYAML(Must(yaml.Marshal(map[string]any{ + err := utils.UnmarshalValidateYAML(expect.Must(yaml.Marshal(map[string]any{ "providers": map[string]any{ "include": tt.filenames, }, @@ -68,13 +67,13 @@ func TestFileProviderValidate(t *testing.T) { if tt.cleanup != nil { for _, filename := range tt.filenames { filepath := path.Join(common.ConfigDir, filename) - assert.NoError(t, tt.cleanup(filepath)) + expect.NoError(t, tt.cleanup(filepath)) } } if tt.expectedErrorContains != "" { - assert.ErrorContains(t, err, tt.expectedErrorContains) + expect.ErrorContains(t, err, tt.expectedErrorContains) } else { - assert.NoError(t, err) + expect.NoError(t, err) } }) } @@ -129,9 +128,9 @@ func TestLoadRouteProviders(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := utils.Validate(tt.providers) if tt.expectedError { - assert.ErrorContains(t, err, "unique") + expect.ErrorContains(t, err, "unique") } else { - assert.NoError(t, err) + expect.NoError(t, err) } }) } @@ -142,9 +141,9 @@ func TestProviderNameUniqueness(t *testing.T) { docker := provider.NewDockerProvider("routes", "unix:///var/run/docker.sock") agent := provider.NewAgentProvider(agent.TestAgentConfig("routes", "192.168.1.100:8080")) - assert.True(t, file.String() != docker.String()) - assert.True(t, file.String() != agent.String()) - assert.True(t, docker.String() != agent.String()) + expect.NotEqual(t, file.String(), docker.String()) + expect.NotEqual(t, file.String(), agent.String()) + expect.NotEqual(t, docker.String(), agent.String()) } func TestFileProviderNameFromFilename(t *testing.T) { @@ -160,7 +159,7 @@ func TestFileProviderNameFromFilename(t *testing.T) { for _, tt := range tests { t.Run(tt.filename, func(t *testing.T) { p := provider.NewFileProvider(tt.filename) - assert.Equal(t, tt.expectedName, p.ShortName()) + expect.Equal(t, tt.expectedName, p.ShortName()) }) } } @@ -179,7 +178,7 @@ func TestDockerProviderString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := provider.NewDockerProvider(tt.name, tt.dockerHost) - assert.Equal(t, tt.expected, p.String()) + expect.Equal(t, tt.expected, p.String()) }) } } @@ -196,7 +195,7 @@ func TestExplicitOnlyProvider(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := provider.NewDockerProvider(tt.name, "unix:///var/run/docker.sock") - assert.Equal(t, tt.expectedFlag, p.IsExplicitOnly()) + expect.Equal(t, tt.expectedFlag, p.IsExplicitOnly()) }) } } diff --git a/internal/docker/container_test.go b/internal/docker/container_test.go index 753ebac..e9c1930 100644 --- a/internal/docker/container_test.go +++ b/internal/docker/container_test.go @@ -4,7 +4,8 @@ import ( "testing" "github.com/docker/docker/api/types/container" - . "github.com/yusing/go-proxy/internal/utils/testing" + + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestContainerExplicit(t *testing.T) { @@ -37,7 +38,7 @@ func TestContainerExplicit(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := FromDocker(&container.Summary{Names: []string{"test"}, State: "test", Labels: tt.labels}, "") - ExpectEqual(t, c.IsExplicit, tt.isExplicit) + expect.Equal(t, c.IsExplicit, tt.isExplicit) }) } } diff --git a/internal/entrypoint/entrypoint_test.go b/internal/entrypoint/entrypoint_test.go index 6b43278..0b5261f 100644 --- a/internal/entrypoint/entrypoint_test.go +++ b/internal/entrypoint/entrypoint_test.go @@ -5,7 +5,8 @@ import ( "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route/routes" - . "github.com/yusing/go-proxy/internal/utils/testing" + + expect "github.com/yusing/go-proxy/internal/utils/testing" ) var ep = NewEntrypoint() @@ -29,15 +30,15 @@ func run(t *testing.T, match []string, noMatch []string) { t.Run(test, func(t *testing.T) { r := addRoute(test) found, err := ep.findRouteFunc(test) - ExpectNoError(t, err) - ExpectTrue(t, found == r) + expect.NoError(t, err) + expect.True(t, found == r) }) } for _, test := range noMatch { t.Run(test, func(t *testing.T) { _, err := ep.findRouteFunc(test) - ExpectError(t, ErrNoSuchRoute, err) + expect.ErrorIs(t, ErrNoSuchRoute, err) }) } } diff --git a/internal/gperr/builder_test.go b/internal/gperr/builder_test.go index 04aa326..00958c1 100644 --- a/internal/gperr/builder_test.go +++ b/internal/gperr/builder_test.go @@ -2,18 +2,17 @@ package gperr_test import ( "context" - "errors" "io" "testing" . "github.com/yusing/go-proxy/internal/gperr" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestBuilderEmpty(t *testing.T) { eb := NewBuilder("foo") - ExpectTrue(t, errors.Is(eb.Error(), nil)) - ExpectFalse(t, eb.HasError()) + expect.NoError(t, eb.Error()) + expect.False(t, eb.HasError()) } func TestBuilderAddNil(t *testing.T) { @@ -26,17 +25,17 @@ func TestBuilderAddNil(t *testing.T) { eb.Add(err) } eb.AddRange(nil, nil, err) - ExpectFalse(t, eb.HasError()) - ExpectTrue(t, eb.Error() == nil) + expect.False(t, eb.HasError()) + expect.NoError(t, eb.Error()) } func TestBuilderIs(t *testing.T) { eb := NewBuilder("foo") eb.Add(context.Canceled) eb.Add(io.ErrShortBuffer) - ExpectTrue(t, eb.HasError()) - ExpectError(t, io.ErrShortBuffer, eb.Error()) - ExpectError(t, context.Canceled, eb.Error()) + expect.True(t, eb.HasError()) + expect.ErrorIs(t, io.ErrShortBuffer, eb.Error()) + expect.ErrorIs(t, context.Canceled, eb.Error()) } func TestBuilderNested(t *testing.T) { @@ -51,5 +50,5 @@ func TestBuilderNested(t *testing.T) { • Inner: 2 • Action 2 • Inner: 3` - ExpectEqual(t, got, expected) + expect.Equal(t, got, expected) } diff --git a/internal/gperr/error_test.go b/internal/gperr/error_test.go index 1421a0a..a14527f 100644 --- a/internal/gperr/error_test.go +++ b/internal/gperr/error_test.go @@ -6,11 +6,11 @@ import ( "testing" "github.com/yusing/go-proxy/internal/utils/strutils/ansi" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestBaseString(t *testing.T) { - ExpectEqual(t, New("error").Error(), "error") + expect.Equal(t, New("error").Error(), "error") } func TestBaseWithSubject(t *testing.T) { @@ -18,13 +18,13 @@ func TestBaseWithSubject(t *testing.T) { withSubject := err.Subject("foo") withSubjectf := err.Subjectf("%s %s", "foo", "bar") - ExpectError(t, err, withSubject) - ExpectEqual(t, ansi.StripANSI(withSubject.Error()), "foo: error") - ExpectTrue(t, withSubject.Is(err)) + expect.ErrorIs(t, err, withSubject) + expect.Equal(t, ansi.StripANSI(withSubject.Error()), "foo: error") + expect.True(t, withSubject.Is(err)) - ExpectError(t, err, withSubjectf) - ExpectEqual(t, ansi.StripANSI(withSubjectf.Error()), "foo bar: error") - ExpectTrue(t, withSubjectf.Is(err)) + expect.ErrorIs(t, err, withSubjectf) + expect.Equal(t, ansi.StripANSI(withSubjectf.Error()), "foo bar: error") + expect.True(t, withSubjectf.Is(err)) } func TestBaseWithExtra(t *testing.T) { @@ -32,22 +32,22 @@ func TestBaseWithExtra(t *testing.T) { extra := New("bar").Subject("baz") withExtra := err.With(extra) - ExpectTrue(t, withExtra.Is(extra)) - ExpectTrue(t, withExtra.Is(err)) + expect.True(t, withExtra.Is(extra)) + expect.True(t, withExtra.Is(err)) - ExpectTrue(t, errors.Is(withExtra, extra)) - ExpectTrue(t, errors.Is(withExtra, err)) + expect.True(t, errors.Is(withExtra, extra)) + expect.True(t, errors.Is(withExtra, err)) - ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error())) - ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error())) - ExpectTrue(t, strings.Contains(withExtra.Error(), "baz")) + expect.True(t, strings.Contains(withExtra.Error(), err.Error())) + expect.True(t, strings.Contains(withExtra.Error(), extra.Error())) + expect.True(t, strings.Contains(withExtra.Error(), "baz")) } func TestBaseUnwrap(t *testing.T) { err := errors.New("err") wrapped := Wrap(err) - ExpectError(t, err, errors.Unwrap(wrapped)) + expect.ErrorIs(t, err, errors.Unwrap(wrapped)) } func TestNestedUnwrap(t *testing.T) { @@ -56,24 +56,24 @@ func TestNestedUnwrap(t *testing.T) { wrapped := Wrap(err).Subject("foo").With(err2.Subject("bar")) unwrapper, ok := wrapped.(interface{ Unwrap() []error }) - ExpectTrue(t, ok) + expect.True(t, ok) - ExpectError(t, err, wrapped) - ExpectError(t, err2, wrapped) - ExpectEqual(t, len(unwrapper.Unwrap()), 2) + expect.ErrorIs(t, err, wrapped) + expect.ErrorIs(t, err2, wrapped) + expect.Equal(t, len(unwrapper.Unwrap()), 2) } func TestErrorIs(t *testing.T) { from := errors.New("error") err := Wrap(from) - ExpectError(t, from, err) + expect.ErrorIs(t, from, err) - ExpectTrue(t, err.Is(from)) - ExpectFalse(t, err.Is(New("error"))) + expect.True(t, err.Is(from)) + expect.False(t, err.Is(New("error"))) - ExpectTrue(t, errors.Is(err.Subject("foo"), from)) - ExpectTrue(t, errors.Is(err.Withf("foo"), from)) - ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from)) + expect.True(t, errors.Is(err.Subject("foo"), from)) + expect.True(t, errors.Is(err.Withf("foo"), from)) + expect.True(t, errors.Is(err.Subject("foo").Withf("bar"), from)) } func TestErrorImmutability(t *testing.T) { @@ -83,14 +83,14 @@ func TestErrorImmutability(t *testing.T) { for range 3 { // t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err) _ = err.Subject("foo") - ExpectFalse(t, strings.Contains(err.Error(), "foo")) + expect.False(t, strings.Contains(err.Error(), "foo")) _ = err.With(err2) - ExpectFalse(t, strings.Contains(err.Error(), "extra")) - ExpectFalse(t, err.Is(err2)) + expect.False(t, strings.Contains(err.Error(), "extra")) + expect.False(t, err.Is(err2)) err = err.Subject("bar").Withf("baz") - ExpectTrue(t, err != nil) + expect.True(t, err != nil) } } @@ -100,24 +100,24 @@ func TestErrorWith(t *testing.T) { err3 := err1.With(err2) - ExpectTrue(t, err3.Is(err1)) - ExpectTrue(t, err3.Is(err2)) + expect.True(t, err3.Is(err1)) + expect.True(t, err3.Is(err2)) _ = err2.Subject("foo") - ExpectTrue(t, err3.Is(err1)) - ExpectTrue(t, err3.Is(err2)) + expect.True(t, err3.Is(err1)) + expect.True(t, err3.Is(err2)) // check if err3 is affected by err2.Subject - ExpectFalse(t, strings.Contains(err3.Error(), "foo")) + expect.False(t, strings.Contains(err3.Error(), "foo")) } func TestErrorStringSimple(t *testing.T) { errFailure := New("generic failure") ne := errFailure.Subject("foo bar") - ExpectEqual(t, ansi.StripANSI(ne.Error()), "foo bar: generic failure") + expect.Equal(t, ansi.StripANSI(ne.Error()), "foo bar: generic failure") ne = ne.Subject("baz") - ExpectEqual(t, ansi.StripANSI(ne.Error()), "baz > foo bar: generic failure") + expect.Equal(t, ansi.StripANSI(ne.Error()), "baz > foo bar: generic failure") } func TestErrorStringNested(t *testing.T) { @@ -154,5 +154,5 @@ func TestErrorStringNested(t *testing.T) { • action 3 > inner3: generic failure • 3 • 3` - ExpectEqual(t, ansi.StripANSI(ne.Error()), want) + expect.Equal(t, ansi.StripANSI(ne.Error()), want) } diff --git a/internal/gperr/multiline_test.go b/internal/gperr/multiline_test.go index d77a405..af87b85 100644 --- a/internal/gperr/multiline_test.go +++ b/internal/gperr/multiline_test.go @@ -4,7 +4,7 @@ import ( "net" "testing" - "github.com/stretchr/testify/require" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestWrapMultiline(t *testing.T) { @@ -25,5 +25,5 @@ func TestPrependSubjectMultiline(t *testing.T) { builder := NewBuilder() builder.Add(multiline) - require.Equal(t, len(builder.errs), len(multiline.Extras), builder.errs) + expect.Equal(t, len(multiline.Extras), len(builder.errs)) } diff --git a/internal/gperr/subject.go b/internal/gperr/subject.go index 6167cca..db7392e 100644 --- a/internal/gperr/subject.go +++ b/internal/gperr/subject.go @@ -5,8 +5,8 @@ import ( "slices" "strings" - "github.com/yusing/go-proxy/pkg/json" "github.com/yusing/go-proxy/internal/utils/strutils/ansi" + "github.com/yusing/go-proxy/pkg/json" ) //nolint:errname diff --git a/internal/homepage/homepage_test.go b/internal/homepage/homepage_test.go index b7836de..a7d766a 100644 --- a/internal/homepage/homepage_test.go +++ b/internal/homepage/homepage_test.go @@ -3,7 +3,7 @@ package homepage import ( "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestOverrideItem(t *testing.T) { @@ -32,5 +32,5 @@ func TestOverrideItem(t *testing.T) { overrides := GetOverrideConfig() overrides.OverrideItem(a.Alias, want) got := a.GetOverride(a.Alias) - ExpectEqual(t, got, want) + expect.Equal(t, got, want) } diff --git a/internal/homepage/icon_url_test.go b/internal/homepage/icon_url_test.go index 580ebaa..1b655fc 100644 --- a/internal/homepage/icon_url_test.go +++ b/internal/homepage/icon_url_test.go @@ -3,7 +3,7 @@ package homepage import ( "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestIconURL(t *testing.T) { @@ -114,11 +114,11 @@ func TestIconURL(t *testing.T) { u := &IconURL{} err := u.Parse(tc.input) if tc.wantErr { - ExpectError(t, ErrInvalidIconURL, err) + expect.ErrorIs(t, ErrInvalidIconURL, err) } else { tc.wantValue.FullValue = tc.input - ExpectNoError(t, err) - ExpectEqual(t, u, tc.wantValue) + expect.NoError(t, err) + expect.Equal(t, u, tc.wantValue) } }) } diff --git a/internal/idlewatcher/types/config_test.go b/internal/idlewatcher/types/config_test.go index dacd412..b9f3861 100644 --- a/internal/idlewatcher/types/config_test.go +++ b/internal/idlewatcher/types/config_test.go @@ -3,7 +3,7 @@ package idlewatcher import ( "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestValidateStartEndpoint(t *testing.T) { @@ -38,7 +38,7 @@ func TestValidateStartEndpoint(t *testing.T) { cfg := Config{StartEndpoint: tc.input} err := cfg.validateStartEndpoint() if err == nil { - ExpectEqual(t, cfg.StartEndpoint, tc.input) + expect.Equal(t, cfg.StartEndpoint, tc.input) } if (err != nil) != tc.wantErr { t.Errorf("validateStartEndpoint() error = %v, wantErr %t", err, tc.wantErr) diff --git a/internal/metrics/systeminfo/system_info_test.go b/internal/metrics/systeminfo/system_info_test.go index 25740a9..d3ecaa5 100644 --- a/internal/metrics/systeminfo/system_info_test.go +++ b/internal/metrics/systeminfo/system_info_test.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/shirou/gopsutil/v4/sensors" - . "github.com/yusing/go-proxy/internal/utils/testing" "github.com/yusing/go-proxy/pkg/json" + + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestExcludeDisks(t *testing.T) { @@ -72,7 +73,7 @@ func TestExcludeDisks(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := shouldExcludeDisk(tt.name) - ExpectEqual(t, result, tt.shouldExclude) + expect.Equal(t, result, tt.shouldExclude) }) } } @@ -147,21 +148,21 @@ var testInfo = &SystemInfo{ func TestSystemInfo(t *testing.T) { // Test marshaling data, err := json.Marshal(testInfo) - ExpectNoError(t, err) + expect.NoError(t, err) // Test unmarshaling back var decoded SystemInfo err = json.Unmarshal(data, &decoded) - ExpectNoError(t, err) + expect.NoError(t, err) // Compare original and decoded - ExpectEqual(t, decoded.Timestamp, testInfo.Timestamp) - ExpectEqual(t, *decoded.CPUAverage, *testInfo.CPUAverage) - ExpectEqual(t, decoded.Memory, testInfo.Memory) - ExpectEqual(t, decoded.Disks, testInfo.Disks) - ExpectEqual(t, decoded.DisksIO, testInfo.DisksIO) - ExpectEqual(t, decoded.Network, testInfo.Network) - ExpectEqual(t, decoded.Sensors, testInfo.Sensors) + expect.Equal(t, decoded.Timestamp, testInfo.Timestamp) + expect.Equal(t, *decoded.CPUAverage, *testInfo.CPUAverage) + expect.Equal(t, decoded.Memory, testInfo.Memory) + expect.Equal(t, decoded.Disks, testInfo.Disks) + expect.Equal(t, decoded.DisksIO, testInfo.DisksIO) + expect.Equal(t, decoded.Network, testInfo.Network) + expect.Equal(t, decoded.Sensors, testInfo.Sensors) // Test nil fields nilInfo := &SystemInfo{ @@ -169,18 +170,18 @@ func TestSystemInfo(t *testing.T) { } data, err = json.Marshal(nilInfo) - ExpectNoError(t, err) + expect.NoError(t, err) var decodedNil SystemInfo err = json.Unmarshal(data, &decodedNil) - ExpectNoError(t, err) + expect.NoError(t, err) - ExpectEqual(t, decodedNil.Timestamp, nilInfo.Timestamp) - ExpectTrue(t, decodedNil.CPUAverage == nil) - ExpectTrue(t, decodedNil.Memory == nil) - ExpectTrue(t, decodedNil.Disks == nil) - ExpectTrue(t, decodedNil.Network == nil) - ExpectTrue(t, decodedNil.Sensors == nil) + expect.Equal(t, decodedNil.Timestamp, nilInfo.Timestamp) + expect.True(t, decodedNil.CPUAverage == nil) + expect.True(t, decodedNil.Memory == nil) + expect.True(t, decodedNil.Disks == nil) + expect.True(t, decodedNil.Network == nil) + expect.True(t, decodedNil.Sensors == nil) } func TestSerialize(t *testing.T) { @@ -193,13 +194,13 @@ func TestSerialize(t *testing.T) { _, result := aggregate(entries, url.Values{"aggregate": []string{query}}) s := result.MarshalJSONTo(nil) var v []map[string]any - ExpectNoError(t, json.Unmarshal(s, &v)) - ExpectEqual(t, len(v), len(result)) + expect.NoError(t, json.Unmarshal(s, &v)) + expect.Equal(t, len(v), len(result)) for i, m := range v { for k, v := range m { // some int64 values are converted to float64 on json.Unmarshal vv := reflect.ValueOf(result[i][k]) - ExpectEqual(t, reflect.ValueOf(v).Convert(vv.Type()).Interface(), vv.Interface()) + expect.Equal(t, reflect.ValueOf(v).Convert(vv.Type()).Interface(), vv.Interface()) } } }) diff --git a/internal/net/gphttp/accesslog/access_logger_test.go b/internal/net/gphttp/accesslog/access_logger_test.go index 805799e..c4382dd 100644 --- a/internal/net/gphttp/accesslog/access_logger_test.go +++ b/internal/net/gphttp/accesslog/access_logger_test.go @@ -12,7 +12,7 @@ import ( . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/task" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) const ( @@ -30,7 +30,7 @@ const ( var ( testTask = task.RootTask("test", false) - testURL = Must(url.Parse("http://" + host + uri)) + testURL = expect.Must(url.Parse("http://" + host + uri)) req = &http.Request{ RemoteAddr: remote, Method: method, @@ -69,7 +69,7 @@ func TestAccessLoggerCommon(t *testing.T) { config := DefaultConfig() config.Format = FormatCommon ts, log := fmtLog(config) - ExpectEqual(t, log, + expect.Equal(t, log, fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d", host, remote, ts, method, uri, proto, status, contentLength, ), @@ -80,7 +80,7 @@ func TestAccessLoggerCombined(t *testing.T) { config := DefaultConfig() config.Format = FormatCombined ts, log := fmtLog(config) - ExpectEqual(t, log, + expect.Equal(t, log, fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"", host, remote, ts, method, uri, proto, status, contentLength, referer, ua, ), @@ -92,7 +92,7 @@ func TestAccessLoggerRedactQuery(t *testing.T) { config.Format = FormatCommon config.Fields.Query.Default = FieldModeRedact ts, log := fmtLog(config) - ExpectEqual(t, log, + expect.Equal(t, log, fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d", host, remote, ts, method, uriRedacted, proto, status, contentLength, ), @@ -124,27 +124,27 @@ func getJSONEntry(t *testing.T, config *Config) JSONLogEntry { var entry JSONLogEntry _, log := fmtLog(config) err := json.Unmarshal([]byte(log), &entry) - ExpectNoError(t, err) + expect.NoError(t, err) return entry } func TestAccessLoggerJSON(t *testing.T) { config := DefaultConfig() entry := getJSONEntry(t, config) - ExpectEqual(t, entry.IP, remote) - ExpectEqual(t, entry.Method, method) - ExpectEqual(t, entry.Scheme, "http") - ExpectEqual(t, entry.Host, testURL.Host) - ExpectEqual(t, entry.URI, testURL.RequestURI()) - ExpectEqual(t, entry.Protocol, proto) - ExpectEqual(t, entry.Status, status) - ExpectEqual(t, entry.ContentType, "text/plain") - ExpectEqual(t, entry.Size, contentLength) - ExpectEqual(t, entry.Referer, referer) - ExpectEqual(t, entry.UserAgent, ua) - ExpectEqual(t, len(entry.Headers), 0) - ExpectEqual(t, len(entry.Cookies), 0) + expect.Equal(t, entry.IP, remote) + expect.Equal(t, entry.Method, method) + expect.Equal(t, entry.Scheme, "http") + expect.Equal(t, entry.Host, testURL.Host) + expect.Equal(t, entry.URI, testURL.RequestURI()) + expect.Equal(t, entry.Protocol, proto) + expect.Equal(t, entry.Status, status) + expect.Equal(t, entry.ContentType, "text/plain") + expect.Equal(t, entry.Size, contentLength) + expect.Equal(t, entry.Referer, referer) + expect.Equal(t, entry.UserAgent, ua) + expect.Equal(t, len(entry.Headers), 0) + expect.Equal(t, len(entry.Cookies), 0) if status >= 400 { - ExpectEqual(t, entry.Error, http.StatusText(status)) + expect.Equal(t, entry.Error, http.StatusText(status)) } } diff --git a/internal/net/gphttp/accesslog/config_test.go b/internal/net/gphttp/accesslog/config_test.go index c91e600..4110b00 100644 --- a/internal/net/gphttp/accesslog/config_test.go +++ b/internal/net/gphttp/accesslog/config_test.go @@ -6,7 +6,7 @@ import ( "github.com/yusing/go-proxy/internal/docker" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestNewConfig(t *testing.T) { @@ -27,27 +27,27 @@ func TestNewConfig(t *testing.T) { "proxy.fields.cookies.config.foo": "keep", } parsed, err := docker.ParseLabels(labels) - ExpectNoError(t, err) + expect.NoError(t, err) var config Config err = utils.MapUnmarshalValidate(parsed, &config) - ExpectNoError(t, err) + expect.NoError(t, err) - ExpectEqual(t, config.BufferSize, 10) - ExpectEqual(t, config.Format, FormatCombined) - ExpectEqual(t, config.Path, "/tmp/access.log") - ExpectEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}}) - ExpectEqual(t, len(config.Filters.Method.Values), 2) - ExpectEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"}) - ExpectEqual(t, len(config.Filters.Headers.Values), 2) - ExpectEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}}) - ExpectTrue(t, config.Filters.Headers.Negative) - ExpectEqual(t, len(config.Filters.CIDR.Values), 1) - ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24") - ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep) - ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact) - ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop) - ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep) - ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact) - ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep) + expect.Equal(t, config.BufferSize, 10) + expect.Equal(t, config.Format, FormatCombined) + expect.Equal(t, config.Path, "/tmp/access.log") + expect.Equal(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}}) + expect.Equal(t, len(config.Filters.Method.Values), 2) + expect.Equal(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"}) + expect.Equal(t, len(config.Filters.Headers.Values), 2) + expect.Equal(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}}) + expect.True(t, config.Filters.Headers.Negative) + expect.Equal(t, len(config.Filters.CIDR.Values), 1) + expect.Equal(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24") + expect.Equal(t, config.Fields.Headers.Default, FieldModeKeep) + expect.Equal(t, config.Fields.Headers.Config["foo"], FieldModeRedact) + expect.Equal(t, config.Fields.Query.Default, FieldModeDrop) + expect.Equal(t, config.Fields.Query.Config["foo"], FieldModeKeep) + expect.Equal(t, config.Fields.Cookies.Default, FieldModeRedact) + expect.Equal(t, config.Fields.Cookies.Config["foo"], FieldModeKeep) } diff --git a/internal/net/gphttp/accesslog/fields_test.go b/internal/net/gphttp/accesslog/fields_test.go index f482734..34ae5b9 100644 --- a/internal/net/gphttp/accesslog/fields_test.go +++ b/internal/net/gphttp/accesslog/fields_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) // Cookie header should be removed, @@ -15,7 +15,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) { entry := getJSONEntry(t, config) for k, v := range req.Header { if k != "Cookie" { - ExpectEqual(t, entry.Headers[k], v) + expect.Equal(t, entry.Headers[k], v) } } @@ -24,8 +24,8 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) { "User-Agent": FieldModeDrop, } entry = getJSONEntry(t, config) - ExpectEqual(t, entry.Headers["Referer"], []string{RedactedValue}) - ExpectEqual(t, entry.Headers["User-Agent"], nil) + expect.Equal(t, entry.Headers["Referer"], []string{RedactedValue}) + expect.Equal(t, entry.Headers["User-Agent"], nil) } func TestAccessLoggerJSONDropHeaders(t *testing.T) { @@ -33,7 +33,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) { config.Fields.Headers.Default = FieldModeDrop entry := getJSONEntry(t, config) for k := range req.Header { - ExpectEqual(t, entry.Headers[k], nil) + expect.Equal(t, entry.Headers[k], nil) } config.Fields.Headers.Config = map[string]FieldMode{ @@ -41,18 +41,18 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) { "User-Agent": FieldModeRedact, } entry = getJSONEntry(t, config) - ExpectEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")}) - ExpectEqual(t, entry.Headers["User-Agent"], []string{RedactedValue}) + expect.Equal(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")}) + expect.Equal(t, entry.Headers["User-Agent"], []string{RedactedValue}) } func TestAccessLoggerJSONRedactHeaders(t *testing.T) { config := DefaultConfig() config.Fields.Headers.Default = FieldModeRedact entry := getJSONEntry(t, config) - ExpectEqual(t, len(entry.Headers["Cookie"]), 0) + expect.Equal(t, len(entry.Headers["Cookie"]), 0) for k := range req.Header { if k != "Cookie" { - ExpectEqual(t, entry.Headers[k], []string{RedactedValue}) + expect.Equal(t, entry.Headers[k], []string{RedactedValue}) } } } @@ -62,9 +62,9 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) { config.Fields.Headers.Default = FieldModeKeep config.Fields.Cookies.Default = FieldModeKeep entry := getJSONEntry(t, config) - ExpectEqual(t, len(entry.Headers["Cookie"]), 0) + expect.Equal(t, len(entry.Headers["Cookie"]), 0) for _, cookie := range req.Cookies() { - ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value) + expect.Equal(t, entry.Cookies[cookie.Name], cookie.Value) } } @@ -73,9 +73,9 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) { config.Fields.Headers.Default = FieldModeKeep config.Fields.Cookies.Default = FieldModeRedact entry := getJSONEntry(t, config) - ExpectEqual(t, len(entry.Headers["Cookie"]), 0) + expect.Equal(t, len(entry.Headers["Cookie"]), 0) for _, cookie := range req.Cookies() { - ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue) + expect.Equal(t, entry.Cookies[cookie.Name], RedactedValue) } } @@ -83,14 +83,14 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) { config := DefaultConfig() config.Fields.Query.Default = FieldModeDrop entry := getJSONEntry(t, config) - ExpectEqual(t, entry.Query["foo"], nil) - ExpectEqual(t, entry.Query["bar"], nil) + expect.Equal(t, entry.Query["foo"], nil) + expect.Equal(t, entry.Query["bar"], nil) } func TestAccessLoggerJSONRedactQuery(t *testing.T) { config := DefaultConfig() config.Fields.Query.Default = FieldModeRedact entry := getJSONEntry(t, config) - ExpectEqual(t, entry.Query["foo"], []string{RedactedValue}) - ExpectEqual(t, entry.Query["bar"], []string{RedactedValue}) + expect.Equal(t, entry.Query["foo"], []string{RedactedValue}) + expect.Equal(t, entry.Query["bar"], []string{RedactedValue}) } diff --git a/internal/net/gphttp/accesslog/file_logger_test.go b/internal/net/gphttp/accesslog/file_logger_test.go index 5159d01..ad450a7 100644 --- a/internal/net/gphttp/accesslog/file_logger_test.go +++ b/internal/net/gphttp/accesslog/file_logger_test.go @@ -6,7 +6,7 @@ import ( "sync" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" "github.com/yusing/go-proxy/internal/task" ) @@ -22,10 +22,10 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { // make test log file file, err := os.Create(cfg.Path) - ExpectNoError(t, err) + expect.NoError(t, err) file.Close() t.Cleanup(func() { - ExpectNoError(t, os.Remove(cfg.Path)) + expect.NoError(t, os.Remove(cfg.Path)) }) for i := range loggerCount { @@ -33,7 +33,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { go func(index int) { defer wg.Done() file, err := newFileIO(cfg.Path) - ExpectNoError(t, err) + expect.NoError(t, err) accessLogIOs[index] = file }(i) } @@ -42,7 +42,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { firstIO := accessLogIOs[0] for _, io := range accessLogIOs { - ExpectEqual(t, io, firstIO) + expect.Equal(t, io, firstIO) } } @@ -78,7 +78,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { expected := loggerCount * logCountPerLogger actual := file.LineCount() - ExpectEqual(t, actual, expected) + expect.Equal(t, actual, expected) } func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) { diff --git a/internal/net/gphttp/accesslog/filter_test.go b/internal/net/gphttp/accesslog/filter_test.go index 5d8e8c8..dd51c89 100644 --- a/internal/net/gphttp/accesslog/filter_test.go +++ b/internal/net/gphttp/accesslog/filter_test.go @@ -7,7 +7,7 @@ import ( . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestStatusCodeFilter(t *testing.T) { @@ -16,20 +16,20 @@ func TestStatusCodeFilter(t *testing.T) { } t.Run("positive", func(t *testing.T) { filter := &LogFilter[*StatusCodeRange]{} - ExpectTrue(t, filter.CheckKeep(nil, nil)) + expect.True(t, filter.CheckKeep(nil, nil)) // keep any 2xx 3xx (inclusive) filter.Values = values - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusForbidden, })) - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusOK, })) - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusMultipleChoices, })) - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusPermanentRedirect, })) }) @@ -38,20 +38,20 @@ func TestStatusCodeFilter(t *testing.T) { filter := &LogFilter[*StatusCodeRange]{ Negative: true, } - ExpectFalse(t, filter.CheckKeep(nil, nil)) + expect.False(t, filter.CheckKeep(nil, nil)) // drop any 2xx 3xx (inclusive) filter.Values = values - ExpectTrue(t, filter.CheckKeep(nil, &http.Response{ + expect.True(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusForbidden, })) - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusOK, })) - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusMultipleChoices, })) - ExpectFalse(t, filter.CheckKeep(nil, &http.Response{ + expect.False(t, filter.CheckKeep(nil, &http.Response{ StatusCode: http.StatusPermanentRedirect, })) }) @@ -60,19 +60,19 @@ func TestStatusCodeFilter(t *testing.T) { func TestMethodFilter(t *testing.T) { t.Run("positive", func(t *testing.T) { filter := &LogFilter[HTTPMethod]{} - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) // keep get only filter.Values = []HTTPMethod{http.MethodGet} - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) }) @@ -81,19 +81,19 @@ func TestMethodFilter(t *testing.T) { filter := &LogFilter[HTTPMethod]{ Negative: true, } - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) // drop post only filter.Values = []HTTPMethod{http.MethodPost} - ExpectFalse(t, filter.CheckKeep(&http.Request{ + expect.False(t, filter.CheckKeep(&http.Request{ Method: http.MethodPost, }, nil)) - ExpectTrue(t, filter.CheckKeep(&http.Request{ + expect.True(t, filter.CheckKeep(&http.Request{ Method: http.MethodGet, }, nil)) }) @@ -113,45 +113,45 @@ func TestHeaderFilter(t *testing.T) { headerFoo := []*HTTPHeader{ strutils.MustParse[*HTTPHeader]("Foo"), } - ExpectEqual(t, headerFoo[0].Key, "Foo") - ExpectEqual(t, headerFoo[0].Value, "") + expect.Equal(t, headerFoo[0].Key, "Foo") + expect.Equal(t, headerFoo[0].Value, "") headerFooBar := []*HTTPHeader{ strutils.MustParse[*HTTPHeader]("Foo=bar"), } - ExpectEqual(t, headerFooBar[0].Key, "Foo") - ExpectEqual(t, headerFooBar[0].Value, "bar") + expect.Equal(t, headerFooBar[0].Key, "Foo") + expect.Equal(t, headerFooBar[0].Value, "bar") t.Run("positive", func(t *testing.T) { filter := &LogFilter[*HTTPHeader]{} - ExpectTrue(t, filter.CheckKeep(fooBar, nil)) - ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + expect.True(t, filter.CheckKeep(fooBar, nil)) + expect.True(t, filter.CheckKeep(fooBaz, nil)) // keep any foo filter.Values = headerFoo - ExpectTrue(t, filter.CheckKeep(fooBar, nil)) - ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + expect.True(t, filter.CheckKeep(fooBar, nil)) + expect.True(t, filter.CheckKeep(fooBaz, nil)) // keep foo == bar filter.Values = headerFooBar - ExpectTrue(t, filter.CheckKeep(fooBar, nil)) - ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + expect.True(t, filter.CheckKeep(fooBar, nil)) + expect.False(t, filter.CheckKeep(fooBaz, nil)) }) t.Run("negative", func(t *testing.T) { filter := &LogFilter[*HTTPHeader]{ Negative: true, } - ExpectFalse(t, filter.CheckKeep(fooBar, nil)) - ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + expect.False(t, filter.CheckKeep(fooBar, nil)) + expect.False(t, filter.CheckKeep(fooBaz, nil)) // drop any foo filter.Values = headerFoo - ExpectFalse(t, filter.CheckKeep(fooBar, nil)) - ExpectFalse(t, filter.CheckKeep(fooBaz, nil)) + expect.False(t, filter.CheckKeep(fooBar, nil)) + expect.False(t, filter.CheckKeep(fooBaz, nil)) // drop foo == bar filter.Values = headerFooBar - ExpectFalse(t, filter.CheckKeep(fooBar, nil)) - ExpectTrue(t, filter.CheckKeep(fooBaz, nil)) + expect.False(t, filter.CheckKeep(fooBar, nil)) + expect.True(t, filter.CheckKeep(fooBaz, nil)) }) } @@ -160,7 +160,7 @@ func TestCIDRFilter(t *testing.T) { IP: net.ParseIP("192.168.10.0"), Mask: net.CIDRMask(24, 32), }} - ExpectEqual(t, cidr[0].String(), "192.168.10.0/24") + expect.Equal(t, cidr[0].String(), "192.168.10.0/24") inCIDR := &http.Request{ RemoteAddr: "192.168.10.1", } @@ -170,21 +170,21 @@ func TestCIDRFilter(t *testing.T) { t.Run("positive", func(t *testing.T) { filter := &LogFilter[*CIDR]{} - ExpectTrue(t, filter.CheckKeep(inCIDR, nil)) - ExpectTrue(t, filter.CheckKeep(notInCIDR, nil)) + expect.True(t, filter.CheckKeep(inCIDR, nil)) + expect.True(t, filter.CheckKeep(notInCIDR, nil)) filter.Values = cidr - ExpectTrue(t, filter.CheckKeep(inCIDR, nil)) - ExpectFalse(t, filter.CheckKeep(notInCIDR, nil)) + expect.True(t, filter.CheckKeep(inCIDR, nil)) + expect.False(t, filter.CheckKeep(notInCIDR, nil)) }) t.Run("negative", func(t *testing.T) { filter := &LogFilter[*CIDR]{Negative: true} - ExpectFalse(t, filter.CheckKeep(inCIDR, nil)) - ExpectFalse(t, filter.CheckKeep(notInCIDR, nil)) + expect.False(t, filter.CheckKeep(inCIDR, nil)) + expect.False(t, filter.CheckKeep(notInCIDR, nil)) filter.Values = cidr - ExpectFalse(t, filter.CheckKeep(inCIDR, nil)) - ExpectTrue(t, filter.CheckKeep(notInCIDR, nil)) + expect.False(t, filter.CheckKeep(inCIDR, nil)) + expect.True(t, filter.CheckKeep(notInCIDR, nil)) }) } diff --git a/internal/net/gphttp/accesslog/retention_test.go b/internal/net/gphttp/accesslog/retention_test.go index 2a2eb98..125039e 100644 --- a/internal/net/gphttp/accesslog/retention_test.go +++ b/internal/net/gphttp/accesslog/retention_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParseRetention(t *testing.T) { @@ -24,9 +24,9 @@ func TestParseRetention(t *testing.T) { r := &Retention{} err := r.Parse(test.input) if !test.shouldErr { - ExpectNoError(t, err) + expect.NoError(t, err) } else { - ExpectEqual(t, r, test.expected) + expect.Equal(t, r, test.expected) } }) } diff --git a/internal/net/gphttp/accesslog/rotate_test.go b/internal/net/gphttp/accesslog/rotate_test.go index 727a3cb..2fda5f2 100644 --- a/internal/net/gphttp/accesslog/rotate_test.go +++ b/internal/net/gphttp/accesslog/rotate_test.go @@ -8,7 +8,7 @@ import ( . "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParseLogTime(t *testing.T) { @@ -26,7 +26,7 @@ func TestParseLogTime(t *testing.T) { for _, test := range tests { t.Run(test, func(t *testing.T) { actual := ParseLogTime([]byte(test)) - ExpectTrue(t, actual.Equal(testTime)) + expect.True(t, actual.Equal(testTime)) }) } } @@ -43,16 +43,16 @@ func TestRetentionCommonFormat(t *testing.T) { logger.Flush() // test.Finish(nil) - ExpectEqual(t, logger.Config().Retention, nil) - ExpectTrue(t, file.Len() > 0) - ExpectEqual(t, file.LineCount(), 10) + expect.Equal(t, logger.Config().Retention, nil) + expect.True(t, file.Len() > 0) + expect.Equal(t, file.LineCount(), 10) t.Run("keep last", func(t *testing.T) { logger.Config().Retention = strutils.MustParse[*Retention]("last 5") - ExpectEqual(t, logger.Config().Retention.Days, 0) - ExpectEqual(t, logger.Config().Retention.Last, 5) - ExpectNoError(t, logger.Rotate()) - ExpectEqual(t, file.LineCount(), 5) + expect.Equal(t, logger.Config().Retention.Days, 0) + expect.Equal(t, logger.Config().Retention.Last, 5) + expect.NoError(t, logger.Rotate()) + expect.Equal(t, file.LineCount(), 5) }) _ = file.Truncate(0) @@ -65,14 +65,14 @@ func TestRetentionCommonFormat(t *testing.T) { logger.Log(req, resp) } logger.Flush() - ExpectEqual(t, file.LineCount(), 10) + expect.Equal(t, file.LineCount(), 10) t.Run("keep days", func(t *testing.T) { logger.Config().Retention = strutils.MustParse[*Retention]("3 days") - ExpectEqual(t, logger.Config().Retention.Days, 3) - ExpectEqual(t, logger.Config().Retention.Last, 0) - ExpectNoError(t, logger.Rotate()) - ExpectEqual(t, file.LineCount(), 3) + expect.Equal(t, logger.Config().Retention.Days, 3) + expect.Equal(t, logger.Config().Retention.Last, 0) + expect.NoError(t, logger.Rotate()) + expect.Equal(t, file.LineCount(), 3) rotated := string(file.Content()) _ = file.Truncate(0) for i := range 3 { @@ -81,6 +81,6 @@ func TestRetentionCommonFormat(t *testing.T) { } logger.Log(req, resp) } - ExpectEqual(t, rotated, string(file.Content())) + expect.Equal(t, rotated, string(file.Content())) }) } diff --git a/internal/net/gphttp/content_type_test.go b/internal/net/gphttp/content_type_test.go index f5bba69..dc3b3cf 100644 --- a/internal/net/gphttp/content_type_test.go +++ b/internal/net/gphttp/content_type_test.go @@ -4,38 +4,38 @@ import ( "net/http" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestContentTypes(t *testing.T) { - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML()) - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML()) - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML()) - ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML()) + expect.False(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML()) - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON()) - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON()) - ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON()) + expect.False(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON()) - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText()) - ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText()) - ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText()) + expect.True(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText()) + expect.False(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText()) } func TestAcceptContentTypes(t *testing.T) { - ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText()) - ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML()) + expect.True(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText()) + expect.True(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText()) + expect.True(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML()) + expect.True(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON()) + expect.True(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText()) + expect.True(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML()) + expect.True(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON()) + expect.True(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText()) + expect.True(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML()) - ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML()) - ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML()) - ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText()) - ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON()) - ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON()) + expect.False(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML()) + expect.False(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML()) + expect.False(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText()) + expect.False(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON()) + expect.False(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON()) } diff --git a/internal/net/gphttp/loadbalancer/loadbalancer_test.go b/internal/net/gphttp/loadbalancer/loadbalancer_test.go index 03f2bfc..6413dd3 100644 --- a/internal/net/gphttp/loadbalancer/loadbalancer_test.go +++ b/internal/net/gphttp/loadbalancer/loadbalancer_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestRebalance(t *testing.T) { @@ -15,7 +15,7 @@ func TestRebalance(t *testing.T) { lb.AddServer(types.TestNewServer(0)) } lb.rebalance() - ExpectEqual(t, lb.sumWeight, maxWeight) + expect.Equal(t, lb.sumWeight, maxWeight) }) t.Run("less", func(t *testing.T) { lb := New(new(types.Config)) @@ -26,7 +26,7 @@ func TestRebalance(t *testing.T) { lb.AddServer(types.TestNewServer(float64(maxWeight) * .1)) lb.rebalance() // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) - ExpectEqual(t, lb.sumWeight, maxWeight) + expect.Equal(t, lb.sumWeight, maxWeight) }) t.Run("more", func(t *testing.T) { lb := New(new(types.Config)) @@ -39,6 +39,6 @@ func TestRebalance(t *testing.T) { lb.AddServer(types.TestNewServer(float64(maxWeight) * .1)) lb.rebalance() // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) - ExpectEqual(t, lb.sumWeight, maxWeight) + expect.Equal(t, lb.sumWeight, maxWeight) }) } diff --git a/internal/net/gphttp/middleware/cidr_whitelist_test.go b/internal/net/gphttp/middleware/cidr_whitelist_test.go index a8c7cee..e6dd123 100644 --- a/internal/net/gphttp/middleware/cidr_whitelist_test.go +++ b/internal/net/gphttp/middleware/cidr_whitelist_test.go @@ -9,7 +9,7 @@ import ( "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) //go:embed test_data/cidr_whitelist_test.yml @@ -23,32 +23,32 @@ func TestCIDRWhitelistValidation(t *testing.T) { "allow": []string{"192.168.2.100/32"}, "message": testMessage, }) - ExpectNoError(t, err) + expect.NoError(t, err) _, err = CIDRWhiteList.New(OptionsRaw{ "allow": []string{"192.168.2.100/32"}, "message": testMessage, "status": 403, }) - ExpectNoError(t, err) + expect.NoError(t, err) _, err = CIDRWhiteList.New(OptionsRaw{ "allow": []string{"192.168.2.100/32"}, "message": testMessage, "status_code": 403, }) - ExpectNoError(t, err) + expect.NoError(t, err) }) t.Run("missing allow", func(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ "message": testMessage, }) - ExpectError(t, utils.ErrValidationError, err) + expect.ErrorIs(t, utils.ErrValidationError, err) }) t.Run("invalid cidr", func(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ "allow": []string{"192.168.2.100/123"}, "message": testMessage, }) - ExpectErrorT[*net.ParseError](t, err) + expect.ErrorT[*net.ParseError](t, err) }) t.Run("invalid status code", func(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ @@ -56,14 +56,14 @@ func TestCIDRWhitelistValidation(t *testing.T) { "status_code": 600, "message": testMessage, }) - ExpectError(t, utils.ErrValidationError, err) + expect.ErrorIs(t, utils.ErrValidationError, err) }) } func TestCIDRWhitelist(t *testing.T) { errs := gperr.NewBuilder("") mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) - ExpectNoError(t, errs.Error()) + expect.NoError(t, errs.Error()) deny = mids["deny@file"] accept = mids["accept@file"] if deny == nil || accept == nil { @@ -74,9 +74,9 @@ func TestCIDRWhitelist(t *testing.T) { t.Parallel() for range 10 { result, err := newMiddlewareTest(deny, nil) - ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode) - ExpectEqual(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message) + expect.NoError(t, err) + expect.Equal(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode) + expect.Equal(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message) } }) @@ -84,8 +84,8 @@ func TestCIDRWhitelist(t *testing.T) { t.Parallel() for range 10 { result, err := newMiddlewareTest(accept, nil) - ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, http.StatusOK) + expect.NoError(t, err) + expect.Equal(t, result.ResponseStatus, http.StatusOK) } }) } diff --git a/internal/net/gphttp/middleware/middleware_builder_test.go b/internal/net/gphttp/middleware/middleware_builder_test.go index 3036ca6..59822c6 100644 --- a/internal/net/gphttp/middleware/middleware_builder_test.go +++ b/internal/net/gphttp/middleware/middleware_builder_test.go @@ -7,7 +7,7 @@ import ( "github.com/yusing/go-proxy/pkg/json" "github.com/yusing/go-proxy/internal/gperr" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) //go:embed test_data/middleware_compose.yml @@ -16,7 +16,7 @@ var testMiddlewareCompose []byte func TestBuild(t *testing.T) { errs := gperr.NewBuilder("") middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs) - ExpectNoError(t, errs.Error()) + expect.NoError(t, errs.Error()) json.Marshal(middlewares) // t.Log(string(data)) // TODO: test diff --git a/internal/net/gphttp/middleware/middleware_test.go b/internal/net/gphttp/middleware/middleware_test.go index 5b6e521..9f09452 100644 --- a/internal/net/gphttp/middleware/middleware_test.go +++ b/internal/net/gphttp/middleware/middleware_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) type testPriority struct { @@ -28,10 +28,10 @@ func TestMiddlewarePriority(t *testing.T) { "priority": p, "value": i, }) - ExpectNoError(t, err) + expect.NoError(t, err) chain[i] = mid } res, err := newMiddlewaresTest(chain, nil) - ExpectNoError(t, err) - ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2") + expect.NoError(t, err) + expect.Equal(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2") } diff --git a/internal/net/gphttp/middleware/modify_request_test.go b/internal/net/gphttp/middleware/modify_request_test.go index aa67d6d..50a582d 100644 --- a/internal/net/gphttp/middleware/modify_request_test.go +++ b/internal/net/gphttp/middleware/modify_request_test.go @@ -8,7 +8,7 @@ import ( "slices" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestModifyRequest(t *testing.T) { @@ -44,15 +44,15 @@ func TestModifyRequest(t *testing.T) { t.Run("set_options", func(t *testing.T) { mr, err := ModifyRequest.New(opts) - ExpectNoError(t, err) - ExpectEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) - ExpectEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) - ExpectEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string)) + expect.NoError(t, err) + expect.Equal(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) + expect.Equal(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) + expect.Equal(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string)) }) t.Run("request_headers", func(t *testing.T) { - reqURL := Must(url.Parse("https://my.app/?arg_1=b")) - upstreamURL := Must(url.Parse("http://test.example.com")) + reqURL := expect.Must(url.Parse("https://my.app/?arg_1=b")) + upstreamURL := expect.Must(url.Parse("http://test.example.com")) result, err := newMiddlewareTest(ModifyRequest, &testArgs{ middlewareOpt: opts, reqURL: reqURL, @@ -62,38 +62,38 @@ func TestModifyRequest(t *testing.T) { "Content-Type": []string{"application/json"}, }, }) - ExpectNoError(t, err) - ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") - ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com") - ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) - ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") + expect.NoError(t, err) + expect.Equal(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") + expect.Equal(t, result.RequestHeaders.Get("Host"), "test.example.com") + expect.True(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) + expect.Equal(t, result.RequestHeaders.Get("Accept"), "") - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET") - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json") - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100") + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET") + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json") + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100") remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port()) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json") + expect.Equal(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json") - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b") + expect.Equal(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b") }) t.Run("add_prefix", func(t *testing.T) { @@ -128,8 +128,8 @@ func TestModifyRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - reqURL := Must(url.Parse("https://my.app" + tt.path)) - upstreamURL := Must(url.Parse(tt.upstreamURL)) + reqURL := expect.Must(url.Parse("https://my.app" + tt.path)) + upstreamURL := expect.Must(url.Parse(tt.upstreamURL)) opts["add_prefix"] = tt.addPrefix result, err := newMiddlewareTest(ModifyRequest, &testArgs{ @@ -137,8 +137,8 @@ func TestModifyRequest(t *testing.T) { reqURL: reqURL, upstreamURL: upstreamURL, }) - ExpectNoError(t, err) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath) + expect.NoError(t, err) + expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath) }) } }) diff --git a/internal/net/gphttp/middleware/modify_response_test.go b/internal/net/gphttp/middleware/modify_response_test.go index 9f426b2..3d09a61 100644 --- a/internal/net/gphttp/middleware/modify_response_test.go +++ b/internal/net/gphttp/middleware/modify_response_test.go @@ -8,7 +8,7 @@ import ( "slices" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestModifyResponse(t *testing.T) { @@ -47,15 +47,15 @@ func TestModifyResponse(t *testing.T) { t.Run("set_options", func(t *testing.T) { mr, err := ModifyResponse.New(opts) - ExpectNoError(t, err) - ExpectEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) - ExpectEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) - ExpectEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) + expect.NoError(t, err) + expect.Equal(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) + expect.Equal(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) + expect.Equal(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) }) t.Run("response_headers", func(t *testing.T) { - reqURL := Must(url.Parse("https://my.app/?arg_1=b")) - upstreamURL := Must(url.Parse("http://test.example.com")) + reqURL := expect.Must(url.Parse("https://my.app/?arg_1=b")) + upstreamURL := expect.Must(url.Parse("http://test.example.com")) result, err := newMiddlewareTest(ModifyResponse, &testArgs{ middlewareOpt: opts, reqURL: reqURL, @@ -70,39 +70,39 @@ func TestModifyResponse(t *testing.T) { respBody: bytes.Repeat([]byte("a"), 50), respStatus: http.StatusOK, }) - ExpectNoError(t, err) - ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) - ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "") + expect.NoError(t, err) + expect.True(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) + expect.Equal(t, result.ResponseHeaders.Get("Accept"), "") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100") remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port()) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json") - ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json") + expect.Equal(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b") }) } diff --git a/internal/net/gphttp/middleware/rate_limit_test.go b/internal/net/gphttp/middleware/rate_limit_test.go index 1264997..b472946 100644 --- a/internal/net/gphttp/middleware/rate_limit_test.go +++ b/internal/net/gphttp/middleware/rate_limit_test.go @@ -4,7 +4,7 @@ import ( "net/http" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestRateLimit(t *testing.T) { @@ -15,13 +15,13 @@ func TestRateLimit(t *testing.T) { } rl, err := RateLimiter.New(opts) - ExpectNoError(t, err) + expect.NoError(t, err) for range 10 { result, err := newMiddlewareTest(rl, nil) - ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, http.StatusOK) + expect.NoError(t, err) + expect.Equal(t, result.ResponseStatus, http.StatusOK) } result, err := newMiddlewareTest(rl, nil) - ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests) + expect.NoError(t, err) + expect.Equal(t, result.ResponseStatus, http.StatusTooManyRequests) } diff --git a/internal/net/gphttp/middleware/real_ip_test.go b/internal/net/gphttp/middleware/real_ip_test.go index a70a182..419cd50 100644 --- a/internal/net/gphttp/middleware/real_ip_test.go +++ b/internal/net/gphttp/middleware/real_ip_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestSetRealIPOpts(t *testing.T) { @@ -40,11 +40,11 @@ func TestSetRealIPOpts(t *testing.T) { } ri, err := RealIP.New(opts) - ExpectNoError(t, err) - ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) - ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) + expect.NoError(t, err) + expect.Equal(t, ri.impl.(*realIP).Header, optExpected.Header) + expect.Equal(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) for i, CIDR := range ri.impl.(*realIP).From { - ExpectEqual(t, CIDR.String(), optExpected.From[i].String()) + expect.Equal(t, CIDR.String(), optExpected.From[i].String()) } } @@ -61,16 +61,16 @@ func TestSetRealIP(t *testing.T) { "set_headers": map[string]string{testHeader: testRealIP}, } realip, err := RealIP.New(opts) - ExpectNoError(t, err) + expect.NoError(t, err) mr, err := ModifyRequest.New(optsMr) - ExpectNoError(t, err) + expect.NoError(t, err) mid := NewMiddlewareChain("test", []*Middleware{mr, realip}) result, err := newMiddlewareTest(mid, nil) - ExpectNoError(t, err) + expect.NoError(t, err) t.Log(traces) - ExpectEqual(t, result.ResponseStatus, http.StatusOK) - ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) + expect.Equal(t, result.ResponseStatus, http.StatusOK) + expect.Equal(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) } diff --git a/internal/net/gphttp/middleware/redirect_http_test.go b/internal/net/gphttp/middleware/redirect_http_test.go index 82dfb7c..18374a1 100644 --- a/internal/net/gphttp/middleware/redirect_http_test.go +++ b/internal/net/gphttp/middleware/redirect_http_test.go @@ -5,22 +5,22 @@ import ( "net/url" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestRedirectToHTTPs(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ - reqURL: Must(url.Parse("http://example.com")), + reqURL: expect.Must(url.Parse("http://example.com")), }) - ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect) - ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com") + expect.NoError(t, err) + expect.Equal(t, result.ResponseStatus, http.StatusPermanentRedirect) + expect.Equal(t, result.ResponseHeaders.Get("Location"), "https://example.com") } func TestNoRedirect(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ - reqURL: Must(url.Parse("https://example.com")), + reqURL: expect.Must(url.Parse("https://example.com")), }) - ExpectNoError(t, err) - ExpectEqual(t, result.ResponseStatus, http.StatusOK) + expect.NoError(t, err) + expect.Equal(t, result.ResponseStatus, http.StatusOK) } diff --git a/internal/net/gphttp/middleware/test_utils.go b/internal/net/gphttp/middleware/test_utils.go index 3448734..888e4ab 100644 --- a/internal/net/gphttp/middleware/test_utils.go +++ b/internal/net/gphttp/middleware/test_utils.go @@ -8,11 +8,11 @@ import ( "net/http/httptest" "net/url" - "github.com/yusing/go-proxy/pkg/json" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" + "github.com/yusing/go-proxy/pkg/json" ) //go:embed test_data/sample_headers.json @@ -96,13 +96,13 @@ type testArgs struct { func (args *testArgs) setDefaults() { if args.reqURL == nil { - args.reqURL = Must(url.Parse("https://example.com")) + args.reqURL = expect.Must(url.Parse("https://example.com")) } if args.reqMethod == "" { args.reqMethod = http.MethodGet } if args.upstreamURL == nil { - args.upstreamURL = Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect + args.upstreamURL = expect.Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect } if args.respHeaders == nil { args.respHeaders = http.Header{} diff --git a/internal/net/ping_test.go b/internal/net/ping_test.go index ca6f7f2..e172e62 100644 --- a/internal/net/ping_test.go +++ b/internal/net/ping_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/stretchr/testify/require" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestPing(t *testing.T) { @@ -17,7 +17,7 @@ func TestPing(t *testing.T) { if errors.Is(err, os.ErrPermission) { t.Skip("permission denied") } - require.NoError(t, err) - require.True(t, ok) + expect.NoError(t, err) + expect.True(t, ok) }) } diff --git a/internal/notif/config_test.go b/internal/notif/config_test.go index 1ade43e..86e2e88 100644 --- a/internal/notif/config_test.go +++ b/internal/notif/config_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestNotificationConfig(t *testing.T) { @@ -152,11 +152,11 @@ func TestNotificationConfig(t *testing.T) { provider := tt.cfg["provider"] err := utils.MapUnmarshalValidate(tt.cfg, &cfg) if tt.wantErr { - ExpectHasError(t, err) + expect.HasError(t, err) } else { - ExpectNoError(t, err) - ExpectEqual(t, provider.(string), cfg.ProviderName) - ExpectEqual(t, cfg.Provider, tt.expected) + expect.NoError(t, err) + expect.Equal(t, provider.(string), cfg.ProviderName) + expect.Equal(t, cfg.Provider, tt.expected) } }) } diff --git a/internal/route/fileserver_test.go b/internal/route/fileserver_test.go index f93e523..62ce7ee 100644 --- a/internal/route/fileserver_test.go +++ b/internal/route/fileserver_test.go @@ -10,7 +10,7 @@ import ( "path/filepath" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestPathTraversalAttack(t *testing.T) { @@ -108,7 +108,7 @@ func TestPathTraversalAttack(t *testing.T) { t.Errorf("Expected status 404 or 400, got %d in url %s", resp.StatusCode, u.String()) } - u = Must(url.Parse(ts.URL + "/" + p)) + u = expect.Must(url.Parse(ts.URL + "/" + p)) resp, err = http.DefaultClient.Do(&http.Request{ Method: http.MethodGet, URL: u, diff --git a/internal/route/provider/docker_labels_test.go b/internal/route/provider/docker_labels_test.go index 4d5b221..ad864b1 100644 --- a/internal/route/provider/docker_labels_test.go +++ b/internal/route/provider/docker_labels_test.go @@ -3,10 +3,10 @@ package provider import ( "testing" - "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" "github.com/goccy/go-yaml" "github.com/yusing/go-proxy/internal/docker" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" _ "embed" ) @@ -18,19 +18,19 @@ func TestParseDockerLabels(t *testing.T) { var provider DockerProvider labels := make(map[string]string) - ExpectNoError(t, yaml.Unmarshal(testDockerLabelsYAML, &labels)) + expect.NoError(t, yaml.Unmarshal(testDockerLabelsYAML, &labels)) routes, err := provider.routesFromContainerLabels( - docker.FromDocker(&types.Container{ + docker.FromDocker(&container.Summary{ Names: []string{"container"}, Labels: labels, State: "running", - Ports: []types.Port{ + Ports: []container.Port{ {Type: "tcp", PrivatePort: 1234, PublicPort: 1234}, }, }, "/var/run/docker.sock"), ) - ExpectNoError(t, err) - ExpectTrue(t, routes.Contains("app")) - ExpectTrue(t, routes.Contains("app1")) + expect.NoError(t, err) + expect.True(t, routes.Contains("app")) + expect.True(t, routes.Contains("app1")) } diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index 8c79950..929df17 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -10,7 +10,7 @@ import ( D "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/route" T "github.com/yusing/go-proxy/internal/route/types" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) var dummyNames = []string{"/a"} @@ -30,7 +30,7 @@ func makeRoutes(cont *container.Summary, dockerHostIP ...string) route.Routes { } cont.ID = "test" p.name = "test" - routes := Must(p.routesFromContainerLabels(D.FromDocker(cont, host))) + routes := expect.Must(p.routesFromContainerLabels(D.FromDocker(cont, host))) for _, r := range routes { r.Finalize() } @@ -39,7 +39,7 @@ func makeRoutes(cont *container.Summary, dockerHostIP ...string) route.Routes { func TestExplicitOnly(t *testing.T) { p := NewDockerProvider("a!", "") - ExpectTrue(t, p.IsExplicitOnly()) + expect.True(t, p.IsExplicitOnly()) } func TestApplyLabel(t *testing.T) { @@ -87,46 +87,46 @@ func TestApplyLabel(t *testing.T) { }) a, ok := entries["a"] - ExpectTrue(t, ok) + expect.True(t, ok) b, ok := entries["b"] - ExpectTrue(t, ok) + expect.True(t, ok) - ExpectEqual(t, a.Scheme, "https") - ExpectEqual(t, b.Scheme, "https") + expect.Equal(t, a.Scheme, "https") + expect.Equal(t, b.Scheme, "https") - ExpectEqual(t, a.Host, "app") - ExpectEqual(t, b.Host, "app") + expect.Equal(t, a.Host, "app") + expect.Equal(t, b.Host, "app") - ExpectEqual(t, a.Port.Proxy, 4567) - ExpectEqual(t, b.Port.Proxy, 4567) + expect.Equal(t, a.Port.Proxy, 4567) + expect.Equal(t, b.Port.Proxy, 4567) - ExpectTrue(t, a.NoTLSVerify) - ExpectTrue(t, b.NoTLSVerify) + expect.True(t, a.NoTLSVerify) + expect.True(t, b.NoTLSVerify) - ExpectEqual(t, a.PathPatterns, pathPatternsExpect) - ExpectEqual(t, len(b.PathPatterns), 0) + expect.Equal(t, a.PathPatterns, pathPatternsExpect) + expect.Equal(t, len(b.PathPatterns), 0) - ExpectEqual(t, a.Middlewares, middlewaresExpect) - ExpectEqual(t, len(b.Middlewares), 0) + expect.Equal(t, a.Middlewares, middlewaresExpect) + expect.Equal(t, len(b.Middlewares), 0) - ExpectEqual(t, a.Container.IdlewatcherConfig.IdleTimeout, 0) - ExpectEqual(t, b.Container.IdlewatcherConfig.IdleTimeout, 0) - ExpectEqual(t, a.Container.IdlewatcherConfig.StopTimeout, time.Hour) - ExpectEqual(t, b.Container.IdlewatcherConfig.StopTimeout, time.Hour) - ExpectEqual(t, a.Container.IdlewatcherConfig.StopMethod, "stop") - ExpectEqual(t, b.Container.IdlewatcherConfig.StopMethod, "stop") - ExpectEqual(t, a.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second) - ExpectEqual(t, b.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second) - ExpectEqual(t, a.Container.IdlewatcherConfig.StopSignal, "SIGTERM") - ExpectEqual(t, b.Container.IdlewatcherConfig.StopSignal, "SIGTERM") + expect.Equal(t, a.Container.IdlewatcherConfig.IdleTimeout, 0) + expect.Equal(t, b.Container.IdlewatcherConfig.IdleTimeout, 0) + expect.Equal(t, a.Container.IdlewatcherConfig.StopTimeout, time.Hour) + expect.Equal(t, b.Container.IdlewatcherConfig.StopTimeout, time.Hour) + expect.Equal(t, a.Container.IdlewatcherConfig.StopMethod, "stop") + expect.Equal(t, b.Container.IdlewatcherConfig.StopMethod, "stop") + expect.Equal(t, a.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second) + expect.Equal(t, b.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second) + expect.Equal(t, a.Container.IdlewatcherConfig.StopSignal, "SIGTERM") + expect.Equal(t, b.Container.IdlewatcherConfig.StopSignal, "SIGTERM") - ExpectEqual(t, a.Homepage.Show, true) - ExpectEqual(t, a.Homepage.Icon.Value, "png/adguard-home.png") - ExpectEqual(t, a.Homepage.Icon.Extra.FileType, "png") - ExpectEqual(t, a.Homepage.Icon.Extra.Name, "adguard-home") + expect.Equal(t, a.Homepage.Show, true) + expect.Equal(t, a.Homepage.Icon.Value, "png/adguard-home.png") + expect.Equal(t, a.Homepage.Icon.Extra.FileType, "png") + expect.Equal(t, a.Homepage.Icon.Extra.Name, "adguard-home") - ExpectEqual(t, a.HealthCheck.Path, "/ping") - ExpectEqual(t, a.HealthCheck.Interval, 10*time.Second) + expect.Equal(t, a.HealthCheck.Path, "/ping") + expect.Equal(t, a.HealthCheck.Interval, 10*time.Second) } func TestApplyLabelWithAlias(t *testing.T) { @@ -142,18 +142,18 @@ func TestApplyLabelWithAlias(t *testing.T) { }, }) a, ok := entries["a"] - ExpectTrue(t, ok) + expect.True(t, ok) b, ok := entries["b"] - ExpectTrue(t, ok) + expect.True(t, ok) c, ok := entries["c"] - ExpectTrue(t, ok) + expect.True(t, ok) - ExpectEqual(t, a.Scheme, "http") - ExpectEqual(t, a.Port.Proxy, 3333) - ExpectEqual(t, a.NoTLSVerify, true) - ExpectEqual(t, b.Scheme, "http") - ExpectEqual(t, b.Port.Proxy, 1234) - ExpectEqual(t, c.Scheme, "https") + expect.Equal(t, a.Scheme, "http") + expect.Equal(t, a.Port.Proxy, 3333) + expect.Equal(t, a.NoTLSVerify, true) + expect.Equal(t, b.Scheme, "http") + expect.Equal(t, b.Port.Proxy, 1234) + expect.Equal(t, c.Scheme, "https") } func TestApplyLabelWithRef(t *testing.T) { @@ -170,18 +170,18 @@ func TestApplyLabelWithRef(t *testing.T) { }, }) a, ok := entries["a"] - ExpectTrue(t, ok) + expect.True(t, ok) b, ok := entries["b"] - ExpectTrue(t, ok) + expect.True(t, ok) c, ok := entries["c"] - ExpectTrue(t, ok) + expect.True(t, ok) - ExpectEqual(t, a.Scheme, "http") - ExpectEqual(t, a.Host, "localhost") - ExpectEqual(t, a.Port.Proxy, 4444) - ExpectEqual(t, b.Port.Proxy, 9999) - ExpectEqual(t, c.Scheme, "https") - ExpectEqual(t, c.Port.Proxy, 1111) + expect.Equal(t, a.Scheme, "http") + expect.Equal(t, a.Host, "localhost") + expect.Equal(t, a.Port.Proxy, 4444) + expect.Equal(t, b.Port.Proxy, 9999) + expect.Equal(t, c.Scheme, "https") + expect.Equal(t, c.Port.Proxy, 1111) } func TestApplyLabelWithRefIndexError(t *testing.T) { @@ -197,7 +197,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { }, "") var p DockerProvider _, err := p.routesFromContainerLabels(c) - ExpectError(t, ErrAliasRefIndexOutOfRange, err) + expect.ErrorIs(t, ErrAliasRefIndexOutOfRange, err) c = D.FromDocker(&container.Summary{ Names: dummyNames, @@ -208,7 +208,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { }, }, "") _, err = p.routesFromContainerLabels(c) - ExpectError(t, ErrAliasRefIndexOutOfRange, err) + expect.ErrorIs(t, ErrAliasRefIndexOutOfRange, err) } func TestDynamicAliases(t *testing.T) { @@ -224,14 +224,14 @@ func TestDynamicAliases(t *testing.T) { entries := makeRoutes(c) r, ok := entries["app1"] - ExpectTrue(t, ok) - ExpectEqual(t, r.Scheme, "http") - ExpectEqual(t, r.Port.Proxy, 1234) + expect.True(t, ok) + expect.Equal(t, r.Scheme, "http") + expect.Equal(t, r.Port.Proxy, 1234) r, ok = entries["app1_backend"] - ExpectTrue(t, ok) - ExpectEqual(t, r.Scheme, "http") - ExpectEqual(t, r.Port.Proxy, 5678) + expect.True(t, ok) + expect.Equal(t, r.Scheme, "http") + expect.Equal(t, r.Port.Proxy, 5678) } func TestDisableHealthCheck(t *testing.T) { @@ -244,24 +244,24 @@ func TestDisableHealthCheck(t *testing.T) { }, } r, ok := makeRoutes(c)["a"] - ExpectTrue(t, ok) - ExpectFalse(t, r.UseHealthCheck()) + expect.True(t, ok) + expect.False(t, r.UseHealthCheck()) } func TestPublicIPLocalhost(t *testing.T) { c := &container.Summary{Names: dummyNames, State: "running"} r, ok := makeRoutes(c)["a"] - ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PublicHostname, "127.0.0.1") - ExpectEqual(t, r.Host, r.Container.PublicHostname) + expect.True(t, ok) + expect.Equal(t, r.Container.PublicHostname, "127.0.0.1") + expect.Equal(t, r.Host, r.Container.PublicHostname) } func TestPublicIPRemote(t *testing.T) { c := &container.Summary{Names: dummyNames, State: "running"} raw, ok := makeRoutes(c, testIP)["a"] - ExpectTrue(t, ok) - ExpectEqual(t, raw.Container.PublicHostname, testIP) - ExpectEqual(t, raw.Host, raw.Container.PublicHostname) + expect.True(t, ok) + expect.Equal(t, raw.Container.PublicHostname, testIP) + expect.Equal(t, raw.Host, raw.Container.PublicHostname) } func TestPrivateIPLocalhost(t *testing.T) { @@ -276,9 +276,9 @@ func TestPrivateIPLocalhost(t *testing.T) { }, } r, ok := makeRoutes(c)["a"] - ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PrivateHostname, testDockerIP) - ExpectEqual(t, r.Host, r.Container.PrivateHostname) + expect.True(t, ok) + expect.Equal(t, r.Container.PrivateHostname, testDockerIP) + expect.Equal(t, r.Host, r.Container.PrivateHostname) } func TestPrivateIPRemote(t *testing.T) { @@ -294,10 +294,10 @@ func TestPrivateIPRemote(t *testing.T) { }, } r, ok := makeRoutes(c, testIP)["a"] - ExpectTrue(t, ok) - ExpectEqual(t, r.Container.PrivateHostname, "") - ExpectEqual(t, r.Container.PublicHostname, testIP) - ExpectEqual(t, r.Host, r.Container.PublicHostname) + expect.True(t, ok) + expect.Equal(t, r.Container.PrivateHostname, "") + expect.Equal(t, r.Container.PublicHostname, testIP) + expect.Equal(t, r.Host, r.Container.PublicHostname) } func TestStreamDefaultValues(t *testing.T) { @@ -321,22 +321,22 @@ func TestStreamDefaultValues(t *testing.T) { t.Run("local", func(t *testing.T) { r, ok := makeRoutes(cont)["a"] - ExpectTrue(t, ok) - ExpectNoError(t, r.Validate()) - ExpectEqual(t, r.Scheme, T.Scheme("udp")) - ExpectEqual(t, r.TargetURL().Hostname(), privIP) - ExpectEqual(t, r.Port.Listening, 0) - ExpectEqual(t, r.Port.Proxy, int(privPort)) + expect.True(t, ok) + expect.NoError(t, r.Validate()) + expect.Equal(t, r.Scheme, T.Scheme("udp")) + expect.Equal(t, r.TargetURL().Hostname(), privIP) + expect.Equal(t, r.Port.Listening, 0) + expect.Equal(t, r.Port.Proxy, int(privPort)) }) t.Run("remote", func(t *testing.T) { r, ok := makeRoutes(cont, testIP)["a"] - ExpectTrue(t, ok) - ExpectNoError(t, r.Validate()) - ExpectEqual(t, r.Scheme, T.Scheme("udp")) - ExpectEqual(t, r.TargetURL().Hostname(), testIP) - ExpectEqual(t, r.Port.Listening, 0) - ExpectEqual(t, r.Port.Proxy, int(pubPort)) + expect.True(t, ok) + expect.NoError(t, r.Validate()) + expect.Equal(t, r.Scheme, T.Scheme("udp")) + expect.Equal(t, r.TargetURL().Hostname(), testIP) + expect.Equal(t, r.Port.Listening, 0) + expect.Equal(t, r.Port.Proxy, int(pubPort)) }) } @@ -349,8 +349,8 @@ func TestExplicitExclude(t *testing.T) { "proxy.a.no_tls_verify": "true", }, }, "")["a"] - ExpectTrue(t, ok) - ExpectTrue(t, r.ShouldExclude()) + expect.True(t, ok) + expect.True(t, r.ShouldExclude()) } func TestImplicitExcludeDatabase(t *testing.T) { @@ -361,8 +361,8 @@ func TestImplicitExcludeDatabase(t *testing.T) { {Source: "/data", Destination: "/var/lib/postgresql/data"}, }, })["a"] - ExpectTrue(t, ok) - ExpectTrue(t, r.ShouldExclude()) + expect.True(t, ok) + expect.True(t, r.ShouldExclude()) }) t.Run("exposed port detection", func(t *testing.T) { r, ok := makeRoutes(&container.Summary{ @@ -371,7 +371,7 @@ func TestImplicitExcludeDatabase(t *testing.T) { {Type: "tcp", PrivatePort: 5432, PublicPort: 5432}, }, })["a"] - ExpectTrue(t, ok) - ExpectTrue(t, r.ShouldExclude()) + expect.True(t, ok) + expect.True(t, r.ShouldExclude()) }) } diff --git a/internal/route/provider/file_test.go b/internal/route/provider/file_test.go index 756095c..7653167 100644 --- a/internal/route/provider/file_test.go +++ b/internal/route/provider/file_test.go @@ -5,7 +5,7 @@ import ( _ "embed" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) //go:embed all_fields.yaml @@ -13,5 +13,5 @@ var testAllFieldsYAML []byte func TestFile(t *testing.T) { _, err := validate(testAllFieldsYAML) - ExpectNoError(t, err) + expect.NoError(t, err) } diff --git a/internal/route/route_test.go b/internal/route/route_test.go index ee68711..14bcf30 100644 --- a/internal/route/route_test.go +++ b/internal/route/route_test.go @@ -3,11 +3,11 @@ package route import ( "testing" - "github.com/stretchr/testify/require" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker" loadbalance "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" route "github.com/yusing/go-proxy/internal/route/types" + expect "github.com/yusing/go-proxy/internal/utils/testing" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -20,8 +20,8 @@ func TestRouteValidate(t *testing.T) { Port: route.Port{Proxy: common.ProxyHTTPPort}, } err := r.Validate() - require.Error(t, err, "Validate should return error for localhost with reserved port") - require.Contains(t, err.Error(), "reserved for godoxy") + expect.HasError(t, err, "Validate should return error for localhost with reserved port") + expect.ErrorContains(t, err, "reserved for godoxy") }) t.Run("ListeningPortWithHTTP", func(t *testing.T) { @@ -32,8 +32,8 @@ func TestRouteValidate(t *testing.T) { Port: route.Port{Proxy: 80, Listening: 1234}, } err := r.Validate() - require.Error(t, err, "Validate should return error for HTTP scheme with listening port") - require.Contains(t, err.Error(), "unexpected listening port") + expect.HasError(t, err, "Validate should return error for HTTP scheme with listening port") + expect.ErrorContains(t, err, "unexpected listening port") }) t.Run("DisabledHealthCheckWithLoadBalancer", func(t *testing.T) { @@ -50,8 +50,8 @@ func TestRouteValidate(t *testing.T) { }, // Minimal LoadBalance config with non-empty Link will be checked by UseLoadBalance } err := r.Validate() - require.Error(t, err, "Validate should return error for disabled healthcheck with loadbalancer") - require.Contains(t, err.Error(), "cannot disable healthcheck") + expect.HasError(t, err, "Validate should return error for disabled healthcheck with loadbalancer") + expect.ErrorContains(t, err, "cannot disable healthcheck") }) t.Run("FileServerScheme", func(t *testing.T) { @@ -63,8 +63,8 @@ func TestRouteValidate(t *testing.T) { Root: "/tmp", // Root is required for file server } err := r.Validate() - require.NoError(t, err, "Validate should not return error for valid file server route") - require.NotNil(t, r.impl, "Impl should be initialized") + expect.NoError(t, err, "Validate should not return error for valid file server route") + expect.NotNil(t, r.impl, "Impl should be initialized") }) t.Run("HTTPScheme", func(t *testing.T) { @@ -75,8 +75,8 @@ func TestRouteValidate(t *testing.T) { Port: route.Port{Proxy: 80}, } err := r.Validate() - require.NoError(t, err, "Validate should not return error for valid HTTP route") - require.NotNil(t, r.impl, "Impl should be initialized") + expect.NoError(t, err, "Validate should not return error for valid HTTP route") + expect.NotNil(t, r.impl, "Impl should be initialized") }) t.Run("TCPScheme", func(t *testing.T) { @@ -87,8 +87,8 @@ func TestRouteValidate(t *testing.T) { Port: route.Port{Proxy: 80, Listening: 8080}, } err := r.Validate() - require.NoError(t, err, "Validate should not return error for valid TCP route") - require.NotNil(t, r.impl, "Impl should be initialized") + expect.NoError(t, err, "Validate should not return error for valid TCP route") + expect.NotNil(t, r.impl, "Impl should be initialized") }) t.Run("DockerContainer", func(t *testing.T) { @@ -107,8 +107,8 @@ func TestRouteValidate(t *testing.T) { }, } err := r.Validate() - require.NoError(t, err, "Validate should not return error for valid docker container route") - require.NotNil(t, r.ProxyURL, "ProxyURL should be set") + expect.NoError(t, err, "Validate should not return error for valid docker container route") + expect.NotNil(t, r.ProxyURL, "ProxyURL should be set") }) t.Run("InvalidScheme", func(t *testing.T) { @@ -118,7 +118,7 @@ func TestRouteValidate(t *testing.T) { Host: "example.com", Port: route.Port{Proxy: 80}, } - require.Panics(t, func() { + expect.Panics(t, func() { _ = r.Validate() }, "Validate should panic for invalid scheme") }) @@ -131,8 +131,8 @@ func TestRouteValidate(t *testing.T) { Port: route.Port{Proxy: 80}, } err := r.Validate() - require.NoError(t, err) - require.NotNil(t, r.ProxyURL) - require.NotNil(t, r.HealthCheck) + expect.NoError(t, err) + expect.NotNil(t, r.ProxyURL) + expect.NotNil(t, r.HealthCheck) }) } diff --git a/internal/route/rules/do_test.go b/internal/route/rules/do_test.go index 5b1d69b..8c8b9c5 100644 --- a/internal/route/rules/do_test.go +++ b/internal/route/rules/do_test.go @@ -3,7 +3,7 @@ package rules import ( "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParseCommands(t *testing.T) { @@ -126,9 +126,9 @@ func TestParseCommands(t *testing.T) { cmd := Command{} err := cmd.Parse(tt.input) if tt.wantErr != nil { - ExpectError(t, tt.wantErr, err) + expect.ErrorIs(t, tt.wantErr, err) } else { - ExpectNoError(t, err) + expect.NoError(t, err) } }) } diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index fb60ffc..d7100db 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/gperr" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" "golang.org/x/crypto/bcrypt" ) @@ -120,9 +120,9 @@ func TestParseOn(t *testing.T) { on := &RuleOn{} err := on.Parse(tt.input) if tt.wantErr != nil { - ExpectError(t, tt.wantErr, err) + expect.ErrorIs(t, tt.wantErr, err) } else { - ExpectNoError(t, err) + expect.NoError(t, err) } }) } @@ -212,7 +212,7 @@ func TestOnCorrectness(t *testing.T) { }, { name: "basic_auth_correct", - checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), input: &http.Request{ Header: http.Header{ "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password" @@ -222,7 +222,7 @@ func TestOnCorrectness(t *testing.T) { }, { name: "basic_auth_incorrect", - checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), input: &http.Request{ Header: http.Header{ "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong" @@ -269,7 +269,7 @@ func TestOnCorrectness(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { on, err := parseOn(tt.checker) - ExpectNoError(t, err) + expect.NoError(t, err) got := on.Check(Cache{}, tt.input) if tt.want != got { t.Errorf("want %v, got %v", tt.want, got) diff --git a/internal/route/rules/parser_test.go b/internal/route/rules/parser_test.go index b5743ae..2d40e20 100644 --- a/internal/route/rules/parser_test.go +++ b/internal/route/rules/parser_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/gperr" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParser(t *testing.T) { @@ -68,15 +68,15 @@ func TestParser(t *testing.T) { t.Run(tt.name, func(t *testing.T) { subject, args, err := parse(tt.input) if tt.wantErr != nil { - ExpectError(t, tt.wantErr, err) + expect.ErrorIs(t, tt.wantErr, err) return } // t.Log(subject, args, err) - ExpectNoError(t, err) - ExpectEqual(t, subject, tt.subject) - ExpectEqual(t, len(args), len(tt.args)) + expect.NoError(t, err) + expect.Equal(t, subject, tt.subject) + expect.Equal(t, len(args), len(tt.args)) for i, arg := range args { - ExpectEqual(t, arg, tt.args[i]) + expect.Equal(t, arg, tt.args[i]) } }) } @@ -89,7 +89,7 @@ func TestParser(t *testing.T) { for i, test := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { _, _, err := parse(test) - ExpectError(t, ErrUnterminatedQuotes, err) + expect.ErrorIs(t, ErrUnterminatedQuotes, err) }) } }) diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go index aa8e8d1..d863fc0 100644 --- a/internal/route/rules/rules_test.go +++ b/internal/route/rules/rules_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParseRule(t *testing.T) { @@ -29,18 +29,18 @@ func TestParseRule(t *testing.T) { Rules Rules } err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules) - ExpectNoError(t, err) - ExpectEqual(t, len(rules.Rules), len(test)) - ExpectEqual(t, rules.Rules[0].Name, "test") - ExpectEqual(t, rules.Rules[0].On.String(), "method POST") - ExpectEqual(t, rules.Rules[0].Do.String(), "error 403 Forbidden") + expect.NoError(t, err) + expect.Equal(t, len(rules.Rules), len(test)) + expect.Equal(t, rules.Rules[0].Name, "test") + expect.Equal(t, rules.Rules[0].On.String(), "method POST") + expect.Equal(t, rules.Rules[0].Do.String(), "error 403 Forbidden") - ExpectEqual(t, rules.Rules[1].Name, "auth") - ExpectEqual(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`) - ExpectEqual(t, rules.Rules[1].Do.String(), "bypass") + expect.Equal(t, rules.Rules[1].Name, "auth") + expect.Equal(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`) + expect.Equal(t, rules.Rules[1].Do.String(), "bypass") - ExpectEqual(t, rules.Rules[2].Name, "default") - ExpectEqual(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm") + expect.Equal(t, rules.Rules[2].Name, "default") + expect.Equal(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm") } // TODO: real tests. diff --git a/internal/route/types/http_config_test.go b/internal/route/types/http_config_test.go index 3040818..c6ea8cf 100644 --- a/internal/route/types/http_config_test.go +++ b/internal/route/types/http_config_test.go @@ -7,7 +7,7 @@ import ( . "github.com/yusing/go-proxy/internal/route" route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/utils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestHTTPConfigDeserialize(t *testing.T) { @@ -42,9 +42,9 @@ func TestHTTPConfigDeserialize(t *testing.T) { tt.input["host"] = "internal" err := utils.MapUnmarshalValidate(tt.input, &cfg) if err != nil { - ExpectNoError(t, err) + expect.NoError(t, err) } - ExpectEqual(t, cfg.HTTPConfig, tt.expected) + expect.Equal(t, cfg.HTTPConfig, tt.expected) }) } } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 4168b0e..a367c7f 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func testTask() *Task { @@ -35,7 +35,7 @@ func TestChildTaskCancellation(t *testing.T) { select { case <-child.Context().Done(): - ExpectError(t, context.Canceled, child.Context().Err()) + expect.ErrorIs(t, context.Canceled, child.Context().Err()) default: t.Fatal("subTask context was not canceled as expected") } @@ -55,10 +55,10 @@ func TestTaskOnCancelOnFinished(t *testing.T) { shouldTrueOnFinish = true }) - ExpectFalse(t, shouldTrueOnFinish) + expect.False(t, shouldTrueOnFinish) task.Finish(nil) - ExpectTrue(t, shouldTrueOnCancel) - ExpectTrue(t, shouldTrueOnFinish) + expect.True(t, shouldTrueOnCancel) + expect.True(t, shouldTrueOnFinish) } func TestCommonFlowWithGracefulShutdown(t *testing.T) { @@ -83,19 +83,19 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) { } }() - ExpectNoError(t, GracefulShutdown(1*time.Second)) - ExpectTrue(t, finished) + expect.NoError(t, GracefulShutdown(1*time.Second)) + expect.True(t, finished) <-root.finished - ExpectError(t, context.Canceled, task.Context().Err()) - ExpectError(t, ErrProgramExiting, task.FinishCause()) + expect.ErrorIs(t, context.Canceled, task.Context().Err()) + expect.ErrorIs(t, ErrProgramExiting, task.FinishCause()) } func TestTimeoutOnGracefulShutdown(t *testing.T) { t.Cleanup(testCleanup) _ = testTask() - ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond)) + expect.ErrorIs(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond)) } func TestFinishMultipleCalls(t *testing.T) { diff --git a/internal/utils/atomic/value.go b/internal/utils/atomic/value.go index d05f80b..ddd0034 100644 --- a/internal/utils/atomic/value.go +++ b/internal/utils/atomic/value.go @@ -33,3 +33,12 @@ func (a *Value[T]) Swap(v T) T { func (a *Value[T]) MarshalJSONTo(buf []byte) []byte { return json.MarshalTo(a.Load(), buf) } + +func (a *Value[T]) UnmarshalJSON(data []byte) error { + var v T + err := json.Unmarshal(data, &v) + if err == nil { + a.Store(v) + } + return err +} diff --git a/internal/utils/functional/map_test.go b/internal/utils/functional/map_test.go index 9715210..e5670e2 100644 --- a/internal/utils/functional/map_test.go +++ b/internal/utils/functional/map_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/yusing/go-proxy/internal/utils/functional" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestNewMapFrom(t *testing.T) { @@ -13,8 +13,8 @@ func TestNewMapFrom(t *testing.T) { "b": 2, "c": 3, }) - ExpectEqual(t, m.Size(), 3) - ExpectTrue(t, m.Has("a")) - ExpectTrue(t, m.Has("b")) - ExpectTrue(t, m.Has("c")) + expect.Equal(t, m.Size(), 3) + expect.True(t, m.Has("a")) + expect.True(t, m.Has("b")) + expect.True(t, m.Has("c")) } diff --git a/internal/utils/ref_count_test.go b/internal/utils/ref_count_test.go index d6e64cd..1548a4a 100644 --- a/internal/utils/ref_count_test.go +++ b/internal/utils/ref_count_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestRefCounterAddSub(t *testing.T) { @@ -23,7 +23,7 @@ func TestRefCounterAddSub(t *testing.T) { } wg.Wait() - ExpectEqual(t, int(rc.refCount), 0) + expect.Equal(t, int(rc.refCount), 0) select { case <-rc.Zero(): @@ -48,7 +48,7 @@ func TestRefCounterMultipleAddSub(t *testing.T) { }() } wg.Wait() - ExpectEqual(t, int(rc.refCount), numAdds+1) + expect.Equal(t, int(rc.refCount), numAdds+1) wg.Add(numSubs) for range numSubs { @@ -58,7 +58,7 @@ func TestRefCounterMultipleAddSub(t *testing.T) { }() } wg.Wait() - ExpectEqual(t, int(rc.refCount), numAdds+1-numSubs) + expect.Equal(t, int(rc.refCount), numAdds+1-numSubs) rc.Sub() select { diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 56907b8..c9ba144 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/goccy/go-yaml" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestUnmarshal(t *testing.T) { @@ -44,8 +44,8 @@ func TestUnmarshal(t *testing.T) { t.Run("unmarshal", func(t *testing.T) { var s2 S err := MapUnmarshalValidate(testStructSerialized, &s2) - ExpectNoError(t, err) - ExpectEqualValues(t, s2, testStruct) + expect.NoError(t, err) + expect.Values(t, s2, testStruct) }) } @@ -64,16 +64,16 @@ func TestUnmarshalAnonymousField(t *testing.T) { // all, anon := extractFields(reflect.TypeOf(s2)) // t.Fatalf("anon %v, all %v", anon, all) err := MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s) - ExpectNoError(t, err) - ExpectEqualValues(t, s.A, 1) - ExpectEqualValues(t, s.B, 2) - ExpectEqualValues(t, s.C, 3) + expect.NoError(t, err) + expect.Values(t, s.A, 1) + expect.Values(t, s.B, 2) + expect.Values(t, s.C, 3) err = MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s2) - ExpectNoError(t, err) - ExpectEqualValues(t, s2.A, 1) - ExpectEqualValues(t, s2.B, 2) - ExpectEqualValues(t, s2.C, 3) + expect.NoError(t, err) + expect.Values(t, s2.A, 1) + expect.Values(t, s2.B, 2) + expect.Values(t, s2.C, 3) } func TestStringIntConvert(t *testing.T) { @@ -93,13 +93,13 @@ func TestStringIntConvert(t *testing.T) { field := refl.Elem().Field(i) t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) { ok, err := ConvertString("127", field) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, field.Interface(), 127) + expect.True(t, ok) + expect.NoError(t, err) + expect.Values(t, field.Interface(), 127) err = Convert(reflect.ValueOf(uint8(64)), field) - ExpectNoError(t, err) - ExpectEqualValues(t, field.Interface(), 64) + expect.NoError(t, err) + expect.Values(t, field.Interface(), 64) }) } } @@ -123,26 +123,26 @@ func (c *testType) Parse(v string) (err error) { func TestConvertor(t *testing.T) { t.Run("valid", func(t *testing.T) { m := new(testModel) - ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m)) + expect.NoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m)) - ExpectEqualValues(t, m.Test.foo, 123) - ExpectEqualValues(t, m.Test.bar, "123") + expect.Values(t, m.Test.foo, 123) + expect.Values(t, m.Test.bar, "123") }) t.Run("int_to_string", func(t *testing.T) { m := new(testModel) - ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m)) + expect.NoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m)) - ExpectEqualValues(t, m.Test.foo, 123) - ExpectEqualValues(t, m.Test.bar, "123") + expect.Values(t, m.Test.foo, 123) + expect.Values(t, m.Test.bar, "123") - ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m)) - ExpectEqualValues(t, m.Baz, "456") + expect.NoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m)) + expect.Values(t, m.Baz, "456") }) t.Run("invalid", func(t *testing.T) { m := new(testModel) - ExpectError(t, ErrUnsupportedConversion, MapUnmarshalValidate(map[string]any{"Test": struct{}{}}, m)) + expect.ErrorIs(t, ErrUnsupportedConversion, MapUnmarshalValidate(map[string]any{"Test": struct{}{}}, m)) }) } @@ -150,23 +150,23 @@ func TestStringToSlice(t *testing.T) { t.Run("comma_separated", func(t *testing.T) { dst := make([]string, 0) convertible, err := ConvertString("a,b,c", reflect.ValueOf(&dst)) - ExpectTrue(t, convertible) - ExpectNoError(t, err) - ExpectEqualValues(t, dst, []string{"a", "b", "c"}) + expect.True(t, convertible) + expect.NoError(t, err) + expect.Values(t, dst, []string{"a", "b", "c"}) }) t.Run("yaml-like", func(t *testing.T) { dst := make([]string, 0) convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst)) - ExpectTrue(t, convertible) - ExpectNoError(t, err) - ExpectEqualValues(t, dst, []string{"a", "b", "c"}) + expect.True(t, convertible) + expect.NoError(t, err) + expect.Values(t, dst, []string{"a", "b", "c"}) }) t.Run("single-line-yaml-like", func(t *testing.T) { dst := make([]string, 0) convertible, err := ConvertString("- a", reflect.ValueOf(&dst)) - ExpectTrue(t, convertible) - ExpectNoError(t, err) - ExpectEqualValues(t, dst, []string{"a"}) + expect.True(t, convertible) + expect.NoError(t, err) + expect.Values(t, dst, []string{"a"}) }) } @@ -188,9 +188,9 @@ func TestStringToMap(t *testing.T) { t.Run("yaml-like", func(t *testing.T) { dst := make(map[string]string) convertible, err := ConvertString(" a: b\n c: d", reflect.ValueOf(&dst)) - ExpectTrue(t, convertible) - ExpectNoError(t, err) - ExpectEqualValues(t, dst, map[string]string{"a": "b", "c": "d"}) + expect.True(t, convertible) + expect.NoError(t, err) + expect.Values(t, dst, map[string]string{"a": "b", "c": "d"}) }) } @@ -216,10 +216,10 @@ func TestStringToStruct(t *testing.T) { t.Run("yaml-like simple", func(t *testing.T) { var dst T convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst)) - ExpectTrue(t, convertible) - ExpectNoError(t, err) - ExpectEqualValues(t, dst.A, "a") - ExpectEqualValues(t, dst.B, 123) + expect.True(t, convertible) + expect.NoError(t, err) + expect.Values(t, dst.A, "a") + expect.Values(t, dst.B, 123) }) type T2 struct { @@ -229,8 +229,8 @@ func TestStringToStruct(t *testing.T) { t.Run("yaml-like complex", func(t *testing.T) { var dst T2 convertible, err := ConvertString(" URL: http://example.com\n CIDR: 1.2.3.0/24", reflect.ValueOf(&dst)) - ExpectTrue(t, convertible) - ExpectNoError(t, err) + expect.True(t, convertible) + expect.NoError(t, err) }) } diff --git a/internal/utils/slices_test.go b/internal/utils/slices_test.go index b37b0bb..1d79594 100644 --- a/internal/utils/slices_test.go +++ b/internal/utils/slices_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestIntersect(t *testing.T) { @@ -19,7 +19,7 @@ func TestIntersect(t *testing.T) { result := Intersect(slice1, slice2) slices.Sort(result) slices.Sort(want) - ExpectEqual(t, result, want) + expect.Equal(t, result, want) }) t.Run("intersection", func(t *testing.T) { var ( @@ -30,7 +30,7 @@ func TestIntersect(t *testing.T) { result := Intersect(slice1, slice2) slices.Sort(result) slices.Sort(want) - ExpectEqual(t, result, want) + expect.Equal(t, result, want) }) }) t.Run("ints", func(t *testing.T) { @@ -43,7 +43,7 @@ func TestIntersect(t *testing.T) { result := Intersect(slice1, slice2) slices.Sort(result) slices.Sort(want) - ExpectEqual(t, result, want) + expect.Equal(t, result, want) }) t.Run("intersection", func(t *testing.T) { var ( @@ -54,7 +54,7 @@ func TestIntersect(t *testing.T) { result := Intersect(slice1, slice2) slices.Sort(result) slices.Sort(want) - ExpectEqual(t, result, want) + expect.Equal(t, result, want) }) }) t.Run("complex", func(t *testing.T) { @@ -75,7 +75,7 @@ func TestIntersect(t *testing.T) { slices.SortFunc(want, func(i T, j T) int { return strings.Compare(i.A, j.A) }) - ExpectEqual(t, result, want) + expect.Equal(t, result, want) }) t.Run("intersection", func(t *testing.T) { var ( @@ -90,7 +90,7 @@ func TestIntersect(t *testing.T) { slices.SortFunc(want, func(i T, j T) int { return strings.Compare(i.A, j.A) }) - ExpectEqual(t, result, want) + expect.Equal(t, result, want) }) }) } diff --git a/internal/utils/strutils/format_test.go b/internal/utils/strutils/format_test.go index f782e05..aeb6ee6 100644 --- a/internal/utils/strutils/format_test.go +++ b/internal/utils/strutils/format_test.go @@ -5,11 +5,11 @@ import ( "time" . "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestFormatTime(t *testing.T) { - now := Must(time.Parse(time.RFC3339, "2021-06-15T12:30:30Z")) + now := expect.Must(time.Parse(time.RFC3339, "2021-06-15T12:30:30Z")) tests := []struct { name string @@ -84,9 +84,9 @@ func TestFormatTime(t *testing.T) { result := FormatTimeWithReference(tt.time, now) if tt.expectedLength > 0 { - ExpectEqual(t, len(result), tt.expectedLength, result) + expect.Equal(t, len(result), tt.expectedLength, result) } else { - ExpectEqual(t, result, tt.expected) + expect.Equal(t, result, tt.expected) } }) } @@ -163,7 +163,7 @@ func TestFormatDuration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := FormatDuration(tt.duration) - ExpectEqual(t, result, tt.expected) + expect.Equal(t, result, tt.expected) }) } } @@ -193,7 +193,7 @@ func TestFormatLastSeen(t *testing.T) { result := FormatLastSeen(tt.time) if tt.name == "zero time" { - ExpectEqual(t, result, tt.expected) + expect.Equal(t, result, tt.expected) } else { // Just make sure it's not "never", the actual formatting is tested in TestFormatTime if result == "never" { diff --git a/internal/utils/strutils/split_join_test.go b/internal/utils/strutils/split_join_test.go index 0c86213..4d1ea8f 100644 --- a/internal/utils/strutils/split_join_test.go +++ b/internal/utils/strutils/split_join_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) var alphaNumeric = func() string { @@ -31,8 +31,8 @@ func TestSplit(t *testing.T) { for sep, rsep := range tests { t.Run(sep, func(t *testing.T) { expected := strings.Split(alphaNumeric, sep) - ExpectEqual(t, SplitRune(alphaNumeric, rsep), expected) - ExpectEqual(t, JoinRune(expected, rsep), alphaNumeric) + expect.Equal(t, SplitRune(alphaNumeric, rsep), expected) + expect.Equal(t, JoinRune(expected, rsep), alphaNumeric) }) } } diff --git a/internal/utils/strutils/url_test.go b/internal/utils/strutils/url_test.go index 8143991..756418c 100644 --- a/internal/utils/strutils/url_test.go +++ b/internal/utils/strutils/url_test.go @@ -3,7 +3,7 @@ package strutils import ( "testing" - "github.com/stretchr/testify/require" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func TestSanitizeURI(t *testing.T) { @@ -57,7 +57,7 @@ func TestSanitizeURI(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := SanitizeURI(tt.input) - require.Equal(t, tt.expected, result) + expect.Equal(t, result, tt.expected) }) } } diff --git a/internal/utils/testing/expect.go b/internal/utils/testing/expect.go new file mode 100644 index 0000000..6f7a17d --- /dev/null +++ b/internal/utils/testing/expect.go @@ -0,0 +1,71 @@ +package expect + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/go-proxy/internal/common" +) + +func init() { + if common.IsTest { + os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...) + } +} + +func Must[Result any](r Result, err error) Result { + if err != nil { + panic(err) + } + return r +} + +var ( + NoError = require.NoError + HasError = require.Error + True = require.True + False = require.False + Nil = require.Nil + NotNil = require.NotNil + ErrorContains = require.ErrorContains + Panics = require.Panics +) + +func ErrorIs(t *testing.T, expected error, err error, msgAndArgs ...any) { + t.Helper() + require.ErrorIs(t, err, expected, msgAndArgs...) +} + +func ErrorT[T error](t *testing.T, err error, msgAndArgs ...any) { + t.Helper() + var errAs T + require.ErrorAs(t, err, &errAs, msgAndArgs...) +} + +func Equal[T any](t *testing.T, got T, want T, msgAndArgs ...any) { + t.Helper() + require.Equal(t, want, got, msgAndArgs...) +} + +func NotEqual[T any](t *testing.T, got T, want T, msgAndArgs ...any) { + t.Helper() + require.NotEqual(t, want, got, msgAndArgs...) +} + +func Values(t *testing.T, got any, want any, msgAndArgs ...any) { + t.Helper() + require.EqualValues(t, want, got, msgAndArgs...) +} + +func Contains[T any](t *testing.T, got T, wants []T, msgAndArgs ...any) { + t.Helper() + require.Contains(t, wants, got, msgAndArgs...) +} + +func Type[T any](t *testing.T, got any, msgAndArgs ...any) (_ T) { + t.Helper() + _, ok := got.(T) + require.True(t, ok, msgAndArgs...) + return got.(T) +} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go deleted file mode 100644 index 7187e15..0000000 --- a/internal/utils/testing/testing.go +++ /dev/null @@ -1,75 +0,0 @@ -package utils - -import ( - "os" - "testing" - - "github.com/stretchr/testify/require" - "github.com/yusing/go-proxy/internal/common" -) - -func init() { - if common.IsTest { - os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...) - } -} - -func Must[Result any](r Result, err error) Result { - if err != nil { - panic(err) - } - return r -} - -func ExpectNoError(t *testing.T, err error, msgAndArgs ...any) { - t.Helper() - require.NoError(t, err, msgAndArgs...) -} - -func ExpectHasError(t *testing.T, err error, msgAndArgs ...any) { - t.Helper() - require.Error(t, err, msgAndArgs...) -} - -func ExpectError(t *testing.T, expected error, err error, msgAndArgs ...any) { - t.Helper() - require.ErrorIs(t, err, expected, msgAndArgs...) -} - -func ExpectErrorT[T error](t *testing.T, err error, msgAndArgs ...any) { - t.Helper() - var errAs T - require.ErrorAs(t, err, &errAs, msgAndArgs...) -} - -func ExpectEqual[T any](t *testing.T, got T, want T, msgAndArgs ...any) { - t.Helper() - require.Equal(t, want, got, msgAndArgs...) -} - -func ExpectEqualValues(t *testing.T, got any, want any, msgAndArgs ...any) { - t.Helper() - require.EqualValues(t, want, got, msgAndArgs...) -} - -func ExpectContains[T any](t *testing.T, got T, wants []T, msgAndArgs ...any) { - t.Helper() - require.Contains(t, wants, got, msgAndArgs...) -} - -func ExpectTrue(t *testing.T, got bool, msgAndArgs ...any) { - t.Helper() - require.True(t, got, msgAndArgs...) -} - -func ExpectFalse(t *testing.T, got bool, msgAndArgs ...any) { - t.Helper() - require.False(t, got, msgAndArgs...) -} - -func ExpectType[T any](t *testing.T, got any, msgAndArgs ...any) (_ T) { - t.Helper() - _, ok := got.(T) - require.True(t, ok, msgAndArgs...) - return got.(T) -} diff --git a/pkg/json/marshal_test.go b/pkg/json/marshal_test.go index 7325103..fa46a03 100644 --- a/pkg/json/marshal_test.go +++ b/pkg/json/marshal_test.go @@ -8,9 +8,8 @@ import ( "testing" "github.com/bytedance/sonic" - "github.com/stretchr/testify/require" "github.com/yusing/go-proxy/internal/utils/strutils" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/pkg/json" ) @@ -350,7 +349,7 @@ func TestMarshal(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, _ := Marshal(tt.input) - require.Equal(t, tt.expected, string(result)) + expect.Equal(t, string(result), tt.expected) }) } @@ -468,7 +467,7 @@ func TestWithTestStruct(t *testing.T) { t.Fatalf("Unmarshal error: %v", err) } - ExpectEqual(t, unmarshalCustom, unmarshalStdlib) + expect.Equal(t, unmarshalCustom, unmarshalStdlib) } func BenchmarkMarshalSimpleStdLib(b *testing.B) {