Skip to content

Commit

Permalink
MulConst support Consensys#5
Browse files Browse the repository at this point in the history
  • Loading branch information
shramee committed Aug 22, 2024
1 parent b6e343e commit 5d2a4b8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 38 deletions.
99 changes: 63 additions & 36 deletions std/math/emulated/cairo.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Op struct {
}

// Tracks entire Cairo circuit
// Contains constant inputs, operations, input and output indices
// Contains constant inputs, operations, constants and output indices
type CairoManager[T FieldParams] struct {
// Field paramaeters for reference during code gen
FParams T
Expand Down Expand Up @@ -62,7 +62,7 @@ type CairoManager[T FieldParams] struct {
// CairoVarName returns code for the Element in Cairo circuit manager
func cairoCodeFromTypeAndIndex(ty string, i int) string {
switch ty {
case "output":
case "op_res":
return fmt.Sprintf("op_%d", i)
case "input":
return fmt.Sprintf("in_%d", i)
Expand All @@ -82,7 +82,7 @@ func (el *Element[T]) CairoVarName() string {
}

// CairoEl returns an El from Element[T].
func CairoEl[T FieldParams](el *Element[T]) El {
func (el *Element[T]) CairoEl() El {
return El{
Limbs: el.Limbs,
Index: el.ElIndex,
Expand Down Expand Up @@ -124,13 +124,15 @@ type CI<const T: usize> = CircuitElement::<CircuitInput<T>>;
const PRIME: [u96; 4] = %s;
pub fn circuit(mut inputs: Array<u384>) -> Array<u384> {
%%s
%%[2]s
let circuit = (op_8, op_9);
let modulus = TryInto::<_, CircuitModulus>::try_into(PRIME).unwrap();
let mut circuit = circuit.new_inputs();
assert( inputs.len() == %%[1]d, 'input required %%[1]d' );
for input in inputs {
circuit = circuit.next(input);
};
%%[3]s
let circuit = circuit.done().eval(modulus).unwrap();
let mut outputs = array![];
outputs.append(circuit.get_output(op_8));
Expand All @@ -141,10 +143,10 @@ pub fn circuit(mut inputs: Array<u384>) -> Array<u384> {
}

// GetValue returns the big.Int value of an Element[T].
func (el *Element[T]) GetValue() *big.Int {
func GetValue(limbs []frontend.Variable) *big.Int {
result := new(big.Int)
shift := uint(0)
for _, limb := range el.Limbs {
for _, limb := range limbs {
part := utils.FromInterface(limb)
part.Lsh(&part, shift) // Shift the current part by the appropriate number of bits.
result.Or(result, &part) // Combine it with the result using a bitwise OR.
Expand All @@ -155,7 +157,11 @@ func (el *Element[T]) GetValue() *big.Int {
}

func (el *Element[T]) ToU384() string {
return ToU384(el.GetValue())
return ToU384(GetValue(el.Limbs))
}

func (el *El) ToU384() string {
return ToU384(GetValue(el.Limbs))
}

// SplitBigIntTo96BitArray splits a big.Int into an array of 96-bit big.Ints, from low to high.
Expand Down Expand Up @@ -197,16 +203,23 @@ func (cm *CairoManager[T]) FinishCircuit(fieldSizeEl *Element[T]) string {
cm.CodeCircuit += cm.CodeInputEl(el.CairoVarName(), i, "")
}

inputsCount := len(cm.Inputs)
indentedConsts := ""
for i, el := range cm.ConstVals {
cm.CodeCircuit += cm.CodeInputEl(el.CairoVarName(), i+inputsCount, "")
indentedConsts += fmt.Sprintf("circuit = circuit.next(%s);\n", el.ToU384())
}

// Generation operations code from CairoManager.Ops
for _, op := range cm.Ops {
cm.CodeCircuit += cm.CodeOperation(op)
}

indentedCode := strings.ReplaceAll(cm.CodeCircuit, "\n", "\n ")
indentedConsts = strings.ReplaceAll(indentedConsts, "\n", "\n ")

code := fmt.Sprintf(cm.CodeModule, indentedCode)
code := fmt.Sprintf(cm.CodeModule, inputsCount, indentedCode, indentedConsts)
// println(" ", indentedCode)
println(code)
cm.Reset()
return code
}
Expand Down Expand Up @@ -237,7 +250,8 @@ func GetField[T FieldParams](api frontend.API) *Field[T] {
func FinishCairoCircuit[T FieldParams](api frontend.API) {
// NewField returns a cached copy of the field if available
field := GetField[T](api)
field.Cairo.FinishCircuit(field.Modulus())
code := field.Cairo.FinishCircuit(field.Modulus())
println(code)
}

// #endregion Circuit Utility functions
Expand Down Expand Up @@ -284,7 +298,7 @@ func (cm *CairoManager[T]) MaybeAddInput(el *Element[T], immediateComment string
if el.ElType == "" {
el.ElIndex = len(cm.Inputs)
el.ElType = "input"
cEl := CairoEl[T](el)
cEl := el.CairoEl()
immediateCode := immediateComment != ""
cEl.Skip = immediateCode
if immediateCode {
Expand All @@ -302,10 +316,9 @@ func (cm *CairoManager[T]) CodeComment(comment string) {
// AddConstantInput adds a constant input to Cairo circuit manager
// Value for the constant inputs is provided within the circuit code
func (cm *CairoManager[T]) AddConstantInput(el *Element[T]) *Element[T] {
e := CairoEl[T](el)
e.Type = "cnst"
e.Index = len(cm.ConstVals)
cm.ConstVals = append(cm.ConstVals, e)
el.ElType = "cnst"
el.ElIndex = len(cm.ConstVals)
cm.ConstVals = append(cm.ConstVals, el.CairoEl())
return el
}

Expand All @@ -320,33 +333,28 @@ func (cm *CairoManager[T]) AssertZero(a, res *Element[T]) *Element[T] {
// RegisterOperation registers an operation for (upto) two elements
// Returned res element has appropriate values for calling CairoVarName
// Adds correct ElType and ElIndex for the result element
func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T]) *Element[T] {
var fnName string
pc, _, _, ok := runtime.Caller(2)
details := runtime.FuncForPC(pc)
if ok && details != nil {
fnName = strings.Replace(details.Name(), "github.com/consensys/gnark/std/math/emulated.(*Field[...]).", "", 1)
} else {
panic("Cairo circuit manager: Failed to get runtime details")
}

a = cm.MaybeAddInput(a, "")
func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T], fnName string) *Element[T] {
var cairoB El
if b != nil {
b = cm.MaybeAddInput(b, "")
cairoB = b.CairoEl()
}
// println("a", a.Limbs[0], a.ElIndex, a.internal, a.ElType)
// println("b", b.Limbs[0], b.ElIndex, b.internal, b.ElType)
a = cm.MaybeAddInput(a, "")
cairoA := a.CairoEl()

if res.ElType != "" {
panic("Cairo circuit manager: Res element already registered")
}
res.ElIndex = len(cm.Ops)
res.ElType = "output"
res.ElType = "op_res"
// println("res", res.Limbs[0], res.ElIndex, res.internal, res.ElType)
// fmt.Printf("let %s = %s(%s, %s);\n", res.CairoVarName(), fnName, a.CairoVarName(), b.CairoVarName())
fmt.Printf("// OP<%d>::%s", res.ElIndex, fnName)
fmt.Printf("(%s, %s);\n", cairoA.CairoVarName(), cairoB.CairoVarName())

op := Op{
A: CairoEl(a),
B: CairoEl(b),
Out: CairoEl[T](res),
A: cairoA,
B: cairoB,
Out: res.CairoEl(),
Type: fnName,
}

Expand All @@ -355,6 +363,26 @@ func (cm *CairoManager[T]) RegisterOperation(a, b, res *Element[T]) *Element[T]
return res
}

// RegisterOperation registers an operation for (upto) two elements
// Returned res element has appropriate values for calling CairoVarName
// Adds correct ElType and ElIndex for the result element
func (cm *CairoManager[T]) RegisterReduceAndOp(a, b, res *Element[T]) *Element[T] {
var fnName string
pc, _, _, ok := runtime.Caller(2)
details := runtime.FuncForPC(pc)
if ok && details != nil {
fnName = strings.Replace(details.Name(), "github.com/consensys/gnark/std/math/emulated.(*Field[...]).", "", 1)
} else {
panic("Cairo circuit manager: Failed to get runtime details")
}

if fnName == "MulConst" {
return res
}

return cm.RegisterOperation(a, b, res, fnName)
}

// #endregion Circuit registration calls

// #region Circuit operations
Expand All @@ -371,11 +399,10 @@ func (cm *CairoManager[T]) FnNameToOperation(fnName string) string {
return "circuit_sub"
case "Inverse":
return "circuit_inverse"
case "Mul":
case "MulMod":
case "Mul", "MulMod", "MulConst":
return "circuit_mul"
}
panic("Cairo circuit manager: Unknown operation")
panic("Cairo circuit manager: Unknown operation " + fnName)
}

// #endregion Circuit operations
4 changes: 3 additions & 1 deletion std/math/emulated/field_mul.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ func (f *Field[T]) MulConst(a *Element[T], c *big.Int) *Element[T] {
for i := range a.Limbs {
limbs[i] = f.api.Mul(a.Limbs[i], c)
}
return f.newInternalElement(limbs, a.overflow+cbl)
res := f.newInternalElement(limbs, a.overflow+cbl)
b := f.Cairo.AddConstantInput(newConstElement[T](c))
return f.Cairo.RegisterOperation(a, b, res, "MulConst")
},
func(a, _ *Element[T]) (nextOverflow uint, err error) {
nextOverflow = a.overflow + uint(cbl)
Expand Down
2 changes: 1 addition & 1 deletion std/math/emulated/field_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ func (f *Field[T]) reduceAndOp(op func(*Element[T], *Element[T], uint) *Element[
}
}

return f.Cairo.RegisterOperation(a, b, op(a, b, nextOverflow))
return f.Cairo.RegisterReduceAndOp(a, b, op(a, b, nextOverflow))
}

type overflowError struct {
Expand Down

0 comments on commit 5d2a4b8

Please sign in to comment.