package middleware import ( "bytes" "io" "net/http" "path/filepath" "strconv" "strings" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" ) type customErrorPage struct{} var CustomErrorPage = NewMiddleware[customErrorPage]() // before implements RequestModifier. func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) { return !ServeStaticErrorPageFile(w, r) } // modifyResponse implements ResponseModifier. func (customErrorPage) modifyResponse(resp *http.Response) error { // only handles non-success status code and html/plain content type contentType := gphttp.GetContentType(resp.Header) if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) if ok { logger.Debug().Msgf("error page for status %d loaded", resp.StatusCode) _, _ = io.Copy(io.Discard, resp.Body) // drain the original body resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.ContentLength = int64(len(errorPage)) resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage))) resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8") } else { logger.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } return nil } return nil } func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bool) { path := r.URL.Path if path != "" && path[0] != '/' { path = "/" + path } if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) { filename := path[len(gphttp.StaticFilePathPrefix):] file, ok := errorpage.GetStaticFile(filename) if !ok { logger.Error().Msg("unable to load resource " + filename) return false } ext := filepath.Ext(filename) switch ext { case ".html": w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8") case ".js": w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8") case ".css": w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8") default: logger.Error().Msgf("unexpected file type %q for %s", ext, filename) } if _, err := w.Write(file); err != nil { logger.Err(err).Msg("unable to write resource " + filename) http.Error(w, "Error page failure", http.StatusInternalServerError) } return true } return false }