Skip to content

Commit

Permalink
feat: add MulNoReduce and Sum methods in field emulation (#1072)
Browse files Browse the repository at this point in the history
* feat: implement mulnoreduce

* test: mulnoreduce test

* docs: add method doc

* feat: add AddMany

* refactor: rename AddMany to Sum

* feat: if only single input then return as is

* test: non-native sum
  • Loading branch information
ivokub authored Mar 12, 2024
1 parent 781de03 commit ce0186e
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 0 deletions.
89 changes: 89 additions & 0 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package emulated
import (
"crypto/rand"
"fmt"
"math"
"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -970,3 +971,91 @@ func testSqrt[T FieldParams](t *testing.T) {
assert.ProverSucceeded(&SqrtCircuit[T]{}, &SqrtCircuit[T]{X: ValueOf[T](X), Expected: ValueOf[T](exp)}, test.WithCurves(testCurve), test.NoSerializationChecks(), test.WithBackends(backend.GROTH16, backend.PLONK))
}, testName[T]())
}

type MulNoReduceCircuit[T FieldParams] struct {
A, B, C Element[T]
expectedOverflow uint
expectedNbLimbs int
}

func (c *MulNoReduceCircuit[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return err
}
res := f.MulNoReduce(&c.A, &c.B)
f.AssertIsEqual(res, &c.C)
if res.overflow != c.expectedOverflow {
return fmt.Errorf("unexpected overflow: got %d, expected %d", res.overflow, c.expectedOverflow)
}
if len(res.Limbs) != c.expectedNbLimbs {
return fmt.Errorf("unexpected number of limbs: got %d, expected %d", len(res.Limbs), c.expectedNbLimbs)
}
return nil
}

func TestMulNoReduce(t *testing.T) {
testMulNoReduce[Goldilocks](t)
testMulNoReduce[Secp256k1Fp](t)
testMulNoReduce[BN254Fp](t)
}

func testMulNoReduce[T FieldParams](t *testing.T) {
var fp T
assert := test.NewAssert(t)
assert.Run(func(assert *test.Assert) {
A, _ := rand.Int(rand.Reader, fp.Modulus())
B, _ := rand.Int(rand.Reader, fp.Modulus())
C := new(big.Int).Mul(A, B)
C.Mod(C, fp.Modulus())
expectedLimbs := 2*fp.NbLimbs() - 1
expectedOverFlow := math.Ceil(math.Log2(float64(expectedLimbs+1))) + float64(fp.BitsPerLimb())
circuit := &MulNoReduceCircuit[T]{expectedOverflow: uint(expectedOverFlow), expectedNbLimbs: int(expectedLimbs)}
assignment := &MulNoReduceCircuit[T]{A: ValueOf[T](A), B: ValueOf[T](B), C: ValueOf[T](C)}
assert.CheckCircuit(circuit, test.WithValidAssignment(assignment))
}, testName[T]())
}

type SumCircuit[T FieldParams] struct {
Inputs []Element[T]
Expected Element[T]
}

func (c *SumCircuit[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return err
}
inputs := make([]*Element[T], len(c.Inputs))
for i := range inputs {
inputs[i] = &c.Inputs[i]
}
res := f.Sum(inputs...)
f.AssertIsEqual(res, &c.Expected)
return nil
}

func TestSum(t *testing.T) {
testSum[Goldilocks](t)
testSum[Secp256k1Fp](t)
testSum[BN254Fp](t)
}

func testSum[T FieldParams](t *testing.T) {
var fp T
nbInputs := 1024
assert := test.NewAssert(t)
assert.Run(func(assert *test.Assert) {
circuit := &SumCircuit[T]{Inputs: make([]Element[T], nbInputs)}
inputs := make([]Element[T], nbInputs)
result := new(big.Int)
for i := range inputs {
val, _ := rand.Int(rand.Reader, fp.Modulus())
result.Add(result, val)
inputs[i] = ValueOf[T](val)
}
result.Mod(result, fp.Modulus())
witness := &SumCircuit[T]{Inputs: inputs, Expected: ValueOf[T](result)}
assert.CheckCircuit(circuit, test.WithValidAssignment(witness))
}, testName[T]())
}
20 changes: 20 additions & 0 deletions std/math/emulated/field_mul.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,23 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) {
}
return
}

// MulNoReduce computes a*b and returns the result without reducing it modulo
// the field order. The number of limbs of the returned element depends on the
// number of limbs of the inputs.
func (f *Field[T]) MulNoReduce(a, b *Element[T]) *Element[T] {
return f.reduceAndOp(f.mulNoReduce, f.mulPreCond, a, b)
}

func (f *Field[T]) mulNoReduce(a, b *Element[T], nextoverflow uint) *Element[T] {
resLimbs := make([]frontend.Variable, nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))
for i := range resLimbs {
resLimbs[i] = 0
}
for i := range a.Limbs {
for j := range b.Limbs {
resLimbs[i+j] = f.api.MulAcc(resLimbs[i+j], a.Limbs[i], b.Limbs[j])
}
}
return f.newInternalElement(resLimbs, nextoverflow)
}
32 changes: 32 additions & 0 deletions std/math/emulated/field_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package emulated
import (
"errors"
"fmt"
"math/bits"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/selector"
Expand Down Expand Up @@ -132,6 +133,37 @@ func (f *Field[T]) add(a, b *Element[T], nextOverflow uint) *Element[T] {
return f.newInternalElement(limbs, nextOverflow)
}

func (f *Field[T]) Sum(inputs ...*Element[T]) *Element[T] {
if len(inputs) == 0 {
return f.Zero()
}
if len(inputs) == 1 {
return inputs[0]
}
overflow := uint(0)
nbLimbs := 0
for i := range inputs {
f.enforceWidthConditional(inputs[i])
if inputs[i].overflow > overflow {
overflow = inputs[i].overflow
}
if len(inputs[i].Limbs) > nbLimbs {
nbLimbs = len(inputs[i].Limbs)
}
}
addOverflow := bits.Len(uint(len(inputs)))
limbs := make([]frontend.Variable, nbLimbs)
for i := range limbs {
limbs[i] = 0
}
for i := range inputs {
for j := range inputs[i].Limbs {
limbs[j] = f.api.Add(limbs[j], inputs[i].Limbs[j])
}
}
return f.newInternalElement(limbs, overflow+uint(addOverflow))
}

// Reduce reduces a modulo the field order and returns it.
func (f *Field[T]) Reduce(a *Element[T]) *Element[T] {
f.enforceWidthConditional(a)
Expand Down

0 comments on commit ce0186e

Please sign in to comment.