From 2c49a0fc8f48477fc46899c165ea5b9b139d7f01 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Fri, 8 Mar 2024 13:20:15 +0000 Subject: [PATCH 1/5] feat: add non-native hint with native output --- std/math/emulated/field_hint.go | 57 +++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go index 52b8e9ddfc..c264136b79 100644 --- a/std/math/emulated/field_hint.go +++ b/std/math/emulated/field_hint.go @@ -23,6 +23,17 @@ func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable // nonnativeHint function with nonnative inputs. After nonnativeHint returns, it // decomposes the outputs into limbs. func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + return unwrapHint(true, true, nativeInputs, nativeOutputs, nonnativeHint) +} + +// UnwrapHintWithNativeOutput unwraps the native inputs into nonnative inputs. Then +// it calls nonnativeHint function with nonnative inputs. After nonnativeHint +// returns, it returns native outputs as-is. +func UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + return unwrapHint(true, false, nativeInputs, nativeOutputs, nonnativeHint) +} + +func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { if len(nativeInputs) < 2 { return fmt.Errorf("hint wrapper header is 2 elements") } @@ -61,20 +72,32 @@ func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hin } readPtr += 1 + currentInputLen } - if len(nativeOutputs)%nbLimbs != 0 { - return fmt.Errorf("output count doesn't divide limb count") + + var nonnativeOutputs []*big.Int + if isEmulatedOutput { + if len(nativeOutputs)%nbLimbs != 0 { + return fmt.Errorf("output count doesn't divide limb count") + } + nonnativeOutputs = make([]*big.Int, len(nativeOutputs)/nbLimbs) + } else { + nonnativeOutputs = make([]*big.Int, len(nativeOutputs)) } - nonnativeOutputs := make([]*big.Int, len(nativeOutputs)/nbLimbs) for i := range nonnativeOutputs { nonnativeOutputs[i] = new(big.Int) } if err := nonnativeHint(nonnativeMod, nonnativeInputs, nonnativeOutputs); err != nil { return fmt.Errorf("nonnative hint: %w", err) } - for i := range nonnativeOutputs { - nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod) - if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil { - return fmt.Errorf("decompose %d-th element: %w", i, err) + if isEmulatedOutput { + for i := range nonnativeOutputs { + nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod) + if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil { + return fmt.Errorf("decompose %d-th element: %w", i, err) + } + } + } else { + for i := range nonnativeOutputs { + nativeOutputs[i].Set(nonnativeOutputs[i]) } } return nil @@ -107,3 +130,23 @@ func (f *Field[T]) NewHint(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) } return outputs, nil } + +// NewHintWithNativeOutput allows to call the emulation hint function hf on +// nonnative inputs, expecting nbOutputs results. This function splits +// internally the emulated element into limbs and passes these to the hint +// function. There is [UnwrapHint] function which performs corresponding +// recomposition of limbs into integer values (and vice verse for output). +// +// This method is an alternation of [NewHint] method, which allows to pass +// nonnative inputs to the hint function and returns native outputs. This is +// useful when the outputs do not necessarily have to be emulated elements (e.g. +// bits) as it skips enforcing range checks on the outputs. +func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) ([]frontend.Variable, error) { + nativeInputs := f.wrapHint(inputs...) + nbNativeOutputs := nbOutputs + nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...) + if err != nil { + return nil, fmt.Errorf("call hint: %w", err) + } + return nativeOutputs, nil +} From c9cf73514f3eab3d0fc790204149b8126c72c43d Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 11 Mar 2024 13:44:37 +0000 Subject: [PATCH 2/5] feat: add non-native hint with native inputs --- std/math/emulated/field_hint.go | 82 +++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 19 deletions(-) diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go index c264136b79..b4b00c01dc 100644 --- a/std/math/emulated/field_hint.go +++ b/std/math/emulated/field_hint.go @@ -19,6 +19,13 @@ func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable return res } +func (f *Field[T]) wrapHintNatives(nativeInputs ...frontend.Variable) []frontend.Variable { + res := []frontend.Variable{f.fParams.BitsPerLimb(), f.fParams.NbLimbs()} + res = append(res, f.Modulus().Limbs...) + res = append(res, nativeInputs...) + return res +} + // UnwrapHint unwraps the native inputs into nonnative inputs. Then it calls // nonnativeHint function with nonnative inputs. After nonnativeHint returns, it // decomposes the outputs into limbs. @@ -33,6 +40,13 @@ func UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs []*big.Int, nonnativ return unwrapHint(true, false, nativeInputs, nativeOutputs, nonnativeHint) } +// UnwrapHintWithNativeInput unwraps the native inputs into native inputs. Then +// it calls nonnativeHint function with native inputs. After nonnativeHint +// returns, it decomposes the outputs into limbs. +func UnwrapHintWithNativeInput(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + return unwrapHint(false, true, nativeInputs, nativeOutputs, nonnativeHint) +} + func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { if len(nativeInputs) < 2 { return fmt.Errorf("hint wrapper header is 2 elements") @@ -49,28 +63,38 @@ func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutp if err := recompose(nativeInputs[2:2+nbLimbs], uint(nbBits), nonnativeMod); err != nil { return fmt.Errorf("cannot recover nonnative mod: %w", err) } - if !nativeInputs[2+nbLimbs].IsInt64() { - return fmt.Errorf("number of nonnative elements must be castable to int64") - } - nbInputs := int(nativeInputs[2+nbLimbs].Int64()) - nonnativeInputs := make([]*big.Int, nbInputs) - readPtr := 3 + nbLimbs - for i := 0; i < nbInputs; i++ { - if len(nativeInputs) < readPtr+1 { - return fmt.Errorf("can not read %d-th native input", i) - } - if !nativeInputs[readPtr].IsInt64() { - return fmt.Errorf("corrupted %d-th native input", i) + var nonnativeInputs []*big.Int + if isEmulatedInput { + if !nativeInputs[2+nbLimbs].IsInt64() { + return fmt.Errorf("number of nonnative elements must be castable to int64") } - currentInputLen := int(nativeInputs[readPtr].Int64()) - if len(nativeInputs) < (readPtr + 1 + currentInputLen) { - return fmt.Errorf("cannot read %d-th nonnative element", i) + nbInputs := int(nativeInputs[2+nbLimbs].Int64()) + readPtr := 3 + nbLimbs + nonnativeInputs = make([]*big.Int, nbInputs) + for i := 0; i < nbInputs; i++ { + if len(nativeInputs) < readPtr+1 { + return fmt.Errorf("can not read %d-th native input", i) + } + if !nativeInputs[readPtr].IsInt64() { + return fmt.Errorf("corrupted %d-th native input", i) + } + currentInputLen := int(nativeInputs[readPtr].Int64()) + if len(nativeInputs) < (readPtr + 1 + currentInputLen) { + return fmt.Errorf("cannot read %d-th nonnative element", i) + } + nonnativeInputs[i] = new(big.Int) + if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil { + return fmt.Errorf("recompose %d-th element: %w", i, err) + } + readPtr += 1 + currentInputLen } - nonnativeInputs[i] = new(big.Int) - if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil { - return fmt.Errorf("recompose %d-th element: %w", i, err) + } else { + nbInputs := len(nativeInputs[2+nbLimbs:]) + readPtr := 2 + nbLimbs + nonnativeInputs = make([]*big.Int, nbInputs) + for i := 0; i < nbInputs; i++ { + nonnativeInputs[i] = new(big.Int).Set(nativeInputs[readPtr+i]) } - readPtr += 1 + currentInputLen } var nonnativeOutputs []*big.Int @@ -150,3 +174,23 @@ func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs } return nativeOutputs, nil } + +// NewHintWithNativeInput allows to call the emulation hint function hf on +// native inputs, expecting nbOutputs results. This function passes the native +// inputs to the hint function directly and reconstructs the outputs into +// non-native elements. There is [UnwrapHintWithNativeInput] function which +// performs corresponding recomposition of limbs into integer values (and vice +// verse for output). +func (f *Field[T]) NewHintWithNativeInput(hf solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]*Element[T], error) { + nativeInputs := f.wrapHintNatives(inputs...) + nbNativeOutputs := int(f.fParams.NbLimbs()) * nbOutputs + nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...) + if err != nil { + return nil, fmt.Errorf("call hint: %w", err) + } + outputs := make([]*Element[T], nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = f.packLimbs(nativeOutputs[i*int(f.fParams.NbLimbs()):(i+1)*int(f.fParams.NbLimbs())], true) + } + return outputs, nil +} From df7cc97161c3c853a2d24cba33f4b382d8b75dfe Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 11 Mar 2024 13:45:27 +0000 Subject: [PATCH 3/5] docs: method doc native output --- std/math/emulated/field_hint.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go index b4b00c01dc..dfef4f9187 100644 --- a/std/math/emulated/field_hint.go +++ b/std/math/emulated/field_hint.go @@ -158,13 +158,21 @@ func (f *Field[T]) NewHint(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) // NewHintWithNativeOutput allows to call the emulation hint function hf on // nonnative inputs, expecting nbOutputs results. This function splits // internally the emulated element into limbs and passes these to the hint -// function. There is [UnwrapHint] function which performs corresponding -// recomposition of limbs into integer values (and vice verse for output). +// function. There is [UnwrapHintWithNativeOutput] function which performs +// corresponding recomposition of limbs into integer values (and vice verse for +// output). // // This method is an alternation of [NewHint] method, which allows to pass // nonnative inputs to the hint function and returns native outputs. This is // useful when the outputs do not necessarily have to be emulated elements (e.g. // bits) as it skips enforcing range checks on the outputs. +// +// The hint function for this method is defined as: +// +// func HintFn(nativeMod *big.Int, inputs, outputs []*big.Int) error { +// return emulated.UnwrapHintWithNativeOutput(inputs, outputs, func(emulatedMod *big.Int, inputs, outputs []*big.Int) error { +// // in the function we have access to both native and nonantive modulus +// })} func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) ([]frontend.Variable, error) { nativeInputs := f.wrapHint(inputs...) nbNativeOutputs := nbOutputs From ebf932622fb4cc2cb43219be2fed044c50163b65 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 11 Mar 2024 13:47:26 +0000 Subject: [PATCH 4/5] docs: add hint definition for native inputs --- std/math/emulated/field_hint.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go index dfef4f9187..2c613c5397 100644 --- a/std/math/emulated/field_hint.go +++ b/std/math/emulated/field_hint.go @@ -189,6 +189,18 @@ func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs // non-native elements. There is [UnwrapHintWithNativeInput] function which // performs corresponding recomposition of limbs into integer values (and vice // verse for output). +// +// This method is an alternation of [NewHint] method, which allows to pass +// native inputs to the hint function and returns nonnative outputs. This is +// useful when the inputs do not necessarily have to be emulated elements (e.g. +// indices) and allows to work between different fields. +// +// The hint function for this method is defined as: +// +// func HintFn(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { +// return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(emulatedMod *big.Int, inputs, outputs []*big.Int) error { +// // in the function we have access to both native and nonantive modulus +// })} func (f *Field[T]) NewHintWithNativeInput(hf solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]*Element[T], error) { nativeInputs := f.wrapHintNatives(inputs...) nbNativeOutputs := int(f.fParams.NbLimbs()) * nbOutputs From 4ed9999f7c6f01eff1b62d57ab0d81295639a527 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 11 Mar 2024 13:47:41 +0000 Subject: [PATCH 5/5] test: add tests for all types of hints --- std/math/emulated/field_test.go | 176 ++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/std/math/emulated/field_test.go b/std/math/emulated/field_test.go index 175131827f..760eef182f 100644 --- a/std/math/emulated/field_test.go +++ b/std/math/emulated/field_test.go @@ -1,8 +1,10 @@ package emulated import ( + "crypto/rand" "errors" "fmt" + "math/big" "testing" "github.com/consensys/gnark/frontend" @@ -131,3 +133,177 @@ func TestSubConstantCircuit(t *testing.T) { _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.IgnoreUnconstrainedInputs()) assert.NoError(err) } + +func nnaHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return UnwrapHint(nativeInputs, nativeOutputs, func(mod *big.Int, inputs, outputs []*big.Int) error { + nominator := inputs[0] + denominator := inputs[1] + res := new(big.Int).ModInverse(denominator, mod) + if res == nil { + return fmt.Errorf("no modular inverse") + } + res.Mul(res, nominator) + res.Mod(res, mod) + outputs[0].Set(res) + return nil + }) +} + +type hintCircuit[T FieldParams] struct { + Nominator Element[T] + Denominator Element[T] + Expected Element[T] +} + +func (c *hintCircuit[T]) Define(api frontend.API) error { + field, err := NewField[T](api) + if err != nil { + return err + } + res, err := field.NewHint(nnaHint, 1, &c.Nominator, &c.Denominator) + if err != nil { + return err + } + field.AssertIsEqual(res[0], &c.Expected) + return nil +} + +func testHint[T FieldParams](t *testing.T) { + var fr T + assert := test.NewAssert(t) + a, _ := rand.Int(rand.Reader, fr.Modulus()) + b, _ := rand.Int(rand.Reader, fr.Modulus()) + c := new(big.Int).ModInverse(b, fr.Modulus()) + c.Mul(c, a) + c.Mod(c, fr.Modulus()) + + circuit := hintCircuit[T]{} + witness := hintCircuit[T]{ + Nominator: ValueOf[T](a), + Denominator: ValueOf[T](b), + Expected: ValueOf[T](c), + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness)) +} + +func TestHint(t *testing.T) { + testHint[Goldilocks](t) + testHint[Secp256k1Fp](t) + testHint[BN254Fp](t) +} + +func nativeInputHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nonnativeMod *big.Int, inputs, outputs []*big.Int) error { + nominator := inputs[0] + denominator := inputs[1] + res := new(big.Int).ModInverse(denominator, nonnativeMod) + if res == nil { + return fmt.Errorf("no modular inverse") + } + res.Mul(res, nominator) + res.Mod(res, nonnativeMod) + outputs[0].Set(res) + return nil + }) +} + +type hintNativeInputCircuit[T FieldParams] struct { + Nominator frontend.Variable + Denominator frontend.Variable + Expected Element[T] +} + +func (c *hintNativeInputCircuit[T]) Define(api frontend.API) error { + field, err := NewField[T](api) + if err != nil { + return err + } + res, err := field.NewHintWithNativeInput(nativeInputHint, 1, c.Nominator, c.Denominator) + if err != nil { + return err + } + field.AssertIsEqual(res[0], &c.Expected) + return nil +} + +func testHintNativeInput[T FieldParams](t *testing.T) { + var fr T + assert := test.NewAssert(t) + a, _ := rand.Int(rand.Reader, testCurve.ScalarField()) + b, _ := rand.Int(rand.Reader, testCurve.ScalarField()) + c := new(big.Int).ModInverse(b, fr.Modulus()) + c.Mul(c, a) + c.Mod(c, fr.Modulus()) + + circuit := hintNativeInputCircuit[T]{} + witness := hintNativeInputCircuit[T]{ + Nominator: a, + Denominator: b, + Expected: ValueOf[T](c), + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(testCurve)) +} + +func TestHintNativeInput(t *testing.T) { + testHintNativeInput[Goldilocks](t) + testHintNativeInput[Secp256k1Fp](t) + testHintNativeInput[BN254Fp](t) +} + +func nativeOutputHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs, func(nonnativeMod *big.Int, inputs, outputs []*big.Int) error { + nominator := inputs[0] + denominator := inputs[1] + res := new(big.Int).ModInverse(denominator, nativeMod) + if res == nil { + return fmt.Errorf("no modular inverse") + } + res.Mul(res, nominator) + res.Mod(res, nativeMod) + outputs[0].Set(res) + return nil + }) +} + +type hintNativeOutputCircuit[T FieldParams] struct { + Nominator Element[T] + Denominator Element[T] + Expected frontend.Variable +} + +func (c *hintNativeOutputCircuit[T]) Define(api frontend.API) error { + field, err := NewField[T](api) + if err != nil { + return err + } + res, err := field.NewHintWithNativeOutput(nativeOutputHint, 1, &c.Nominator, &c.Denominator) + if err != nil { + return err + } + api.AssertIsEqual(res[0], c.Expected) + return nil +} + +func testHintNativeOutput[T FieldParams](t *testing.T) { + var fr T + assert := test.NewAssert(t) + a, _ := rand.Int(rand.Reader, fr.Modulus()) + b, _ := rand.Int(rand.Reader, fr.Modulus()) + c := new(big.Int).ModInverse(b, testCurve.ScalarField()) + c.Mul(c, a) + c.Mod(c, testCurve.ScalarField()) + + circuit := hintNativeOutputCircuit[T]{} + witness := hintNativeOutputCircuit[T]{ + Nominator: ValueOf[T](a), + Denominator: ValueOf[T](b), + Expected: c, + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(testCurve)) +} + +func TestHintNativeOutput(t *testing.T) { + testHintNativeOutput[Goldilocks](t) + testHintNativeOutput[Secp256k1Fp](t) + testHintNativeOutput[BN254Fp](t) +}