From 5d2a4b85c5663839094eb8d8240092f1e87d96e6 Mon Sep 17 00:00:00 2001 From: Shramee Srivastav Date: Thu, 22 Aug 2024 01:13:28 -0400 Subject: [PATCH] MulConst support #5 --- std/math/emulated/cairo.go | 99 +++++++++++++++++++++------------- std/math/emulated/field_mul.go | 4 +- std/math/emulated/field_ops.go | 2 +- 3 files changed, 67 insertions(+), 38 deletions(-) diff --git a/std/math/emulated/cairo.go b/std/math/emulated/cairo.go index 53528aba9..cc6ae338c 100644 --- a/std/math/emulated/cairo.go +++ b/std/math/emulated/cairo.go @@ -33,7 +33,7 @@ type Op struct { } // Tracks entire Cairo circuit -// Contains constant inputs, operations, input and output indices +// Contains constant inputs, operations, constants and output indices type CairoManager[T FieldParams] struct { // Field paramaeters for reference during code gen FParams T @@ -62,7 +62,7 @@ type CairoManager[T FieldParams] struct { // CairoVarName returns code for the Element in Cairo circuit manager func cairoCodeFromTypeAndIndex(ty string, i int) string { switch ty { - case "output": + case "op_res": return fmt.Sprintf("op_%d", i) case "input": return fmt.Sprintf("in_%d", i) @@ -82,7 +82,7 @@ func (el *Element[T]) CairoVarName() string { } // CairoEl returns an El from Element[T]. -func CairoEl[T FieldParams](el *Element[T]) El { +func (el *Element[T]) CairoEl() El { return El{ Limbs: el.Limbs, Index: el.ElIndex, @@ -124,13 +124,15 @@ type CI = CircuitElement::>; const PRIME: [u96; 4] = %s; pub fn circuit(mut inputs: Array) -> Array { - %%s + %%[2]s let circuit = (op_8, op_9); let modulus = TryInto::<_, CircuitModulus>::try_into(PRIME).unwrap(); let mut circuit = circuit.new_inputs(); + assert( inputs.len() == %%[1]d, 'input required %%[1]d' ); for input in inputs { circuit = circuit.next(input); }; + %%[3]s let circuit = circuit.done().eval(modulus).unwrap(); let mut outputs = array![]; outputs.append(circuit.get_output(op_8)); @@ -141,10 +143,10 @@ pub fn circuit(mut inputs: Array) -> Array { } // GetValue returns the big.Int value of an Element[T]. -func (el *Element[T]) GetValue() *big.Int { +func GetValue(limbs []frontend.Variable) *big.Int { result := new(big.Int) shift := uint(0) - for _, limb := range el.Limbs { + for _, limb := range limbs { part := utils.FromInterface(limb) part.Lsh(&part, shift) // Shift the current part by the appropriate number of bits. result.Or(result, &part) // Combine it with the result using a bitwise OR. @@ -155,7 +157,11 @@ func (el *Element[T]) GetValue() *big.Int { } func (el *Element[T]) ToU384() string { - return ToU384(el.GetValue()) + return ToU384(GetValue(el.Limbs)) +} + +func (el *El) ToU384() string { + return ToU384(GetValue(el.Limbs)) } // SplitBigIntTo96BitArray splits a big.Int into an array of 96-bit big.Ints, from low to high. @@ -197,16 +203,23 @@ func (cm *CairoManager[T]) FinishCircuit(fieldSizeEl *Element[T]) string { cm.CodeCircuit += cm.CodeInputEl(el.CairoVarName(), i, "") } + inputsCount := len(cm.Inputs) + indentedConsts := "" + for i, el := range cm.ConstVals { + cm.CodeCircuit += cm.CodeInputEl(el.CairoVarName(), i+inputsCount, "") + indentedConsts += fmt.Sprintf("circuit = circuit.next(%s);\n", el.ToU384()) + } + // Generation operations code from CairoManager.Ops for _, op := range cm.Ops { cm.CodeCircuit += cm.CodeOperation(op) } indentedCode := strings.ReplaceAll(cm.CodeCircuit, "\n", "\n ") + indentedConsts = strings.ReplaceAll(indentedConsts, "\n", "\n ") - code := fmt.Sprintf(cm.CodeModule, indentedCode) + code := fmt.Sprintf(cm.CodeModule, inputsCount, indentedCode, indentedConsts) // println(" ", indentedCode) - println(code) cm.Reset() return code } @@ -237,7 +250,8 @@ func GetField[T FieldParams](api frontend.API) *Field[T] { func FinishCairoCircuit[T FieldParams](api frontend.API) { // NewField returns a cached copy of the field if available field := GetField[T](api) - field.Cairo.FinishCircuit(field.Modulus()) + code := field.Cairo.FinishCircuit(field.Modulus()) + println(code) } // #endregion Circuit Utility functions @@ -284,7 +298,7 @@ func (cm *CairoManager[T]) MaybeAddInput(el *Element[T], immediateComment string if el.ElType == "" { el.ElIndex = len(cm.Inputs) el.ElType = "input" - cEl := CairoEl[T](el) + cEl := el.CairoEl() immediateCode := immediateComment != "" cEl.Skip = immediateCode if immediateCode { @@ -302,10 +316,9 @@ func (cm *CairoManager[T]) CodeComment(comment string) { // AddConstantInput adds a constant input to Cairo circuit manager // Value for the constant inputs is provided within the circuit code func (cm *CairoManager[T]) AddConstantInput(el *Element[T]) *Element[T] { - e := CairoEl[T](el) - e.Type = "cnst" - e.Index = len(cm.ConstVals) - cm.ConstVals = append(cm.ConstVals, e) + el.ElType = "cnst" + el.ElIndex = len(cm.ConstVals) + cm.ConstVals = append(cm.ConstVals, el.CairoEl()) return el } @@ -320,33 +333,28 @@ func (cm *CairoManager[T]) AssertZero(a, res *Element[T]) *Element[T] { // RegisterOperation registers an operation for (upto) two elements // Returned res element has appropriate values for calling CairoVarName // Adds correct ElType and ElIndex for the result element -func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T]) *Element[T] { - var fnName string - pc, _, _, ok := runtime.Caller(2) - details := runtime.FuncForPC(pc) - if ok && details != nil { - fnName = strings.Replace(details.Name(), "github.com/consensys/gnark/std/math/emulated.(*Field[...]).", "", 1) - } else { - panic("Cairo circuit manager: Failed to get runtime details") - } - - a = cm.MaybeAddInput(a, "") +func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T], fnName string) *Element[T] { + var cairoB El if b != nil { b = cm.MaybeAddInput(b, "") + cairoB = b.CairoEl() } - // println("a", a.Limbs[0], a.ElIndex, a.internal, a.ElType) - // println("b", b.Limbs[0], b.ElIndex, b.internal, b.ElType) + a = cm.MaybeAddInput(a, "") + cairoA := a.CairoEl() + if res.ElType != "" { panic("Cairo circuit manager: Res element already registered") } res.ElIndex = len(cm.Ops) - res.ElType = "output" + res.ElType = "op_res" // println("res", res.Limbs[0], res.ElIndex, res.internal, res.ElType) - // fmt.Printf("let %s = %s(%s, %s);\n", res.CairoVarName(), fnName, a.CairoVarName(), b.CairoVarName()) + fmt.Printf("// OP<%d>::%s", res.ElIndex, fnName) + fmt.Printf("(%s, %s);\n", cairoA.CairoVarName(), cairoB.CairoVarName()) + op := Op{ - A: CairoEl(a), - B: CairoEl(b), - Out: CairoEl[T](res), + A: cairoA, + B: cairoB, + Out: res.CairoEl(), Type: fnName, } @@ -355,6 +363,26 @@ func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T]) *Element[T] return res } +// RegisterOperation registers an operation for (upto) two elements +// Returned res element has appropriate values for calling CairoVarName +// Adds correct ElType and ElIndex for the result element +func (cm *CairoManager[T]) RegisterReduceAndOp(a, b, res *Element[T]) *Element[T] { + var fnName string + pc, _, _, ok := runtime.Caller(2) + details := runtime.FuncForPC(pc) + if ok && details != nil { + fnName = strings.Replace(details.Name(), "github.com/consensys/gnark/std/math/emulated.(*Field[...]).", "", 1) + } else { + panic("Cairo circuit manager: Failed to get runtime details") + } + + if fnName == "MulConst" { + return res + } + + return cm.RegisterOperation(a, b, res, fnName) +} + // #endregion Circuit registration calls // #region Circuit operations @@ -371,11 +399,10 @@ func (cm *CairoManager[T]) FnNameToOperation(fnName string) string { return "circuit_sub" case "Inverse": return "circuit_inverse" - case "Mul": - case "MulMod": + case "Mul", "MulMod", "MulConst": return "circuit_mul" } - panic("Cairo circuit manager: Unknown operation") + panic("Cairo circuit manager: Unknown operation " + fnName) } // #endregion Circuit operations diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 278b9a502..6ead4a97f 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -448,7 +448,9 @@ func (f *Field[T]) MulConst(a *Element[T], c *big.Int) *Element[T] { for i := range a.Limbs { limbs[i] = f.api.Mul(a.Limbs[i], c) } - return f.newInternalElement(limbs, a.overflow+cbl) + res := f.newInternalElement(limbs, a.overflow+cbl) + b := f.Cairo.AddConstantInput(newConstElement[T](c)) + return f.Cairo.RegisterOperation(a, b, res, "MulConst") }, func(a, _ *Element[T]) (nextOverflow uint, err error) { nextOverflow = a.overflow + uint(cbl) diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index 1835bb232..4bd79ef7f 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -356,7 +356,7 @@ func (f *Field[T]) reduceAndOp(op func(*Element[T], *Element[T], uint) *Element[ } } - return f.Cairo.RegisterOperation(a, b, op(a, b, nextOverflow)) + return f.Cairo.RegisterReduceAndOp(a, b, op(a, b, nextOverflow)) } type overflowError struct {