Skip to content

Commit

Permalink
feat: circuit output Consensys#7
Browse files Browse the repository at this point in the history
  • Loading branch information
shramee committed Aug 22, 2024
1 parent 5d2a4b8 commit 3709908
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 37 deletions.
17 changes: 6 additions & 11 deletions std/algebra/emulated/sw_bn254/cairo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,16 @@ import (
)

type e2Mul struct {
A, B, C f.E2
A, B f.E2
}

func (circuit *e2Mul) Define(api frontend.API) error {

em.ParseCircuitFields[em.BN254Fp](api, circuit)
em.ParseCircuitInputs[em.BN254Fp](api, circuit)

e := f.NewExt2(api)
expected := e.Mul(&circuit.A, &circuit.B)
e.AssertIsEqual(expected, &circuit.C)

em.FinishCairoCircuit[em.BN254Fp](api)
em.FinishCairoCircuit[em.BN254Fp](api, expected)
return nil
}

Expand All @@ -39,7 +37,6 @@ func TestCairoE2Mul(t *testing.T) {
witness := e2Mul{
A: f.FromE2(&a),
B: f.FromE2(&b),
C: f.FromE2(&c),
}

err := test.IsSolved(&e2Mul{}, &witness, ecc.BN254.ScalarField())
Expand All @@ -48,15 +45,14 @@ func TestCairoE2Mul(t *testing.T) {
}

type e12Mul struct {
A, B, C f.E12
A, B f.E12
}

func (circuit *e12Mul) Define(api frontend.API) error {
em.ParseCircuitFields[em.BN254Fp](api, circuit)
em.ParseCircuitInputs[em.BN254Fp](api, circuit)
e := f.NewExt12(api)
expected := e.Mul(&circuit.A, &circuit.B)
e.AssertIsEqual(expected, &circuit.C)
em.FinishCairoCircuit[em.BN254Fp](api)
em.FinishCairoCircuit[em.BN254Fp](api, expected)
return nil
}

Expand All @@ -72,7 +68,6 @@ func TestCairoE12Mul(t *testing.T) {
witness := e12Mul{
A: f.FromE12(&a),
B: f.FromE12(&b),
C: f.FromE12(&c),
}

err := test.IsSolved(&e12Mul{}, &witness, ecc.BN254.ScalarField())
Expand Down
68 changes: 42 additions & 26 deletions std/math/emulated/cairo.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ type CairoManager[T FieldParams] struct {
func cairoCodeFromTypeAndIndex(ty string, i int) string {
switch ty {
case "op_res":
return fmt.Sprintf("op_%d", i)
return fmt.Sprintf("t%d", i)
case "input":
return fmt.Sprintf("in_%d", i)
return fmt.Sprintf("i%d", i)
case "cnst":
return fmt.Sprintf("cn_%d", i)
return fmt.Sprintf("c%d", i)
default:
panic("Cairo circuit manager: Unknown element type")
}
Expand Down Expand Up @@ -121,23 +121,18 @@ func moduleCode(modulo string) string {
};
type CI<const T: usize> = CircuitElement::<CircuitInput<T>>;
const PRIME: [u96; 4] = %s;
pub fn circuit(mut inputs: Array<u384>) -> Array<u384> {
%%[2]s
let circuit = (op_8, op_9);
let modulus = TryInto::<_, CircuitModulus>::try_into(PRIME).unwrap();
let modulus = TryInto::<_, CircuitModulus>::try_into(%[1]s).unwrap();
let circuit = (%%[5]s);
let mut circuit = circuit.new_inputs();
assert( inputs.len() == %%[1]d, 'input required %%[1]d' );
for input in inputs {
circuit = circuit.next(input);
};
%%[3]s
let circuit = circuit.done().eval(modulus).unwrap();
let mut outputs = array![];
outputs.append(circuit.get_output(op_8));
outputs.append(circuit.get_output(op_9));
outputs
array![%%[4]s]
}
`, modulo)
}
Expand Down Expand Up @@ -209,16 +204,23 @@ func (cm *CairoManager[T]) FinishCircuit(fieldSizeEl *Element[T]) string {
cm.CodeCircuit += cm.CodeInputEl(el.CairoVarName(), i+inputsCount, "")
indentedConsts += fmt.Sprintf("circuit = circuit.next(%s);\n", el.ToU384())
}
indentedConsts = strings.ReplaceAll(indentedConsts, "\n", "\n ")

indentedOutputs := ""
circuitOutputs := ""
for _, el := range cm.Outputs {
indentedOutputs += fmt.Sprintf("\ncircuit.get_output(%s),", el.CairoVarName())
circuitOutputs += el.CairoVarName() + ", "
}
indentedOutputs = strings.ReplaceAll(indentedOutputs, "\n", "\n ") + "\n "

// Generation operations code from CairoManager.Ops
for _, op := range cm.Ops {
cm.CodeCircuit += cm.CodeOperation(op)
}

indentedCode := strings.ReplaceAll(cm.CodeCircuit, "\n", "\n ")
indentedConsts = strings.ReplaceAll(indentedConsts, "\n", "\n ")

code := fmt.Sprintf(cm.CodeModule, inputsCount, indentedCode, indentedConsts)
code := fmt.Sprintf(cm.CodeModule, inputsCount, indentedCode, indentedConsts, indentedOutputs, circuitOutputs)
// println(" ", indentedCode)
cm.Reset()
return code
Expand Down Expand Up @@ -247,8 +249,9 @@ func GetField[T FieldParams](api frontend.API) *Field[T] {
}

// FinishCairoCircuit runs FinishCircuit to generate cairo circuit code
func FinishCairoCircuit[T FieldParams](api frontend.API) {
func FinishCairoCircuit[T FieldParams, CT any](api frontend.API, output *CT) {
// NewField returns a cached copy of the field if available
ParseCircuitOutputs[T](api, output)
field := GetField[T](api)
code := field.Cairo.FinishCircuit(field.Modulus())
println(code)
Expand All @@ -258,15 +261,22 @@ func FinishCairoCircuit[T FieldParams](api frontend.API) {

// #region Circuit registration calls

// Generic method to register inputs from any type
func ParseCircuitFields[T FieldParams, CT any](api frontend.API, circuit *CT) {
// ParseCircuitInputs registers inputs from Elements in any type
func ParseCircuitInputs[T FieldParams, CT any](api frontend.API, circuit *CT) {
field := GetField[T](api)
val := reflect.ValueOf(circuit).Elem()
field.Cairo.ParseCircuitFields(val, "", field.Cairo.MaybeAddInput)
}

// ParseCircuitOutputs registers outputs from Elements in any type
func ParseCircuitOutputs[T FieldParams, CT any](api frontend.API, circuit *CT) {
field := GetField[T](api)
val := reflect.ValueOf(circuit).Elem()
field.Cairo.ParseCircuitFields(val, "")
field.Cairo.ParseCircuitFields(val, "", func(el *Element[T], _ string) *Element[T] { return field.Cairo.AddOutput(el) })
}

// Helper function to register each field in a struct recursively
func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string) {
func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string, op func(*Element[T], string) *Element[T]) {
v = reflect.Indirect(v) // Dereference pointers
typ := v.Type().Name()

Expand All @@ -275,7 +285,7 @@ func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string) {
el := v.Addr().Interface().(*Element[T])
// .(*Element[T])
// el := v.Interface().(Element[T])
cm.MaybeAddInput(el, path)
op(el, path)
// Here you would register the element with your witness object
} else if v.Kind() == reflect.Struct {
if path == "" {
Expand All @@ -285,7 +295,7 @@ func (cm *CairoManager[T]) ParseCircuitFields(v reflect.Value, path string) {
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
cm.ParseCircuitFields(field, fmt.Sprintf("%s[%d]", path, i))
cm.ParseCircuitFields(field, fmt.Sprintf("%s[%d]", path, i), op)
}
} else {
println(errors.New("cairo circuit manager: Unknown field type"))
Expand All @@ -309,6 +319,14 @@ func (cm *CairoManager[T]) MaybeAddInput(el *Element[T], immediateComment string
return el
}

// AddInput adds an input to Cairo circuit manager
// Value for the constant input is provided within the circuit
func (cm *CairoManager[T]) AddOutput(el *Element[T]) *Element[T] {
cEl := el.CairoEl()
cm.Outputs = append(cm.Outputs, cEl)
return el
}

func (cm *CairoManager[T]) CodeComment(comment string) {
cm.CodeCircuit += "//" + comment + "\n"
}
Expand All @@ -322,9 +340,7 @@ func (cm *CairoManager[T]) AddConstantInput(el *Element[T]) *Element[T] {
return el
}

// RegisterOperation registers an operation for (upto) two elements
// Returned res element has appropriate values for calling CairoVarName
// Adds correct ElType and ElIndex for the result element
// AssertZero adds an assertion for a variable to be zero
func (cm *CairoManager[T]) AssertZero(a, res *Element[T]) *Element[T] {
cm.Assertions = append(cm.Assertions, a.CairoVarName())
return res
Expand All @@ -348,8 +364,8 @@ func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T], fnName strin
res.ElIndex = len(cm.Ops)
res.ElType = "op_res"
// println("res", res.Limbs[0], res.ElIndex, res.internal, res.ElType)
fmt.Printf("// OP<%d>::%s", res.ElIndex, fnName)
fmt.Printf("(%s, %s);\n", cairoA.CairoVarName(), cairoB.CairoVarName())
// fmt.Printf("// OP<%d>::%s", res.ElIndex, fnName)
// fmt.Printf("(%s, %s);\n", cairoA.CairoVarName(), cairoB.CairoVarName())

op := Op{
A: cairoA,
Expand Down

0 comments on commit 3709908

Please sign in to comment.