Skip to content

Commit

Permalink
perf: simplify the glv decomposition hint
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Mar 7, 2024
1 parent 2f2fadc commit c35311d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
29 changes: 12 additions & 17 deletions std/algebra/emulated/sw_emulated/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,35 @@ func decomposeScalarG1(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro
if len(inputs) != 3 {
return fmt.Errorf("expecting three inputs")
}
if len(outputs) != 7 {
return fmt.Errorf("expecting seven outputs")
if len(outputs) != 6 {
return fmt.Errorf("expecting six outputs")
}
glvBasis := new(ecc.Lattice)
ecc.PrecomputeLattice(inputs[2], inputs[1], glvBasis)
sp := ecc.SplitScalar(inputs[0], glvBasis)
outputs[0].Set(&(sp[0]))
outputs[1].Set(&(sp[1]))
// figure out how many times we have overflowed
outputs[2].Mul(outputs[1], inputs[1]).Add(outputs[2], outputs[0])
outputs[2].Sub(outputs[2], inputs[0])
outputs[2].Div(outputs[2], inputs[2])

// we need the negative values for to check that s0+λ*s1 == s mod r
// output5 = s0 mod r
// output6 = s1 mod r
outputs[5].Set(outputs[0])
outputs[6].Set(outputs[1])
// output4 = s0 mod r
// output5 = s1 mod r
outputs[4].Set(outputs[0])
outputs[5].Set(outputs[1])
// we need the absolute values for the in-circuit computations,
// otherwise the negative values will be reduced modulo the SNARK scalar
// field and not the emulated field.
// output0 = |s0| mod r
// output1 = |s1| mod r
// output3 = 1 if s0 is positive, 0 if s0 is negative
// output4 = 1 if s1 is positive, 0 if s0 is negative
outputs[3].SetUint64(1)
// output2 = 1 if s0 is positive, 0 if s0 is negative
// output3 = 1 if s1 is positive, 0 if s0 is negative
outputs[2].SetUint64(1)
if outputs[0].Sign() == -1 {
outputs[0].Neg(outputs[0])
outputs[3].SetUint64(0)
outputs[2].SetUint64(0)
}
outputs[4].SetUint64(1)
outputs[3].SetUint64(1)
if outputs[1].Sign() == -1 {
outputs[1].Neg(outputs[1])
outputs[4].SetUint64(0)
outputs[3].SetUint64(0)
}

return nil
Expand Down
18 changes: 9 additions & 9 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,15 +522,15 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
one := c.baseApi.One()
Q = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, Q)
}
sd, err := c.scalarApi.NewHint(decomposeScalarG1, 7, s, c.eigenvalue, frModulus)
sd, err := c.scalarApi.NewHint(decomposeScalarG1, 6, s, c.eigenvalue, frModulus)
if err != nil {
panic(fmt.Sprintf("compute GLV decomposition: %v", err))
}
s1, s2, s3, s4, s5, s6 := sd[0], sd[1], sd[3], sd[4], sd[5], sd[6]
s1, s2, s3, s4, s5, s6 := sd[0], sd[1], sd[2], sd[3], sd[4], sd[5]

c.scalarApi.AssertIsEqual(
c.scalarApi.Add(s5, c.scalarApi.Mul(s6, c.eigenvalue)),
c.scalarApi.Add(s, c.scalarApi.Mul(frModulus, sd[2])),
s,
)

// s1, s2 can be negative (bigints) in the hint, but will be reduced
Expand Down Expand Up @@ -858,27 +858,27 @@ func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated
func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] {
var st S
frModulus := c.scalarApi.Modulus()
sd, err := c.scalarApi.NewHint(decomposeScalarG1, 7, s, c.eigenvalue, frModulus)
sd, err := c.scalarApi.NewHint(decomposeScalarG1, 6, s, c.eigenvalue, frModulus)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2, s3, s4, s5, s6 := sd[0], sd[1], sd[3], sd[4], sd[5], sd[6]
s1, s2, s3, s4, s5, s6 := sd[0], sd[1], sd[2], sd[3], sd[4], sd[5]

td, err := c.scalarApi.NewHint(decomposeScalarG1, 7, t, c.eigenvalue, frModulus)
td, err := c.scalarApi.NewHint(decomposeScalarG1, 6, t, c.eigenvalue, frModulus)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
t1, t2, t3, t4, t5, t6 := td[0], td[1], td[3], td[4], td[5], td[6]
t1, t2, t3, t4, t5, t6 := td[0], td[1], td[2], td[3], td[4], td[5]

c.scalarApi.AssertIsEqual(
c.scalarApi.Add(s5, c.scalarApi.Mul(s6, c.eigenvalue)),
c.scalarApi.Add(s, c.scalarApi.Mul(frModulus, sd[2])),
s,
)
c.scalarApi.AssertIsEqual(
c.scalarApi.Add(t5, c.scalarApi.Mul(t6, c.eigenvalue)),
c.scalarApi.Add(t, c.scalarApi.Mul(frModulus, td[2])),
t,
)

// s1, s2 can be negative (bigints) in the hint, but will be reduced
Expand Down

0 comments on commit c35311d

Please sign in to comment.