Skip to content

Commit

Permalink
perf(emulated): huge optim scalarMulGLV
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Feb 21, 2024
1 parent 33b31c5 commit 16990de
Showing 1 changed file with 112 additions and 20 deletions.
132 changes: 112 additions & 20 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,13 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
selector1 := c.api.IsZero(s3.Limbs[0])
selector2 := c.api.IsZero(s4.Limbs[0])

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)
nbits := st.Modulus().BitLen()>>1 + 2

var Acc *AffinePoint[B]
// precompute -Q, Q, [3]Q, -Φ(Q), Φ(Q), [3]Φ(Q)
var tableQ, tablePhiQ [2]*AffinePoint[B]
var tableQ, tablePhiQ [3]*AffinePoint[B]
tableQ[1] = &AffinePoint[B]{
X: Q.X,
Y: *c.baseApi.Select(selector1, c.baseApi.Neg(&Q.Y), &Q.Y),
Expand All @@ -540,35 +544,123 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
Y: *c.baseApi.Select(selector2, c.baseApi.Neg(&Q.Y), &Q.Y),
}
tablePhiQ[0] = c.Neg(tablePhiQ[1])
tableQ[2] = c.triple(tableQ[1])
tablePhiQ[2] = &AffinePoint[B]{
X: *c.baseApi.Mul(&tableQ[2].X, c.thirdRootOne),
Y: *c.baseApi.Select(selector2, c.baseApi.Neg(&tableQ[2].Y), &tableQ[2].Y),
}

// Acc = Q + Φ(Q)
// Starting with Acc = Q + Φ(Q)
Acc = c.Add(tableQ[1], tablePhiQ[1])

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)

nbits := st.Modulus().BitLen()>>1 + 2

// At each iteration look up the point P from:
// At each iteration we need to compute:
// [2]Acc ± Q ± Φ(Q).
// We can compute [2]Acc and look up the (precomputed) point P from:
// B1 = Q+Φ(Q)
// B2 = -Q-Φ(Q)
// B3 = -Φ(Q)+Q
// B4 = Φ(Q)-Q
// and compute [2]Acc+P. We don't use doubleAndAdd as it involves edge
// cases when first bits are 00 (P==-Acc) or 11 (P==Acc).
B1 := Acc
B2 := c.Neg(B1)
B3 := c.Add(tablePhiQ[0], tableQ[1])
B4 := c.Neg(B3)
var P *AffinePoint[B]
for i := nbits - 1; i > 0; i-- {
//
// If we extend this by merging two iterations, we need to look up P and P'
// both from {B1, B2, B3, B4} and compute:
// [2]([2]Acc+P)+P' = [4]Acc + T
// where T = [2]P+P'. So at each (merged) iteration, we can compute [4]Acc
// and look up T from the precomputed list of points:
//
// T = [3](Q + Φ(Q))
// P = B1 and P' = B1
T1 := c.Add(tableQ[2], tablePhiQ[2])
// T = Q + Φ(Q)
// P = B1 and P' = B2
T2 := Acc
// T = [3]Q + Φ(Q)
// P = B1 and P' = B3
T3 := c.Add(tableQ[2], tablePhiQ[1])
// T = Q + [3]Φ(Q)
// P = B1 and P' = B4
T4 := c.Add(tableQ[1], tablePhiQ[2])
// T = -Q - Φ(Q)
// P = B2 and P' = B1
T5 := c.Neg(T2)
// T = -[3](Q + Φ(Q))
// P = B2 and P' = B2
T6 := c.Neg(T1)
// T = -Q - [3]Φ(Q)
// P = B2 and P' = B3
T7 := c.Neg(T4)
// T = -[3]Q - Φ(Q)
// P = B2 and P' = B4
T8 := c.Neg(T3)
// T = [3]Q - Φ(Q)
// P = B3 and P' = B1
T9 := c.Add(tableQ[2], tablePhiQ[0])
// T = Q - [3]Φ(Q)
// P = B3 and P' = B2
T11 := c.Neg(tablePhiQ[2])
T10 := c.Add(tableQ[1], T11)
// T = [3](Q - Φ(Q))
// P = B3 and P' = B3
T11 = c.Add(tableQ[2], T11)
// T = -Φ(Q) + Q
// P = B3 and P' = B4
T12 := c.Add(tablePhiQ[0], tableQ[1])
// T = [3]Φ(Q) - Q
// P = B4 and P' = B1
T13 := c.Add(tablePhiQ[2], tableQ[0])
// T = Φ(Q) - [3]Q
// P = B4 and P' = B2
T14 := c.Neg(T9)
// T = Φ(Q) - Q
// P = B4 and P' = B3
T15 := c.Neg(T12)
// T = [3](Φ(Q) - Q)
// P = B4 and P' = B4
T16 := c.Neg(T11)

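// when nbits is even, the main loop below (which consumes two bits per
// iteration and leaves bit 0 for the end) cannot cover bit nbits-1, so we
// need to handle that first iteration separately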
if nbits%2 == 0 {
// Acc = [2]Acc ± Q ± Φ(Q)
P = &AffinePoint[B]{
X: *c.baseApi.Select(c.api.Xor(s1bits[i], s2bits[i]), &B3.X, &B2.X),
Y: *c.baseApi.Lookup2(s1bits[i], s2bits[i], &B2.Y, &B3.Y, &B4.Y, &B1.Y),
T := &AffinePoint[B]{
X: *c.baseApi.Select(c.api.Xor(s1bits[nbits-1], s2bits[nbits-1]), &T12.X, &T5.X),
Y: *c.baseApi.Lookup2(s1bits[nbits-1], s2bits[nbits-1], &T5.Y, &T12.Y, &T15.Y, &T2.Y),
}
// We don't use doubleAndAdd as it would involve edge cases
// when bits are 00 (T==-Acc) or 11 (T==Acc).
Acc = c.double(Acc)
Acc = c.add(Acc, P)
Acc = c.add(Acc, T)
} else {
// when nbits is odd, we bump it so that the main loop below starts at
// the original nbits - 1
nbits++
}
for i := nbits - 2; i > 0; i -= 2 {
// selectorY takes values in [0,15]
selectorY := c.api.Add(
s1bits[i],
c.api.Mul(s2bits[i], 2),
c.api.Mul(s1bits[i-1], 4),
c.api.Mul(s2bits[i-1], 8),
)
// selectorX takes values in [0,7] with:
// - when selectorY < 8: selectorX = selectorY
// - when selectorY >= 8: selectorX = 15 - selectorY
selectorX := c.api.Add(
c.api.Mul(selectorY, c.api.Sub(1, c.api.Mul(s2bits[i-1], 2))),
c.api.Mul(s2bits[i-1], 15),
)
// The Ti.Y are all distinct so we need a 16-to-1 multiplexer,
// but only half of the Ti.X are distinct so an 8-to-1 multiplexer suffices.
T := &AffinePoint[B]{
X: *c.baseApi.Mux(selectorX,
&T6.X, &T10.X, &T14.X, &T2.X, &T7.X, &T11.X, &T15.X, &T3.X,
),
Y: *c.baseApi.Mux(selectorY,
&T6.Y, &T10.Y, &T14.Y, &T2.Y, &T7.Y, &T11.Y, &T15.Y, &T3.Y,
&T8.Y, &T12.Y, &T16.Y, &T4.Y, &T5.Y, &T9.Y, &T13.Y, &T1.Y,
),
}
// Acc = [4]Acc + T
Acc = c.double(Acc)
Acc = c.doubleAndAdd(Acc, T)
}

// i = 0
Expand Down

0 comments on commit 16990de

Please sign in to comment.