feat: hCaptcha middleware

This commit is contained in:
yusing 2025-05-04 17:21:12 +08:00
parent e275ee634c
commit f9a8aede20
15 changed files with 793 additions and 10 deletions

View file

@ -131,7 +131,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
logging.Err(err).Msg("failed to sign session token")
return
}
setTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
}
func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) {

View file

@ -176,7 +176,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
}
state := generateState()
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
SetTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
// redirect user to Idp
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
}
@ -301,12 +301,12 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request)
}
func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) {
setTokenCookie(w, r, CookieOauthToken, jwt, ttl)
SetTokenCookie(w, r, CookieOauthToken, jwt, ttl)
}
func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) {
clearTokenCookie(w, r, CookieOauthToken)
clearTokenCookie(w, r, CookieOauthSessionToken)
ClearTokenCookie(w, r, CookieOauthToken)
ClearTokenCookie(w, r, CookieOauthSessionToken)
}
// handleTestCallback handles OIDC callback in test environment.
@ -323,7 +323,7 @@ func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Requ
}
// Create test JWT token
setTokenCookie(w, r, CookieOauthToken, "test", time.Hour)
SetTokenCookie(w, r, CookieOauthToken, "test", time.Hour)
http.Redirect(w, r, "/", http.StatusFound)
}

View file

@ -119,7 +119,7 @@ func (auth *UserPassAuth) PostAuthCallbackHandler(w http.ResponseWriter, r *http
gphttp.ServerError(w, r, err)
return
}
setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL)
SetTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL)
w.WriteHeader(http.StatusOK)
}
@ -128,7 +128,7 @@ func (auth *UserPassAuth) LoginHandler(w http.ResponseWriter, r *http.Request) {
}
func (auth *UserPassAuth) LogoutHandler(w http.ResponseWriter, r *http.Request) {
clearTokenCookie(w, r, auth.TokenCookieName())
ClearTokenCookie(w, r, auth.TokenCookieName())
http.Redirect(w, r, "/", http.StatusFound)
}

View file

@ -44,7 +44,7 @@ func cookieDomain(r *http.Request) string {
return strutils.JoinRune(parts, '.')
}
func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) {
func SetTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
@ -57,7 +57,7 @@ func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string,
})
}
func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
func ClearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",

View file

@ -85,6 +85,14 @@ func Join(errors ...error) Error {
return &nestedError{Extras: errs}
}
func JoinLines(main error, errors ...string) Error {
errs := make([]error, len(errors))
for i, err := range errors {
errs[i] = newError(err)
}
return &nestedError{Err: main, Extras: errs}
}
func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T {
result, err := fn(arg)
eb.Add(err)

View file

@ -0,0 +1,69 @@
package httpheaders
import (
"net/http"
"strings"
)
// AppendCSP appends a CSP header to specific directives in the response writer.
//
// Directives other than the ones in cspDirectives will be kept as is.
//
// It will replace 'none' with the sources.
//
// It will append 'self' to the sources if it's not already present.
func AppendCSP(w http.ResponseWriter, r *http.Request, cspDirectives []string, sources []string) {
csp := make(map[string]string)
cspValues := r.Header.Values("Content-Security-Policy")
if len(cspValues) == 1 {
cspValues = strings.Split(cspValues[0], ";")
for i, cspString := range cspValues {
cspValues[i] = strings.TrimSpace(cspString)
}
}
for _, cspString := range cspValues {
parts := strings.SplitN(cspString, " ", 2)
if len(parts) == 2 {
csp[parts[0]] = parts[1]
}
}
for _, directive := range cspDirectives {
value, ok := csp[directive]
if !ok {
value = "'self'"
}
switch value {
case "'self'":
csp[directive] = value + " " + strings.Join(sources, " ")
case "'none'":
csp[directive] = strings.Join(sources, " ")
default:
for _, source := range sources {
if !strings.Contains(value, source) {
value += " " + source
}
}
if !strings.Contains(value, "'self'") {
value = "'self' " + value
}
csp[directive] = value
}
}
values := make([]string, 0, len(csp))
for directive, value := range csp {
values = append(values, directive+" "+value)
}
// Remove existing CSP header, case insensitive
for k := range w.Header() {
if strings.EqualFold(k, "Content-Security-Policy") {
delete(w.Header(), k)
}
}
// Set new CSP header
w.Header()["Content-Security-Policy"] = values
}

View file

@ -0,0 +1,168 @@
package httpheaders
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestAppendCSP(t *testing.T) {
tests := []struct {
name string
initialHeaders map[string][]string
sources []string
directives []string
expectedCSP map[string]string
}{
{
name: "No CSP header",
initialHeaders: map[string][]string{},
sources: []string{},
directives: []string{"default-src", "script-src", "frame-src", "style-src", "connect-src"},
expectedCSP: map[string]string{"default-src": "'self'", "script-src": "'self'", "frame-src": "'self'", "style-src": "'self'", "connect-src": "'self'"},
},
{
name: "No CSP header with sources",
initialHeaders: map[string][]string{},
sources: []string{"https://example.com"},
directives: []string{"default-src", "script-src", "frame-src", "style-src", "connect-src"},
expectedCSP: map[string]string{"default-src": "'self' https://example.com", "script-src": "'self' https://example.com", "frame-src": "'self' https://example.com", "style-src": "'self' https://example.com", "connect-src": "'self' https://example.com"},
},
{
name: "replace 'none' with sources",
initialHeaders: map[string][]string{
"Content-Security-Policy": {"default-src 'none'"},
},
sources: []string{"https://example.com"},
directives: []string{"default-src"},
expectedCSP: map[string]string{"default-src": "https://example.com"},
},
{
name: "CSP header with some directives",
initialHeaders: map[string][]string{
"Content-Security-Policy": {"default-src 'none'", "script-src 'unsafe-inline'"},
},
sources: []string{"https://example.com"},
directives: []string{"script-src"},
expectedCSP: map[string]string{
"default-src": "'none",
"script-src": "'unsafe-inline' https://example.com",
},
},
{
name: "CSP header with some directives with self",
initialHeaders: map[string][]string{
"Content-Security-Policy": {"default-src 'self'", "connect-src 'self'"},
},
sources: []string{"https://api.example.com"},
directives: []string{"default-src", "connect-src"},
expectedCSP: map[string]string{
"default-src": "'self' https://api.example.com",
"connect-src": "'self' https://api.example.com",
},
},
{
name: "AppendCSP sources conflict with existing CSP header",
initialHeaders: map[string][]string{
"Content-Security-Policy": {"default-src 'self' https://cdn.example.com", "script-src 'unsafe-inline'"},
},
sources: []string{"https://cdn.example.com", "https://api.example.com"},
directives: []string{"default-src", "script-src"},
expectedCSP: map[string]string{
"default-src": "'self' https://cdn.example.com https://api.example.com",
"script-src": "'unsafe-inline' https://cdn.example.com https://api.example.com",
},
},
{
name: "Non-standard CSP directive",
initialHeaders: map[string][]string{
"Content-Security-Policy": {
"default-src 'self'",
"script-src 'unsafe-inline'",
"img-src 'self'", // img-src is not in cspDirectives list
},
},
sources: []string{"https://example.com"},
directives: []string{"default-src", "script-src"},
expectedCSP: map[string]string{
"default-src": "'self' https://example.com",
"script-src": "'unsafe-inline' https://example.com",
// img-src should not be present in response as it's not in cspDirectives
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a test request with initial headers
req := httptest.NewRequest(http.MethodGet, "/", nil)
for header, values := range tc.initialHeaders {
req.Header[header] = values
}
// Create a test response recorder
w := httptest.NewRecorder()
// Call the function under test
AppendCSP(w, req, tc.directives, tc.sources)
// Check the resulting CSP headers
respHeaders := w.Header()
cspValues, exists := respHeaders["Content-Security-Policy"]
// If we expect no CSP headers, verify none exist
if len(tc.expectedCSP) == 0 {
if exists && len(cspValues) > 0 {
t.Errorf("Expected no CSP header, but got %v", cspValues)
}
return
}
// Verify CSP headers exist when expected
if !exists || len(cspValues) == 0 {
t.Errorf("Expected CSP header to be set, but it was not")
return
}
// Parse the CSP response and verify each directive
foundDirectives := make(map[string]string)
for _, cspValue := range cspValues {
parts := strings.Split(cspValue, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
directiveParts := strings.SplitN(part, " ", 2)
if len(directiveParts) != 2 {
t.Errorf("Invalid CSP directive format: %s", part)
continue
}
directive := directiveParts[0]
value := directiveParts[1]
foundDirectives[directive] = value
}
}
// Verify expected directives
for directive, expectedValue := range tc.expectedCSP {
actualValue, ok := foundDirectives[directive]
if !ok {
t.Errorf("Expected directive %s not found in response", directive)
continue
}
// Check if all expected sources are in the actual value
expectedSources := strings.SplitSeq(expectedValue, " ")
for source := range expectedSources {
if !strings.Contains(actualValue, source) {
t.Errorf("Directive %s missing expected source %s. Got: %s", directive, source, actualValue)
}
}
}
})
}
}

View file

@ -0,0 +1,17 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/captcha"
)
type hCaptcha struct {
captcha.HcaptchaProvider
}
func (h *hCaptcha) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
return captcha.PreRequest(h, w, r)
}
var HCaptcha = NewMiddleware[hCaptcha]()

View file

@ -0,0 +1,293 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Verification Required</title>
{{.ScriptHTML}}
<script>
function updateTheme() {
const theme = window.matchMedia("(prefers-color-scheme: dark)").matches
? "dark"
: "light";
document
.querySelector("#verification-form > :first-child")
.setAttribute("data-theme", theme);
}
window.addEventListener("load", updateTheme);
</script>
<style>
:root {
/* Light mode colors */
--background-light: #f8f9fa;
--text-light: #2d3748;
--container-bg-light: #ffffff;
--shadow-light: rgba(0, 0, 0, 0.08);
--heading-light: #3d4852;
--button-bg-light: #4f46e5;
--button-hover-light: #4338ca;
--button-disabled-bg-light: #e9ecef;
--button-disabled-text-light: #a0aec0;
--accent-light: #6366f1;
/* Dark mode colors */
--background-dark: #111827;
--text-dark: #e5e7eb;
--container-bg-dark: #1f2937;
--shadow-dark: rgba(0, 0, 0, 0.3);
--heading-dark: #f3f4f6;
--button-bg-dark: #6366f1;
--button-hover-dark: #4f46e5;
--button-disabled-bg-dark: #374151;
--button-disabled-text-dark: #9ca3af;
--accent-dark: #818cf8;
}
@media (prefers-color-scheme: light) {
body {
background: linear-gradient(135deg, var(--background-light), #f0f4f8);
color: var(--text-light);
}
.container {
background-color: var(--container-bg-light);
box-shadow: 0 10px 25px var(--shadow-light);
border: 1px solid rgba(0, 0, 0, 0.04);
}
h1 {
color: var(--heading-light);
}
button {
background: linear-gradient(
to right,
var(--button-bg-light),
var(--accent-light)
);
}
button:hover:not(:disabled) {
background: linear-gradient(
to right,
var(--button-hover-light),
var(--button-bg-light)
);
}
button:disabled {
background: var(--button-disabled-bg-light);
color: var(--button-disabled-text-light);
}
.container::before {
background: linear-gradient(
135deg,
rgba(99, 102, 241, 0.1),
rgba(79, 70, 229, 0.05)
);
}
}
@media (prefers-color-scheme: dark) {
body {
background: linear-gradient(135deg, var(--background-dark), #0f172a);
color: var(--text-dark);
}
.container {
background-color: var(--container-bg-dark);
box-shadow: 0 10px 25px var(--shadow-dark);
border: 1px solid rgba(255, 255, 255, 0.05);
}
h1 {
color: var(--heading-dark);
}
button {
background: linear-gradient(
to right,
var(--button-bg-dark),
var(--accent-dark)
);
}
button:hover:not(:disabled) {
background: linear-gradient(
to right,
var(--button-hover-dark),
var(--button-bg-dark)
);
}
button:disabled {
background: var(--button-disabled-bg-dark);
color: var(--button-disabled-text-dark);
}
.container::before {
background: linear-gradient(
135deg,
rgba(99, 102, 241, 0.1),
rgba(129, 140, 248, 0.05)
);
}
}
body {
font-family:
"Inter",
system-ui,
-apple-system,
BlinkMacSystemFont,
"Segoe UI",
Roboto,
Oxygen,
Ubuntu,
Cantarell,
"Open Sans",
"Helvetica Neue",
sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
transition:
background-color 0.5s ease,
color 0.3s ease;
line-height: 1.6;
}
.container {
position: relative;
padding: 48px 42px;
border-radius: 16px;
text-align: center;
max-width: 420px;
width: 90%;
transition:
background-color 0.3s ease,
box-shadow 0.3s ease,
transform 0.3s ease;
overflow: hidden;
animation: fadeIn 0.5s ease-out;
}
.container::before {
content: "";
position: absolute;
top: -10%;
left: -10%;
width: 120%;
height: 120%;
border-radius: 30%;
opacity: 0.5;
z-index: 0;
transform: rotate(-8deg);
}
.content {
position: relative;
z-index: 1;
}
h1 {
font-size: 1.75em;
font-weight: 700;
margin-bottom: 28px;
transition: color 0.3s ease;
letter-spacing: -0.02em;
}
button {
color: white;
border: none;
padding: 13px 30px;
border-radius: 10px;
cursor: pointer;
font-size: 1rem;
font-weight: 600;
letter-spacing: 0.01em;
transition:
all 0.25s ease,
transform 0.15s ease;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
position: relative;
overflow: hidden;
}
button:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 6px 15px rgba(0, 0, 0, 0.2);
}
button:active:not(:disabled) {
transform: translateY(0);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
}
button:focus {
outline: none;
box-shadow:
0 0 0 2px rgba(99, 102, 241, 0.5),
0 4px 12px rgba(0, 0, 0, 0.15);
}
button:disabled {
cursor: not-allowed;
box-shadow: none;
}
#verification-form {
margin-top: 30px;
display: flex;
flex-direction: column;
align-items: center;
gap: 22px;
position: relative;
z-index: 1;
}
#verification-form > :first-child {
margin-left: auto;
margin-right: auto;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.description {
color: var(--text-light);
opacity: 0.85;
font-size: 0.95rem;
margin-bottom: 20px;
max-width: 90%;
margin-left: auto;
margin-right: auto;
}
@media (prefers-color-scheme: dark) {
.description {
color: var(--text-dark);
opacity: 0.75;
}
}
</style>
</head>
<body>
<script>
function onDataCallback() {
document.getElementById("verification-form").submit();
}
</script>
<div class="container">
<div class="content">
<h1>Human Verification</h1>
<p class="description">
Please complete the verification below to continue.
</p>
<form id="verification-form" method="POST" action="">
{{.FormHTML}}
</form>
</div>
</div>
</body>
</html>

View file

@ -0,0 +1,96 @@
package captcha
import (
"bytes"
"context"
"encoding/json"
"errors"
"net"
"net/http"
"net/url"
"time"
_ "embed"
"github.com/yusing/go-proxy/internal/gperr"
)
type HcaptchaProvider struct {
ProviderBase
SiteKey string `json:"site_key" validate:"required"`
Secret string `json:"secret" validate:"required"`
}
// https://docs.hcaptcha.com/#content-security-policy-settings
func (p *HcaptchaProvider) CSPDirectives() []string {
return []string{"script-src", "frame-src", "style-src", "connect-src"}
}
// https://docs.hcaptcha.com/#content-security-policy-settings
func (p *HcaptchaProvider) CSPSources() []string {
return []string{
"https://hcaptcha.com",
"https://*.hcaptcha.com",
}
}
func (p *HcaptchaProvider) Verify(r *http.Request) error {
response := r.PostFormValue("h-captcha-response")
if response == "" {
return errors.New("h-captcha-response is missing")
}
remoteIP := r.RemoteAddr
if ip, _, err := net.SplitHostPort(remoteIP); err == nil {
remoteIP = ip
}
ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
defer cancel()
formData := url.Values{}
formData.Set("secret", p.Secret)
formData.Set("response", response)
formData.Set("remoteip", remoteIP)
formData.Set("sitekey", p.SiteKey)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.hcaptcha.com/siteverify", bytes.NewBufferString(formData.Encode()))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var respData struct {
Success bool `json:"success"`
Error []string `json:"error-codes"`
}
if err := json.NewDecoder(resp.Body).Decode(&respData); err != nil {
return err
}
if !respData.Success {
return gperr.JoinLines(ErrCaptchaVerificationFailed, respData.Error...)
}
return nil
}
func (p *HcaptchaProvider) ScriptHTML() string {
return `
<script src="https://js.hcaptcha.com/1/api.js" async defer></script>`
}
func (p *HcaptchaProvider) FormHTML() string {
return `
<div
class="h-captcha"
data-sitekey="` + p.SiteKey + `"
data-callback="onDataCallback"
/>`
}

View file

@ -0,0 +1,61 @@
package captcha
import (
"net/http"
"text/template"
"github.com/yusing/go-proxy/internal/auth"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
_ "embed"
)
const cookieName = "godoxy_captcha_session"
//go:embed captcha.html
var captchaPageHTML string
var captchaPage = template.Must(template.New("captcha").Parse(captchaPageHTML))
func PreRequest(p Provider, w http.ResponseWriter, r *http.Request) (proceed bool) {
// check session
sessionID, err := r.Cookie(cookieName)
if err == nil {
session, ok := CaptchaSessions.Load(sessionID.Value)
if ok {
if session.expired() {
CaptchaSessions.Delete(sessionID.Value)
} else {
return true
}
}
}
if !gphttp.GetAccept(r.Header).AcceptHTML() {
gphttp.Forbidden(w, "Captcha is required")
return false
}
if r.Method == http.MethodPost {
err := p.Verify(r)
if err == nil {
session := newCaptchaSession(p)
CaptchaSessions.Store(session.ID, session)
auth.SetTokenCookie(w, r, cookieName, session.ID, p.SessionExpiry())
http.Redirect(w, r, r.URL.Path, http.StatusFound)
return false
}
gphttp.Unauthorized(w, err.Error())
return false
}
// captcha challenge
err = captchaPage.Execute(w, map[string]any{
"ScriptHTML": p.ScriptHTML(),
"FormHTML": p.FormHTML(),
})
if err != nil {
logging.Error().Err(err).Msg("failed to execute captcha page")
}
return false
}

View file

@ -0,0 +1,21 @@
package captcha
import (
"net/http"
"time"
"github.com/yusing/go-proxy/internal/gperr"
)
type Provider interface {
CSPDirectives() []string
CSPSources() []string
Verify(r *http.Request) error
SessionExpiry() time.Duration
ScriptHTML() string
FormHTML() string
}
var (
ErrCaptchaVerificationFailed = gperr.New("captcha verification failed")
)

View file

@ -0,0 +1,14 @@
package captcha
import "time"
type ProviderBase struct {
Expiry time.Duration `json:"session_expiry"`
}
func (p *ProviderBase) SessionExpiry() time.Duration {
if p.Expiry == 0 {
p.Expiry = 24 * time.Hour
}
return p.Expiry
}

View file

@ -0,0 +1,34 @@
package captcha
import (
"crypto/rand"
"encoding/hex"
"time"
_ "embed"
"github.com/yusing/go-proxy/internal/jsonstore"
"github.com/yusing/go-proxy/internal/utils"
)
type CaptchaSession struct {
ID string `json:"id"`
Expiry time.Time `json:"expiry"`
}
var CaptchaSessions = jsonstore.Store[*CaptchaSession]("captcha_sessions")
func newCaptchaSession(p Provider) *CaptchaSession {
buf := make([]byte, 32)
_, _ = rand.Read(buf)
now := utils.TimeNow()
return &CaptchaSession{
ID: hex.EncodeToString(buf),
Expiry: now.Add(p.SessionExpiry()),
}
}
func (s *CaptchaSession) expired() bool {
return utils.TimeNow().After(s.Expiry)
}

View file

@ -32,6 +32,8 @@ var allMiddlewares = map[string]*Middleware{
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
"hcaptcha": HCaptcha,
}
var (