Skip to content

Commit

Permalink
perf(2-chain): optimize varScalarMul
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Mar 15, 2024
1 parent ce0186e commit 9bc2788
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 75 deletions.
2 changes: 1 addition & 1 deletion std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op
// B1 = Q+Φ(Q)
// B2 = -Q-Φ(Q)
// B3 = Q-Φ(Q)
// B4 = -(Q)
// B4 = -Q+Φ(Q)
//
// If we extend this by merging two iterations, we need to look up P and P'
// both from {B1, B2, B3, B4} and compute:
Expand Down
120 changes: 61 additions & 59 deletions std/algebra/native/sw_bls12377/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,29 +186,30 @@ func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}, opts .
}
}

// varScalarMul sets P = [s] Q and returns P.
// varScalarMul sets P = [s]Q and returns P. It doesn't modify Q nor s.
// It implements an optimized version based on algorithm 1 of [Halo] (see Section 6.2 and appendix C).
//
// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set.
// (0,0) is not on the curve but we conventionally take it as the
// neutral/infinity point as per the [EVM].
//
// [Halo]: https://eprint.iacr.org/2019/1021.pdf
// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf
func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variable, opts ...algopts.AlgebraOption) *G1Affine {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(err)
}
// This method computes [s] Q. We use several methods to reduce the number
// of added constraints - first, instead of classical double-and-add, we use
// the optimized version from https://github.com/zcash/zcash/issues/3924
// which allows to omit computation of several intermediate values.
// Secondly, we use the GLV scalar multiplication to reduce the number
// iterations in the main loop. There is a small difference though - as
// two-bit select takes three constraints, then it takes as many constraints
// to compute ± Q ± Φ(Q) every iteration instead of selecting the value
// from a precomputed table. However, precomputing the table adds 12
// additional constraints and thus table-version is more expensive than
// addition-version.
var selector frontend.Variable
if cfg.CompleteArithmetic {
// if Q=(0,0) we assign a dummy (1,1) to Q and continue
selector = api.And(api.IsZero(Q.X), api.IsZero(Q.Y))
Q.Select(api, selector, G1Affine{X: 1, Y: 1}, Q)
}

// We use the endomorphism à la GLV to compute [s]Q as
// [s1]Q + [s2]Φ(Q)
//
// The context we are working is based on the `outer` curve. However, the
// points and the operations on the points are performed on the `inner`
// curve of the outer curve. We require some parameters from the inner
Expand All @@ -218,77 +219,73 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl
// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod r,
// where λ is third root of one in 𝔽_r.
sd, err := api.Compiler().NewHint(decomposeScalarG1, 3, s)
sd, err := api.Compiler().NewHint(decomposeScalarG1, 2, s)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2 := sd[0], sd[1]

// when we split scalar, then s1, s2 < lambda by default. However, to have
// the high 1-2 bits of s1, s2 set, the hint functions compute the
// decomposition for
// s + k*r (for some k)
// instead and omits the last reduction. Thus, to constrain s1 and s2, we
// have to assert that
// s1 + λ * s2 == s + k*r
api.AssertIsEqual(api.Add(s1, api.Mul(s2, cc.lambda)), api.Add(s, api.Mul(cc.fr, sd[2])))

// As the decomposed scalars are not fully reduced, then in addition of
// having the high bit set, an overflow bit may also be set. Thus, the total
// number of bits may be one more than the bitlength of λ.
nbits := cc.lambda.BitLen() + 1
// s1 + λ * s2 == s
api.AssertIsEqual(
api.Add(s1, api.Mul(s2, cc.lambda)),
s,
)

// For BLS12 λ bitsize is 127 equal to half of r bitsize
nbits := cc.lambda.BitLen()
s1bits := api.ToBinary(s1, nbits)
s2bits := api.ToBinary(s2, nbits)

var Acc /*accumulator*/, B, B2 /*tmp vars*/ G1Affine
// precompute -Q, -Φ(Q), Φ(Q)
var tableQ, tablePhiQ [2]G1Affine
tableQ[1] = Q
tableQ[0].Neg(api, Q)
cc.phi1(api, &tablePhiQ[1], &Q)
tablePhiQ[0].Neg(api, tablePhiQ[1])

// We now initialize the accumulator. Due to the way the scalar is
// decomposed, either the high bits of s1 or s2 are set and we can use the
// incomplete addition laws.

// Acc = Q + Φ(Q) = -Φ²(Q)
// we suppose that the first bits of the sub-scalars are 1 and set:
// Acc = Q + Φ(Q) = -Φ²(Q)
var Acc, B G1Affine
cc.phi2Neg(api, &Acc, &Q)

// However, we can not directly add step value conditionally as we may get
// to incomplete path of the addition formula. We either add or subtract
// step value from [2] Acc (instead of conditionally adding step value to
// Acc):
// Acc = [2] (Q + Φ(Q)) ± Q ± Φ(Q)
// only y coordinate differs for negation, select on that instead.
// first bit
B.X = tableQ[0].X
B.Y = api.Select(s1bits[nbits-1], tableQ[1].Y, tableQ[0].Y)
Acc.DoubleAndAdd(api, &Acc, &B)
B.X = tablePhiQ[0].X
B.Y = api.Select(s2bits[nbits-1], tablePhiQ[1].Y, tablePhiQ[0].Y)
Acc.AddAssign(api, B)

// second bit
B.X = tableQ[0].X
B.Y = api.Select(s1bits[nbits-2], tableQ[1].Y, tableQ[0].Y)
Acc.DoubleAndAdd(api, &Acc, &B)
B.X = tablePhiQ[0].X
B.Y = api.Select(s2bits[nbits-2], tablePhiQ[1].Y, tablePhiQ[0].Y)
Acc.AddAssign(api, B)
// At each iteration we need to compute:
// [2]Acc ± Q ± Φ(Q).
// We can compute [2]Acc and look up the (precomputed) point B from:
// B1 = +Q + Φ(Q)
B1 := Acc
// B2 = -Q - Φ(Q)
B2 := G1Affine{}
B2.Neg(api, B1)
// B3 = +Q - Φ(Q)
B3 := tableQ[1]
B3.AddAssign(api, tablePhiQ[0])
// B4 = -Q + Φ(Q)
B4 := G1Affine{}
B4.Neg(api, B3)
//
// Note that half the points are negatives of the other half,
// hence have the same X coordinates.

// However when doing doubleAndAdd(Acc, B) as (Acc+B)+Acc it might happen
// that Acc==B or -B. So we add the base point G to it to avoid incomplete
// additions in the loop by forcing Acc to be different than the stored B.
// However now we need at the end to subtract [2^nbits]G (harcoded) from
// the result.
//
// Acc = Q + Φ(Q) + G
points := getCurvePoints()
Acc.AddAssign(api, G1Affine{X: points.G1x, Y: points.G1y})

B2.X = tablePhiQ[0].X
for i := nbits - 3; i > 0; i-- {
B.X = Q.X
B.Y = api.Select(s1bits[i], tableQ[1].Y, tableQ[0].Y)
B2.Y = api.Select(s2bits[i], tablePhiQ[1].Y, tablePhiQ[0].Y)
B.AddAssign(api, B2)
for i := nbits - 1; i > 0; i-- {
B.X = api.Select(api.Xor(s1bits[i], s2bits[i]), B3.X, B2.X)
B.Y = api.Lookup2(s1bits[i], s2bits[i], B2.Y, B3.Y, B4.Y, B1.Y)
// Acc = [2]Acc + B
Acc.DoubleAndAdd(api, &Acc, &B)
}

// i = 0
// subtract the Q, R, Φ(Q), Φ(R) if the first bits are 0.
// When cfg.CompleteArithmetic is set, we use AddUnified instead of Add. This means
// when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0).
if cfg.CompleteArithmetic {
Expand All @@ -304,6 +301,11 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl
Acc.Select(api, s2bits[0], Acc, tablePhiQ[0])
}

// subtract [2^nbits]G since we added G at the beginning
B.X = points.G1m[nbits-1][0]
B.Y = api.Neg(points.G1m[nbits-1][1])
Acc.AddAssign(api, B)

P.X = Acc.X
P.Y = Acc.Y

Expand Down
25 changes: 10 additions & 15 deletions std/algebra/native/sw_bls12377/hints.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sw_bls12377

import (
"fmt"
"math/big"

"github.com/consensys/gnark-crypto/ecc"
Expand All @@ -18,23 +19,17 @@ func init() {
solver.RegisterHint(GetHints()...)
}

func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error {
func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(outputs) != 2 {
return fmt.Errorf("expecting two outputs")
}
cc := getInnerCurveConfig(scalarField)
sp := ecc.SplitScalar(inputs[0], cc.glvBasis)
res[0].Set(&(sp[0]))
res[1].Set(&(sp[1]))
one := big.NewInt(1)
// add (lambda+1, lambda) until scalar compostion is over Fr to ensure that
// the high bits are set in decomposition.
for res[0].Cmp(cc.lambda) < 1 && res[1].Cmp(cc.lambda) < 1 {
res[0].Add(res[0], cc.lambda)
res[0].Add(res[0], one)
res[1].Add(res[1], cc.lambda)
}
// figure out how many times we have overflowed
res[2].Mul(res[1], cc.lambda).Add(res[2], res[0])
res[2].Sub(res[2], inputs[0])
res[2].Div(res[2], cc.fr)
outputs[0].Set(&(sp[0]))
outputs[1].Set(&(sp[1]))

return nil
}
Expand Down

0 comments on commit 9bc2788

Please sign in to comment.