diff --git a/frontend/builder.go b/frontend/builder.go index 7dd1b6c488..cd52e71169 100644 --- a/frontend/builder.go +++ b/frontend/builder.go @@ -78,3 +78,12 @@ type Committer interface { // Commit commits to the variables and returns the commitment. Commit(toCommit ...Variable) (commitment Variable, err error) } + +// Rangechecker allows to externally range-check the variables to be of +// specified width. Not all compilers implement this interface. Users should +// instead use [github.com/consensys/gnark/std/rangecheck] package which +// automatically chooses most optimal method for range checking the variables. +type Rangechecker interface { + // Check checks that the given variable v has bit-length bits. + Check(v Variable, bits int) +} diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index 87ac36d348..6d085e41fc 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -29,6 +29,7 @@ import ( "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/internal/circuitdefer" + "github.com/consensys/gnark/internal/frontendtype" "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/tinyfield" "github.com/consensys/gnark/internal/utils" @@ -452,3 +453,7 @@ func (builder *builder) compress(le expr.LinearExpression) expr.LinearExpression func (builder *builder) Defer(cb func(frontend.API) error) { circuitdefer.Put(builder, cb) } + +func (*builder) FrontendType() frontendtype.Type { + return frontendtype.R1CS +} diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index f9f379fcf5..fb1ecbf2d8 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/frontendtype" "github.com/consensys/gnark/std/math/bits" ) @@ -557,3 +558,7 @@ func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder, func (builder *builder) Compiler() frontend.Compiler { return builder } + +func (*builder) FrontendType() frontendtype.Type { + return frontendtype.SCS +} diff --git a/internal/frontendtype/frontendtype.go b/internal/frontendtype/frontendtype.go new file mode 100644 index 0000000000..040af53c5b --- /dev/null +++ b/internal/frontendtype/frontendtype.go @@ -0,0 +1,13 @@ +// Package frontendtype allows to assert frontend type. +package frontendtype + +type Type int + +const ( + R1CS Type = iota + SCS +) + +type FrontendTyper interface { + FrontendType() Type +} diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index 3d6d18dabd..0948135ddd 100644 Binary files a/internal/stats/latest.stats and b/internal/stats/latest.stats differ diff --git a/std/hints.go b/std/hints.go index f9b85d3266..c566ccaad8 100644 --- a/std/hints.go +++ b/std/hints.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark/std/algebra/native/sw_bls24315" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/rangecheck" "github.com/consensys/gnark/std/selector" ) @@ -34,4 +35,5 @@ func registerHints() { solver.RegisterHint(selector.MuxIndicators) solver.RegisterHint(selector.MapIndicators) solver.RegisterHint(emulated.GetHints()...) + solver.RegisterHint(rangecheck.CountHint, rangecheck.DecomposeHint) } diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 87653a430a..3d47bf74be 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/rangecheck" "github.com/rs/zerolog" "golang.org/x/exp/constraints" ) @@ -38,6 +39,7 @@ type Field[T FieldParams] struct { log zerolog.Logger constrainedLimbs map[uint64]struct{} + checker frontend.Rangechecker } // NewField returns an object to be used in-circuit to perform emulated @@ -53,6 +55,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { api: native, log: logger.Logger(), constrainedLimbs: make(map[uint64]struct{}), + checker: rangecheck.New(native), } // ensure prime is correctly set diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index c8daae2bf5..af16d4e065 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -5,13 +5,12 @@ import ( "math/big" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/bits" ) // assertLimbsEqualitySlow is the main routine in the package. It asserts that the // two slices of limbs represent the same integer value. This is also the most // costly operation in the package as it does bit decomposition of the limbs. -func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { +func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { nbLimbs := max(len(l), len(r)) maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) @@ -33,52 +32,29 @@ func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, // carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1] // we know that diff[:nbBits] are 0 bits, but still need to constrain them. // to do both; we do a "clean" right shift and only need to boolean constrain the carry part - carry = rsh(api, diff, int(nbBits), int(nbBits+nbCarryBits+1)) + carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1)) } api.AssertIsEqual(carry, maxValueShift) } -// rsh right shifts a variable endDigit-startDigit bits and returns it. -func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) frontend.Variable { +func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable { // if v is a constant, work with the big int value. - if c, ok := api.Compiler().ConstantValue(v); ok { + if c, ok := f.api.Compiler().ConstantValue(v); ok { bits := make([]frontend.Variable, endDigit-startDigit) for i := 0; i < len(bits); i++ { bits[i] = c.Bit(i + startDigit) } return bits } - - bits, err := api.Compiler().NewHint(NBitsShifted, endDigit-startDigit, v, startDigit) + shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v) if err != nil { - panic(err) - } - - // we compute 2 sums; - // Σbi ensures that "ignoring" the lowest bits (< startDigit) still is a valid bit decomposition. - // that is, it ensures that bits from startDigit to endDigit * corresponding coefficients (powers of 2 shifted) - // are equal to the input variable - // ΣbiRShift computes the actual result; that is, the Σ (2**i * b[i]) - Σbi := frontend.Variable(0) - ΣbiRShift := frontend.Variable(0) - - cRShift := big.NewInt(1) - c := big.NewInt(1) - c.Lsh(c, uint(startDigit)) - - for i := 0; i < len(bits); i++ { - Σbi = api.MulAcc(Σbi, bits[i], c) - ΣbiRShift = api.MulAcc(ΣbiRShift, bits[i], cRShift) - - c.Lsh(c, 1) - cRShift.Lsh(cRShift, 1) - api.AssertIsBoolean(bits[i]) + panic(fmt.Sprintf("right shift: %v", err)) } - - // constraint Σ (2**i_shift * b[i]) == v - api.AssertIsEqual(Σbi, v) - return ΣbiRShift - + f.checker.Check(shifted[0], endDigit-startDigit) + shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit)) + composed := f.api.Mul(shifted[0], shift) + f.api.AssertIsEqual(composed, v) + return shifted[0] } // AssertLimbsEquality asserts that the limbs represent a same integer value. @@ -107,9 +83,9 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) { // TODO: we previously assumed that one side was "larger" than the other // side, but I think this assumption is not valid anymore if a.overflow > b.overflow { - assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) + f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) } else { - assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) + f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) } } @@ -133,10 +109,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { // take only required bits from the most significant limb limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1 } - // bits.ToBinary restricts the least significant NbDigits to be equal to - // the limb value. This is sufficient to restrict for the bitlength and - // we can discard the bits themselves. - bits.ToBinary(f.api, a.Limbs[i], bits.WithNbDigits(limbNbBits)) + f.checker.Check(a.Limbs[i], limbNbBits) } } diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 9901aab49f..00ecd2d8f6 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -23,7 +23,7 @@ func GetHints() []solver.Hint { InverseHint, MultiplicationHint, RemHint, - NBitsShifted, + RightShift, } } @@ -287,13 +287,20 @@ func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error return nbBits, nbLimbs, x, y, nil } -// NBitsShifted returns the first bits of the input, with a shift. The number of returned bits is -// defined by the length of the results slice. -func NBitsShifted(_ *big.Int, inputs []*big.Int, results []*big.Int) error { - n := inputs[0] - shift := inputs[1].Uint64() // TODO @gbotrel validate input vs perf in large circuits. - for i := 0; i < len(results); i++ { - results[i].SetUint64(uint64(n.Bit(i + int(shift)))) - } +// RightShift shifts input by the given number of bits. Expects two inputs: +// - first input is the shift, will be represented as uint64; +// - second input is the value to be shifted. +// +// Returns a single output which is the value shifted. Errors if number of +// inputs is not 2 and number of outputs is not 1. +func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 2 { + return fmt.Errorf("expecting two inputs") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting single output") + } + shift := inputs[0].Uint64() + outputs[0].Rsh(inputs[1], uint(shift)) return nil } diff --git a/std/rangecheck/rangecheck.go b/std/rangecheck/rangecheck.go new file mode 100644 index 0000000000..b3bf870690 --- /dev/null +++ b/std/rangecheck/rangecheck.go @@ -0,0 +1,30 @@ +// Package rangecheck implements range checking gadget +// +// This package chooses the most optimal path for performing range checks: +// - if the backend supports native range checking and the frontend exports the variables in the proprietary format by implementing [frontend.Rangechecker], then use it directly; +// - if the backend supports creating a commitment of variables by implementing [frontend.Committer], then we use the product argument as in [BCG+18]. [r1cs.NewBuilder] returns a builder which implements this interface; +// - lacking these, we perform binary decomposition of variable into bits. +// +// [BCG+18]: https://eprint.iacr.org/2018/380 +package rangecheck + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +// only for documentation purposes. If we import the package then godoc knows +// how to refer to package r1cs and we get nice links in godoc. We import the +// package anyway in test. +var _ = r1cs.NewBuilder + +// New returns a new range checker depending on the frontend capabilities. +func New(api frontend.API) frontend.Rangechecker { + if rc, ok := api.(frontend.Rangechecker); ok { + return rc + } + if _, ok := api.(frontend.Committer); ok { + return newCommitRangechecker(api) + } + return plainChecker{api: api} +} diff --git a/std/rangecheck/rangecheck_commit.go b/std/rangecheck/rangecheck_commit.go new file mode 100644 index 0000000000..f5b6ddb30c --- /dev/null +++ b/std/rangecheck/rangecheck_commit.go @@ -0,0 +1,238 @@ +package rangecheck + +import ( + "fmt" + "math" + "math/big" + stdbits "math/bits" + + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/frontendtype" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/std/math/bits" +) + +type ctxCheckerKey struct{} + +func init() { + solver.RegisterHint(DecomposeHint, CountHint) +} + +type checkedVariable struct { + v frontend.Variable + bits int +} + +type commitChecker struct { + collected []checkedVariable + closed bool +} + +func newCommitRangechecker(api frontend.API) *commitChecker { + kv, ok := api.Compiler().(kvstore.Store) + if !ok { + panic("builder should implement key-value store") + } + ch := kv.GetKeyValue(ctxCheckerKey{}) + if ch != nil { + if cht, ok := ch.(*commitChecker); ok { + return cht + } else { + panic("stored rangechecker is not valid") + } + } + cht := &commitChecker{} + kv.SetKeyValue(ctxCheckerKey{}, cht) + api.Compiler().Defer(cht.commit) + return cht +} + +func (c *commitChecker) Check(in frontend.Variable, bits int) { + if c.closed { + panic("checker already closed") + } + c.collected = append(c.collected, checkedVariable{v: in, bits: bits}) +} + +func (c *commitChecker) commit(api frontend.API) error { + if c.closed { + return nil + } + defer func() { c.closed = true }() + if len(c.collected) == 0 { + return nil + } + committer, ok := api.(frontend.Committer) + if !ok { + panic("expected committer API") + } + baseLength := c.getOptimalBasewidth(api) + // decompose into smaller limbs + decomposed := make([]frontend.Variable, 0, len(c.collected)) + collected := make([]frontend.Variable, len(c.collected)) + base := new(big.Int).Lsh(big.NewInt(1), uint(baseLength)) + for i := range c.collected { + // collect all vars for commitment input + collected[i] = c.collected[i].v + // decompose value into limbs + nbLimbs := decompSize(c.collected[i].bits, baseLength) + limbs, err := api.Compiler().NewHint(DecomposeHint, int(nbLimbs), c.collected[i].bits, baseLength, c.collected[i].v) + if err != nil { + panic(fmt.Sprintf("decompose %v", err)) + } + // store all limbs for counting + decomposed = append(decomposed, limbs...) + // check that limbs are correct. We check the sizes of the limbs later + var composed frontend.Variable = 0 + for j := range limbs { + composed = api.Add(composed, api.Mul(limbs[j], new(big.Int).Exp(base, big.NewInt(int64(j)), nil))) + } + api.AssertIsEqual(composed, c.collected[i].v) + } + nbTable := 1 << baseLength + // compute the counts for every value in the range + exps, err := api.Compiler().NewHint(CountHint, nbTable, decomposed...) + if err != nil { + panic(fmt.Sprintf("count %v", err)) + } + // compute the poly \pi (X - s_i)^{e_i} + commitment, err := committer.Commit(collected...) + if err != nil { + panic(fmt.Sprintf("commit %v", err)) + } + logn := stdbits.Len(uint(len(decomposed))) + var lp frontend.Variable = 1 + for i := 0; i < nbTable; i++ { + expbits := bits.ToBinary(api, exps[i], bits.WithNbDigits(logn)) + var acc frontend.Variable = 1 + tmp := api.Sub(commitment, i) + for j := 0; j < logn; j++ { + curr := api.Select(expbits[j], tmp, 1) + acc = api.Mul(acc, curr) + tmp = api.Mul(tmp, tmp) + } + lp = api.Mul(lp, acc) + } + // compute the poly \pi (X - f_i) + var rp frontend.Variable = 1 + for i := range decomposed { + val := api.Sub(commitment, decomposed[i]) + rp = api.Mul(rp, val) + } + api.AssertIsEqual(lp, rp) + return nil +} + +func decompSize(varSize int, limbSize int) int { + return (varSize + limbSize - 1) / limbSize +} + +// DecomposeHint is a hint used for range checking with commitment. It +// decomposes large variables into chunks which can be individually range-check +// in the native range. +func DecomposeHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 3 { + return fmt.Errorf("input must be 3 elements") + } + if !inputs[0].IsUint64() || !inputs[1].IsUint64() { + return fmt.Errorf("first two inputs have to be uint64") + } + varSize := int(inputs[0].Int64()) + limbSize := int(inputs[1].Int64()) + val := inputs[2] + nbLimbs := decompSize(varSize, limbSize) + if len(outputs) != nbLimbs { + return fmt.Errorf("need %d outputs instead to decompose", nbLimbs) + } + base := new(big.Int).Lsh(big.NewInt(1), uint(limbSize)) + tmp := new(big.Int).Set(val) + for i := 0; i < len(outputs); i++ { + outputs[i].Mod(tmp, base) + tmp.Rsh(tmp, uint(limbSize)) + } + return nil +} + +// CountHint is a hint function which is used in range checking using +// commitment. It counts the occurences of checked variables in the range and +// returns the counts. +func CountHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error { + nbVals := len(outputs) + if len(outputs) != nbVals { + return fmt.Errorf("output size %d does not match range size %d", len(outputs), nbVals) + } + counts := make(map[uint64]uint64, nbVals) + for i := 0; i < len(inputs); i++ { + if !inputs[i].IsUint64() { + return fmt.Errorf("input %d not uint64", i) + } + c := inputs[i].Uint64() + counts[c]++ + } + for i := 0; i < nbVals; i++ { + outputs[i].SetUint64(counts[uint64(i)]) + } + return nil +} + +func (c *commitChecker) getOptimalBasewidth(api frontend.API) int { + if ft, ok := api.(frontendtype.FrontendTyper); ok { + switch ft.FrontendType() { + case frontendtype.R1CS: + return optimalWidth(nbR1CSConstraints, c.collected) + case frontendtype.SCS: + return optimalWidth(nbPLONKConstraints, c.collected) + } + } + return optimalWidth(nbR1CSConstraints, c.collected) +} + +func optimalWidth(countFn func(baseLength int, collected []checkedVariable) int, collected []checkedVariable) int { + min := math.MaxInt64 + minVal := 0 + for j := 2; j < 18; j++ { + current := countFn(j, collected) + if current < min { + min = current + minVal = j + } + } + return minVal +} + +func nbR1CSConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := len(collected) // single composition check per collected + logn := stdbits.Len(uint(nbDecomposed)) + nbTable := 1 << baseLength + nbLeft := nbTable * + (logn + // tobinary + logn + // select per exponent bit + logn + // mul per exponent bit + logn + // mul per exponent bit + 1) // final mul + nbRight := nbDecomposed // mul all decomposed + return nbLeft + nbRight + eqs + 1 // single for final equality +} + +func nbPLONKConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := nbDecomposed // check correctness of every decomposition. this is nbDecomp adds + eq cost per collected + logn := stdbits.Len(uint(nbDecomposed)) + nbTable := 1 << baseLength + nbLeft := nbTable * + (3*logn + // tobinary. decomposition check + binary check + 2*logn + // select per exponent bit + logn + // mul per exponent bit + logn + // mul per exponent bit + 1) // final mul + nbRight := 2 * nbDecomposed // per decomposed sub and mul + return nbLeft + nbRight + eqs + 1 // single for final equality +} diff --git a/std/rangecheck/rangecheck_plain.go b/std/rangecheck/rangecheck_plain.go new file mode 100644 index 0000000000..6f20418f2d --- /dev/null +++ b/std/rangecheck/rangecheck_plain.go @@ -0,0 +1,14 @@ +package rangecheck + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/bits" +) + +type plainChecker struct { + api frontend.API +} + +func (pl plainChecker) Check(v frontend.Variable, nbBits int) { + bits.ToBinary(pl.api, v, bits.WithNbDigits(nbBits)) +} diff --git a/std/rangecheck/rangecheck_test.go b/std/rangecheck/rangecheck_test.go new file mode 100644 index 0000000000..e135376afb --- /dev/null +++ b/std/rangecheck/rangecheck_test.go @@ -0,0 +1,50 @@ +package rangecheck + +import ( + "crypto/rand" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/goldilocks" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" +) + +type CheckCircuit struct { + Vals []frontend.Variable + bits int + base int +} + +func (c *CheckCircuit) Define(api frontend.API) error { + r := newCommitRangechecker(api) + for i := range c.Vals { + r.Check(c.Vals[i], c.bits) + } + return nil +} + +func TestCheck(t *testing.T) { + assert := test.NewAssert(t) + var err error + bits := 64 + base := 11 + nbVals := 100000 + bound := new(big.Int).Lsh(big.NewInt(1), uint(bits)) + vals := make([]frontend.Variable, nbVals) + for i := range vals { + vals[i], err = rand.Int(rand.Reader, bound) + if err != nil { + t.Fatal(err) + } + } + witness := CheckCircuit{Vals: vals, bits: bits, base: base} + circuit := CheckCircuit{Vals: make([]frontend.Variable, len(vals)), bits: bits, base: base} + err = test.IsSolved(&circuit, &witness, goldilocks.Modulus()) + assert.NoError(err) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + assert.NoError(err) + t.Log(ccs.GetNbConstraints()) +}