From b39b13fdd7018a15445f7df591c4f1140e4c5e9c Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 26 Jul 2023 03:06:29 -0500 Subject: [PATCH] Perf: Improve MultiLin.Eval number of constraints (#788) * bench: multilin eval constraints number * perf: fewer multilin folding constraints * fix: correct nb constraints * fix: panic if error * perf: sometimes defer scaling of folding results --- std/polynomial/polynomial.go | 65 +++++++++++++++++++++++++------ std/polynomial/polynomial_test.go | 50 ++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 12 deletions(-) diff --git a/std/polynomial/polynomial.go b/std/polynomial/polynomial.go index 72cc329f9f..0953cb3ac7 100644 --- a/std/polynomial/polynomial.go +++ b/std/polynomial/polynomial.go @@ -9,30 +9,71 @@ import ( type Polynomial []frontend.Variable type MultiLin []frontend.Variable +var minFoldScaledLogSize = 16 + // Evaluate assumes len(m) = 1 << len(at) +// it doesn't modify m func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable { - - eqs := make([]frontend.Variable, len(m)) - eqs[0] = 1 - for i, rI := range at { - prevSize := 1 << i - for j := prevSize - 1; j >= 0; j-- { - eqs[2*j+1] = api.Mul(rI, eqs[j]) - eqs[2*j] = api.Sub(eqs[j], eqs[2*j+1]) // eq[2j] == (1 - rI) * eq[j] + _m := m.Clone() + + /*minFoldScaledLogSize := 16 + if api is r1cs { + minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs + }*/ + + scaleCorrectionFactor := frontend.Variable(1) + // at each iteration fold by at[i] + for len(_m) > 1 { + if len(_m) >= minFoldScaledLogSize { + scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0])) + } else { + _m.fold(api, at[0]) } + _m = _m[:len(_m)/2] + at = at[1:] + } + + if len(at) != 0 { + panic("incompatible evaluation vector size") + } + + return api.Mul(_m[0], scaleCorrectionFactor) +} + +// fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size +// WARNING: The user should halve m themselves after the call +func (m MultiLin) fold(api frontend.API, at frontend.Variable) { + zero := m[:len(m)/2] + one := m[len(m)/2:] + for j := range zero { + diff := api.Sub(one[j], zero[j]) + zero[j] = api.MulAcc(zero[j], diff, at) } +} - evaluation := frontend.Variable(0) - for j := range m { - evaluation = api.MulAcc(evaluation, eqs[j], m[j]) +// foldScaled(m, at) = fold(m, at) / (1 - at) +// it returns 1 - at, for convenience +func (m MultiLin) foldScaled(api frontend.API, at frontend.Variable) (denom frontend.Variable) { + denom = api.Sub(1, at) + coeff := api.Div(at, denom) + zero := m[:len(m)/2] + one := m[len(m)/2:] + for j := range zero { + zero[j] = api.MulAcc(zero[j], one[j], coeff) } - return evaluation + return } func (m MultiLin) NumVars() int { return bits.TrailingZeros(uint(len(m))) } +func (m MultiLin) Clone() MultiLin { + clone := make(MultiLin, len(m)) + copy(clone, m) + return clone +} + func (p Polynomial) Eval(api frontend.API, at frontend.Variable) (pAt frontend.Variable) { pAt = 0 diff --git a/std/polynomial/polynomial_test.go b/std/polynomial/polynomial_test.go index 92656d779a..24b706a3f6 100644 --- a/std/polynomial/polynomial_test.go +++ b/std/polynomial/polynomial_test.go @@ -1,10 +1,13 @@ package polynomial import ( + "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" "testing" ) @@ -70,6 +73,31 @@ func TestEvalDeltasQuadratic(t *testing.T) { testEvalDeltas(t, 3, []int64{1, -3, 3}) } +type foldMultiLinCircuit struct { + M []frontend.Variable + At frontend.Variable + Result []frontend.Variable +} + +func (c *foldMultiLinCircuit) Define(api frontend.API) error { + if len(c.M) != 2*len(c.Result) { + return errors.New("folding size mismatch") + } + m := MultiLin(c.M) + m.fold(api, c.At) + for i := range c.Result { + api.AssertIsEqual(m[i], c.Result[i]) + } + return nil +} + +func TestFoldSmall(t *testing.T) { + test.NewAssert(t).SolvingSucceeded( + &foldMultiLinCircuit{M: make([]frontend.Variable, 4), Result: make([]frontend.Variable, 2)}, + &foldMultiLinCircuit{M: []frontend.Variable{0, 1, 2, 3}, At: 2, Result: []frontend.Variable{4, 5}}, + ) +} + type evalMultiLinCircuit struct { M []frontend.Variable `gnark:",public"` At []frontend.Variable `gnark:",secret"` @@ -204,3 +232,25 @@ func int64SliceToVariableSlice(slice []int64) []frontend.Variable { } return res } + +func ExampleMultiLin_Evaluate() { + const logSize = 20 + const size = 1 << logSize + m := MultiLin(make([]frontend.Variable, size)) + e := MultiLin(make([]frontend.Variable, logSize)) + + cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &evalMultiLinCircuit{M: m, At: e, Evaluation: 0}) + if err != nil { + panic(err) + } + fmt.Println("r1cs size:", cs.GetNbConstraints()) + + cs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &evalMultiLinCircuit{M: m, At: e, Evaluation: 0}) + if err != nil { + panic(err) + } + fmt.Println("scs size:", cs.GetNbConstraints()) + + // Output: r1cs size: 1048627 + //scs size: 2097226 +}