Skip to content

Commit

Permalink
feat: add non-native hint with native inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ivokub committed Mar 11, 2024
1 parent 2c49a0f commit c9cf735
Showing 1 changed file with 63 additions and 19 deletions.
82 changes: 63 additions & 19 deletions std/math/emulated/field_hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable
return res
}

func (f *Field[T]) wrapHintNatives(nativeInputs ...frontend.Variable) []frontend.Variable {
res := []frontend.Variable{f.fParams.BitsPerLimb(), f.fParams.NbLimbs()}
res = append(res, f.Modulus().Limbs...)
res = append(res, nativeInputs...)
return res
}

// UnwrapHint unwraps the native inputs into nonnative inputs. Then it calls
// nonnativeHint function with nonnative inputs. After nonnativeHint returns, it
// decomposes the outputs into limbs.
Expand All @@ -33,6 +40,13 @@ func UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs []*big.Int, nonnativ
return unwrapHint(true, false, nativeInputs, nativeOutputs, nonnativeHint)
}

// UnwrapHintWithNativeInput unwraps the native inputs into native inputs. Then
// it calls nonnativeHint function with native inputs. After nonnativeHint
// returns, it decomposes the outputs into limbs.
func UnwrapHintWithNativeInput(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error {
return unwrapHint(false, true, nativeInputs, nativeOutputs, nonnativeHint)
}

func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error {
if len(nativeInputs) < 2 {
return fmt.Errorf("hint wrapper header is 2 elements")
Expand All @@ -49,28 +63,38 @@ func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutp
if err := recompose(nativeInputs[2:2+nbLimbs], uint(nbBits), nonnativeMod); err != nil {
return fmt.Errorf("cannot recover nonnative mod: %w", err)
}
if !nativeInputs[2+nbLimbs].IsInt64() {
return fmt.Errorf("number of nonnative elements must be castable to int64")
}
nbInputs := int(nativeInputs[2+nbLimbs].Int64())
nonnativeInputs := make([]*big.Int, nbInputs)
readPtr := 3 + nbLimbs
for i := 0; i < nbInputs; i++ {
if len(nativeInputs) < readPtr+1 {
return fmt.Errorf("can not read %d-th native input", i)
}
if !nativeInputs[readPtr].IsInt64() {
return fmt.Errorf("corrupted %d-th native input", i)
var nonnativeInputs []*big.Int
if isEmulatedInput {
if !nativeInputs[2+nbLimbs].IsInt64() {
return fmt.Errorf("number of nonnative elements must be castable to int64")
}
currentInputLen := int(nativeInputs[readPtr].Int64())
if len(nativeInputs) < (readPtr + 1 + currentInputLen) {
return fmt.Errorf("cannot read %d-th nonnative element", i)
nbInputs := int(nativeInputs[2+nbLimbs].Int64())
readPtr := 3 + nbLimbs
nonnativeInputs = make([]*big.Int, nbInputs)
for i := 0; i < nbInputs; i++ {
if len(nativeInputs) < readPtr+1 {
return fmt.Errorf("can not read %d-th native input", i)
}
if !nativeInputs[readPtr].IsInt64() {
return fmt.Errorf("corrupted %d-th native input", i)
}
currentInputLen := int(nativeInputs[readPtr].Int64())
if len(nativeInputs) < (readPtr + 1 + currentInputLen) {
return fmt.Errorf("cannot read %d-th nonnative element", i)
}
nonnativeInputs[i] = new(big.Int)
if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil {
return fmt.Errorf("recompose %d-th element: %w", i, err)
}
readPtr += 1 + currentInputLen
}
nonnativeInputs[i] = new(big.Int)
if err := recompose(nativeInputs[readPtr+1:readPtr+1+currentInputLen], uint(nbBits), nonnativeInputs[i]); err != nil {
return fmt.Errorf("recompose %d-th element: %w", i, err)
} else {
nbInputs := len(nativeInputs[2+nbLimbs:])
readPtr := 2 + nbLimbs
nonnativeInputs = make([]*big.Int, nbInputs)
for i := 0; i < nbInputs; i++ {
nonnativeInputs[i] = new(big.Int).Set(nativeInputs[readPtr+i])
}
readPtr += 1 + currentInputLen
}

var nonnativeOutputs []*big.Int
Expand Down Expand Up @@ -150,3 +174,23 @@ func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs
}
return nativeOutputs, nil
}

// NewHintWithNativeInput allows to call the emulation hint function hf on
// native inputs, expecting nbOutputs results. This function passes the native
// inputs to the hint function directly and reconstructs the outputs into
// non-native elements. There is [UnwrapHintWithNativeInput] function which
// performs corresponding recomposition of limbs into integer values (and vice
// verse for output).
func (f *Field[T]) NewHintWithNativeInput(hf solver.Hint, nbOutputs int, inputs ...frontend.Variable) ([]*Element[T], error) {
nativeInputs := f.wrapHintNatives(inputs...)
nbNativeOutputs := int(f.fParams.NbLimbs()) * nbOutputs
nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...)
if err != nil {
return nil, fmt.Errorf("call hint: %w", err)
}
outputs := make([]*Element[T], nbOutputs)
for i := 0; i < nbOutputs; i++ {
outputs[i] = f.packLimbs(nativeOutputs[i*int(f.fParams.NbLimbs()):(i+1)*int(f.fParams.NbLimbs())], true)
}
return outputs, nil
}

0 comments on commit c9cf735

Please sign in to comment.