From 370990865ffe2d431e9ff5af02add01c95e59a15 Mon Sep 17 00:00:00 2001 From: Shramee Srivastav Date: Thu, 22 Aug 2024 02:41:22 -0400 Subject: [PATCH] feat: circuit output #7 --- std/algebra/emulated/sw_bn254/cairo_test.go | 17 ++---- std/math/emulated/cairo.go | 68 +++++++++++++-------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/std/algebra/emulated/sw_bn254/cairo_test.go b/std/algebra/emulated/sw_bn254/cairo_test.go index bd202fbf3..1778cfb39 100644 --- a/std/algebra/emulated/sw_bn254/cairo_test.go +++ b/std/algebra/emulated/sw_bn254/cairo_test.go @@ -13,18 +13,16 @@ import ( ) type e2Mul struct { - A, B, C f.E2 + A, B f.E2 } func (circuit *e2Mul) Define(api frontend.API) error { - em.ParseCircuitFields[em.BN254Fp](api, circuit) + em.ParseCircuitInputs[em.BN254Fp](api, circuit) e := f.NewExt2(api) expected := e.Mul(&circuit.A, &circuit.B) - e.AssertIsEqual(expected, &circuit.C) - - em.FinishCairoCircuit[em.BN254Fp](api) + em.FinishCairoCircuit[em.BN254Fp](api, expected) return nil } @@ -39,7 +37,6 @@ func TestCairoE2Mul(t *testing.T) { witness := e2Mul{ A: f.FromE2(&a), B: f.FromE2(&b), - C: f.FromE2(&c), } err := test.IsSolved(&e2Mul{}, &witness, ecc.BN254.ScalarField()) @@ -48,15 +45,14 @@ func TestCairoE2Mul(t *testing.T) { } type e12Mul struct { - A, B, C f.E12 + A, B f.E12 } func (circuit *e12Mul) Define(api frontend.API) error { - em.ParseCircuitFields[em.BN254Fp](api, circuit) + em.ParseCircuitInputs[em.BN254Fp](api, circuit) e := f.NewExt12(api) expected := e.Mul(&circuit.A, &circuit.B) - e.AssertIsEqual(expected, &circuit.C) - em.FinishCairoCircuit[em.BN254Fp](api) + em.FinishCairoCircuit[em.BN254Fp](api, expected) return nil } @@ -72,7 +68,6 @@ func TestCairoE12Mul(t *testing.T) { witness := e12Mul{ A: f.FromE12(&a), B: f.FromE12(&b), - C: f.FromE12(&c), } err := test.IsSolved(&e12Mul{}, &witness, ecc.BN254.ScalarField()) diff --git a/std/math/emulated/cairo.go b/std/math/emulated/cairo.go index cc6ae338c..b6dfc88c3 100644 --- a/std/math/emulated/cairo.go +++ b/std/math/emulated/cairo.go @@ -63,11 +63,11 @@ type CairoManager[T FieldParams] struct { func cairoCodeFromTypeAndIndex(ty string, i int) string { switch ty { case "op_res": - return fmt.Sprintf("op_%d", i) + return fmt.Sprintf("t%d", i) case "input": - return fmt.Sprintf("in_%d", i) + return fmt.Sprintf("i%d", i) case "cnst": - return fmt.Sprintf("cn_%d", i) + return fmt.Sprintf("c%d", i) default: panic("Cairo circuit manager: Unknown element type") } @@ -121,12 +121,10 @@ func moduleCode(modulo string) string { }; type CI = CircuitElement::>; -const PRIME: [u96; 4] = %s; - pub fn circuit(mut inputs: Array) -> Array { %%[2]s - let circuit = (op_8, op_9); - let modulus = TryInto::<_, CircuitModulus>::try_into(PRIME).unwrap(); + let modulus = TryInto::<_, CircuitModulus>::try_into(%[1]s).unwrap(); + let circuit = (%%[5]s); let mut circuit = circuit.new_inputs(); assert( inputs.len() == %%[1]d, 'input required %%[1]d' ); for input in inputs { @@ -134,10 +132,7 @@ pub fn circuit(mut inputs: Array) -> Array { }; %%[3]s let circuit = circuit.done().eval(modulus).unwrap(); - let mut outputs = array![]; - outputs.append(circuit.get_output(op_8)); - outputs.append(circuit.get_output(op_9)); - outputs + array![%%[4]s] } `, modulo) } @@ -209,16 +204,23 @@ func (cm *CairoManager[T]) FinishCircuit(fieldSizeEl *Element[T]) string { cm.CodeCircuit += cm.CodeInputEl(el.CairoVarName(), i+inputsCount, "") indentedConsts += fmt.Sprintf("circuit = circuit.next(%s);\n", el.ToU384()) } + indentedConsts = strings.ReplaceAll(indentedConsts, "\n", "\n ") + + indentedOutputs := "" + circuitOutputs := "" + for _, el := range cm.Outputs { + indentedOutputs += fmt.Sprintf("\ncircuit.get_output(%s),", el.CairoVarName()) + circuitOutputs += el.CairoVarName() + ", " + } + indentedOutputs = strings.ReplaceAll(indentedOutputs, "\n", "\n ") + "\n " // Generation operations code from CairoManager.Ops for _, op := range cm.Ops { cm.CodeCircuit += cm.CodeOperation(op) } - indentedCode := strings.ReplaceAll(cm.CodeCircuit, "\n", "\n ") - indentedConsts = strings.ReplaceAll(indentedConsts, "\n", "\n ") - code := fmt.Sprintf(cm.CodeModule, inputsCount, indentedCode, indentedConsts) + code := fmt.Sprintf(cm.CodeModule, inputsCount, indentedCode, indentedConsts, indentedOutputs, circuitOutputs) // println(" ", indentedCode) cm.Reset() return code @@ -247,8 +249,9 @@ func GetField[T FieldParams](api frontend.API) *Field[T] { } // FinishCairoCircuit runs FinishCircuit to generate cairo circuit code -func FinishCairoCircuit[T FieldParams](api frontend.API) { +func FinishCairoCircuit[T FieldParams, CT any](api frontend.API, output *CT) { // NewField returns a cached copy of the field if available + ParseCircuitOutputs[T](api, output) field := GetField[T](api) code := field.Cairo.FinishCircuit(field.Modulus()) println(code) @@ -258,15 +261,22 @@ func FinishCairoCircuit[T FieldParams](api frontend.API) { // #region Circuit registration calls -// Generic method to register inputs from any type -func ParseCircuitFields[T FieldParams, CT any](api frontend.API, circuit *CT) { +// ParseCircuitInputs registers inputs from Elements in any type +func ParseCircuitInputs[T FieldParams, CT any](api frontend.API, circuit *CT) { + field := GetField[T](api) + val := reflect.ValueOf(circuit).Elem() + field.Cairo.ParseCircuitFields(val, "", field.Cairo.MaybeAddInput) +} + +// ParseCircuitOutputs registers outputs from Elements in any type +func ParseCircuitOutputs[T FieldParams, CT any](api frontend.API, circuit *CT) { field := GetField[T](api) val := reflect.ValueOf(circuit).Elem() - field.Cairo.ParseCircuitFields(val, "") + field.Cairo.ParseCircuitFields(val, "", func(el *Element[T], _ string) *Element[T] { return field.Cairo.AddOutput(el) }) } // Helper function to register each field in a struct recursively -func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string) { +func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string, op func(*Element[T], string) *Element[T]) { v = reflect.Indirect(v) // Dereference pointers typ := v.Type().Name() @@ -275,7 +285,7 @@ func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string) { el := v.Addr().Interface().(*Element[T]) // .(*Element[T]) // el := v.Interface().(Element[T]) - cm.MaybeAddInput(el, path) + op(el, path) // Here you would register the element with your witness object } else if v.Kind() == reflect.Struct { if path == "" { @@ -285,7 +295,7 @@ func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string) { } for i := 0; i < v.NumField(); i++ { field := v.Field(i) - cm.ParseCircuitFields(field, fmt.Sprintf("%s[%d]", path, i)) + cm.ParseCircuitFields(field, fmt.Sprintf("%s[%d]", path, i), op) } } else { println(errors.New("cairo circuit manager: Unknown field type")) @@ -309,6 +319,14 @@ func (cm *CairoManager[T]) MaybeAddInput(el *Element[T], immediateComment string return el } +// AddInput adds an input to Cairo circuit manager +// Value for the constant input is provided within the circuit +func (cm *CairoManager[T]) AddOutput(el *Element[T]) *Element[T] { + cEl := el.CairoEl() + cm.Outputs = append(cm.Outputs, cEl) + return el +} + func (cm *CairoManager[T]) CodeComment(comment string) { cm.CodeCircuit += "//" + comment + "\n" } @@ -322,9 +340,7 @@ func (cm *CairoManager[T]) AddConstantInput(el *Element[T]) *Element[T] { return el } -// RegisterOperation registers an operation for (upto) two elements -// Returned res element has appropriate values for calling CairoVarName -// Adds correct ElType and ElIndex for the result element +// AssertZero adds an assertion for a variable to be zero func (cm *CairoManager[T]) AssertZero(a, res *Element[T]) *Element[T] { cm.Assertions = append(cm.Assertions, a.CairoVarName()) return res @@ -348,8 +364,8 @@ func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T], fnName strin res.ElIndex = len(cm.Ops) res.ElType = "op_res" // println("res", res.Limbs[0], res.ElIndex, res.internal, res.ElType) - fmt.Printf("// OP<%d>::%s", res.ElIndex, fnName) - fmt.Printf("(%s, %s);\n", cairoA.CairoVarName(), cairoB.CairoVarName()) + // fmt.Printf("// OP<%d>::%s", res.ElIndex, fnName) + // fmt.Printf("(%s, %s);\n", cairoA.CairoVarName(), cairoB.CairoVarName()) op := Op{ A: cairoA,