Skip to content

Commit

Permalink
Feature: Add DoRedirects, DoTimeout and DoDeadline to Proxy middleware (
Browse files Browse the repository at this point in the history
#2332)

* Add support for DoRedirects

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>

* Fix linter issues

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>

* Add example to README

* Add support for DoDeadline and DoTimeout. Expand unit-tests

* Fix linter errors

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>

* Add examples for Proxy Middleware

---------

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>
  • Loading branch information
gaby committed Feb 24, 2023
1 parent b634ba0 commit dc038d8
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 16 deletions.
36 changes: 36 additions & 0 deletions middleware/proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ func Balancer(config Config) fiber.Handler
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
// Do performs the given http request and fills the given http response.
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error
// DoRedirects performs the given http request and fills the given http response while following up to maxRedirectsCount redirects.
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error
// DoDeadline performs the given request and waits for response until the given deadline.
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error
// DoTimeout performs the given request and waits for response during the given timeout duration.
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error
// DomainForward the given http request based on the given domain and fills the given http response
func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler
// BalancerForward performs the given http request based round robin balancer and fills the given http response
Expand Down Expand Up @@ -73,6 +79,36 @@ app.Get("/:id", func(c *fiber.Ctx) error {
return nil
})

// Make proxy requests while following redirects
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := proxy.DoRedirects(c, "http://google.com", 3); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Make proxy requests and wait up to 5 seconds before timing out
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := proxy.DoTimeout(c, "http://localhost:3000", time.Second * 5); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Make proxy requests, timeout a minute from now
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := DoDeadline(c, "http://localhost", time.Now().Add(time.Minute)); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Minimal round robin balancer
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
Expand Down
47 changes: 42 additions & 5 deletions middleware/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/url"
"strings"
"sync"
"time"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
Expand Down Expand Up @@ -139,16 +140,53 @@ func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
// Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.Do(req, resp)
}, clients...)
}

// DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects.
// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned.
// This method can be used within a fiber.Handler
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoRedirects(req, resp, maxRedirectsCount)
}, clients...)
}

// DoDeadline performs the given request and waits for response until the given deadline.
// This method can be used within a fiber.Handler
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoDeadline(req, resp, deadline)
}, clients...)
}

// DoTimeout performs the given request and waits for response during the given timeout duration.
// This method can be used within a fiber.Handler
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoTimeout(req, resp, timeout)
}, clients...)
}

func doAction(
c *fiber.Ctx,
addr string,
action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
clients ...*fasthttp.Client,
) error {
var cli *fasthttp.Client

// set local or global client
if len(clients) != 0 {
// Set local client
cli = clients[0]
} else {
// Set global client
lock.RLock()
cli = client
lock.RUnlock()
}

req := c.Request()
res := c.Response()
originalURL := utils.CopyString(c.OriginalURL())
Expand All @@ -157,14 +195,13 @@ func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
copiedURL := utils.CopyString(addr)
req.SetRequestURI(copiedURL)
// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
// issue reference:
// https://github.com/gofiber/fiber/issues/1762
// Reference: https://github.com/gofiber/fiber/issues/1762
if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
req.URI().SetSchemeBytes(scheme)
}

req.Header.Del(fiber.HeaderConnection)
if err := cli.Do(req, res); err != nil {
if err := action(cli, req, res); err != nil {
return err
}
res.Header.Del(fiber.HeaderConnection)
Expand Down
179 changes: 168 additions & 11 deletions middleware/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"crypto/tls"
"errors"
"io"
"net"
"net/http/httptest"
Expand Down Expand Up @@ -48,6 +49,19 @@ func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
app.Use(Balancer(Config{Servers: []string{}}))
}

// go test -run Test_Proxy_Empty_Config
func Test_Proxy_Empty_Config(t *testing.T) {
t.Parallel()

defer func() {
if r := recover(); r != nil {
utils.AssertEqual(t, "Servers cannot be empty", r)
}
}()
app := fiber.New()
app.Use(New(Config{}))
}

// go test -run Test_Proxy_Next
func Test_Proxy_Next(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -345,24 +359,167 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
// go test -race -run Test_Proxy_Do_RestoreOriginalURL
func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/proxy", func(c *fiber.Ctx) error {
return c.SendString("ok")
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "http://"+addr)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
}

// go test -race -run Test_Proxy_Do_WithRealURL
func Test_Proxy_Do_WithRealURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
originalURL := utils.CopyString(c.OriginalURL())
if err := Do(c, "/proxy"); err != nil {
return err
}
utils.AssertEqual(t, originalURL, c.OriginalURL())
return c.SendString("ok")
return Do(c, "https://www.google.com")
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
}

// go test -race -run Test_Proxy_Do_WithRedirect
func Test_Proxy_Do_WithRedirect(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "https://google.com")
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
utils.AssertEqual(t, 301, resp.StatusCode)
}

// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL
func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 1)
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
_, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects
func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 0)
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "too many redirects detected when doing the request", string(body))
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL
func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoTimeout_Timeout
func Test_Proxy_DoTimeout_Timeout(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})

_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
// This test requires multiple requests due to zero allocation used in fiber
_, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}

// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL
func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoDeadline_PastDeadline
func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})

_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}

// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
Expand Down

0 comments on commit dc038d8

Please sign in to comment.