Skip to content

Commit

Permalink
fix: edge cases in SM and JSM were inverted + comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Mar 8, 2024
1 parent 9fa2c4c commit a2f0bdc
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 84 deletions.
12 changes: 6 additions & 6 deletions std/algebra/emulated/sw_emulated/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ func GetHints() []solver.Hint {
func decomposeScalarG1Subscalars(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error {
return emulated.UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("expecting three inputs")
return fmt.Errorf("expecting two inputs")
}
if len(outputs) != 2 {
return fmt.Errorf("expecting six outputs")
return fmt.Errorf("expecting two outputs")
}
glvBasis := new(ecc.Lattice)
ecc.PrecomputeLattice(field, inputs[1], glvBasis)
Expand All @@ -49,10 +49,10 @@ func decomposeScalarG1Subscalars(mod *big.Int, inputs []*big.Int, outputs []*big
func decomposeScalarG1Signs(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error {
return emulated.UnwrapHintWithNativeOutput(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("expecting three inputs")
return fmt.Errorf("expecting two inputs")
}
if len(outputs) != 2 {
return fmt.Errorf("expecting six outputs")
return fmt.Errorf("expecting two outputs")
}
glvBasis := new(ecc.Lattice)
ecc.PrecomputeLattice(field, inputs[1], glvBasis)
Expand All @@ -73,10 +73,10 @@ func decomposeScalarG1Signs(mod *big.Int, inputs []*big.Int, outputs []*big.Int)
func decomposeScalarG1(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error {
return emulated.UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error {
if len(inputs) != 3 {
return fmt.Errorf("expecting three inputs")
return fmt.Errorf("expecting two inputs")
}
if len(outputs) != 6 {
return fmt.Errorf("expecting six outputs")
return fmt.Errorf("expecting two outputs")
}
glvBasis := new(ecc.Lattice)
ecc.PrecomputeLattice(inputs[2], inputs[1], glvBasis)
Expand Down
157 changes: 82 additions & 75 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
if err != nil {
panic(err)
}
var st S
addFn := c.Add
var selector frontend.Variable
if cfg.CompleteArithmetic {
Expand All @@ -521,36 +520,39 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
one := c.baseApi.One()
Q = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, Q)
}

// We use the endomorphism à la GLV to compute [s]Q as
// [s1]Q + [s2]Φ(Q)
// the sub-scalars s1, s2 can be negative (bigints) in the hint. If so,
// they will be reduced in-circuit modulo the SNARK scalar field and not
// the emulated field. So we return in the hint |s1|, |s2| and boolean
// flags sdBits to negate the points Q, Φ(Q) instead of the corresponding
// sub-scalars.

// decompose s into s1 and s2
sd, err := c.scalarApi.NewHint(decomposeScalarG1Subscalars, 2, s, c.eigenvalue)
if err != nil {
panic(fmt.Sprintf("compute GLV decomposition: %v", err))
}
s1, s2 := sd[0], sd[1]

// s1, s2 can be negative (bigints) in the hint, but will be reduced
// in-circuit modulo the SNARK scalar field and not the emulated field. So
// we return in the hint both |s1|, |s2| and the flags s5=0/1, s6=0/1 to
// negate the point instead of the corresponding scalar.
sdBits, err := c.scalarApi.NewHintWithNativeOutput(decomposeScalarG1Signs, 2, s, c.eigenvalue)
if err != nil {
panic(fmt.Sprintf("compute GLV decomposition bits: %v", err))
}
selector1, selector2 := sdBits[0], sdBits[1]

s5 := c.scalarApi.Select(selector1, c.scalarApi.Neg(s1), s1)
s6 := c.scalarApi.Select(selector2, c.scalarApi.Neg(s2), s2)

// s == s5 + lambda*s6
s3 := c.scalarApi.Select(selector1, c.scalarApi.Neg(s1), s1)
s4 := c.scalarApi.Select(selector2, c.scalarApi.Neg(s2), s2)
// s == s3 + [λ]s4
c.scalarApi.AssertIsEqual(
c.scalarApi.Add(s5, c.scalarApi.Mul(s6, c.eigenvalue)),
c.scalarApi.Add(s3, c.scalarApi.Mul(s4, c.eigenvalue)),
s,
)

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)
var st S
nbits := st.Modulus().BitLen()>>1 + 2

var Acc *AffinePoint[B]
// precompute -Q, -Φ(Q), Φ(Q)
var tableQ, tablePhiQ [3]*AffinePoint[B]
negQY := c.baseApi.Neg(&Q.Y)
Expand All @@ -570,8 +572,9 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
Y: *c.baseApi.Select(selector2, c.baseApi.Neg(&tableQ[2].Y), &tableQ[2].Y),
}

// Starting with Acc = Q + Φ(Q)
Acc = c.Add(tableQ[1], tablePhiQ[1])
// we suppose that the first bits of the sub-scalars are 1 and set:
// Acc = Q + Φ(Q)
Acc := c.Add(tableQ[1], tablePhiQ[1])

// At each iteration we need to compute:
// [2]Acc ± Q ± Φ(Q).
Expand Down Expand Up @@ -636,6 +639,8 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
// T = [3](Φ(Q) - Q)
// P = B4 and P' = B4
T16 := c.Neg(T11)
// note that half the points are negatives of the other half,
// hence have the same X coordinates.

// when nbits is odd, we need to handle the first iteration separately
if nbits%2 == 0 {
Expand All @@ -644,7 +649,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
X: *c.baseApi.Select(c.api.Xor(s1bits[nbits-1], s2bits[nbits-1]), &T12.X, &T5.X),
Y: *c.baseApi.Lookup2(s1bits[nbits-1], s2bits[nbits-1], &T5.Y, &T12.Y, &T15.Y, &T2.Y),
}
// We don't use doubleAndAdd as it would involve edge cases
// We don't use doubleAndAdd here as it would involve edge cases
// when bits are 00 (T==-Acc) or 11 (T==Acc).
Acc = c.double(Acc)
Acc = c.add(Acc, T)
Expand Down Expand Up @@ -684,8 +689,9 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
}

// i = 0
// When cfg.CompleteArithmetic is set, we use AddUnified instead of Add. This means
// when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0).
// subtract the Q, Φ(Q) if the first bits are 0.
// When cfg.CompleteArithmetic is set, we use AddUnified instead of Add.
// This means when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0).
tableQ[0] = addFn(tableQ[0], Acc)
Acc = c.Select(s1bits[0], Acc, tableQ[0])
tablePhiQ[0] = addFn(tablePhiQ[0], Acc)
Expand Down Expand Up @@ -769,7 +775,7 @@ func (c *Curve[B, S]) scalarMulGeneric(p *AffinePoint[B], s *emulated.Element[S]
return R0
}

// jointScalarMul computes s1 * p1 + s2 * p2 and returns it. It doesn't modify the inputs.
// jointScalarMul computes [s1]p1 + [s2]p2 and returns it. It doesn't modify the inputs.
// This function doesn't check that the p1 and p2 are on the curve. See AssertIsOnCurve.
//
// jointScalarMul calls jointScalarMulGeneric or jointScalarMulGLV depending on whether an efficient endomorphism is available.
Expand Down Expand Up @@ -861,54 +867,51 @@ func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated
// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify Q, R nor s, t.
// ⚠️ The scalars must be nonzero and the points different from (0,0).
func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] {
var st S
// We use the endomorphism à la GLV to compute [s]Q + [t]R as
// [s1]Q + [s2]Φ(Q) + [t1]R + [t2]Φ(R)
// the sub-scalars s1, s2, t1, t2 can be negative (bigints) in the hint. If
// so, they will be reduced in-circuit modulo the SNARK scalar field and
// not the emulated field. So we return in the hint |s1|, |s2|, |t1|, |t2|
// and boolean flags sdBits and tdBits to negate the points Q, Φ(Q), R and
// Φ(R) instead of the corresponding sub-scalars.

// decompose s into s1 and s2
sd, err := c.scalarApi.NewHint(decomposeScalarG1Subscalars, 2, s, c.eigenvalue)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2 := sd[0], sd[1]

td, err := c.scalarApi.NewHint(decomposeScalarG1Subscalars, 2, t, c.eigenvalue)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
t1, t2 := td[0], td[1]

// s1, s2 can be negative (bigints) in the hint, but will be reduced
// in-circuit modulo the SNARK scalar field and not the emulated field. So
// we return in the hint both |s1|, |s2| and the flags s3=0/1, s4=0/1 to
// negate the point instead of the corresponding scalar. Since s3, s4 are
// either 0 or 1, we only need to check the first limb is zero or not.
// Same goes for t1, t2.
sdBits, err := c.scalarApi.NewHintWithNativeOutput(decomposeScalarG1Signs, 2, s, c.eigenvalue)
if err != nil {
panic(fmt.Sprintf("compute s GLV decomposition bits: %v", err))
}
selector1, selector2 := sdBits[0], sdBits[1]

s5 := c.scalarApi.Select(selector1, c.scalarApi.Neg(s1), s1)
s6 := c.scalarApi.Select(selector2, c.scalarApi.Neg(s2), s2)

// s == s5 + lambda*s6
s3 := c.scalarApi.Select(selector1, c.scalarApi.Neg(s1), s1)
s4 := c.scalarApi.Select(selector2, c.scalarApi.Neg(s2), s2)
// s == s3 + [λ]s4
c.scalarApi.AssertIsEqual(
c.scalarApi.Add(s5, c.scalarApi.Mul(s6, c.eigenvalue)),
c.scalarApi.Add(s3, c.scalarApi.Mul(s4, c.eigenvalue)),
s,
)

// decompose t into t1 and t2
td, err := c.scalarApi.NewHint(decomposeScalarG1Subscalars, 2, t, c.eigenvalue)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
t1, t2 := td[0], td[1]
tdBits, err := c.scalarApi.NewHintWithNativeOutput(decomposeScalarG1Signs, 2, t, c.eigenvalue)
if err != nil {
panic(fmt.Sprintf("compute t GLV decomposition bits: %v", err))
}
selector3, selector4 := tdBits[0], tdBits[1]

t5 := c.scalarApi.Select(selector3, c.scalarApi.Neg(t1), t1)
t6 := c.scalarApi.Select(selector4, c.scalarApi.Neg(t2), t2)

// t == t5 + lambda*t6
t3 := c.scalarApi.Select(selector3, c.scalarApi.Neg(t1), t1)
t4 := c.scalarApi.Select(selector4, c.scalarApi.Neg(t2), t2)
// t == t3 + [λ]t4
c.scalarApi.AssertIsEqual(
c.scalarApi.Add(t5, c.scalarApi.Mul(t6, c.eigenvalue)),
c.scalarApi.Add(t3, c.scalarApi.Mul(t4, c.eigenvalue)),
t,
)

Expand All @@ -925,6 +928,7 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
Y: *c.baseApi.Select(selector2, negQY, &Q.Y),
}
tablePhiQ[0] = c.Neg(tablePhiQ[1])

// precompute -R, -Φ(R), Φ(R)
var tableR, tablePhiR [2]*AffinePoint[B]
negRY := c.baseApi.Neg(&R.Y)
Expand All @@ -938,6 +942,7 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
Y: *c.baseApi.Select(selector4, negRY, &R.Y),
}
tablePhiR[0] = c.Neg(tablePhiR[1])

// precompute Q+R, -Q-R, Q-R, -Q+R, Φ(Q)+Φ(R), -Φ(Q)-Φ(R), Φ(Q)-Φ(R), -Φ(Q)+Φ(R)
var tableS, tablePhiS [4]*AffinePoint[B]
tableS[0] = tableQ[0]
Expand All @@ -959,42 +964,41 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
}
tablePhiS[3] = c.Neg(tablePhiS[2])

// suppose first bits are 1 and set:
// Acc = Q + R + Φ(Q) + Φ(R)
// we suppose that the first bits of the sub-scalars are 1 and set:
// Acc = Q + R + Φ(Q) + Φ(R)
Acc := c.Add(tableS[1], tablePhiS[1])
B1 := Acc
// then conditionally add to Acc either Φ²(G) or G (if Acc==-Φ²(G)) where G is
// the base point to avoid incomplete additions in the loop, because when
// doing doubleAndAdd(Acc, Bi) as (Acc+Bi)+Acc it might happen that Acc==Bi
// or Acc==-Bi. But now we force Acc to be different than the stored Bi.
// However we need at the end to subtract [2^nbits]Φ²(G) (or conditionally
// [2^nbits]G) from the result.
// then we conditionally add to Acc either G (the base point) or
// conditionally Φ²(G) (if Acc==-G) to avoid incomplete additions in the
// loop, because when doing doubleAndAdd(Acc, Bi) as (Acc+Bi)+Acc it might
// happen that Acc==Bi or Acc==-Bi. But now we force Acc to be different
// than the stored Bi. However we need at the end to subtract [2^nbits]G
// or conditionally [2^nbits]Φ²(G) from the result.
//
// Acc = Q + R + Φ(Q) + Φ(R) + Φ²(G) or Q + R + Φ(Q) + Φ(R) + G ( = -Φ²(G)+G )
// g0 = G
g0 := c.Generator()
// g1 = Φ²(G)
g1 := &AffinePoint[B]{
X: *c.baseApi.Mul(
c.baseApi.Mul(&g0.X, c.thirdRootOne), c.thirdRootOne),
Y: g0.Y,
}
selector0 := c.baseApi.IsZero(
c.baseApi.Add(&Acc.Y, &c.Generator().Y),
)
g := c.Select(
selector0,
// G
g0,
// Φ²(G)
&AffinePoint[B]{
X: *c.baseApi.Mul(
c.baseApi.Mul(&g0.X, c.thirdRootOne), c.thirdRootOne),
Y: g0.Y,
},
c.baseApi.Add(&Acc.Y, &g0.Y),
)
g := c.Select(selector0, g1, g0)
// Acc = Q + R + Φ(Q) + Φ(R) + G or
// Q + R + Φ(Q) + Φ(R) + Φ²(G) ( = -G+Φ²(G) = -2G-Φ(G) )
Acc = c.Add(Acc, g)

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)
t1bits := c.scalarApi.ToBits(t1)
t2bits := c.scalarApi.ToBits(t2)
var st S
nbits := st.Modulus().BitLen()>>1 + 2

// At each iteration look up the point P from:
// At each iteration we look up the point Bi from:
// B1 = +Q + R + Φ(Q) + Φ(R)
// B2 = +Q + R + Φ(Q) - Φ(R)
B2 := c.Add(tableS[1], tablePhiS[2])
Expand Down Expand Up @@ -1026,8 +1030,10 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
B15 := c.Neg(B2)
// B16 = -Q - R - Φ(Q) - Φ(R)
B16 := c.Neg(B1)
// note that half the points are negatives of the other half,
// hence have the same X coordinates.

var P *AffinePoint[B]
var Bi *AffinePoint[B]
for i := nbits - 1; i > 0; i-- {
// selectorY takes values in [0,15]
selectorY := c.api.Add(
Expand All @@ -1045,7 +1051,7 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
)
// Bi.Y are distints so we need a 16-to-1 multiplexer,
// but only half of the Bi.X are distinct so we need a 8-to-1.
P = &AffinePoint[B]{
Bi = &AffinePoint[B]{
X: *c.baseApi.Mux(selectorX,
&B16.X, &B8.X, &B14.X, &B6.X, &B12.X, &B4.X, &B10.X, &B2.X,
),
Expand All @@ -1054,11 +1060,12 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
&B15.Y, &B7.Y, &B13.Y, &B5.Y, &B11.Y, &B3.Y, &B9.Y, &B1.Y,
),
}
Acc = c.doubleAndAdd(Acc, P)
// Acc = [2]Acc + Bi
Acc = c.doubleAndAdd(Acc, Bi)
}

// i = 0
// subtract the initial points from the accumulator when first bits are 0
// subtract the Q, R, Φ(Q), Φ(R) if the first bits are 0
tableQ[0] = c.Add(tableQ[0], Acc)
Acc = c.Select(s1bits[0], Acc, tableQ[0])
tablePhiQ[0] = c.Add(tablePhiQ[0], Acc)
Expand All @@ -1068,18 +1075,18 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
tablePhiR[0] = c.Add(tablePhiR[0], Acc)
Acc = c.Select(t2bits[0], Acc, tablePhiR[0])

// subtract [2^nbits]Φ²(G) (or conditionally [2^nbits]G)
// subtract [2^nbits]G or conditionally [2^nbits]Φ²(G)
gm := c.GeneratorMultiples()[nbits-1]
g = c.Select(
selector0,
// [2^nbits]G
&gm,
// [2^nbits]Φ²(G)
&AffinePoint[B]{
X: *c.baseApi.Mul(
c.baseApi.Mul(&gm.X, c.thirdRootOne), c.thirdRootOne),
Y: gm.Y,
},
// [2^nbits]G
&gm,
)
Acc = c.Add(Acc, c.Neg(g))

Expand Down
11 changes: 8 additions & 3 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1532,8 +1532,9 @@ func TestJointScalarMul4(t *testing.T) {
assert.NoError(err)
}

// We explicitly choose here P1 and P2 s.t. P1+P2 = G (the base point). This
// triggers the edge case where Q + R + Φ(Q) + Φ(R) + Φ²(G) == inf
// We explicitly choose here P1 and P2 s.t. P1+P2 = Φ(G) (G the base point).
// This should sometimes (when the sub-scalars are positive in the hint)
// triggers the edge case Q + R + Φ(Q) + Φ(R) + G == inf
func TestJointScalarMulSpecial6(t *testing.T) {
assert := test.NewAssert(t)
var r1, r2 fr_bw6761.Element
Expand All @@ -1546,9 +1547,13 @@ func TestJointScalarMulSpecial6(t *testing.T) {
var res, tmp, p1, p2 bw6761.G1Affine
// P1
p1.ScalarMultiplicationBase(s1)
// P2 = G-P1
// P2 = Φ(G)-P1
_, _, g, _ := bw6761.Generators()
var lambdaGLV big.Int
lambdaGLV.SetString("80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410945", 10) // (x⁵-3x⁴+3x³-x+1)
g.ScalarMultiplication(&g, &lambdaGLV)
p2.Sub(&g, &p1)
// res = [s1]P+[s2]P
tmp.ScalarMultiplication(&p1, s1)
res.ScalarMultiplication(&p2, s2)
res.Add(&res, &tmp)
Expand Down

0 comments on commit a2f0bdc

Please sign in to comment.