Skip to content

Commit

Permalink
feat: range check gadget (#472)
Browse files Browse the repository at this point in the history
* feat: add external range checker interface

* feat: add optional FrontendType method to the builders

In range checking gadget we try to estimate the number of constraints given
different parameters. But for estimating we need to know the costs of
operations. And the costs of the operations depend on the way we arithmetize
the circuit.

Added an internal interface which allows to query the arithmetization method
and implement this in existing builders.

* feat: implement range checking

* feat: use range checking in field emulation

* test: update circuit statistics

* test: update stats

* test: update stats
  • Loading branch information
ivokub authored Mar 9, 2023
1 parent bd39e9f commit a8af4f3
Show file tree
Hide file tree
Showing 13 changed files with 399 additions and 50 deletions.
9 changes: 9 additions & 0 deletions frontend/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,12 @@ type Committer interface {
// Commit commits to the variables and returns the commitment.
Commit(toCommit ...Variable) (commitment Variable, err error)
}

// Rangechecker allows to externally range-check the variables to be of
// specified width. Not all compilers implement this interface. Users should
// instead use [github.com/consensys/gnark/std/rangecheck] package which
// automatically chooses most optimal method for range checking the variables.
type Rangechecker interface {
// Check checks that the given variable v has bit-length bits.
Check(v Variable, bits int)
}
5 changes: 5 additions & 0 deletions frontend/cs/r1cs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/consensys/gnark/frontend/internal/expr"
"github.com/consensys/gnark/frontend/schema"
"github.com/consensys/gnark/internal/circuitdefer"
"github.com/consensys/gnark/internal/frontendtype"
"github.com/consensys/gnark/internal/kvstore"
"github.com/consensys/gnark/internal/tinyfield"
"github.com/consensys/gnark/internal/utils"
Expand Down Expand Up @@ -452,3 +453,7 @@ func (builder *builder) compress(le expr.LinearExpression) expr.LinearExpression
func (builder *builder) Defer(cb func(frontend.API) error) {
circuitdefer.Put(builder, cb)
}

func (*builder) FrontendType() frontendtype.Type {
return frontendtype.R1CS
}
5 changes: 5 additions & 0 deletions frontend/cs/scs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/internal/expr"
"github.com/consensys/gnark/frontend/schema"
"github.com/consensys/gnark/internal/frontendtype"
"github.com/consensys/gnark/std/math/bits"
)

Expand Down Expand Up @@ -557,3 +558,7 @@ func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder,
func (builder *builder) Compiler() frontend.Compiler {
return builder
}

func (*builder) FrontendType() frontendtype.Type {
return frontendtype.SCS
}
13 changes: 13 additions & 0 deletions internal/frontendtype/frontendtype.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Package frontendtype allows to assert frontend type.
package frontendtype

type Type int

const (
R1CS Type = iota
SCS
)

type FrontendTyper interface {
FrontendType() Type
}
Binary file modified internal/stats/latest.stats
Binary file not shown.
2 changes: 2 additions & 0 deletions std/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/consensys/gnark/std/algebra/native/sw_bls24315"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/gnark/std/selector"
)

Expand All @@ -34,4 +35,5 @@ func registerHints() {
solver.RegisterHint(selector.MuxIndicators)
solver.RegisterHint(selector.MapIndicators)
solver.RegisterHint(emulated.GetHints()...)
solver.RegisterHint(rangecheck.CountHint, rangecheck.DecomposeHint)
}
3 changes: 3 additions & 0 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/logger"
"github.com/consensys/gnark/std/rangecheck"
"github.com/rs/zerolog"
"golang.org/x/exp/constraints"
)
Expand Down Expand Up @@ -38,6 +39,7 @@ type Field[T FieldParams] struct {
log zerolog.Logger

constrainedLimbs map[uint64]struct{}
checker frontend.Rangechecker
}

// NewField returns an object to be used in-circuit to perform emulated
Expand All @@ -53,6 +55,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
api: native,
log: logger.Logger(),
constrainedLimbs: make(map[uint64]struct{}),
checker: rangecheck.New(native),
}

// ensure prime is correctly set
Expand Down
55 changes: 14 additions & 41 deletions std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ import (
"math/big"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/bits"
)

// assertLimbsEqualitySlow is the main routine in the package. It asserts that the
// two slices of limbs represent the same integer value. This is also the most
// costly operation in the package as it does bit decomposition of the limbs.
func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) {
func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) {

nbLimbs := max(len(l), len(r))
maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits)
Expand All @@ -33,52 +32,29 @@ func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits,
// carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1]
// we know that diff[:nbBits] are 0 bits, but still need to constrain them.
// to do both; we do a "clean" right shift and only need to boolean constrain the carry part
carry = rsh(api, diff, int(nbBits), int(nbBits+nbCarryBits+1))
carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1))
}
api.AssertIsEqual(carry, maxValueShift)
}

// rsh right shifts a variable endDigit-startDigit bits and returns it.
func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) frontend.Variable {
func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable {
// if v is a constant, work with the big int value.
if c, ok := api.Compiler().ConstantValue(v); ok {
if c, ok := f.api.Compiler().ConstantValue(v); ok {
bits := make([]frontend.Variable, endDigit-startDigit)
for i := 0; i < len(bits); i++ {
bits[i] = c.Bit(i + startDigit)
}
return bits
}

bits, err := api.Compiler().NewHint(NBitsShifted, endDigit-startDigit, v, startDigit)
shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v)
if err != nil {
panic(err)
}

// we compute 2 sums;
// Σbi ensures that "ignoring" the lowest bits (< startDigit) still is a valid bit decomposition.
// that is, it ensures that bits from startDigit to endDigit * corresponding coefficients (powers of 2 shifted)
// are equal to the input variable
// ΣbiRShift computes the actual result; that is, the Σ (2**i * b[i])
Σbi := frontend.Variable(0)
ΣbiRShift := frontend.Variable(0)

cRShift := big.NewInt(1)
c := big.NewInt(1)
c.Lsh(c, uint(startDigit))

for i := 0; i < len(bits); i++ {
Σbi = api.MulAcc(Σbi, bits[i], c)
ΣbiRShift = api.MulAcc(ΣbiRShift, bits[i], cRShift)

c.Lsh(c, 1)
cRShift.Lsh(cRShift, 1)
api.AssertIsBoolean(bits[i])
panic(fmt.Sprintf("right shift: %v", err))
}

// constraint Σ (2**i_shift * b[i]) == v
api.AssertIsEqual(Σbi, v)
return ΣbiRShift

f.checker.Check(shifted[0], endDigit-startDigit)
shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit))
composed := f.api.Mul(shifted[0], shift)
f.api.AssertIsEqual(composed, v)
return shifted[0]
}

// AssertLimbsEquality asserts that the limbs represent a same integer value.
Expand Down Expand Up @@ -107,9 +83,9 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) {
// TODO: we previously assumed that one side was "larger" than the other
// side, but I think this assumption is not valid anymore
if a.overflow > b.overflow {
assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow)
f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow)
} else {
assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow)
f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow)
}
}

Expand All @@ -133,10 +109,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) {
// take only required bits from the most significant limb
limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1
}
// bits.ToBinary restricts the least significant NbDigits to be equal to
// the limb value. This is sufficient to restrict for the bitlength and
// we can discard the bits themselves.
bits.ToBinary(f.api, a.Limbs[i], bits.WithNbDigits(limbNbBits))
f.checker.Check(a.Limbs[i], limbNbBits)
}
}

Expand Down
25 changes: 16 additions & 9 deletions std/math/emulated/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func GetHints() []solver.Hint {
InverseHint,
MultiplicationHint,
RemHint,
NBitsShifted,
RightShift,
}
}

Expand Down Expand Up @@ -287,13 +287,20 @@ func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error
return nbBits, nbLimbs, x, y, nil
}

// NBitsShifted returns the first bits of the input, with a shift. The number of returned bits is
// defined by the length of the results slice.
func NBitsShifted(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
n := inputs[0]
shift := inputs[1].Uint64() // TODO @gbotrel validate input vs perf in large circuits.
for i := 0; i < len(results); i++ {
results[i].SetUint64(uint64(n.Bit(i + int(shift))))
}
// RightShift shifts input by the given number of bits. Expects two inputs:
// - first input is the shift, will be represented as uint64;
// - second input is the value to be shifted.
//
// Returns a single output which is the value shifted. Errors if number of
// inputs is not 2 and number of outputs is not 1.
func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("expecting two inputs")
}
if len(outputs) != 1 {
return fmt.Errorf("expecting single output")
}
shift := inputs[0].Uint64()
outputs[0].Rsh(inputs[1], uint(shift))
return nil
}
30 changes: 30 additions & 0 deletions std/rangecheck/rangecheck.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Package rangecheck implements range checking gadget
//
// This package chooses the most optimal path for performing range checks:
// - if the backend supports native range checking and the frontend exports the variables in the proprietary format by implementing [frontend.Rangechecker], then use it directly;
// - if the backend supports creating a commitment of variables by implementing [frontend.Committer], then we use the product argument as in [BCG+18]. [r1cs.NewBuilder] returns a builder which implements this interface;
// - lacking these, we perform binary decomposition of variable into bits.
//
// [BCG+18]: https://eprint.iacr.org/2018/380
package rangecheck

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
)

// only for documentation purposes. If we import the package then godoc knows
// how to refer to package r1cs and we get nice links in godoc. We import the
// package anyway in test.
var _ = r1cs.NewBuilder

// New returns a new range checker depending on the frontend capabilities.
func New(api frontend.API) frontend.Rangechecker {
if rc, ok := api.(frontend.Rangechecker); ok {
return rc
}
if _, ok := api.(frontend.Committer); ok {
return newCommitRangechecker(api)
}
return plainChecker{api: api}
}
Loading

0 comments on commit a8af4f3

Please sign in to comment.