Skip to content

Commit

Permalink
Pass the application context to the request and response modifiers. A…
Browse files Browse the repository at this point in the history
…lso extend the structs passed to the modifier functions so it can retrieve the execution Context. Reponse modifers now have access to the Request. Updated examples to show the new features.

Signed-off-by: Daniel Ortiz <dortiz@krakend.io>
  • Loading branch information
taik0 committed Feb 23, 2024
1 parent 7bca413 commit 5995f0d
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 22 deletions.
43 changes: 28 additions & 15 deletions proxy/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str
return resp, err
}

return executeResponseModifiers(respModifiers, resp)
return executeResponseModifiers(ctx, respModifiers, resp, NewRequestWrapper(ctx, r))
}
}

if totRespModifiers == 0 {
return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(reqModifiers, r)
r, err = executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}
Expand All @@ -119,7 +119,7 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str

return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(reqModifiers, r)
r, err = executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}
Expand All @@ -129,22 +129,14 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str
return resp, err
}

return executeResponseModifiers(respModifiers, resp)
return executeResponseModifiers(ctx, respModifiers, resp, NewRequestWrapper(ctx, r))
}
}
}

func executeRequestModifiers(reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) {
func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) {
var tmp RequestWrapper
tmp = &requestWrapper{
method: r.Method,
url: r.URL,
query: r.Query,
path: r.Path,
body: r.Body,
params: r.Params,
headers: r.Headers,
}
tmp = NewRequestWrapper(ctx, r)

for _, f := range reqModifiers {
res, err := f(tmp)
Expand All @@ -169,9 +161,11 @@ func executeRequestModifiers(reqModifiers []func(interface{}) (interface{}, erro
return r, nil
}

func executeResponseModifiers(respModifiers []func(interface{}) (interface{}, error), r *Response) (*Response, error) {
func executeResponseModifiers(ctx context.Context, respModifiers []func(interface{}) (interface{}, error), r *Response, req RequestWrapper) (*Response, error) {
var tmp ResponseWrapper
tmp = responseWrapper{
ctx: ctx,
request: req,
data: r.Data,
isComplete: r.IsComplete,
metadata: metadataWrapper{
Expand Down Expand Up @@ -222,7 +216,21 @@ type ResponseWrapper interface {
StatusCode() int
}

func NewRequestWrapper(ctx context.Context, r *Request) *requestWrapper {
return &requestWrapper{
ctx: ctx,
method: r.Method,
url: r.URL,
query: r.Query,
path: r.Path,
body: r.Body,
params: r.Params,
headers: r.Headers,
}
}

type requestWrapper struct {
ctx context.Context
method string
url *url.URL
query url.Values
Expand All @@ -232,6 +240,7 @@ type requestWrapper struct {
headers map[string][]string
}

func (r *requestWrapper) Context() context.Context { return r.ctx }
func (r *requestWrapper) Method() string { return r.method }
func (r *requestWrapper) URL() *url.URL { return r.url }
func (r *requestWrapper) Query() url.Values { return r.query }
Expand All @@ -249,12 +258,16 @@ func (m metadataWrapper) Headers() map[string][]string { return m.headers }
func (m metadataWrapper) StatusCode() int { return m.statusCode }

type responseWrapper struct {
ctx context.Context
request interface{}
data map[string]interface{}
isComplete bool
metadata metadataWrapper
io io.Reader
}

func (r responseWrapper) Context() context.Context { return r.ctx }
func (r responseWrapper) Request() interface{} { return r.request }
func (r responseWrapper) Data() map[string]interface{} { return r.data }
func (r responseWrapper) IsComplete() bool { return r.isComplete }
func (r responseWrapper) Io() io.Reader { return r.io }
Expand Down
23 changes: 18 additions & 5 deletions proxy/plugin/modifier.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// SPDX-License-Identifier: Apache-2.0

/*
Package plugin provides tools for loading and registering proxy plugins
Package plugin provides tools for loading and registering proxy plugins
*/
package plugin

import (
"context"
"fmt"
"plugin"
"strings"
Expand Down Expand Up @@ -84,6 +85,10 @@ type LoggerRegisterer interface {
RegisterLogger(interface{})
}

type ContextRegisterer interface {
RegisterContext(context.Context)
}

// RegisterModifierFunc type is the function passed to the loaded Registerers
type RegisterModifierFunc func(
name string,
Expand All @@ -99,18 +104,22 @@ func Load(path, pattern string, rmf RegisterModifierFunc) (int, error) {

// LoadWithLogger scans the given path using the pattern and registers all the found modifier plugins into the rmf
func LoadWithLogger(path, pattern string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
return LoadWithLoggerAndContext(context.Background(), path, pattern, rmf, logger)
}

func LoadWithLoggerAndContext(ctx context.Context, path, pattern string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
plugins, err := luraplugin.Scan(path, pattern)
if err != nil {
return 0, err
}
return load(plugins, rmf, logger)
return load(ctx, plugins, rmf, logger)
}

func load(plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
func load(ctx context.Context, plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
errors := []error{}
loadedPlugins := 0
for k, pluginName := range plugins {
if err := open(pluginName, rmf, logger); err != nil {
if err := open(ctx, pluginName, rmf, logger); err != nil {
errors = append(errors, fmt.Errorf("plugin #%d (%s): %s", k, pluginName, err.Error()))
continue
}
Expand All @@ -123,7 +132,7 @@ func load(plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (in
return loadedPlugins, nil
}

func open(pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (err error) {
func open(ctx context.Context, pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
Expand Down Expand Up @@ -155,6 +164,10 @@ func open(pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (e
}
}

if lr, ok := r.(ContextRegisterer); ok {
lr.RegisterContext(ctx)
}

registerer.RegisterModifiers(rmf)
return
}
Expand Down
80 changes: 78 additions & 2 deletions proxy/plugin/modifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,87 @@
package plugin

import (
"bytes"
"context"
"fmt"
"io"
"net/url"
"strings"
"testing"

"github.com/luraproject/lura/v2/logging"
)

func ExampleLoadWithLoggerAndContext() {
data := []byte{}
buf := bytes.NewBuffer(data)
logger, err := logging.NewLogger("DEBUG", buf, "")
if err != nil {
fmt.Println(err.Error())
return
}
total, err := LoadWithLoggerAndContext(context.Background(), "./tests", ".so", RegisterModifier, logger)
if err != nil {
fmt.Println(err.Error())
return
}
if total != 2 {
fmt.Printf("unexpected number of loaded plugins!. have %d, want 2\n", total)
return
}

modFactory, ok := GetRequestModifier("lura-request-modifier-example-request")
if !ok {
fmt.Println("modifier factory not found in the register")
return
}

modifier := modFactory(map[string]interface{}{})

input := requestWrapper{
ctx: context.WithValue(context.Background(), "myCtxKey", "some"),
path: "/bar",
method: "GET",
}

tmp, err := modifier(input)
if err != nil {
fmt.Println(err.Error())
return
}

output, ok := tmp.(RequestWrapper)
if !ok {
fmt.Println("unexpected result type")
return
}

if res := output.Path(); res != "/bar/fooo" {
fmt.Printf("unexpected result path. have %s, want /bar/fooo\n", res)
return
}

lines := strings.Split(buf.String(), "\n")
for i := range lines[:len(lines)-1] {
fmt.Println(lines[i][21:])
}

// output:
// DEBUG: [PLUGIN: lura-error-example] Logger loaded
// DEBUG: [PLUGIN: lura-request-modifier-example] Logger loaded
// DEBUG: [PLUGIN: lura-request-modifier-example] Context loaded
// DEBUG: [PLUGIN: lura-request-modifier-example] Request modifier injected
// DEBUG: context key: some
// DEBUG: params: map[]
// DEBUG: headers: map[]
// DEBUG: method: GET
// DEBUG: url: <nil>
// DEBUG: query: map[]
// DEBUG: path: /bar/fooo
}

func TestLoad(t *testing.T) {
total, err := Load("./tests", ".so", RegisterModifier)
total, err := LoadWithLogger("./tests", ".so", RegisterModifier, logging.NoOp)
if err != nil {
t.Error(err.Error())
t.Fail()
Expand All @@ -29,7 +103,7 @@ func TestLoad(t *testing.T) {

modifier := modFactory(map[string]interface{}{})

input := requestWrapper{path: "/bar"}
input := requestWrapper{ctx: context.WithValue(context.Background(), "myCtxKey", "some"), path: "/bar"}

tmp, err := modifier(input)
if err != nil {
Expand Down Expand Up @@ -59,6 +133,7 @@ type RequestWrapper interface {
}

type requestWrapper struct {
ctx context.Context
method string
url *url.URL
query url.Values
Expand All @@ -68,6 +143,7 @@ type requestWrapper struct {
headers map[string][]string
}

func (r requestWrapper) Context() context.Context { return r.ctx }
func (r requestWrapper) Method() string { return r.method }
func (r requestWrapper) URL() *url.URL { return r.url }
func (r requestWrapper) Query() url.Values { return r.query }
Expand Down
Loading

0 comments on commit 5995f0d

Please sign in to comment.