refactor: rename utils/testing to expect, replace all dot imports in tests

This commit is contained in:
yusing 2025-04-17 15:45:30 +08:00
parent aeb6a69e3d
commit 293bb80f0b
59 changed files with 778 additions and 771 deletions

View file

@ -8,59 +8,59 @@ import (
"net/http/httptest"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewAgent(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
ExpectTrue(t, ca != nil)
ExpectTrue(t, srv != nil)
ExpectTrue(t, client != nil)
expect.NoError(t, err)
expect.True(t, ca != nil)
expect.True(t, srv != nil)
expect.True(t, client != nil)
}
func TestPEMPair(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
expect.NoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) {
var pp PEMPair
err := pp.Load(p.String())
ExpectNoError(t, err)
ExpectEqual(t, p.Cert, pp.Cert)
ExpectEqual(t, p.Key, pp.Key)
expect.NoError(t, err)
expect.Equal(t, p.Cert, pp.Cert)
expect.Equal(t, p.Key, pp.Key)
})
}
}
func TestPEMPairToTLSCert(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
expect.NoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) {
cert, err := p.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, cert != nil)
expect.NoError(t, err)
expect.True(t, cert != nil)
})
}
}
func TestServerClient(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
expect.NoError(t, err)
srvTLS, err := srv.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, srvTLS != nil)
expect.NoError(t, err)
expect.True(t, srvTLS != nil)
clientTLS, err := client.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, clientTLS != nil)
expect.NoError(t, err)
expect.True(t, clientTLS != nil)
caPool := x509.NewCertPool()
ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert))
expect.True(t, caPool.AppendCertsFromPEM(ca.Cert))
srvTLSConfig := &tls.Config{
Certificates: []tls.Certificate{*srvTLS},
@ -86,6 +86,6 @@ func TestServerClient(t *testing.T) {
}
resp, err := httpClient.Get(server.URL)
ExpectNoError(t, err)
ExpectEqual(t, resp.StatusCode, http.StatusOK)
expect.NoError(t, err)
expect.Equal(t, resp.StatusCode, http.StatusOK)
}

View file

@ -3,17 +3,17 @@ package certs
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestZipCert(t *testing.T) {
ca, crt, key := []byte("test1"), []byte("test2"), []byte("test3")
zipData, err := ZipCert(ca, crt, key)
ExpectNoError(t, err)
expect.NoError(t, err)
ca2, crt2, key2, err := ExtractCert(zipData)
ExpectNoError(t, err)
ExpectEqual(t, ca, ca2)
ExpectEqual(t, crt, crt2)
ExpectEqual(t, key, key2)
expect.NoError(t, err)
expect.Equal(t, ca, ca2)
expect.Equal(t, crt, crt2)
expect.Equal(t, key, key2)
}

View file

@ -10,10 +10,11 @@ import (
"github.com/yusing/go-proxy/pkg/json"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/handler"
"github.com/yusing/go-proxy/internal/watcher/health"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestCheckHealthHTTP(t *testing.T) {
@ -79,12 +80,12 @@ func TestCheckHealthHTTP(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil)
handler.CheckHealth(recorder, request)
require.Equal(t, recorder.Code, tt.expectedStatus)
expect.Equal(t, recorder.Code, tt.expectedStatus)
if tt.expectedStatus == http.StatusOK {
var result health.HealthCheckResult
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
require.Equal(t, result.Healthy, tt.expectedHealthy)
expect.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
expect.Equal(t, result.Healthy, tt.expectedHealthy)
}
})
}
@ -124,32 +125,32 @@ func TestCheckHealthFileServer(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil)
handler.CheckHealth(recorder, request)
require.Equal(t, recorder.Code, tt.expectedStatus)
expect.Equal(t, recorder.Code, tt.expectedStatus)
var result health.HealthCheckResult
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
require.Equal(t, result.Healthy, tt.expectedHealthy)
require.Equal(t, result.Detail, tt.expectedDetail)
expect.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
expect.Equal(t, result.Healthy, tt.expectedHealthy)
expect.Equal(t, result.Detail, tt.expectedDetail)
})
}
}
func TestCheckHealthTCPUDP(t *testing.T) {
tcp, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
expect.NoError(t, err)
go func() {
conn, err := tcp.Accept()
require.NoError(t, err)
expect.NoError(t, err)
conn.Close()
}()
udp, err := net.ListenPacket("udp", "localhost:0")
require.NoError(t, err)
expect.NoError(t, err)
go func() {
buf := make([]byte, 1024)
n, addr, err := udp.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, string(buf[:n]), "ping")
expect.NoError(t, err)
expect.Equal(t, string(buf[:n]), "ping")
_, _ = udp.WriteTo([]byte("pong"), addr)
udp.Close()
}()
@ -207,11 +208,11 @@ func TestCheckHealthTCPUDP(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, agent.APIEndpointBase+agent.EndpointHealth+"?"+query.Encode(), nil)
handler.CheckHealth(recorder, request)
require.Equal(t, recorder.Code, tt.expectedStatus)
expect.Equal(t, recorder.Code, tt.expectedStatus)
var result health.HealthCheckResult
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
require.Equal(t, result.Healthy, tt.expectedHealthy)
expect.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
expect.Equal(t, result.Healthy, tt.expectedHealthy)
})
}
}

View file

@ -17,7 +17,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"golang.org/x/oauth2"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
// setupMockOIDC configures mock OIDC provider for testing.
@ -75,7 +75,7 @@ func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = keyID
signed, err := token.SignedString(j.key)
ExpectNoError(t, err)
expect.NoError(t, err)
return signed
}
@ -84,7 +84,7 @@ func setupProvider(t *testing.T) *provider {
// Generate an RSA key pair for the test.
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
ExpectNoError(t, err)
expect.NoError(t, err)
// Build the matching public JWK that will be served by the endpoint.
jwk := buildRSAJWK(t, &privKey.PublicKey, keyID)
@ -227,12 +227,12 @@ func TestOIDCCallbackHandler(t *testing.T) {
}
if tt.wantStatus == http.StatusTemporaryRedirect {
setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
ExpectTrue(t, setCookie.Value != "")
ExpectEqual(t, setCookie.Path, "/")
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
ExpectEqual(t, setCookie.HttpOnly, true)
setCookie := expect.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
expect.Equal(t, setCookie.Name, defaultAuth.TokenCookieName())
expect.True(t, setCookie.Value != "")
expect.Equal(t, setCookie.Path, "/")
expect.Equal(t, setCookie.SameSite, http.SameSiteLaxMode)
expect.Equal(t, setCookie.HttpOnly, true)
}
})
}
@ -245,7 +245,7 @@ func TestInitOIDC(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
expect.NoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
})
server = httptest.NewServer(mux)
t.Cleanup(server.Close)
@ -446,9 +446,9 @@ func TestCheckToken(t *testing.T) {
// Call CheckToken and verify the result.
err := auth.CheckToken(req)
if tc.wantErr == nil {
ExpectNoError(t, err)
expect.NoError(t, err)
} else {
ExpectError(t, tc.wantErr, err)
expect.ErrorIs(t, tc.wantErr, err)
}
})
}

View file

@ -10,14 +10,14 @@ import (
"github.com/yusing/go-proxy/pkg/json"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"golang.org/x/crypto/bcrypt"
)
func newMockUserPassAuth() *UserPassAuth {
return &UserPassAuth{
username: "username",
pwdHash: Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
pwdHash: expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
secret: []byte("abcdefghijklmnopqrstuvwxyz"),
tokenTTL: time.Hour,
}
@ -26,17 +26,17 @@ func newMockUserPassAuth() *UserPassAuth {
func TestUserPassValidateCredentials(t *testing.T) {
auth := newMockUserPassAuth()
err := auth.validatePassword("username", "password")
ExpectNoError(t, err)
expect.NoError(t, err)
err = auth.validatePassword("username", "wrong-password")
ExpectError(t, ErrInvalidPassword, err)
expect.ErrorIs(t, ErrInvalidPassword, err)
err = auth.validatePassword("wrong-username", "password")
ExpectError(t, ErrInvalidUsername, err)
expect.ErrorIs(t, ErrInvalidUsername, err)
}
func TestUserPassCheckToken(t *testing.T) {
auth := newMockUserPassAuth()
token, err := auth.NewToken()
ExpectNoError(t, err)
expect.NoError(t, err)
tests := []struct {
token string
wantErr bool
@ -61,9 +61,9 @@ func TestUserPassCheckToken(t *testing.T) {
}
err = auth.CheckToken(req)
if tt.wantErr {
ExpectTrue(t, err != nil)
expect.True(t, err != nil)
} else {
ExpectNoError(t, err)
expect.NoError(t, err)
}
}
}
@ -97,20 +97,20 @@ func TestUserPassLoginCallbackHandler(t *testing.T) {
w := httptest.NewRecorder()
req := &http.Request{
Host: "app.example.com",
Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))),
Body: io.NopCloser(bytes.NewReader(expect.Must(json.Marshal(tt.creds)))),
}
auth.LoginCallbackHandler(w, req)
if tt.wantErr {
ExpectEqual(t, w.Code, http.StatusUnauthorized)
expect.Equal(t, w.Code, http.StatusUnauthorized)
} else {
setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
ExpectTrue(t, setCookie.Name == auth.TokenCookieName())
ExpectTrue(t, setCookie.Value != "")
ExpectEqual(t, setCookie.Domain, "example.com")
ExpectEqual(t, setCookie.Path, "/")
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
ExpectEqual(t, setCookie.HttpOnly, true)
ExpectEqual(t, w.Code, http.StatusOK)
setCookie := expect.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
expect.True(t, setCookie.Name == auth.TokenCookieName())
expect.True(t, setCookie.Value != "")
expect.Equal(t, setCookie.Domain, "example.com")
expect.Equal(t, setCookie.Path, "/")
expect.Equal(t, setCookie.SameSite, http.SameSiteLaxMode)
expect.Equal(t, setCookie.HttpOnly, true)
expect.Equal(t, w.Code, http.StatusOK)
}
}
}

View file

@ -6,7 +6,8 @@ import (
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/goccy/go-yaml"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
// type Config struct {
@ -44,7 +45,7 @@ oauth2_config:
}
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
ExpectNoError(t, utils.MapUnmarshalValidate(opt, cfg))
ExpectEqual(t, cfg, cfgExpected)
expect.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
expect.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
expect.Equal(t, cfg, cfgExpected)
}

View file

@ -6,13 +6,12 @@ import (
"testing"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestFileProviderValidate(t *testing.T) {
@ -57,10 +56,10 @@ func TestFileProviderValidate(t *testing.T) {
if tt.init != nil {
for _, filename := range tt.filenames {
filepath := path.Join(common.ConfigDir, filename)
assert.NoError(t, tt.init(filepath))
expect.NoError(t, tt.init(filepath))
}
}
err := utils.UnmarshalValidateYAML(Must(yaml.Marshal(map[string]any{
err := utils.UnmarshalValidateYAML(expect.Must(yaml.Marshal(map[string]any{
"providers": map[string]any{
"include": tt.filenames,
},
@ -68,13 +67,13 @@ func TestFileProviderValidate(t *testing.T) {
if tt.cleanup != nil {
for _, filename := range tt.filenames {
filepath := path.Join(common.ConfigDir, filename)
assert.NoError(t, tt.cleanup(filepath))
expect.NoError(t, tt.cleanup(filepath))
}
}
if tt.expectedErrorContains != "" {
assert.ErrorContains(t, err, tt.expectedErrorContains)
expect.ErrorContains(t, err, tt.expectedErrorContains)
} else {
assert.NoError(t, err)
expect.NoError(t, err)
}
})
}
@ -129,9 +128,9 @@ func TestLoadRouteProviders(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
err := utils.Validate(tt.providers)
if tt.expectedError {
assert.ErrorContains(t, err, "unique")
expect.ErrorContains(t, err, "unique")
} else {
assert.NoError(t, err)
expect.NoError(t, err)
}
})
}
@ -142,9 +141,9 @@ func TestProviderNameUniqueness(t *testing.T) {
docker := provider.NewDockerProvider("routes", "unix:///var/run/docker.sock")
agent := provider.NewAgentProvider(agent.TestAgentConfig("routes", "192.168.1.100:8080"))
assert.True(t, file.String() != docker.String())
assert.True(t, file.String() != agent.String())
assert.True(t, docker.String() != agent.String())
expect.NotEqual(t, file.String(), docker.String())
expect.NotEqual(t, file.String(), agent.String())
expect.NotEqual(t, docker.String(), agent.String())
}
func TestFileProviderNameFromFilename(t *testing.T) {
@ -160,7 +159,7 @@ func TestFileProviderNameFromFilename(t *testing.T) {
for _, tt := range tests {
t.Run(tt.filename, func(t *testing.T) {
p := provider.NewFileProvider(tt.filename)
assert.Equal(t, tt.expectedName, p.ShortName())
expect.Equal(t, tt.expectedName, p.ShortName())
})
}
}
@ -179,7 +178,7 @@ func TestDockerProviderString(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := provider.NewDockerProvider(tt.name, tt.dockerHost)
assert.Equal(t, tt.expected, p.String())
expect.Equal(t, tt.expected, p.String())
})
}
}
@ -196,7 +195,7 @@ func TestExplicitOnlyProvider(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := provider.NewDockerProvider(tt.name, "unix:///var/run/docker.sock")
assert.Equal(t, tt.expectedFlag, p.IsExplicitOnly())
expect.Equal(t, tt.expectedFlag, p.IsExplicitOnly())
})
}
}

View file

@ -4,7 +4,8 @@ import (
"testing"
"github.com/docker/docker/api/types/container"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestContainerExplicit(t *testing.T) {
@ -37,7 +38,7 @@ func TestContainerExplicit(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := FromDocker(&container.Summary{Names: []string{"test"}, State: "test", Labels: tt.labels}, "")
ExpectEqual(t, c.IsExplicit, tt.isExplicit)
expect.Equal(t, c.IsExplicit, tt.isExplicit)
})
}
}

View file

@ -5,7 +5,8 @@ import (
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/routes"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
var ep = NewEntrypoint()
@ -29,15 +30,15 @@ func run(t *testing.T, match []string, noMatch []string) {
t.Run(test, func(t *testing.T) {
r := addRoute(test)
found, err := ep.findRouteFunc(test)
ExpectNoError(t, err)
ExpectTrue(t, found == r)
expect.NoError(t, err)
expect.True(t, found == r)
})
}
for _, test := range noMatch {
t.Run(test, func(t *testing.T) {
_, err := ep.findRouteFunc(test)
ExpectError(t, ErrNoSuchRoute, err)
expect.ErrorIs(t, ErrNoSuchRoute, err)
})
}
}

View file

@ -2,18 +2,17 @@ package gperr_test
import (
"context"
"errors"
"io"
"testing"
. "github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("foo")
ExpectTrue(t, errors.Is(eb.Error(), nil))
ExpectFalse(t, eb.HasError())
expect.NoError(t, eb.Error())
expect.False(t, eb.HasError())
}
func TestBuilderAddNil(t *testing.T) {
@ -26,17 +25,17 @@ func TestBuilderAddNil(t *testing.T) {
eb.Add(err)
}
eb.AddRange(nil, nil, err)
ExpectFalse(t, eb.HasError())
ExpectTrue(t, eb.Error() == nil)
expect.False(t, eb.HasError())
expect.NoError(t, eb.Error())
}
func TestBuilderIs(t *testing.T) {
eb := NewBuilder("foo")
eb.Add(context.Canceled)
eb.Add(io.ErrShortBuffer)
ExpectTrue(t, eb.HasError())
ExpectError(t, io.ErrShortBuffer, eb.Error())
ExpectError(t, context.Canceled, eb.Error())
expect.True(t, eb.HasError())
expect.ErrorIs(t, io.ErrShortBuffer, eb.Error())
expect.ErrorIs(t, context.Canceled, eb.Error())
}
func TestBuilderNested(t *testing.T) {
@ -51,5 +50,5 @@ func TestBuilderNested(t *testing.T) {
Inner: 2
Action 2
Inner: 3`
ExpectEqual(t, got, expected)
expect.Equal(t, got, expected)
}

View file

@ -6,11 +6,11 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBaseString(t *testing.T) {
ExpectEqual(t, New("error").Error(), "error")
expect.Equal(t, New("error").Error(), "error")
}
func TestBaseWithSubject(t *testing.T) {
@ -18,13 +18,13 @@ func TestBaseWithSubject(t *testing.T) {
withSubject := err.Subject("foo")
withSubjectf := err.Subjectf("%s %s", "foo", "bar")
ExpectError(t, err, withSubject)
ExpectEqual(t, ansi.StripANSI(withSubject.Error()), "foo: error")
ExpectTrue(t, withSubject.Is(err))
expect.ErrorIs(t, err, withSubject)
expect.Equal(t, ansi.StripANSI(withSubject.Error()), "foo: error")
expect.True(t, withSubject.Is(err))
ExpectError(t, err, withSubjectf)
ExpectEqual(t, ansi.StripANSI(withSubjectf.Error()), "foo bar: error")
ExpectTrue(t, withSubjectf.Is(err))
expect.ErrorIs(t, err, withSubjectf)
expect.Equal(t, ansi.StripANSI(withSubjectf.Error()), "foo bar: error")
expect.True(t, withSubjectf.Is(err))
}
func TestBaseWithExtra(t *testing.T) {
@ -32,22 +32,22 @@ func TestBaseWithExtra(t *testing.T) {
extra := New("bar").Subject("baz")
withExtra := err.With(extra)
ExpectTrue(t, withExtra.Is(extra))
ExpectTrue(t, withExtra.Is(err))
expect.True(t, withExtra.Is(extra))
expect.True(t, withExtra.Is(err))
ExpectTrue(t, errors.Is(withExtra, extra))
ExpectTrue(t, errors.Is(withExtra, err))
expect.True(t, errors.Is(withExtra, extra))
expect.True(t, errors.Is(withExtra, err))
ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), "baz"))
expect.True(t, strings.Contains(withExtra.Error(), err.Error()))
expect.True(t, strings.Contains(withExtra.Error(), extra.Error()))
expect.True(t, strings.Contains(withExtra.Error(), "baz"))
}
func TestBaseUnwrap(t *testing.T) {
err := errors.New("err")
wrapped := Wrap(err)
ExpectError(t, err, errors.Unwrap(wrapped))
expect.ErrorIs(t, err, errors.Unwrap(wrapped))
}
func TestNestedUnwrap(t *testing.T) {
@ -56,24 +56,24 @@ func TestNestedUnwrap(t *testing.T) {
wrapped := Wrap(err).Subject("foo").With(err2.Subject("bar"))
unwrapper, ok := wrapped.(interface{ Unwrap() []error })
ExpectTrue(t, ok)
expect.True(t, ok)
ExpectError(t, err, wrapped)
ExpectError(t, err2, wrapped)
ExpectEqual(t, len(unwrapper.Unwrap()), 2)
expect.ErrorIs(t, err, wrapped)
expect.ErrorIs(t, err2, wrapped)
expect.Equal(t, len(unwrapper.Unwrap()), 2)
}
func TestErrorIs(t *testing.T) {
from := errors.New("error")
err := Wrap(from)
ExpectError(t, from, err)
expect.ErrorIs(t, from, err)
ExpectTrue(t, err.Is(from))
ExpectFalse(t, err.Is(New("error")))
expect.True(t, err.Is(from))
expect.False(t, err.Is(New("error")))
ExpectTrue(t, errors.Is(err.Subject("foo"), from))
ExpectTrue(t, errors.Is(err.Withf("foo"), from))
ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from))
expect.True(t, errors.Is(err.Subject("foo"), from))
expect.True(t, errors.Is(err.Withf("foo"), from))
expect.True(t, errors.Is(err.Subject("foo").Withf("bar"), from))
}
func TestErrorImmutability(t *testing.T) {
@ -83,14 +83,14 @@ func TestErrorImmutability(t *testing.T) {
for range 3 {
// t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err)
_ = err.Subject("foo")
ExpectFalse(t, strings.Contains(err.Error(), "foo"))
expect.False(t, strings.Contains(err.Error(), "foo"))
_ = err.With(err2)
ExpectFalse(t, strings.Contains(err.Error(), "extra"))
ExpectFalse(t, err.Is(err2))
expect.False(t, strings.Contains(err.Error(), "extra"))
expect.False(t, err.Is(err2))
err = err.Subject("bar").Withf("baz")
ExpectTrue(t, err != nil)
expect.True(t, err != nil)
}
}
@ -100,24 +100,24 @@ func TestErrorWith(t *testing.T) {
err3 := err1.With(err2)
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
expect.True(t, err3.Is(err1))
expect.True(t, err3.Is(err2))
_ = err2.Subject("foo")
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
expect.True(t, err3.Is(err1))
expect.True(t, err3.Is(err2))
// check if err3 is affected by err2.Subject
ExpectFalse(t, strings.Contains(err3.Error(), "foo"))
expect.False(t, strings.Contains(err3.Error(), "foo"))
}
func TestErrorStringSimple(t *testing.T) {
errFailure := New("generic failure")
ne := errFailure.Subject("foo bar")
ExpectEqual(t, ansi.StripANSI(ne.Error()), "foo bar: generic failure")
expect.Equal(t, ansi.StripANSI(ne.Error()), "foo bar: generic failure")
ne = ne.Subject("baz")
ExpectEqual(t, ansi.StripANSI(ne.Error()), "baz > foo bar: generic failure")
expect.Equal(t, ansi.StripANSI(ne.Error()), "baz > foo bar: generic failure")
}
func TestErrorStringNested(t *testing.T) {
@ -154,5 +154,5 @@ func TestErrorStringNested(t *testing.T) {
action 3 > inner3: generic failure
3
3`
ExpectEqual(t, ansi.StripANSI(ne.Error()), want)
expect.Equal(t, ansi.StripANSI(ne.Error()), want)
}

View file

@ -4,7 +4,7 @@ import (
"net"
"testing"
"github.com/stretchr/testify/require"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestWrapMultiline(t *testing.T) {
@ -25,5 +25,5 @@ func TestPrependSubjectMultiline(t *testing.T) {
builder := NewBuilder()
builder.Add(multiline)
require.Equal(t, len(builder.errs), len(multiline.Extras), builder.errs)
expect.Equal(t, len(multiline.Extras), len(builder.errs))
}

View file

@ -5,8 +5,8 @@ import (
"slices"
"strings"
"github.com/yusing/go-proxy/pkg/json"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
"github.com/yusing/go-proxy/pkg/json"
)
//nolint:errname

View file

@ -3,7 +3,7 @@ package homepage
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestOverrideItem(t *testing.T) {
@ -32,5 +32,5 @@ func TestOverrideItem(t *testing.T) {
overrides := GetOverrideConfig()
overrides.OverrideItem(a.Alias, want)
got := a.GetOverride(a.Alias)
ExpectEqual(t, got, want)
expect.Equal(t, got, want)
}

View file

@ -3,7 +3,7 @@ package homepage
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestIconURL(t *testing.T) {
@ -114,11 +114,11 @@ func TestIconURL(t *testing.T) {
u := &IconURL{}
err := u.Parse(tc.input)
if tc.wantErr {
ExpectError(t, ErrInvalidIconURL, err)
expect.ErrorIs(t, ErrInvalidIconURL, err)
} else {
tc.wantValue.FullValue = tc.input
ExpectNoError(t, err)
ExpectEqual(t, u, tc.wantValue)
expect.NoError(t, err)
expect.Equal(t, u, tc.wantValue)
}
})
}

View file

@ -3,7 +3,7 @@ package idlewatcher
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestValidateStartEndpoint(t *testing.T) {
@ -38,7 +38,7 @@ func TestValidateStartEndpoint(t *testing.T) {
cfg := Config{StartEndpoint: tc.input}
err := cfg.validateStartEndpoint()
if err == nil {
ExpectEqual(t, cfg.StartEndpoint, tc.input)
expect.Equal(t, cfg.StartEndpoint, tc.input)
}
if (err != nil) != tc.wantErr {
t.Errorf("validateStartEndpoint() error = %v, wantErr %t", err, tc.wantErr)

View file

@ -6,8 +6,9 @@ import (
"testing"
"github.com/shirou/gopsutil/v4/sensors"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/pkg/json"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestExcludeDisks(t *testing.T) {
@ -72,7 +73,7 @@ func TestExcludeDisks(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shouldExcludeDisk(tt.name)
ExpectEqual(t, result, tt.shouldExclude)
expect.Equal(t, result, tt.shouldExclude)
})
}
}
@ -147,21 +148,21 @@ var testInfo = &SystemInfo{
func TestSystemInfo(t *testing.T) {
// Test marshaling
data, err := json.Marshal(testInfo)
ExpectNoError(t, err)
expect.NoError(t, err)
// Test unmarshaling back
var decoded SystemInfo
err = json.Unmarshal(data, &decoded)
ExpectNoError(t, err)
expect.NoError(t, err)
// Compare original and decoded
ExpectEqual(t, decoded.Timestamp, testInfo.Timestamp)
ExpectEqual(t, *decoded.CPUAverage, *testInfo.CPUAverage)
ExpectEqual(t, decoded.Memory, testInfo.Memory)
ExpectEqual(t, decoded.Disks, testInfo.Disks)
ExpectEqual(t, decoded.DisksIO, testInfo.DisksIO)
ExpectEqual(t, decoded.Network, testInfo.Network)
ExpectEqual(t, decoded.Sensors, testInfo.Sensors)
expect.Equal(t, decoded.Timestamp, testInfo.Timestamp)
expect.Equal(t, *decoded.CPUAverage, *testInfo.CPUAverage)
expect.Equal(t, decoded.Memory, testInfo.Memory)
expect.Equal(t, decoded.Disks, testInfo.Disks)
expect.Equal(t, decoded.DisksIO, testInfo.DisksIO)
expect.Equal(t, decoded.Network, testInfo.Network)
expect.Equal(t, decoded.Sensors, testInfo.Sensors)
// Test nil fields
nilInfo := &SystemInfo{
@ -169,18 +170,18 @@ func TestSystemInfo(t *testing.T) {
}
data, err = json.Marshal(nilInfo)
ExpectNoError(t, err)
expect.NoError(t, err)
var decodedNil SystemInfo
err = json.Unmarshal(data, &decodedNil)
ExpectNoError(t, err)
expect.NoError(t, err)
ExpectEqual(t, decodedNil.Timestamp, nilInfo.Timestamp)
ExpectTrue(t, decodedNil.CPUAverage == nil)
ExpectTrue(t, decodedNil.Memory == nil)
ExpectTrue(t, decodedNil.Disks == nil)
ExpectTrue(t, decodedNil.Network == nil)
ExpectTrue(t, decodedNil.Sensors == nil)
expect.Equal(t, decodedNil.Timestamp, nilInfo.Timestamp)
expect.True(t, decodedNil.CPUAverage == nil)
expect.True(t, decodedNil.Memory == nil)
expect.True(t, decodedNil.Disks == nil)
expect.True(t, decodedNil.Network == nil)
expect.True(t, decodedNil.Sensors == nil)
}
func TestSerialize(t *testing.T) {
@ -193,13 +194,13 @@ func TestSerialize(t *testing.T) {
_, result := aggregate(entries, url.Values{"aggregate": []string{query}})
s := result.MarshalJSONTo(nil)
var v []map[string]any
ExpectNoError(t, json.Unmarshal(s, &v))
ExpectEqual(t, len(v), len(result))
expect.NoError(t, json.Unmarshal(s, &v))
expect.Equal(t, len(v), len(result))
for i, m := range v {
for k, v := range m {
// some int64 values are converted to float64 on json.Unmarshal
vv := reflect.ValueOf(result[i][k])
ExpectEqual(t, reflect.ValueOf(v).Convert(vv.Type()).Interface(), vv.Interface())
expect.Equal(t, reflect.ValueOf(v).Convert(vv.Type()).Interface(), vv.Interface())
}
}
})

View file

@ -12,7 +12,7 @@ import (
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
@ -30,7 +30,7 @@ const (
var (
testTask = task.RootTask("test", false)
testURL = Must(url.Parse("http://" + host + uri))
testURL = expect.Must(url.Parse("http://" + host + uri))
req = &http.Request{
RemoteAddr: remote,
Method: method,
@ -69,7 +69,7 @@ func TestAccessLoggerCommon(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCommon
ts, log := fmtLog(config)
ExpectEqual(t, log,
expect.Equal(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uri, proto, status, contentLength,
),
@ -80,7 +80,7 @@ func TestAccessLoggerCombined(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCombined
ts, log := fmtLog(config)
ExpectEqual(t, log,
expect.Equal(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"",
host, remote, ts, method, uri, proto, status, contentLength, referer, ua,
),
@ -92,7 +92,7 @@ func TestAccessLoggerRedactQuery(t *testing.T) {
config.Format = FormatCommon
config.Fields.Query.Default = FieldModeRedact
ts, log := fmtLog(config)
ExpectEqual(t, log,
expect.Equal(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uriRedacted, proto, status, contentLength,
),
@ -124,27 +124,27 @@ func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
var entry JSONLogEntry
_, log := fmtLog(config)
err := json.Unmarshal([]byte(log), &entry)
ExpectNoError(t, err)
expect.NoError(t, err)
return entry
}
func TestAccessLoggerJSON(t *testing.T) {
config := DefaultConfig()
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.IP, remote)
ExpectEqual(t, entry.Method, method)
ExpectEqual(t, entry.Scheme, "http")
ExpectEqual(t, entry.Host, testURL.Host)
ExpectEqual(t, entry.URI, testURL.RequestURI())
ExpectEqual(t, entry.Protocol, proto)
ExpectEqual(t, entry.Status, status)
ExpectEqual(t, entry.ContentType, "text/plain")
ExpectEqual(t, entry.Size, contentLength)
ExpectEqual(t, entry.Referer, referer)
ExpectEqual(t, entry.UserAgent, ua)
ExpectEqual(t, len(entry.Headers), 0)
ExpectEqual(t, len(entry.Cookies), 0)
expect.Equal(t, entry.IP, remote)
expect.Equal(t, entry.Method, method)
expect.Equal(t, entry.Scheme, "http")
expect.Equal(t, entry.Host, testURL.Host)
expect.Equal(t, entry.URI, testURL.RequestURI())
expect.Equal(t, entry.Protocol, proto)
expect.Equal(t, entry.Status, status)
expect.Equal(t, entry.ContentType, "text/plain")
expect.Equal(t, entry.Size, contentLength)
expect.Equal(t, entry.Referer, referer)
expect.Equal(t, entry.UserAgent, ua)
expect.Equal(t, len(entry.Headers), 0)
expect.Equal(t, len(entry.Cookies), 0)
if status >= 400 {
ExpectEqual(t, entry.Error, http.StatusText(status))
expect.Equal(t, entry.Error, http.StatusText(status))
}
}

View file

@ -6,7 +6,7 @@ import (
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewConfig(t *testing.T) {
@ -27,27 +27,27 @@ func TestNewConfig(t *testing.T) {
"proxy.fields.cookies.config.foo": "keep",
}
parsed, err := docker.ParseLabels(labels)
ExpectNoError(t, err)
expect.NoError(t, err)
var config Config
err = utils.MapUnmarshalValidate(parsed, &config)
ExpectNoError(t, err)
expect.NoError(t, err)
ExpectEqual(t, config.BufferSize, 10)
ExpectEqual(t, config.Format, FormatCombined)
ExpectEqual(t, config.Path, "/tmp/access.log")
ExpectEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
ExpectEqual(t, len(config.Filters.Method.Values), 2)
ExpectEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
ExpectEqual(t, len(config.Filters.Headers.Values), 2)
ExpectEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
ExpectTrue(t, config.Filters.Headers.Negative)
ExpectEqual(t, len(config.Filters.CIDR.Values), 1)
ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep)
ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop)
ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep)
ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact)
ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
expect.Equal(t, config.BufferSize, 10)
expect.Equal(t, config.Format, FormatCombined)
expect.Equal(t, config.Path, "/tmp/access.log")
expect.Equal(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
expect.Equal(t, len(config.Filters.Method.Values), 2)
expect.Equal(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
expect.Equal(t, len(config.Filters.Headers.Values), 2)
expect.Equal(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
expect.True(t, config.Filters.Headers.Negative)
expect.Equal(t, len(config.Filters.CIDR.Values), 1)
expect.Equal(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
expect.Equal(t, config.Fields.Headers.Default, FieldModeKeep)
expect.Equal(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
expect.Equal(t, config.Fields.Query.Default, FieldModeDrop)
expect.Equal(t, config.Fields.Query.Config["foo"], FieldModeKeep)
expect.Equal(t, config.Fields.Cookies.Default, FieldModeRedact)
expect.Equal(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
}

View file

@ -4,7 +4,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
// Cookie header should be removed,
@ -15,7 +15,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
entry := getJSONEntry(t, config)
for k, v := range req.Header {
if k != "Cookie" {
ExpectEqual(t, entry.Headers[k], v)
expect.Equal(t, entry.Headers[k], v)
}
}
@ -24,8 +24,8 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
"User-Agent": FieldModeDrop,
}
entry = getJSONEntry(t, config)
ExpectEqual(t, entry.Headers["Referer"], []string{RedactedValue})
ExpectEqual(t, entry.Headers["User-Agent"], nil)
expect.Equal(t, entry.Headers["Referer"], []string{RedactedValue})
expect.Equal(t, entry.Headers["User-Agent"], nil)
}
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
@ -33,7 +33,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
config.Fields.Headers.Default = FieldModeDrop
entry := getJSONEntry(t, config)
for k := range req.Header {
ExpectEqual(t, entry.Headers[k], nil)
expect.Equal(t, entry.Headers[k], nil)
}
config.Fields.Headers.Config = map[string]FieldMode{
@ -41,18 +41,18 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
"User-Agent": FieldModeRedact,
}
entry = getJSONEntry(t, config)
ExpectEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
ExpectEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
expect.Equal(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
expect.Equal(t, entry.Headers["User-Agent"], []string{RedactedValue})
}
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
expect.Equal(t, len(entry.Headers["Cookie"]), 0)
for k := range req.Header {
if k != "Cookie" {
ExpectEqual(t, entry.Headers[k], []string{RedactedValue})
expect.Equal(t, entry.Headers[k], []string{RedactedValue})
}
}
}
@ -62,9 +62,9 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) {
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeKeep
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
expect.Equal(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value)
expect.Equal(t, entry.Cookies[cookie.Name], cookie.Value)
}
}
@ -73,9 +73,9 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) {
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
expect.Equal(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue)
expect.Equal(t, entry.Cookies[cookie.Name], RedactedValue)
}
}
@ -83,14 +83,14 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeDrop
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.Query["foo"], nil)
ExpectEqual(t, entry.Query["bar"], nil)
expect.Equal(t, entry.Query["foo"], nil)
expect.Equal(t, entry.Query["bar"], nil)
}
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectEqual(t, entry.Query["bar"], []string{RedactedValue})
expect.Equal(t, entry.Query["foo"], []string{RedactedValue})
expect.Equal(t, entry.Query["bar"], []string{RedactedValue})
}

View file

@ -6,7 +6,7 @@ import (
"sync"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/task"
)
@ -22,10 +22,10 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
// make test log file
file, err := os.Create(cfg.Path)
ExpectNoError(t, err)
expect.NoError(t, err)
file.Close()
t.Cleanup(func() {
ExpectNoError(t, os.Remove(cfg.Path))
expect.NoError(t, os.Remove(cfg.Path))
})
for i := range loggerCount {
@ -33,7 +33,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
go func(index int) {
defer wg.Done()
file, err := newFileIO(cfg.Path)
ExpectNoError(t, err)
expect.NoError(t, err)
accessLogIOs[index] = file
}(i)
}
@ -42,7 +42,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
firstIO := accessLogIOs[0]
for _, io := range accessLogIOs {
ExpectEqual(t, io, firstIO)
expect.Equal(t, io, firstIO)
}
}
@ -78,7 +78,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
expected := loggerCount * logCountPerLogger
actual := file.LineCount()
ExpectEqual(t, actual, expected)
expect.Equal(t, actual, expected)
}
func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {

View file

@ -7,7 +7,7 @@ import (
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestStatusCodeFilter(t *testing.T) {
@ -16,20 +16,20 @@ func TestStatusCodeFilter(t *testing.T) {
}
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{}
ExpectTrue(t, filter.CheckKeep(nil, nil))
expect.True(t, filter.CheckKeep(nil, nil))
// keep any 2xx 3xx (inclusive)
filter.Values = values
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
@ -38,20 +38,20 @@ func TestStatusCodeFilter(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(nil, nil))
expect.False(t, filter.CheckKeep(nil, nil))
// drop any 2xx 3xx (inclusive)
filter.Values = values
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
@ -60,19 +60,19 @@ func TestStatusCodeFilter(t *testing.T) {
func TestMethodFilter(t *testing.T) {
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[HTTPMethod]{}
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// keep get only
filter.Values = []HTTPMethod{http.MethodGet}
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
})
@ -81,19 +81,19 @@ func TestMethodFilter(t *testing.T) {
filter := &LogFilter[HTTPMethod]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// drop post only
filter.Values = []HTTPMethod{http.MethodPost}
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
})
@ -113,45 +113,45 @@ func TestHeaderFilter(t *testing.T) {
headerFoo := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo"),
}
ExpectEqual(t, headerFoo[0].Key, "Foo")
ExpectEqual(t, headerFoo[0].Value, "")
expect.Equal(t, headerFoo[0].Key, "Foo")
expect.Equal(t, headerFoo[0].Value, "")
headerFooBar := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo=bar"),
}
ExpectEqual(t, headerFooBar[0].Key, "Foo")
ExpectEqual(t, headerFooBar[0].Value, "bar")
expect.Equal(t, headerFooBar[0].Key, "Foo")
expect.Equal(t, headerFooBar[0].Value, "bar")
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{}
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
expect.True(t, filter.CheckKeep(fooBar, nil))
expect.True(t, filter.CheckKeep(fooBaz, nil))
// keep any foo
filter.Values = headerFoo
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
expect.True(t, filter.CheckKeep(fooBar, nil))
expect.True(t, filter.CheckKeep(fooBaz, nil))
// keep foo == bar
filter.Values = headerFooBar
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
expect.True(t, filter.CheckKeep(fooBar, nil))
expect.False(t, filter.CheckKeep(fooBaz, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
expect.False(t, filter.CheckKeep(fooBar, nil))
expect.False(t, filter.CheckKeep(fooBaz, nil))
// drop any foo
filter.Values = headerFoo
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
expect.False(t, filter.CheckKeep(fooBar, nil))
expect.False(t, filter.CheckKeep(fooBaz, nil))
// drop foo == bar
filter.Values = headerFooBar
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
expect.False(t, filter.CheckKeep(fooBar, nil))
expect.True(t, filter.CheckKeep(fooBaz, nil))
})
}
@ -160,7 +160,7 @@ func TestCIDRFilter(t *testing.T) {
IP: net.ParseIP("192.168.10.0"),
Mask: net.CIDRMask(24, 32),
}}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
expect.Equal(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{
RemoteAddr: "192.168.10.1",
}
@ -170,21 +170,21 @@ func TestCIDRFilter(t *testing.T) {
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*CIDR]{}
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
expect.True(t, filter.CheckKeep(inCIDR, nil))
expect.True(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
expect.True(t, filter.CheckKeep(inCIDR, nil))
expect.False(t, filter.CheckKeep(notInCIDR, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*CIDR]{Negative: true}
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
expect.False(t, filter.CheckKeep(inCIDR, nil))
expect.False(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
expect.False(t, filter.CheckKeep(inCIDR, nil))
expect.True(t, filter.CheckKeep(notInCIDR, nil))
})
}

View file

@ -4,7 +4,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
@ -24,9 +24,9 @@ func TestParseRetention(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
expect.NoError(t, err)
} else {
ExpectEqual(t, r, test.expected)
expect.Equal(t, r, test.expected)
}
})
}

View file

@ -8,7 +8,7 @@ import (
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseLogTime(t *testing.T) {
@ -26,7 +26,7 @@ func TestParseLogTime(t *testing.T) {
for _, test := range tests {
t.Run(test, func(t *testing.T) {
actual := ParseLogTime([]byte(test))
ExpectTrue(t, actual.Equal(testTime))
expect.True(t, actual.Equal(testTime))
})
}
}
@ -43,16 +43,16 @@ func TestRetentionCommonFormat(t *testing.T) {
logger.Flush()
// test.Finish(nil)
ExpectEqual(t, logger.Config().Retention, nil)
ExpectTrue(t, file.Len() > 0)
ExpectEqual(t, file.LineCount(), 10)
expect.Equal(t, logger.Config().Retention, nil)
expect.True(t, file.Len() > 0)
expect.Equal(t, file.LineCount(), 10)
t.Run("keep last", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
ExpectEqual(t, logger.Config().Retention.Days, 0)
ExpectEqual(t, logger.Config().Retention.Last, 5)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.LineCount(), 5)
expect.Equal(t, logger.Config().Retention.Days, 0)
expect.Equal(t, logger.Config().Retention.Last, 5)
expect.NoError(t, logger.Rotate())
expect.Equal(t, file.LineCount(), 5)
})
_ = file.Truncate(0)
@ -65,14 +65,14 @@ func TestRetentionCommonFormat(t *testing.T) {
logger.Log(req, resp)
}
logger.Flush()
ExpectEqual(t, file.LineCount(), 10)
expect.Equal(t, file.LineCount(), 10)
t.Run("keep days", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
ExpectEqual(t, logger.Config().Retention.Days, 3)
ExpectEqual(t, logger.Config().Retention.Last, 0)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.LineCount(), 3)
expect.Equal(t, logger.Config().Retention.Days, 3)
expect.Equal(t, logger.Config().Retention.Last, 0)
expect.NoError(t, logger.Rotate())
expect.Equal(t, file.LineCount(), 3)
rotated := string(file.Content())
_ = file.Truncate(0)
for i := range 3 {
@ -81,6 +81,6 @@ func TestRetentionCommonFormat(t *testing.T) {
}
logger.Log(req, resp)
}
ExpectEqual(t, rotated, string(file.Content()))
expect.Equal(t, rotated, string(file.Content()))
})
}

View file

@ -4,38 +4,38 @@ import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestContentTypes(t *testing.T) {
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
expect.True(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsHTML())
expect.True(t, GetContentType(http.Header{"Content-Type": {"text/html; charset=utf-8"}}).IsHTML())
expect.True(t, GetContentType(http.Header{"Content-Type": {"application/xhtml+xml"}}).IsHTML())
expect.False(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsHTML())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
expect.True(t, GetContentType(http.Header{"Content-Type": {"application/json"}}).IsJSON())
expect.True(t, GetContentType(http.Header{"Content-Type": {"application/json; charset=utf-8"}}).IsJSON())
expect.False(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsJSON())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
ExpectTrue(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
ExpectFalse(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
expect.True(t, GetContentType(http.Header{"Content-Type": {"text/plain"}}).IsPlainText())
expect.True(t, GetContentType(http.Header{"Content-Type": {"text/plain; charset=utf-8"}}).IsPlainText())
expect.False(t, GetContentType(http.Header{"Content-Type": {"text/html"}}).IsPlainText())
}
func TestAcceptContentTypes(t *testing.T) {
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
ExpectTrue(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
expect.True(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptPlainText())
expect.True(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain; charset=utf-8"}}).AcceptPlainText())
expect.True(t, GetAccept(http.Header{"Accept": {"text/html", "text/plain"}}).AcceptHTML())
expect.True(t, GetAccept(http.Header{"Accept": {"application/json"}}).AcceptJSON())
expect.True(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptPlainText())
expect.True(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptHTML())
expect.True(t, GetAccept(http.Header{"Accept": {"*/*"}}).AcceptJSON())
expect.True(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptPlainText())
expect.True(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
ExpectFalse(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
expect.False(t, GetAccept(http.Header{"Accept": {"text/plain"}}).AcceptHTML())
expect.False(t, GetAccept(http.Header{"Accept": {"text/plain; charset=utf-8"}}).AcceptHTML())
expect.False(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptPlainText())
expect.False(t, GetAccept(http.Header{"Accept": {"text/html"}}).AcceptJSON())
expect.False(t, GetAccept(http.Header{"Accept": {"text/*"}}).AcceptJSON())
}

View file

@ -4,7 +4,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
@ -15,7 +15,7 @@ func TestRebalance(t *testing.T) {
lb.AddServer(types.TestNewServer(0))
}
lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
expect.Equal(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(new(types.Config))
@ -26,7 +26,7 @@ func TestRebalance(t *testing.T) {
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
expect.Equal(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(new(types.Config))
@ -39,6 +39,6 @@ func TestRebalance(t *testing.T) {
lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
expect.Equal(t, lb.sumWeight, maxWeight)
})
}

View file

@ -9,7 +9,7 @@ import (
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/cidr_whitelist_test.yml
@ -23,32 +23,32 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
})
ExpectNoError(t, err)
expect.NoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status": 403,
})
ExpectNoError(t, err)
expect.NoError(t, err)
_, err = CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/32"},
"message": testMessage,
"status_code": 403,
})
ExpectNoError(t, err)
expect.NoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
expect.ErrorIs(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"allow": []string{"192.168.2.100/123"},
"message": testMessage,
})
ExpectErrorT[*net.ParseError](t, err)
expect.ErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
@ -56,14 +56,14 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"status_code": 600,
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
expect.ErrorIs(t, utils.ErrValidationError, err)
})
}
func TestCIDRWhitelist(t *testing.T) {
errs := gperr.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
ExpectNoError(t, errs.Error())
expect.NoError(t, errs.Error())
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
@ -74,9 +74,9 @@ func TestCIDRWhitelist(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
ExpectEqual(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message)
expect.NoError(t, err)
expect.Equal(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
expect.Equal(t, strings.TrimSpace(string(result.Data)), cidrWhitelistDefaults.Message)
}
})
@ -84,8 +84,8 @@ func TestCIDRWhitelist(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
expect.NoError(t, err)
expect.Equal(t, result.ResponseStatus, http.StatusOK)
}
})
}

View file

@ -7,7 +7,7 @@ import (
"github.com/yusing/go-proxy/pkg/json"
"github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/middleware_compose.yml
@ -16,7 +16,7 @@ var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
errs := gperr.NewBuilder("")
middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
ExpectNoError(t, errs.Error())
expect.NoError(t, errs.Error())
json.Marshal(middlewares)
// t.Log(string(data))
// TODO: test

View file

@ -6,7 +6,7 @@ import (
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
type testPriority struct {
@ -28,10 +28,10 @@ func TestMiddlewarePriority(t *testing.T) {
"priority": p,
"value": i,
})
ExpectNoError(t, err)
expect.NoError(t, err)
chain[i] = mid
}
res, err := newMiddlewaresTest(chain, nil)
ExpectNoError(t, err)
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
expect.NoError(t, err)
expect.Equal(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
}

View file

@ -8,7 +8,7 @@ import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestModifyRequest(t *testing.T) {
@ -44,15 +44,15 @@ func TestModifyRequest(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
expect.NoError(t, err)
expect.Equal(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
expect.Equal(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
expect.Equal(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := Must(url.Parse("http://test.example.com"))
reqURL := expect.Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := expect.Must(url.Parse("http://test.example.com"))
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
@ -62,38 +62,38 @@ func TestModifyRequest(t *testing.T) {
"Content-Type": []string{"application/json"},
},
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
expect.NoError(t, err)
expect.Equal(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
expect.Equal(t, result.RequestHeaders.Get("Host"), "test.example.com")
expect.True(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
expect.Equal(t, result.RequestHeaders.Get("Accept"), "")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
expect.Equal(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
expect.Equal(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
t.Run("add_prefix", func(t *testing.T) {
@ -128,8 +128,8 @@ func TestModifyRequest(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqURL := Must(url.Parse("https://my.app" + tt.path))
upstreamURL := Must(url.Parse(tt.upstreamURL))
reqURL := expect.Must(url.Parse("https://my.app" + tt.path))
upstreamURL := expect.Must(url.Parse(tt.upstreamURL))
opts["add_prefix"] = tt.addPrefix
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
@ -137,8 +137,8 @@ func TestModifyRequest(t *testing.T) {
reqURL: reqURL,
upstreamURL: upstreamURL,
})
ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath)
expect.NoError(t, err)
expect.Equal(t, result.RequestHeaders.Get("X-Test-Req-Path"), tt.expectedPath)
})
}
})

View file

@ -8,7 +8,7 @@ import (
"slices"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestModifyResponse(t *testing.T) {
@ -47,15 +47,15 @@ func TestModifyResponse(t *testing.T) {
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
expect.NoError(t, err)
expect.Equal(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
expect.Equal(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
expect.Equal(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("response_headers", func(t *testing.T) {
reqURL := Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := Must(url.Parse("http://test.example.com"))
reqURL := expect.Must(url.Parse("https://my.app/?arg_1=b"))
upstreamURL := expect.Must(url.Parse("http://test.example.com"))
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
@ -70,39 +70,39 @@ func TestModifyResponse(t *testing.T) {
respBody: bytes.Repeat([]byte("a"), 50),
respStatus: http.StatusOK,
})
ExpectNoError(t, err)
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
expect.NoError(t, err)
expect.True(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
expect.Equal(t, result.ResponseHeaders.Get("Accept"), "")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100")
remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
expect.Equal(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
}

View file

@ -4,7 +4,7 @@ import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRateLimit(t *testing.T) {
@ -15,13 +15,13 @@ func TestRateLimit(t *testing.T) {
}
rl, err := RateLimiter.New(opts)
ExpectNoError(t, err)
expect.NoError(t, err)
for range 10 {
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
expect.NoError(t, err)
expect.Equal(t, result.ResponseStatus, http.StatusOK)
}
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
expect.NoError(t, err)
expect.Equal(t, result.ResponseStatus, http.StatusTooManyRequests)
}

View file

@ -7,7 +7,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIPOpts(t *testing.T) {
@ -40,11 +40,11 @@ func TestSetRealIPOpts(t *testing.T) {
}
ri, err := RealIP.New(opts)
ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
expect.NoError(t, err)
expect.Equal(t, ri.impl.(*realIP).Header, optExpected.Header)
expect.Equal(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
expect.Equal(t, CIDR.String(), optExpected.From[i].String())
}
}
@ -61,16 +61,16 @@ func TestSetRealIP(t *testing.T) {
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := RealIP.New(opts)
ExpectNoError(t, err)
expect.NoError(t, err)
mr, err := ModifyRequest.New(optsMr)
ExpectNoError(t, err)
expect.NoError(t, err)
mid := NewMiddlewareChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err)
expect.NoError(t, err)
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
expect.Equal(t, result.ResponseStatus, http.StatusOK)
expect.Equal(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
}

View file

@ -5,22 +5,22 @@ import (
"net/url"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: Must(url.Parse("http://example.com")),
reqURL: expect.Must(url.Parse("http://example.com")),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusPermanentRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com")
expect.NoError(t, err)
expect.Equal(t, result.ResponseStatus, http.StatusPermanentRedirect)
expect.Equal(t, result.ResponseHeaders.Get("Location"), "https://example.com")
}
func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
reqURL: Must(url.Parse("https://example.com")),
reqURL: expect.Must(url.Parse("https://example.com")),
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
expect.NoError(t, err)
expect.Equal(t, result.ResponseStatus, http.StatusOK)
}

View file

@ -8,11 +8,11 @@ import (
"net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/pkg/json"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/pkg/json"
)
//go:embed test_data/sample_headers.json
@ -96,13 +96,13 @@ type testArgs struct {
func (args *testArgs) setDefaults() {
if args.reqURL == nil {
args.reqURL = Must(url.Parse("https://example.com"))
args.reqURL = expect.Must(url.Parse("https://example.com"))
}
if args.reqMethod == "" {
args.reqMethod = http.MethodGet
}
if args.upstreamURL == nil {
args.upstreamURL = Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect
args.upstreamURL = expect.Must(url.Parse("https://10.0.0.1:8443")) // dummy url, no actual effect
}
if args.respHeaders == nil {
args.respHeaders = http.Header{}

View file

@ -7,7 +7,7 @@ import (
"os"
"testing"
"github.com/stretchr/testify/require"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestPing(t *testing.T) {
@ -17,7 +17,7 @@ func TestPing(t *testing.T) {
if errors.Is(err, os.ErrPermission) {
t.Skip("permission denied")
}
require.NoError(t, err)
require.True(t, ok)
expect.NoError(t, err)
expect.True(t, ok)
})
}

View file

@ -5,7 +5,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNotificationConfig(t *testing.T) {
@ -152,11 +152,11 @@ func TestNotificationConfig(t *testing.T) {
provider := tt.cfg["provider"]
err := utils.MapUnmarshalValidate(tt.cfg, &cfg)
if tt.wantErr {
ExpectHasError(t, err)
expect.HasError(t, err)
} else {
ExpectNoError(t, err)
ExpectEqual(t, provider.(string), cfg.ProviderName)
ExpectEqual(t, cfg.Provider, tt.expected)
expect.NoError(t, err)
expect.Equal(t, provider.(string), cfg.ProviderName)
expect.Equal(t, cfg.Provider, tt.expected)
}
})
}

View file

@ -10,7 +10,7 @@ import (
"path/filepath"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestPathTraversalAttack(t *testing.T) {
@ -108,7 +108,7 @@ func TestPathTraversalAttack(t *testing.T) {
t.Errorf("Expected status 404 or 400, got %d in url %s", resp.StatusCode, u.String())
}
u = Must(url.Parse(ts.URL + "/" + p))
u = expect.Must(url.Parse(ts.URL + "/" + p))
resp, err = http.DefaultClient.Do(&http.Request{
Method: http.MethodGet,
URL: u,

View file

@ -3,10 +3,10 @@ package provider
import (
"testing"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/goccy/go-yaml"
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
_ "embed"
)
@ -18,19 +18,19 @@ func TestParseDockerLabels(t *testing.T) {
var provider DockerProvider
labels := make(map[string]string)
ExpectNoError(t, yaml.Unmarshal(testDockerLabelsYAML, &labels))
expect.NoError(t, yaml.Unmarshal(testDockerLabelsYAML, &labels))
routes, err := provider.routesFromContainerLabels(
docker.FromDocker(&types.Container{
docker.FromDocker(&container.Summary{
Names: []string{"container"},
Labels: labels,
State: "running",
Ports: []types.Port{
Ports: []container.Port{
{Type: "tcp", PrivatePort: 1234, PublicPort: 1234},
},
}, "/var/run/docker.sock"),
)
ExpectNoError(t, err)
ExpectTrue(t, routes.Contains("app"))
ExpectTrue(t, routes.Contains("app1"))
expect.NoError(t, err)
expect.True(t, routes.Contains("app"))
expect.True(t, routes.Contains("app1"))
}

View file

@ -10,7 +10,7 @@ import (
D "github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/route"
T "github.com/yusing/go-proxy/internal/route/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
var dummyNames = []string{"/a"}
@ -30,7 +30,7 @@ func makeRoutes(cont *container.Summary, dockerHostIP ...string) route.Routes {
}
cont.ID = "test"
p.name = "test"
routes := Must(p.routesFromContainerLabels(D.FromDocker(cont, host)))
routes := expect.Must(p.routesFromContainerLabels(D.FromDocker(cont, host)))
for _, r := range routes {
r.Finalize()
}
@ -39,7 +39,7 @@ func makeRoutes(cont *container.Summary, dockerHostIP ...string) route.Routes {
func TestExplicitOnly(t *testing.T) {
p := NewDockerProvider("a!", "")
ExpectTrue(t, p.IsExplicitOnly())
expect.True(t, p.IsExplicitOnly())
}
func TestApplyLabel(t *testing.T) {
@ -87,46 +87,46 @@ func TestApplyLabel(t *testing.T) {
})
a, ok := entries["a"]
ExpectTrue(t, ok)
expect.True(t, ok)
b, ok := entries["b"]
ExpectTrue(t, ok)
expect.True(t, ok)
ExpectEqual(t, a.Scheme, "https")
ExpectEqual(t, b.Scheme, "https")
expect.Equal(t, a.Scheme, "https")
expect.Equal(t, b.Scheme, "https")
ExpectEqual(t, a.Host, "app")
ExpectEqual(t, b.Host, "app")
expect.Equal(t, a.Host, "app")
expect.Equal(t, b.Host, "app")
ExpectEqual(t, a.Port.Proxy, 4567)
ExpectEqual(t, b.Port.Proxy, 4567)
expect.Equal(t, a.Port.Proxy, 4567)
expect.Equal(t, b.Port.Proxy, 4567)
ExpectTrue(t, a.NoTLSVerify)
ExpectTrue(t, b.NoTLSVerify)
expect.True(t, a.NoTLSVerify)
expect.True(t, b.NoTLSVerify)
ExpectEqual(t, a.PathPatterns, pathPatternsExpect)
ExpectEqual(t, len(b.PathPatterns), 0)
expect.Equal(t, a.PathPatterns, pathPatternsExpect)
expect.Equal(t, len(b.PathPatterns), 0)
ExpectEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0)
expect.Equal(t, a.Middlewares, middlewaresExpect)
expect.Equal(t, len(b.Middlewares), 0)
ExpectEqual(t, a.Container.IdlewatcherConfig.IdleTimeout, 0)
ExpectEqual(t, b.Container.IdlewatcherConfig.IdleTimeout, 0)
ExpectEqual(t, a.Container.IdlewatcherConfig.StopTimeout, time.Hour)
ExpectEqual(t, b.Container.IdlewatcherConfig.StopTimeout, time.Hour)
ExpectEqual(t, a.Container.IdlewatcherConfig.StopMethod, "stop")
ExpectEqual(t, b.Container.IdlewatcherConfig.StopMethod, "stop")
ExpectEqual(t, a.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second)
ExpectEqual(t, b.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second)
ExpectEqual(t, a.Container.IdlewatcherConfig.StopSignal, "SIGTERM")
ExpectEqual(t, b.Container.IdlewatcherConfig.StopSignal, "SIGTERM")
expect.Equal(t, a.Container.IdlewatcherConfig.IdleTimeout, 0)
expect.Equal(t, b.Container.IdlewatcherConfig.IdleTimeout, 0)
expect.Equal(t, a.Container.IdlewatcherConfig.StopTimeout, time.Hour)
expect.Equal(t, b.Container.IdlewatcherConfig.StopTimeout, time.Hour)
expect.Equal(t, a.Container.IdlewatcherConfig.StopMethod, "stop")
expect.Equal(t, b.Container.IdlewatcherConfig.StopMethod, "stop")
expect.Equal(t, a.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second)
expect.Equal(t, b.Container.IdlewatcherConfig.WakeTimeout, 10*time.Second)
expect.Equal(t, a.Container.IdlewatcherConfig.StopSignal, "SIGTERM")
expect.Equal(t, b.Container.IdlewatcherConfig.StopSignal, "SIGTERM")
ExpectEqual(t, a.Homepage.Show, true)
ExpectEqual(t, a.Homepage.Icon.Value, "png/adguard-home.png")
ExpectEqual(t, a.Homepage.Icon.Extra.FileType, "png")
ExpectEqual(t, a.Homepage.Icon.Extra.Name, "adguard-home")
expect.Equal(t, a.Homepage.Show, true)
expect.Equal(t, a.Homepage.Icon.Value, "png/adguard-home.png")
expect.Equal(t, a.Homepage.Icon.Extra.FileType, "png")
expect.Equal(t, a.Homepage.Icon.Extra.Name, "adguard-home")
ExpectEqual(t, a.HealthCheck.Path, "/ping")
ExpectEqual(t, a.HealthCheck.Interval, 10*time.Second)
expect.Equal(t, a.HealthCheck.Path, "/ping")
expect.Equal(t, a.HealthCheck.Interval, 10*time.Second)
}
func TestApplyLabelWithAlias(t *testing.T) {
@ -142,18 +142,18 @@ func TestApplyLabelWithAlias(t *testing.T) {
},
})
a, ok := entries["a"]
ExpectTrue(t, ok)
expect.True(t, ok)
b, ok := entries["b"]
ExpectTrue(t, ok)
expect.True(t, ok)
c, ok := entries["c"]
ExpectTrue(t, ok)
expect.True(t, ok)
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Port.Proxy, 3333)
ExpectEqual(t, a.NoTLSVerify, true)
ExpectEqual(t, b.Scheme, "http")
ExpectEqual(t, b.Port.Proxy, 1234)
ExpectEqual(t, c.Scheme, "https")
expect.Equal(t, a.Scheme, "http")
expect.Equal(t, a.Port.Proxy, 3333)
expect.Equal(t, a.NoTLSVerify, true)
expect.Equal(t, b.Scheme, "http")
expect.Equal(t, b.Port.Proxy, 1234)
expect.Equal(t, c.Scheme, "https")
}
func TestApplyLabelWithRef(t *testing.T) {
@ -170,18 +170,18 @@ func TestApplyLabelWithRef(t *testing.T) {
},
})
a, ok := entries["a"]
ExpectTrue(t, ok)
expect.True(t, ok)
b, ok := entries["b"]
ExpectTrue(t, ok)
expect.True(t, ok)
c, ok := entries["c"]
ExpectTrue(t, ok)
expect.True(t, ok)
ExpectEqual(t, a.Scheme, "http")
ExpectEqual(t, a.Host, "localhost")
ExpectEqual(t, a.Port.Proxy, 4444)
ExpectEqual(t, b.Port.Proxy, 9999)
ExpectEqual(t, c.Scheme, "https")
ExpectEqual(t, c.Port.Proxy, 1111)
expect.Equal(t, a.Scheme, "http")
expect.Equal(t, a.Host, "localhost")
expect.Equal(t, a.Port.Proxy, 4444)
expect.Equal(t, b.Port.Proxy, 9999)
expect.Equal(t, c.Scheme, "https")
expect.Equal(t, c.Port.Proxy, 1111)
}
func TestApplyLabelWithRefIndexError(t *testing.T) {
@ -197,7 +197,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
}, "")
var p DockerProvider
_, err := p.routesFromContainerLabels(c)
ExpectError(t, ErrAliasRefIndexOutOfRange, err)
expect.ErrorIs(t, ErrAliasRefIndexOutOfRange, err)
c = D.FromDocker(&container.Summary{
Names: dummyNames,
@ -208,7 +208,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
},
}, "")
_, err = p.routesFromContainerLabels(c)
ExpectError(t, ErrAliasRefIndexOutOfRange, err)
expect.ErrorIs(t, ErrAliasRefIndexOutOfRange, err)
}
func TestDynamicAliases(t *testing.T) {
@ -224,14 +224,14 @@ func TestDynamicAliases(t *testing.T) {
entries := makeRoutes(c)
r, ok := entries["app1"]
ExpectTrue(t, ok)
ExpectEqual(t, r.Scheme, "http")
ExpectEqual(t, r.Port.Proxy, 1234)
expect.True(t, ok)
expect.Equal(t, r.Scheme, "http")
expect.Equal(t, r.Port.Proxy, 1234)
r, ok = entries["app1_backend"]
ExpectTrue(t, ok)
ExpectEqual(t, r.Scheme, "http")
ExpectEqual(t, r.Port.Proxy, 5678)
expect.True(t, ok)
expect.Equal(t, r.Scheme, "http")
expect.Equal(t, r.Port.Proxy, 5678)
}
func TestDisableHealthCheck(t *testing.T) {
@ -244,24 +244,24 @@ func TestDisableHealthCheck(t *testing.T) {
},
}
r, ok := makeRoutes(c)["a"]
ExpectTrue(t, ok)
ExpectFalse(t, r.UseHealthCheck())
expect.True(t, ok)
expect.False(t, r.UseHealthCheck())
}
func TestPublicIPLocalhost(t *testing.T) {
c := &container.Summary{Names: dummyNames, State: "running"}
r, ok := makeRoutes(c)["a"]
ExpectTrue(t, ok)
ExpectEqual(t, r.Container.PublicHostname, "127.0.0.1")
ExpectEqual(t, r.Host, r.Container.PublicHostname)
expect.True(t, ok)
expect.Equal(t, r.Container.PublicHostname, "127.0.0.1")
expect.Equal(t, r.Host, r.Container.PublicHostname)
}
func TestPublicIPRemote(t *testing.T) {
c := &container.Summary{Names: dummyNames, State: "running"}
raw, ok := makeRoutes(c, testIP)["a"]
ExpectTrue(t, ok)
ExpectEqual(t, raw.Container.PublicHostname, testIP)
ExpectEqual(t, raw.Host, raw.Container.PublicHostname)
expect.True(t, ok)
expect.Equal(t, raw.Container.PublicHostname, testIP)
expect.Equal(t, raw.Host, raw.Container.PublicHostname)
}
func TestPrivateIPLocalhost(t *testing.T) {
@ -276,9 +276,9 @@ func TestPrivateIPLocalhost(t *testing.T) {
},
}
r, ok := makeRoutes(c)["a"]
ExpectTrue(t, ok)
ExpectEqual(t, r.Container.PrivateHostname, testDockerIP)
ExpectEqual(t, r.Host, r.Container.PrivateHostname)
expect.True(t, ok)
expect.Equal(t, r.Container.PrivateHostname, testDockerIP)
expect.Equal(t, r.Host, r.Container.PrivateHostname)
}
func TestPrivateIPRemote(t *testing.T) {
@ -294,10 +294,10 @@ func TestPrivateIPRemote(t *testing.T) {
},
}
r, ok := makeRoutes(c, testIP)["a"]
ExpectTrue(t, ok)
ExpectEqual(t, r.Container.PrivateHostname, "")
ExpectEqual(t, r.Container.PublicHostname, testIP)
ExpectEqual(t, r.Host, r.Container.PublicHostname)
expect.True(t, ok)
expect.Equal(t, r.Container.PrivateHostname, "")
expect.Equal(t, r.Container.PublicHostname, testIP)
expect.Equal(t, r.Host, r.Container.PublicHostname)
}
func TestStreamDefaultValues(t *testing.T) {
@ -321,22 +321,22 @@ func TestStreamDefaultValues(t *testing.T) {
t.Run("local", func(t *testing.T) {
r, ok := makeRoutes(cont)["a"]
ExpectTrue(t, ok)
ExpectNoError(t, r.Validate())
ExpectEqual(t, r.Scheme, T.Scheme("udp"))
ExpectEqual(t, r.TargetURL().Hostname(), privIP)
ExpectEqual(t, r.Port.Listening, 0)
ExpectEqual(t, r.Port.Proxy, int(privPort))
expect.True(t, ok)
expect.NoError(t, r.Validate())
expect.Equal(t, r.Scheme, T.Scheme("udp"))
expect.Equal(t, r.TargetURL().Hostname(), privIP)
expect.Equal(t, r.Port.Listening, 0)
expect.Equal(t, r.Port.Proxy, int(privPort))
})
t.Run("remote", func(t *testing.T) {
r, ok := makeRoutes(cont, testIP)["a"]
ExpectTrue(t, ok)
ExpectNoError(t, r.Validate())
ExpectEqual(t, r.Scheme, T.Scheme("udp"))
ExpectEqual(t, r.TargetURL().Hostname(), testIP)
ExpectEqual(t, r.Port.Listening, 0)
ExpectEqual(t, r.Port.Proxy, int(pubPort))
expect.True(t, ok)
expect.NoError(t, r.Validate())
expect.Equal(t, r.Scheme, T.Scheme("udp"))
expect.Equal(t, r.TargetURL().Hostname(), testIP)
expect.Equal(t, r.Port.Listening, 0)
expect.Equal(t, r.Port.Proxy, int(pubPort))
})
}
@ -349,8 +349,8 @@ func TestExplicitExclude(t *testing.T) {
"proxy.a.no_tls_verify": "true",
},
}, "")["a"]
ExpectTrue(t, ok)
ExpectTrue(t, r.ShouldExclude())
expect.True(t, ok)
expect.True(t, r.ShouldExclude())
}
func TestImplicitExcludeDatabase(t *testing.T) {
@ -361,8 +361,8 @@ func TestImplicitExcludeDatabase(t *testing.T) {
{Source: "/data", Destination: "/var/lib/postgresql/data"},
},
})["a"]
ExpectTrue(t, ok)
ExpectTrue(t, r.ShouldExclude())
expect.True(t, ok)
expect.True(t, r.ShouldExclude())
})
t.Run("exposed port detection", func(t *testing.T) {
r, ok := makeRoutes(&container.Summary{
@ -371,7 +371,7 @@ func TestImplicitExcludeDatabase(t *testing.T) {
{Type: "tcp", PrivatePort: 5432, PublicPort: 5432},
},
})["a"]
ExpectTrue(t, ok)
ExpectTrue(t, r.ShouldExclude())
expect.True(t, ok)
expect.True(t, r.ShouldExclude())
})
}

View file

@ -5,7 +5,7 @@ import (
_ "embed"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed all_fields.yaml
@ -13,5 +13,5 @@ var testAllFieldsYAML []byte
func TestFile(t *testing.T) {
_, err := validate(testAllFieldsYAML)
ExpectNoError(t, err)
expect.NoError(t, err)
}

View file

@ -3,11 +3,11 @@ package route
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker"
loadbalance "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
route "github.com/yusing/go-proxy/internal/route/types"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -20,8 +20,8 @@ func TestRouteValidate(t *testing.T) {
Port: route.Port{Proxy: common.ProxyHTTPPort},
}
err := r.Validate()
require.Error(t, err, "Validate should return error for localhost with reserved port")
require.Contains(t, err.Error(), "reserved for godoxy")
expect.HasError(t, err, "Validate should return error for localhost with reserved port")
expect.ErrorContains(t, err, "reserved for godoxy")
})
t.Run("ListeningPortWithHTTP", func(t *testing.T) {
@ -32,8 +32,8 @@ func TestRouteValidate(t *testing.T) {
Port: route.Port{Proxy: 80, Listening: 1234},
}
err := r.Validate()
require.Error(t, err, "Validate should return error for HTTP scheme with listening port")
require.Contains(t, err.Error(), "unexpected listening port")
expect.HasError(t, err, "Validate should return error for HTTP scheme with listening port")
expect.ErrorContains(t, err, "unexpected listening port")
})
t.Run("DisabledHealthCheckWithLoadBalancer", func(t *testing.T) {
@ -50,8 +50,8 @@ func TestRouteValidate(t *testing.T) {
}, // Minimal LoadBalance config with non-empty Link will be checked by UseLoadBalance
}
err := r.Validate()
require.Error(t, err, "Validate should return error for disabled healthcheck with loadbalancer")
require.Contains(t, err.Error(), "cannot disable healthcheck")
expect.HasError(t, err, "Validate should return error for disabled healthcheck with loadbalancer")
expect.ErrorContains(t, err, "cannot disable healthcheck")
})
t.Run("FileServerScheme", func(t *testing.T) {
@ -63,8 +63,8 @@ func TestRouteValidate(t *testing.T) {
Root: "/tmp", // Root is required for file server
}
err := r.Validate()
require.NoError(t, err, "Validate should not return error for valid file server route")
require.NotNil(t, r.impl, "Impl should be initialized")
expect.NoError(t, err, "Validate should not return error for valid file server route")
expect.NotNil(t, r.impl, "Impl should be initialized")
})
t.Run("HTTPScheme", func(t *testing.T) {
@ -75,8 +75,8 @@ func TestRouteValidate(t *testing.T) {
Port: route.Port{Proxy: 80},
}
err := r.Validate()
require.NoError(t, err, "Validate should not return error for valid HTTP route")
require.NotNil(t, r.impl, "Impl should be initialized")
expect.NoError(t, err, "Validate should not return error for valid HTTP route")
expect.NotNil(t, r.impl, "Impl should be initialized")
})
t.Run("TCPScheme", func(t *testing.T) {
@ -87,8 +87,8 @@ func TestRouteValidate(t *testing.T) {
Port: route.Port{Proxy: 80, Listening: 8080},
}
err := r.Validate()
require.NoError(t, err, "Validate should not return error for valid TCP route")
require.NotNil(t, r.impl, "Impl should be initialized")
expect.NoError(t, err, "Validate should not return error for valid TCP route")
expect.NotNil(t, r.impl, "Impl should be initialized")
})
t.Run("DockerContainer", func(t *testing.T) {
@ -107,8 +107,8 @@ func TestRouteValidate(t *testing.T) {
},
}
err := r.Validate()
require.NoError(t, err, "Validate should not return error for valid docker container route")
require.NotNil(t, r.ProxyURL, "ProxyURL should be set")
expect.NoError(t, err, "Validate should not return error for valid docker container route")
expect.NotNil(t, r.ProxyURL, "ProxyURL should be set")
})
t.Run("InvalidScheme", func(t *testing.T) {
@ -118,7 +118,7 @@ func TestRouteValidate(t *testing.T) {
Host: "example.com",
Port: route.Port{Proxy: 80},
}
require.Panics(t, func() {
expect.Panics(t, func() {
_ = r.Validate()
}, "Validate should panic for invalid scheme")
})
@ -131,8 +131,8 @@ func TestRouteValidate(t *testing.T) {
Port: route.Port{Proxy: 80},
}
err := r.Validate()
require.NoError(t, err)
require.NotNil(t, r.ProxyURL)
require.NotNil(t, r.HealthCheck)
expect.NoError(t, err)
expect.NotNil(t, r.ProxyURL)
expect.NotNil(t, r.HealthCheck)
})
}

View file

@ -3,7 +3,7 @@ package rules
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseCommands(t *testing.T) {
@ -126,9 +126,9 @@ func TestParseCommands(t *testing.T) {
cmd := Command{}
err := cmd.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
expect.ErrorIs(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
expect.NoError(t, err)
}
})
}

View file

@ -8,7 +8,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"golang.org/x/crypto/bcrypt"
)
@ -120,9 +120,9 @@ func TestParseOn(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
expect.ErrorIs(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
expect.NoError(t, err)
}
})
}
@ -212,7 +212,7 @@ func TestOnCorrectness(t *testing.T) {
},
{
name: "basic_auth_correct",
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
@ -222,7 +222,7 @@ func TestOnCorrectness(t *testing.T) {
},
{
name: "basic_auth_incorrect",
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
@ -269,7 +269,7 @@ func TestOnCorrectness(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on, err := parseOn(tt.checker)
ExpectNoError(t, err)
expect.NoError(t, err)
got := on.Check(Cache{}, tt.input)
if tt.want != got {
t.Errorf("want %v, got %v", tt.want, got)

View file

@ -5,7 +5,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParser(t *testing.T) {
@ -68,15 +68,15 @@ func TestParser(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
subject, args, err := parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
expect.ErrorIs(t, tt.wantErr, err)
return
}
// t.Log(subject, args, err)
ExpectNoError(t, err)
ExpectEqual(t, subject, tt.subject)
ExpectEqual(t, len(args), len(tt.args))
expect.NoError(t, err)
expect.Equal(t, subject, tt.subject)
expect.Equal(t, len(args), len(tt.args))
for i, arg := range args {
ExpectEqual(t, arg, tt.args[i])
expect.Equal(t, arg, tt.args[i])
}
})
}
@ -89,7 +89,7 @@ func TestParser(t *testing.T) {
for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
_, _, err := parse(test)
ExpectError(t, ErrUnterminatedQuotes, err)
expect.ErrorIs(t, ErrUnterminatedQuotes, err)
})
}
})

View file

@ -4,7 +4,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRule(t *testing.T) {
@ -29,18 +29,18 @@ func TestParseRule(t *testing.T) {
Rules Rules
}
err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules)
ExpectNoError(t, err)
ExpectEqual(t, len(rules.Rules), len(test))
ExpectEqual(t, rules.Rules[0].Name, "test")
ExpectEqual(t, rules.Rules[0].On.String(), "method POST")
ExpectEqual(t, rules.Rules[0].Do.String(), "error 403 Forbidden")
expect.NoError(t, err)
expect.Equal(t, len(rules.Rules), len(test))
expect.Equal(t, rules.Rules[0].Name, "test")
expect.Equal(t, rules.Rules[0].On.String(), "method POST")
expect.Equal(t, rules.Rules[0].Do.String(), "error 403 Forbidden")
ExpectEqual(t, rules.Rules[1].Name, "auth")
ExpectEqual(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`)
ExpectEqual(t, rules.Rules[1].Do.String(), "bypass")
expect.Equal(t, rules.Rules[1].Name, "auth")
expect.Equal(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`)
expect.Equal(t, rules.Rules[1].Do.String(), "bypass")
ExpectEqual(t, rules.Rules[2].Name, "default")
ExpectEqual(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm")
expect.Equal(t, rules.Rules[2].Name, "default")
expect.Equal(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm")
}
// TODO: real tests.

View file

@ -7,7 +7,7 @@ import (
. "github.com/yusing/go-proxy/internal/route"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestHTTPConfigDeserialize(t *testing.T) {
@ -42,9 +42,9 @@ func TestHTTPConfigDeserialize(t *testing.T) {
tt.input["host"] = "internal"
err := utils.MapUnmarshalValidate(tt.input, &cfg)
if err != nil {
ExpectNoError(t, err)
expect.NoError(t, err)
}
ExpectEqual(t, cfg.HTTPConfig, tt.expected)
expect.Equal(t, cfg.HTTPConfig, tt.expected)
})
}
}

View file

@ -6,7 +6,7 @@ import (
"testing"
"time"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func testTask() *Task {
@ -35,7 +35,7 @@ func TestChildTaskCancellation(t *testing.T) {
select {
case <-child.Context().Done():
ExpectError(t, context.Canceled, child.Context().Err())
expect.ErrorIs(t, context.Canceled, child.Context().Err())
default:
t.Fatal("subTask context was not canceled as expected")
}
@ -55,10 +55,10 @@ func TestTaskOnCancelOnFinished(t *testing.T) {
shouldTrueOnFinish = true
})
ExpectFalse(t, shouldTrueOnFinish)
expect.False(t, shouldTrueOnFinish)
task.Finish(nil)
ExpectTrue(t, shouldTrueOnCancel)
ExpectTrue(t, shouldTrueOnFinish)
expect.True(t, shouldTrueOnCancel)
expect.True(t, shouldTrueOnFinish)
}
func TestCommonFlowWithGracefulShutdown(t *testing.T) {
@ -83,19 +83,19 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) {
}
}()
ExpectNoError(t, GracefulShutdown(1*time.Second))
ExpectTrue(t, finished)
expect.NoError(t, GracefulShutdown(1*time.Second))
expect.True(t, finished)
<-root.finished
ExpectError(t, context.Canceled, task.Context().Err())
ExpectError(t, ErrProgramExiting, task.FinishCause())
expect.ErrorIs(t, context.Canceled, task.Context().Err())
expect.ErrorIs(t, ErrProgramExiting, task.FinishCause())
}
func TestTimeoutOnGracefulShutdown(t *testing.T) {
t.Cleanup(testCleanup)
_ = testTask()
ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
expect.ErrorIs(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
}
func TestFinishMultipleCalls(t *testing.T) {

View file

@ -33,3 +33,12 @@ func (a *Value[T]) Swap(v T) T {
func (a *Value[T]) MarshalJSONTo(buf []byte) []byte {
return json.MarshalTo(a.Load(), buf)
}
func (a *Value[T]) UnmarshalJSON(data []byte) error {
var v T
err := json.Unmarshal(data, &v)
if err == nil {
a.Store(v)
}
return err
}

View file

@ -4,7 +4,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/functional"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewMapFrom(t *testing.T) {
@ -13,8 +13,8 @@ func TestNewMapFrom(t *testing.T) {
"b": 2,
"c": 3,
})
ExpectEqual(t, m.Size(), 3)
ExpectTrue(t, m.Has("a"))
ExpectTrue(t, m.Has("b"))
ExpectTrue(t, m.Has("c"))
expect.Equal(t, m.Size(), 3)
expect.True(t, m.Has("a"))
expect.True(t, m.Has("b"))
expect.True(t, m.Has("c"))
}

View file

@ -5,7 +5,7 @@ import (
"testing"
"time"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRefCounterAddSub(t *testing.T) {
@ -23,7 +23,7 @@ func TestRefCounterAddSub(t *testing.T) {
}
wg.Wait()
ExpectEqual(t, int(rc.refCount), 0)
expect.Equal(t, int(rc.refCount), 0)
select {
case <-rc.Zero():
@ -48,7 +48,7 @@ func TestRefCounterMultipleAddSub(t *testing.T) {
}()
}
wg.Wait()
ExpectEqual(t, int(rc.refCount), numAdds+1)
expect.Equal(t, int(rc.refCount), numAdds+1)
wg.Add(numSubs)
for range numSubs {
@ -58,7 +58,7 @@ func TestRefCounterMultipleAddSub(t *testing.T) {
}()
}
wg.Wait()
ExpectEqual(t, int(rc.refCount), numAdds+1-numSubs)
expect.Equal(t, int(rc.refCount), numAdds+1-numSubs)
rc.Sub()
select {

View file

@ -9,7 +9,7 @@ import (
"testing"
"github.com/goccy/go-yaml"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestUnmarshal(t *testing.T) {
@ -44,8 +44,8 @@ func TestUnmarshal(t *testing.T) {
t.Run("unmarshal", func(t *testing.T) {
var s2 S
err := MapUnmarshalValidate(testStructSerialized, &s2)
ExpectNoError(t, err)
ExpectEqualValues(t, s2, testStruct)
expect.NoError(t, err)
expect.Values(t, s2, testStruct)
})
}
@ -64,16 +64,16 @@ func TestUnmarshalAnonymousField(t *testing.T) {
// all, anon := extractFields(reflect.TypeOf(s2))
// t.Fatalf("anon %v, all %v", anon, all)
err := MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s)
ExpectNoError(t, err)
ExpectEqualValues(t, s.A, 1)
ExpectEqualValues(t, s.B, 2)
ExpectEqualValues(t, s.C, 3)
expect.NoError(t, err)
expect.Values(t, s.A, 1)
expect.Values(t, s.B, 2)
expect.Values(t, s.C, 3)
err = MapUnmarshalValidate(map[string]any{"a": 1, "b": 2, "c": 3}, &s2)
ExpectNoError(t, err)
ExpectEqualValues(t, s2.A, 1)
ExpectEqualValues(t, s2.B, 2)
ExpectEqualValues(t, s2.C, 3)
expect.NoError(t, err)
expect.Values(t, s2.A, 1)
expect.Values(t, s2.B, 2)
expect.Values(t, s2.C, 3)
}
func TestStringIntConvert(t *testing.T) {
@ -93,13 +93,13 @@ func TestStringIntConvert(t *testing.T) {
field := refl.Elem().Field(i)
t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) {
ok, err := ConvertString("127", field)
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, field.Interface(), 127)
expect.True(t, ok)
expect.NoError(t, err)
expect.Values(t, field.Interface(), 127)
err = Convert(reflect.ValueOf(uint8(64)), field)
ExpectNoError(t, err)
ExpectEqualValues(t, field.Interface(), 64)
expect.NoError(t, err)
expect.Values(t, field.Interface(), 64)
})
}
}
@ -123,26 +123,26 @@ func (c *testType) Parse(v string) (err error) {
func TestConvertor(t *testing.T) {
t.Run("valid", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
expect.NoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
ExpectEqualValues(t, m.Test.foo, 123)
ExpectEqualValues(t, m.Test.bar, "123")
expect.Values(t, m.Test.foo, 123)
expect.Values(t, m.Test.bar, "123")
})
t.Run("int_to_string", func(t *testing.T) {
m := new(testModel)
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
expect.NoError(t, MapUnmarshalValidate(map[string]any{"Test": "123"}, m))
ExpectEqualValues(t, m.Test.foo, 123)
ExpectEqualValues(t, m.Test.bar, "123")
expect.Values(t, m.Test.foo, 123)
expect.Values(t, m.Test.bar, "123")
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m))
ExpectEqualValues(t, m.Baz, "456")
expect.NoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m))
expect.Values(t, m.Baz, "456")
})
t.Run("invalid", func(t *testing.T) {
m := new(testModel)
ExpectError(t, ErrUnsupportedConversion, MapUnmarshalValidate(map[string]any{"Test": struct{}{}}, m))
expect.ErrorIs(t, ErrUnsupportedConversion, MapUnmarshalValidate(map[string]any{"Test": struct{}{}}, m))
})
}
@ -150,23 +150,23 @@ func TestStringToSlice(t *testing.T) {
t.Run("comma_separated", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("a,b,c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqualValues(t, dst, []string{"a", "b", "c"})
expect.True(t, convertible)
expect.NoError(t, err)
expect.Values(t, dst, []string{"a", "b", "c"})
})
t.Run("yaml-like", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqualValues(t, dst, []string{"a", "b", "c"})
expect.True(t, convertible)
expect.NoError(t, err)
expect.Values(t, dst, []string{"a", "b", "c"})
})
t.Run("single-line-yaml-like", func(t *testing.T) {
dst := make([]string, 0)
convertible, err := ConvertString("- a", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqualValues(t, dst, []string{"a"})
expect.True(t, convertible)
expect.NoError(t, err)
expect.Values(t, dst, []string{"a"})
})
}
@ -188,9 +188,9 @@ func TestStringToMap(t *testing.T) {
t.Run("yaml-like", func(t *testing.T) {
dst := make(map[string]string)
convertible, err := ConvertString(" a: b\n c: d", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqualValues(t, dst, map[string]string{"a": "b", "c": "d"})
expect.True(t, convertible)
expect.NoError(t, err)
expect.Values(t, dst, map[string]string{"a": "b", "c": "d"})
})
}
@ -216,10 +216,10 @@ func TestStringToStruct(t *testing.T) {
t.Run("yaml-like simple", func(t *testing.T) {
var dst T
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqualValues(t, dst.A, "a")
ExpectEqualValues(t, dst.B, 123)
expect.True(t, convertible)
expect.NoError(t, err)
expect.Values(t, dst.A, "a")
expect.Values(t, dst.B, 123)
})
type T2 struct {
@ -229,8 +229,8 @@ func TestStringToStruct(t *testing.T) {
t.Run("yaml-like complex", func(t *testing.T) {
var dst T2
convertible, err := ConvertString(" URL: http://example.com\n CIDR: 1.2.3.0/24", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
expect.True(t, convertible)
expect.NoError(t, err)
})
}

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestIntersect(t *testing.T) {
@ -19,7 +19,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
ExpectEqual(t, result, want)
expect.Equal(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
@ -30,7 +30,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
ExpectEqual(t, result, want)
expect.Equal(t, result, want)
})
})
t.Run("ints", func(t *testing.T) {
@ -43,7 +43,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
ExpectEqual(t, result, want)
expect.Equal(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
@ -54,7 +54,7 @@ func TestIntersect(t *testing.T) {
result := Intersect(slice1, slice2)
slices.Sort(result)
slices.Sort(want)
ExpectEqual(t, result, want)
expect.Equal(t, result, want)
})
})
t.Run("complex", func(t *testing.T) {
@ -75,7 +75,7 @@ func TestIntersect(t *testing.T) {
slices.SortFunc(want, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
ExpectEqual(t, result, want)
expect.Equal(t, result, want)
})
t.Run("intersection", func(t *testing.T) {
var (
@ -90,7 +90,7 @@ func TestIntersect(t *testing.T) {
slices.SortFunc(want, func(i T, j T) int {
return strings.Compare(i.A, j.A)
})
ExpectEqual(t, result, want)
expect.Equal(t, result, want)
})
})
}

View file

@ -5,11 +5,11 @@ import (
"time"
. "github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestFormatTime(t *testing.T) {
now := Must(time.Parse(time.RFC3339, "2021-06-15T12:30:30Z"))
now := expect.Must(time.Parse(time.RFC3339, "2021-06-15T12:30:30Z"))
tests := []struct {
name string
@ -84,9 +84,9 @@ func TestFormatTime(t *testing.T) {
result := FormatTimeWithReference(tt.time, now)
if tt.expectedLength > 0 {
ExpectEqual(t, len(result), tt.expectedLength, result)
expect.Equal(t, len(result), tt.expectedLength, result)
} else {
ExpectEqual(t, result, tt.expected)
expect.Equal(t, result, tt.expected)
}
})
}
@ -163,7 +163,7 @@ func TestFormatDuration(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatDuration(tt.duration)
ExpectEqual(t, result, tt.expected)
expect.Equal(t, result, tt.expected)
})
}
}
@ -193,7 +193,7 @@ func TestFormatLastSeen(t *testing.T) {
result := FormatLastSeen(tt.time)
if tt.name == "zero time" {
ExpectEqual(t, result, tt.expected)
expect.Equal(t, result, tt.expected)
} else {
// Just make sure it's not "never", the actual formatting is tested in TestFormatTime
if result == "never" {

View file

@ -5,7 +5,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
var alphaNumeric = func() string {
@ -31,8 +31,8 @@ func TestSplit(t *testing.T) {
for sep, rsep := range tests {
t.Run(sep, func(t *testing.T) {
expected := strings.Split(alphaNumeric, sep)
ExpectEqual(t, SplitRune(alphaNumeric, rsep), expected)
ExpectEqual(t, JoinRune(expected, rsep), alphaNumeric)
expect.Equal(t, SplitRune(alphaNumeric, rsep), expected)
expect.Equal(t, JoinRune(expected, rsep), alphaNumeric)
})
}
}

View file

@ -3,7 +3,7 @@ package strutils
import (
"testing"
"github.com/stretchr/testify/require"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSanitizeURI(t *testing.T) {
@ -57,7 +57,7 @@ func TestSanitizeURI(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeURI(tt.input)
require.Equal(t, tt.expected, result)
expect.Equal(t, result, tt.expected)
})
}
}

View file

@ -0,0 +1,71 @@
package expect
import (
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
)
func init() {
if common.IsTest {
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}
}
func Must[Result any](r Result, err error) Result {
if err != nil {
panic(err)
}
return r
}
var (
NoError = require.NoError
HasError = require.Error
True = require.True
False = require.False
Nil = require.Nil
NotNil = require.NotNil
ErrorContains = require.ErrorContains
Panics = require.Panics
)
func ErrorIs(t *testing.T, expected error, err error, msgAndArgs ...any) {
t.Helper()
require.ErrorIs(t, err, expected, msgAndArgs...)
}
func ErrorT[T error](t *testing.T, err error, msgAndArgs ...any) {
t.Helper()
var errAs T
require.ErrorAs(t, err, &errAs, msgAndArgs...)
}
func Equal[T any](t *testing.T, got T, want T, msgAndArgs ...any) {
t.Helper()
require.Equal(t, want, got, msgAndArgs...)
}
func NotEqual[T any](t *testing.T, got T, want T, msgAndArgs ...any) {
t.Helper()
require.NotEqual(t, want, got, msgAndArgs...)
}
func Values(t *testing.T, got any, want any, msgAndArgs ...any) {
t.Helper()
require.EqualValues(t, want, got, msgAndArgs...)
}
func Contains[T any](t *testing.T, got T, wants []T, msgAndArgs ...any) {
t.Helper()
require.Contains(t, wants, got, msgAndArgs...)
}
func Type[T any](t *testing.T, got any, msgAndArgs ...any) (_ T) {
t.Helper()
_, ok := got.(T)
require.True(t, ok, msgAndArgs...)
return got.(T)
}

View file

@ -1,75 +0,0 @@
package utils
import (
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
)
func init() {
if common.IsTest {
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}
}
func Must[Result any](r Result, err error) Result {
if err != nil {
panic(err)
}
return r
}
func ExpectNoError(t *testing.T, err error, msgAndArgs ...any) {
t.Helper()
require.NoError(t, err, msgAndArgs...)
}
func ExpectHasError(t *testing.T, err error, msgAndArgs ...any) {
t.Helper()
require.Error(t, err, msgAndArgs...)
}
func ExpectError(t *testing.T, expected error, err error, msgAndArgs ...any) {
t.Helper()
require.ErrorIs(t, err, expected, msgAndArgs...)
}
func ExpectErrorT[T error](t *testing.T, err error, msgAndArgs ...any) {
t.Helper()
var errAs T
require.ErrorAs(t, err, &errAs, msgAndArgs...)
}
func ExpectEqual[T any](t *testing.T, got T, want T, msgAndArgs ...any) {
t.Helper()
require.Equal(t, want, got, msgAndArgs...)
}
func ExpectEqualValues(t *testing.T, got any, want any, msgAndArgs ...any) {
t.Helper()
require.EqualValues(t, want, got, msgAndArgs...)
}
func ExpectContains[T any](t *testing.T, got T, wants []T, msgAndArgs ...any) {
t.Helper()
require.Contains(t, wants, got, msgAndArgs...)
}
func ExpectTrue(t *testing.T, got bool, msgAndArgs ...any) {
t.Helper()
require.True(t, got, msgAndArgs...)
}
func ExpectFalse(t *testing.T, got bool, msgAndArgs ...any) {
t.Helper()
require.False(t, got, msgAndArgs...)
}
func ExpectType[T any](t *testing.T, got any, msgAndArgs ...any) (_ T) {
t.Helper()
_, ok := got.(T)
require.True(t, ok, msgAndArgs...)
return got.(T)
}

View file

@ -8,9 +8,8 @@ import (
"testing"
"github.com/bytedance/sonic"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
. "github.com/yusing/go-proxy/pkg/json"
)
@ -350,7 +349,7 @@ func TestMarshal(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, _ := Marshal(tt.input)
require.Equal(t, tt.expected, string(result))
expect.Equal(t, string(result), tt.expected)
})
}
@ -468,7 +467,7 @@ func TestWithTestStruct(t *testing.T) {
t.Fatalf("Unmarshal error: %v", err)
}
ExpectEqual(t, unmarshalCustom, unmarshalStdlib)
expect.Equal(t, unmarshalCustom, unmarshalStdlib)
}
func BenchmarkMarshalSimpleStdLib(b *testing.B) {