From ce0186ef32c1a447c0c35d5598ea04b0a5696feb Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 12 Mar 2024 18:56:55 +0100 Subject: [PATCH] feat: add MulNoReduce and Sum methods in field emulation (#1072) * feat: implement mulnoreduce * test: mulnoreduce test * docs: add method doc * feat: add AddMany * refactor: rename AddMany to Sum * feat: if only single input then return as is * test: non-native sum --- std/math/emulated/element_test.go | 89 +++++++++++++++++++++++++++++++ std/math/emulated/field_mul.go | 20 +++++++ std/math/emulated/field_ops.go | 32 +++++++++++ 3 files changed, 141 insertions(+) diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index c8e5c817b5..8954fc4d69 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -3,6 +3,7 @@ package emulated import ( "crypto/rand" "fmt" + "math" "math/big" "reflect" "testing" @@ -970,3 +971,91 @@ func testSqrt[T FieldParams](t *testing.T) { assert.ProverSucceeded(&SqrtCircuit[T]{}, &SqrtCircuit[T]{X: ValueOf[T](X), Expected: ValueOf[T](exp)}, test.WithCurves(testCurve), test.NoSerializationChecks(), test.WithBackends(backend.GROTH16, backend.PLONK)) }, testName[T]()) } + +type MulNoReduceCircuit[T FieldParams] struct { + A, B, C Element[T] + expectedOverflow uint + expectedNbLimbs int +} + +func (c *MulNoReduceCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.MulNoReduce(&c.A, &c.B) + f.AssertIsEqual(res, &c.C) + if res.overflow != c.expectedOverflow { + return fmt.Errorf("unexpected overflow: got %d, expected %d", res.overflow, c.expectedOverflow) + } + if len(res.Limbs) != c.expectedNbLimbs { + return fmt.Errorf("unexpected number of limbs: got %d, expected %d", len(res.Limbs), c.expectedNbLimbs) + } + return nil +} + +func TestMulNoReduce(t *testing.T) { + testMulNoReduce[Goldilocks](t) + testMulNoReduce[Secp256k1Fp](t) + testMulNoReduce[BN254Fp](t) +} + +func testMulNoReduce[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + A, _ := rand.Int(rand.Reader, fp.Modulus()) + B, _ := rand.Int(rand.Reader, fp.Modulus()) + C := new(big.Int).Mul(A, B) + C.Mod(C, fp.Modulus()) + expectedLimbs := 2*fp.NbLimbs() - 1 + expectedOverFlow := math.Ceil(math.Log2(float64(expectedLimbs+1))) + float64(fp.BitsPerLimb()) + circuit := &MulNoReduceCircuit[T]{expectedOverflow: uint(expectedOverFlow), expectedNbLimbs: int(expectedLimbs)} + assignment := &MulNoReduceCircuit[T]{A: ValueOf[T](A), B: ValueOf[T](B), C: ValueOf[T](C)} + assert.CheckCircuit(circuit, test.WithValidAssignment(assignment)) + }, testName[T]()) +} + +type SumCircuit[T FieldParams] struct { + Inputs []Element[T] + Expected Element[T] +} + +func (c *SumCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + inputs := make([]*Element[T], len(c.Inputs)) + for i := range inputs { + inputs[i] = &c.Inputs[i] + } + res := f.Sum(inputs...) + f.AssertIsEqual(res, &c.Expected) + return nil +} + +func TestSum(t *testing.T) { + testSum[Goldilocks](t) + testSum[Secp256k1Fp](t) + testSum[BN254Fp](t) +} + +func testSum[T FieldParams](t *testing.T) { + var fp T + nbInputs := 1024 + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + circuit := &SumCircuit[T]{Inputs: make([]Element[T], nbInputs)} + inputs := make([]Element[T], nbInputs) + result := new(big.Int) + for i := range inputs { + val, _ := rand.Int(rand.Reader, fp.Modulus()) + result.Add(result, val) + inputs[i] = ValueOf[T](val) + } + result.Mod(result, fp.Modulus()) + witness := &SumCircuit[T]{Inputs: inputs, Expected: ValueOf[T](result)} + assert.CheckCircuit(circuit, test.WithValidAssignment(witness)) + }, testName[T]()) +} diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 3b3235e5cb..9a2671d08a 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -443,3 +443,23 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { } return } + +// MulNoReduce computes a*b and returns the result without reducing it modulo +// the field order. The number of limbs of the returned element depends on the +// number of limbs of the inputs. +func (f *Field[T]) MulNoReduce(a, b *Element[T]) *Element[T] { + return f.reduceAndOp(f.mulNoReduce, f.mulPreCond, a, b) +} + +func (f *Field[T]) mulNoReduce(a, b *Element[T], nextoverflow uint) *Element[T] { + resLimbs := make([]frontend.Variable, nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs))) + for i := range resLimbs { + resLimbs[i] = 0 + } + for i := range a.Limbs { + for j := range b.Limbs { + resLimbs[i+j] = f.api.MulAcc(resLimbs[i+j], a.Limbs[i], b.Limbs[j]) + } + } + return f.newInternalElement(resLimbs, nextoverflow) +} diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index 4115089a8c..aeaf2c3059 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -3,6 +3,7 @@ package emulated import ( "errors" "fmt" + "math/bits" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/selector" @@ -132,6 +133,37 @@ func (f *Field[T]) add(a, b *Element[T], nextOverflow uint) *Element[T] { return f.newInternalElement(limbs, nextOverflow) } +func (f *Field[T]) Sum(inputs ...*Element[T]) *Element[T] { + if len(inputs) == 0 { + return f.Zero() + } + if len(inputs) == 1 { + return inputs[0] + } + overflow := uint(0) + nbLimbs := 0 + for i := range inputs { + f.enforceWidthConditional(inputs[i]) + if inputs[i].overflow > overflow { + overflow = inputs[i].overflow + } + if len(inputs[i].Limbs) > nbLimbs { + nbLimbs = len(inputs[i].Limbs) + } + } + addOverflow := bits.Len(uint(len(inputs))) + limbs := make([]frontend.Variable, nbLimbs) + for i := range limbs { + limbs[i] = 0 + } + for i := range inputs { + for j := range inputs[i].Limbs { + limbs[j] = f.api.Add(limbs[j], inputs[i].Limbs[j]) + } + } + return f.newInternalElement(limbs, overflow+uint(addOverflow)) +} + // Reduce reduces a modulo the field order and returns it. func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { f.enforceWidthConditional(a)