Skip to content

Commit

Permalink
Merge pull request #1080 from Consensys/feat/emulated-nativehint
Browse files Browse the repository at this point in the history
feat: add hint calling with either native inputs or outputs
  • Loading branch information
yelhousni authored Mar 12, 2024
2 parents 732620b + 4ed9999 commit 4ae5707
Show file tree
Hide file tree
Showing 2 changed files with 309 additions and 26 deletions.
159 changes: 133 additions & 26 deletions std/math/emulated/field_hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,35 @@ func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable
return res
}

func (f *Field[T]) wrapHintNatives(nativeInputs ...frontend.Variable) []frontend.Variable {
res := []frontend.Variable{f.fParams.BitsPerLimb(), f.fParams.NbLimbs()}
res = append(res, f.Modulus().Limbs...)
res = append(res, nativeInputs...)
return res
}

// UnwrapHint unwraps the native inputs into nonnative inputs. Then it calls
// nonnativeHint function with nonnative inputs. After nonnativeHint returns, it
// decomposes the outputs into limbs.
func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error {
return unwrapHint(true, true, nativeInputs, nativeOutputs, nonnativeHint)
}

// UnwrapHintWithNativeOutput unwraps the native inputs into nonnative inputs. Then
// it calls nonnativeHint function with nonnative inputs. After nonnativeHint
// returns, it returns native outputs as-is.
func UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error {
return unwrapHint(true, false, nativeInputs, nativeOutputs, nonnativeHint)
}

// UnwrapHintWithNativeInput unwraps the native inputs into native inputs. Then
// it calls nonnativeHint function with native inputs. After nonnativeHint
// returns, it decomposes the outputs into limbs.
func UnwrapHintWithNativeInput(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error {
return unwrapHint(false, true, nativeInputs, nativeOutputs, nonnativeHint)
}

func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error {
if len(nativeInputs) < 2 {
return fmt.Errorf("hint wrapper header is 2 elements")
}
Expand All @@ -38,43 +63,65 @@ func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hin
if err := recompose(nativeInputs[2:2+nbLimbs], uint(nbBits), nonnativeMod); err != nil {
return fmt.Errorf("cannot recover nonnative mod: %w", err)
}
if !nativeInputs[2+nbLimbs].IsInt64() {
return fmt.Errorf("number of nonnative elements must be castable to int64")
}
nbInputs := int(nativeInputs[2+nbLimbs].Int64())
nonnativeInputs := make([]*big.Int, nbInputs)
readPtr := 3 + nbLimbs
for i := 0; i < nbInputs; i++ {
if len(nativeInputs) < readPtr+1 {
return fmt.Errorf("can not read %d-th native input", i)
var nonnativeInputs []*big.Int
if isEmulatedInput {
if !nativeInputs[2+nbLimbs].IsInt64() {
return fmt.Errorf("number of nonnative elements must be castable to int64")
}
if !nativeInputs[readPtr].IsInt64() {
return fmt.Errorf("corrupted %d-th native input", i)
nbInputs := int(nativeInputs[2+nbLimbs].Int64())
readPtr := 3 + nbLimbs
nonnativeInputs = make([]*big.Int, nbInputs)
for i := 0; i < nbInputs; i++ {
if len(nativeInputs) < readPtr+1 {
return fmt.Errorf("can not read %d-th native input", i)
}
if !nativeInputs[readPtr].IsInt64() {
return fmt.Errorf("corrupted %d-th native input", i)
}
currentInputLen := int(nativeInputs[readPtr].Int64())
if len(nativeInputs) < (readPtr + 1 + currentInputLen) {
return fmt.Errorf("cannot read %d-th nonnative element", i)
}
nonnativeInputs[i] = new(big.Int)
if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil {
return fmt.Errorf("recompose %d-th element: %w", i, err)
}
readPtr += 1 + currentInputLen
}
currentInputLen := int(nativeInputs[readPtr].Int64())
if len(nativeInputs) < (readPtr + 1 + currentInputLen) {
return fmt.Errorf("cannot read %d-th nonnative element", i)
} else {
nbInputs := len(nativeInputs[2+nbLimbs:])
readPtr := 2 + nbLimbs
nonnativeInputs = make([]*big.Int, nbInputs)
for i := 0; i < nbInputs; i++ {
nonnativeInputs[i] = new(big.Int).Set(nativeInputs[readPtr+i])
}
nonnativeInputs[i] = new(big.Int)
if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil {
return fmt.Errorf("recompose %d-th element: %w", i, err)
}
readPtr += 1 + currentInputLen
}
if len(nativeOutputs)%nbLimbs != 0 {
return fmt.Errorf("output count doesn't divide limb count")

var nonnativeOutputs []*big.Int
if isEmulatedOutput {
if len(nativeOutputs)%nbLimbs != 0 {
return fmt.Errorf("output count doesn't divide limb count")
}
nonnativeOutputs = make([]*big.Int, len(nativeOutputs)/nbLimbs)
} else {
nonnativeOutputs = make([]*big.Int, len(nativeOutputs))
}
nonnativeOutputs := make([]*big.Int, len(nativeOutputs)/nbLimbs)
for i := range nonnativeOutputs {
nonnativeOutputs[i] = new(big.Int)
}
if err := nonnativeHint(nonnativeMod, nonnativeInputs, nonnativeOutputs); err != nil {
return fmt.Errorf("nonnative hint: %w", err)
}
for i := range nonnativeOutputs {
nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod)
if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil {
return fmt.Errorf("decompose %d-th element: %w", i, err)
if isEmulatedOutput {
for i := range nonnativeOutputs {
nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod)
if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil {
return fmt.Errorf("decompose %d-th element: %w", i, err)
}
}
} else {
for i := range nonnativeOutputs {
nativeOutputs[i].Set(nonnativeOutputs[i])
}
}
return nil
Expand Down Expand Up @@ -107,3 +154,63 @@ func (f *Field[T]) NewHint(hf solver.Hint, nbOutputs int, inputs ...*Element[T])
}
return outputs, nil
}

// NewHintWithNativeOutput allows to call the emulation hint function hf on
// nonnative inputs, expecting nbOutputs results. This function splits
// internally the emulated element into limbs and passes these to the hint
// function. There is [UnwrapHintWithNativeOutput] function which performs
// corresponding recomposition of limbs into integer values (and vice verse for
// output).
//
// This method is an alternation of [NewHint] method, which allows to pass
// nonnative inputs to the hint function and returns native outputs. This is
// useful when the outputs do not necessarily have to be emulated elements (e.g.
// bits) as it skips enforcing range checks on the outputs.
//
// The hint function for this method is defined as:
//
// func HintFn(nativeMod *big.Int, inputs, outputs []*big.Int) error {
// return emulated.UnwrapHintWithNativeOutput(inputs, outputs, func(emulatedMod *big.Int, inputs, outputs []*big.Int) error {
// // in the function we have access to both native and nonantive modulus
// })}
func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) ([]frontend.Variable, error) {
nativeInputs := f.wrapHint(inputs...)
nbNativeOutputs := nbOutputs
nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...)
if err != nil {
return nil, fmt.Errorf("call hint: %w", err)
}
return nativeOutputs, nil
}

// NewHintWithNativeInput allows to call the emulation hint function hf on
// native inputs, expecting nbOutputs results. This function passes the native
// inputs to the hint function directly and reconstructs the outputs into
// non-native elements. There is [UnwrapHintWithNativeInput] function which
// performs corresponding recomposition of limbs into integer values (and vice
// verse for output).
//
// This method is an alternation of [NewHint] method, which allows to pass
// native inputs to the hint function and returns nonnative outputs. This is
// useful when the inputs do not necessarily have to be emulated elements (e.g.
// indices) and allows to work between different fields.
//
// The hint function for this method is defined as:
//
// func HintFn(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
// return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(emulatedMod *big.Int, inputs, outputs []*big.Int) error {
// // in the function we have access to both native and nonantive modulus
// })}
func (f *Field[T]) NewHintWithNativeInput(hf solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]*Element[T], error) {
nativeInputs := f.wrapHintNatives(inputs...)
nbNativeOutputs := int(f.fParams.NbLimbs()) * nbOutputs
nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...)
if err != nil {
return nil, fmt.Errorf("call hint: %w", err)
}
outputs := make([]*Element[T], nbOutputs)
for i := 0; i < nbOutputs; i++ {
outputs[i] = f.packLimbs(nativeOutputs[i*int(f.fParams.NbLimbs()):(i+1)*int(f.fParams.NbLimbs())], true)
}
return outputs, nil
}
176 changes: 176 additions & 0 deletions std/math/emulated/field_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package emulated

import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"testing"

"github.com/consensys/gnark/frontend"
Expand Down Expand Up @@ -131,3 +133,177 @@ func TestSubConstantCircuit(t *testing.T) {
_, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.IgnoreUnconstrainedInputs())
assert.NoError(err)
}

func nnaHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return UnwrapHint(nativeInputs, nativeOutputs, func(mod *big.Int, inputs, outputs []*big.Int) error {
nominator := inputs[0]
denominator := inputs[1]
res := new(big.Int).ModInverse(denominator, mod)
if res == nil {
return fmt.Errorf("no modular inverse")
}
res.Mul(res, nominator)
res.Mod(res, mod)
outputs[0].Set(res)
return nil
})
}

type hintCircuit[T FieldParams] struct {
Nominator Element[T]
Denominator Element[T]
Expected Element[T]
}

func (c *hintCircuit[T]) Define(api frontend.API) error {
field, err := NewField[T](api)
if err != nil {
return err
}
res, err := field.NewHint(nnaHint, 1, &c.Nominator, &c.Denominator)
if err != nil {
return err
}
field.AssertIsEqual(res[0], &c.Expected)
return nil
}

func testHint[T FieldParams](t *testing.T) {
var fr T
assert := test.NewAssert(t)
a, _ := rand.Int(rand.Reader, fr.Modulus())
b, _ := rand.Int(rand.Reader, fr.Modulus())
c := new(big.Int).ModInverse(b, fr.Modulus())
c.Mul(c, a)
c.Mod(c, fr.Modulus())

circuit := hintCircuit[T]{}
witness := hintCircuit[T]{
Nominator: ValueOf[T](a),
Denominator: ValueOf[T](b),
Expected: ValueOf[T](c),
}
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness))
}

func TestHint(t *testing.T) {
testHint[Goldilocks](t)
testHint[Secp256k1Fp](t)
testHint[BN254Fp](t)
}

func nativeInputHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nonnativeMod *big.Int, inputs, outputs []*big.Int) error {
nominator := inputs[0]
denominator := inputs[1]
res := new(big.Int).ModInverse(denominator, nonnativeMod)
if res == nil {
return fmt.Errorf("no modular inverse")
}
res.Mul(res, nominator)
res.Mod(res, nonnativeMod)
outputs[0].Set(res)
return nil
})
}

type hintNativeInputCircuit[T FieldParams] struct {
Nominator frontend.Variable
Denominator frontend.Variable
Expected Element[T]
}

func (c *hintNativeInputCircuit[T]) Define(api frontend.API) error {
field, err := NewField[T](api)
if err != nil {
return err
}
res, err := field.NewHintWithNativeInput(nativeInputHint, 1, c.Nominator, c.Denominator)
if err != nil {
return err
}
field.AssertIsEqual(res[0], &c.Expected)
return nil
}

func testHintNativeInput[T FieldParams](t *testing.T) {
var fr T
assert := test.NewAssert(t)
a, _ := rand.Int(rand.Reader, testCurve.ScalarField())
b, _ := rand.Int(rand.Reader, testCurve.ScalarField())
c := new(big.Int).ModInverse(b, fr.Modulus())
c.Mul(c, a)
c.Mod(c, fr.Modulus())

circuit := hintNativeInputCircuit[T]{}
witness := hintNativeInputCircuit[T]{
Nominator: a,
Denominator: b,
Expected: ValueOf[T](c),
}
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(testCurve))
}

func TestHintNativeInput(t *testing.T) {
testHintNativeInput[Goldilocks](t)
testHintNativeInput[Secp256k1Fp](t)
testHintNativeInput[BN254Fp](t)
}

func nativeOutputHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs, func(nonnativeMod *big.Int, inputs, outputs []*big.Int) error {
nominator := inputs[0]
denominator := inputs[1]
res := new(big.Int).ModInverse(denominator, nativeMod)
if res == nil {
return fmt.Errorf("no modular inverse")
}
res.Mul(res, nominator)
res.Mod(res, nativeMod)
outputs[0].Set(res)
return nil
})
}

type hintNativeOutputCircuit[T FieldParams] struct {
Nominator Element[T]
Denominator Element[T]
Expected frontend.Variable
}

func (c *hintNativeOutputCircuit[T]) Define(api frontend.API) error {
field, err := NewField[T](api)
if err != nil {
return err
}
res, err := field.NewHintWithNativeOutput(nativeOutputHint, 1, &c.Nominator, &c.Denominator)
if err != nil {
return err
}
api.AssertIsEqual(res[0], c.Expected)
return nil
}

func testHintNativeOutput[T FieldParams](t *testing.T) {
var fr T
assert := test.NewAssert(t)
a, _ := rand.Int(rand.Reader, fr.Modulus())
b, _ := rand.Int(rand.Reader, fr.Modulus())
c := new(big.Int).ModInverse(b, testCurve.ScalarField())
c.Mul(c, a)
c.Mod(c, testCurve.ScalarField())

circuit := hintNativeOutputCircuit[T]{}
witness := hintNativeOutputCircuit[T]{
Nominator: ValueOf[T](a),
Denominator: ValueOf[T](b),
Expected: c,
}
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(testCurve))
}

func TestHintNativeOutput(t *testing.T) {
testHintNativeOutput[Goldilocks](t)
testHintNativeOutput[Secp256k1Fp](t)
testHintNativeOutput[BN254Fp](t)
}

0 comments on commit 4ae5707

Please sign in to comment.