Skip to content

Commit

Permalink
perf: optimize test execution time on large repeated data set (#122)
Browse files Browse the repository at this point in the history
* perf: parse configuration once

* feat: notes on future improvements

* perf: implement query cache optimization
  • Loading branch information
bcho authored Oct 2, 2024
1 parent 1dd29ce commit 27f391b
Show file tree
Hide file tree
Showing 12 changed files with 363 additions and 21 deletions.
22 changes: 16 additions & 6 deletions sg/internal/cli/test/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ func (s *failSettings) CheckQueryResults(results []result.QueryResults) error {

// cliApp is the CLI cliApplication for the test subcommand.
type cliApp struct {
projectSpecFile string
contextRoot string
outputFormat string
failSettings *failSettings
projectSpecFile string
contextRoot string
outputFormat string
failSettings *failSettings
enableQueryCache bool

stdout io.Writer
}
Expand Down Expand Up @@ -100,9 +101,11 @@ func (cliApp *cliApp) Run() error {
return fmt.Errorf("read project spec: %w", err)
}

queryCache := engine.NewQueryCache()

var queryResultsList []result.QueryResults
for _, target := range projectSpec.Files {
queryResult, err := cliApp.queryFileTarget(ctx, cliApp.contextRoot, target)
queryResult, err := cliApp.queryFileTarget(ctx, cliApp.contextRoot, target, queryCache)
if err != nil {
return fmt.Errorf("run target (%s): %w", target.Name, err)
}
Expand All @@ -127,6 +130,7 @@ func (cliApp *cliApp) BindCLIFlags(fs *pflag.FlagSet) {
&cliApp.outputFormat, "output", "o", cliApp.outputFormat,
fmt.Sprintf("Output format. Available formats: %s", presenter.AvailableFormatsHelp()),
)
fs.BoolVarP(&cliApp.enableQueryCache, "enable-query-cache", "", false, "Enable query cache (experimental).")
cliApp.failSettings.BindCLIFlags(fs)
}

Expand Down Expand Up @@ -163,6 +167,7 @@ func (cliApp *cliApp) queryFileTarget(
ctx context.Context,
contextRoot string,
target project.FileTargetSpec,
queryCache engine.QueryCache,
) ([]result.QueryResults, error) {
resolveToContextRoot := resolveToContextRootFn(contextRoot)

Expand All @@ -176,7 +181,12 @@ func (cliApp *cliApp) queryFileTarget(
return nil, fmt.Errorf("load sources failed: %w", err)
}

queryer, err := engine.QueryWithPolicy(policyPaths).Complete()
qb := engine.QueryWithPolicy(policyPaths)
if cliApp.enableQueryCache {
qb.WithQueueCache(queryCache)
}

queryer, err := qb.Complete()
if err != nil {
return nil, fmt.Errorf("create queryer failed: %w", err)
}
Expand Down
21 changes: 16 additions & 5 deletions sg/internal/engine/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (

// QueryerBuilder constructs a Queryer.
type QueryerBuilder struct {
packages []policy.Package
err error
packages []policy.Package
queryCache QueryCache
err error
}

// QueryWithPolicy creates a QueryerBuilder with loading packages from the given paths.
func QueryWithPolicy(policyPaths []string) *QueryerBuilder {
qb := &QueryerBuilder{}
qb := &QueryerBuilder{
queryCache: noopQueryCache,
}

qb.packages, qb.err = policy.LoadPackagesFromPaths(policyPaths)
if qb.err != nil {
Expand All @@ -24,24 +27,32 @@ func QueryWithPolicy(policyPaths []string) *QueryerBuilder {
return qb
}

// WithQueueCache sets the query cache for the queryer.
func (qb *QueryerBuilder) WithQueueCache(cache QueryCache) *QueryerBuilder {
qb.queryCache = cache
return qb
}

// Complete constructs the Queryer.
func (qb *QueryerBuilder) Complete() (Queryer, error) {
if qb.err != nil {
return nil, qb.err
}

compiler, err := policy.NewRegoCompiler(qb.packages)
compiler, compilerKey, err := policy.NewRegoCompiler(qb.packages)
if err != nil {
return nil, fmt.Errorf("failed to create compiler from packages: %w", err)
}

rv := &RegoEngine{
policyPackages: qb.packages,
compiler: compiler,
compilerKey: compilerKey,
// NOTE: we limit the actual query by CPU count as policy evaluation is CPU bounded.
// For input actions like reading policy files / source code, we allow them to run unbounded,
// as the actual limiting is done by this limiter.
limiter: newLimiterFromMaxProcs(),
limiter: newLimiterFromMaxProcs(),
queryCache: qb.queryCache,
}
return rv, nil
}
77 changes: 77 additions & 0 deletions sg/internal/engine/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ import (
)

func Test_Integration_BrokenPolicy(t *testing.T) {
t.Parallel()

_, err := QueryWithPolicy([]string{
"./testdata/broken-policy",
}).Complete()
assert.Error(t, err)
}

func Test_Integration_Basic(t *testing.T) {
t.Parallel()

queryer, err := QueryWithPolicy([]string{
"./testdata/basic/policy",
}).Complete()
Expand Down Expand Up @@ -77,3 +81,76 @@ func Test_Integration_Basic(t *testing.T) {
assert.Equal(t, failureResult.RuleDocLink, "https://example.com/foo-deny-001-foo")
}
}

func Test_Integration_Basic_QueryCacheEnabled(t *testing.T) {
t.Parallel()

queryCache := NewQueryCache()
queryer, err := QueryWithPolicy([]string{
"./testdata/basic/policy",
}).
WithQueueCache(queryCache).
Complete()
assert.NoError(t, err)
assert.NotNil(t, queryer)

sources, err := source.FromPath([]string{
"./testdata/basic/configurations",
}).Complete()
assert.NoError(t, err)
assert.Len(t, sources, 1)
dataYAMLSource := sources[0]

round := func() {
ctx := context.Background()

queryResult, err := queryer.Query(ctx, dataYAMLSource)
assert.NoError(t, err)
assert.NotNil(t, queryResult)
assert.Equal(t, queryResult.Source, dataYAMLSource)
assert.Equal(t, queryResult.Successes, 2, "one document passes the test")

assert.Len(t, queryResult.Exceptions, 2, "one document emits two exceptions")
{
denyExcResult := queryResult.Exceptions[0]
assert.Equal(t, denyExcResult.Query, `data.main.exception[_][_] == "foo"`)
assert.Equal(t, denyExcResult.Rule.Kind, policy.QueryKindDeny)
assert.Equal(t, denyExcResult.Rule.Name, "foo")
assert.Equal(t, denyExcResult.Rule.Namespace, "main")
assert.Equal(t, denyExcResult.RuleDocLink, "https://example.com/foo-deny-001-foo")

warnExcResult := queryResult.Exceptions[1]
assert.Equal(t, warnExcResult.Query, `data.main.exception[_][_] == "foo"`)
assert.Equal(t, warnExcResult.Rule.Kind, policy.QueryKindWarn)
assert.Equal(t, warnExcResult.Rule.Name, "foo")
assert.Equal(t, warnExcResult.Rule.Namespace, "main")
assert.Equal(t, warnExcResult.RuleDocLink, "https://example.com/foo-warn-001-foo")
}

assert.Len(t, queryResult.Warnings, 1, "one document emits warning")
{
warnResult := queryResult.Warnings[0]
assert.Equal(t, warnResult.Message, "name is foo")
assert.Equal(t, warnResult.Query, "data.main.warn_foo")
assert.Equal(t, warnResult.Rule.Kind, policy.QueryKindWarn)
assert.Equal(t, warnResult.Rule.Name, "foo")
assert.Equal(t, warnResult.Rule.Namespace, "main")
assert.Equal(t, warnResult.RuleDocLink, "https://example.com/foo-warn-001-foo")
}

assert.Len(t, queryResult.Failures, 1, "one document fails the test")
{
failureResult := queryResult.Failures[0]
assert.Equal(t, failureResult.Message, "name cannot be foo")
assert.Equal(t, failureResult.Query, "data.main.deny_foo")
assert.Equal(t, failureResult.Rule.Kind, policy.QueryKindDeny)
assert.Equal(t, failureResult.Rule.Name, "foo")
assert.Equal(t, failureResult.Rule.Namespace, "main")
assert.Equal(t, failureResult.RuleDocLink, "https://example.com/foo-deny-001-foo")
}
}

for i := 0; i < 10; i++ {
round()
}
}
57 changes: 57 additions & 0 deletions sg/internal/engine/query_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package engine

import (
"fmt"
"sync"

"github.com/Azure/ShieldGuard/sg/internal/result"
)

func (k queryCacheKey) cacheKey() string {
return fmt.Sprintf(
"%s:%d:%s",
k.compilerKey,
k.parsedInput.Hash(),
k.query,
)
}

type noopQueryCacheT struct{}

func (noopQueryCacheT) set(_ queryCacheKey, _ []result.Result) {}

func (noopQueryCacheT) get(_ queryCacheKey) ([]result.Result, bool) {
return nil, false
}

var noopQueryCache = noopQueryCacheT{}

type queryCache struct {
mu *sync.RWMutex
items map[string][]result.Result
}

// NewQueryCache creates a new QueryCache.
func NewQueryCache() QueryCache {
return &queryCache{
mu: &sync.RWMutex{},
items: make(map[string][]result.Result),
}
}

var _ QueryCache = (*queryCache)(nil)

func (qc *queryCache) set(key queryCacheKey, value []result.Result) {
qc.mu.Lock()
defer qc.mu.Unlock()

qc.items[key.cacheKey()] = value
}

func (qc *queryCache) get(key queryCacheKey) ([]result.Result, bool) {
qc.mu.RLock()
defer qc.mu.RUnlock()

rv, ok := qc.items[key.cacheKey()]
return rv, ok
}
100 changes: 100 additions & 0 deletions sg/internal/engine/query_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package engine

import (
"sync"
"testing"

"github.com/Azure/ShieldGuard/sg/internal/result"
"github.com/open-policy-agent/opa/ast"
"github.com/stretchr/testify/assert"
)

func Test_noopQueryCacheT(t *testing.T) {
t.Parallel()

_, ok := noopQueryCache.get(queryCacheKey{})
assert.False(t, ok)
noopQueryCache.set(queryCacheKey{}, nil)
_, ok = noopQueryCache.get(queryCacheKey{})
assert.False(t, ok)
}

func Test_QueryCache(t *testing.T) {
t.Run("basic", func(t *testing.T) {
t.Parallel()

cacheKey1 := queryCacheKey{
compilerKey: "compilerKey1",
parsedInput: ast.String("input1"),
query: "query1",
}
cacheKey2 := queryCacheKey{
compilerKey: "compilerKey2",
parsedInput: ast.String("input2"),
query: "query2",
}
queryResult := []result.Result{
{
Query: "query",
},
}

qc := NewQueryCache()

{
_, ok := qc.get(cacheKey1)
assert.False(t, ok)
_, ok = qc.get(cacheKey2)
assert.False(t, ok)
}

{
qc.set(cacheKey1, queryResult)
cached, ok := qc.get(cacheKey1)
assert.True(t, ok)
assert.Equal(t, queryResult, cached)

_, ok = qc.get(cacheKey2)
assert.False(t, ok)
}
})

t.Run("concurrent access", func(t *testing.T) {
t.Parallel()

cacheKey := queryCacheKey{
compilerKey: "compilerKey",
parsedInput: ast.String("input"),
query: "query",
}
queryResult := []result.Result{
{
Query: "query",
},
}

qc := NewQueryCache()

{
_, ok := qc.get(cacheKey)
assert.False(t, ok)
}

var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()

qc.set(cacheKey, queryResult)
}()
}
wg.Wait()

{
cached, ok := qc.get(cacheKey)
assert.True(t, ok)
assert.Equal(t, queryResult, cached)
}
})
}
Loading

0 comments on commit 27f391b

Please sign in to comment.