Skip to content

Commit

Permalink
Add a Constraints helper
Browse files Browse the repository at this point in the history
There are two existing patterns for constructing a gate from a set of
constraints with a common selector:

- Create an iterator of constraints, where each constraint includes the
  selector:
  ```
  vec![
      ("foo", selector.clone() * foo),
      ("bar", selector.clone() * bar),
      ("baz", selector * bar),
  ]
  ```
  This requires the user to write O(n) `selector.clone()` calls.

- Create an iterator of constraints, and then map the selector in:
  ```
  vec![
      ("foo", foo),
      ("bar", bar),
      ("baz", bar),
  ].into_iter().map(move |(name, poly)| (name, selector.clone() * poly))
  ```
  This looks cleaner overall, but the API is not as intuitive, and it
  is messier when the constraints are named.

The `Constraints` struct provides a third, clearer API:
```
Constraints::with_selector(
    selector,
    vec![
        ("foo", foo),
        ("bar", bar),
        ("baz", bar),
    ],
)
```
This focuses on the structure of the constraints, and handles the
selector application for the user.
  • Loading branch information
str4d committed Dec 3, 2021
1 parent 0295dc7 commit 34912cf
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ and this project adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `halo2::plonk::Constraints` helper, for constructing a gate from a set of
constraints with a common selector.

### Changed
- `halo2::plonk::Error` has been overhauled:
- `Error` now implements `std::fmt::Display` and `std::error::Error`.
Expand Down
77 changes: 50 additions & 27 deletions examples/sha256/table16/compression/compression_gates.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::super::{util::*, Gate};
use halo2::{arithmetic::FieldExt, plonk::Expression};
use halo2::{
arithmetic::FieldExt,
plonk::{Constraints, Expression},
};
use std::{array, marker::PhantomData};

pub struct CompressionGate<F: FieldExt>(PhantomData<F>);
Expand Down Expand Up @@ -32,7 +35,11 @@ impl<F: FieldExt> CompressionGate<F> {
spread_word_lo: Expression<F>,
word_hi: Expression<F>,
spread_word_hi: Expression<F>,
) -> impl Iterator<Item = (&'static str, Expression<F>)> {
) -> Constraints<
F,
(&'static str, Expression<F>),
impl Iterator<Item = (&'static str, Expression<F>)>,
> {
let check_spread_and_range =
Gate::three_bit_spread_and_range(c_lo.clone(), spread_c_lo.clone())
.chain(Gate::three_bit_spread_and_range(
Expand Down Expand Up @@ -63,12 +70,14 @@ impl<F: FieldExt> CompressionGate<F> {
+ spread_word_lo * (-F::one())
+ spread_word_hi * F::from_u64(1 << 32) * (-F::one());

check_spread_and_range
.chain(Some(("range_check_tag_b", range_check_tag_b)))
.chain(Some(("range_check_tag_d", range_check_tag_d)))
.chain(Some(("dense_check", dense_check)))
.chain(Some(("spread_check", spread_check)))
.map(move |(name, poly)| (name, s_decompose_abcd.clone() * poly))
Constraints::with_selector(
s_decompose_abcd,
check_spread_and_range
.chain(Some(("range_check_tag_b", range_check_tag_b)))
.chain(Some(("range_check_tag_d", range_check_tag_d)))
.chain(Some(("dense_check", dense_check)))
.chain(Some(("spread_check", spread_check))),
)
}

// Decompose `E,F,G,H` words
Expand All @@ -94,7 +103,11 @@ impl<F: FieldExt> CompressionGate<F> {
spread_word_lo: Expression<F>,
word_hi: Expression<F>,
spread_word_hi: Expression<F>,
) -> impl Iterator<Item = (&'static str, Expression<F>)> {
) -> Constraints<
F,
(&'static str, Expression<F>),
impl Iterator<Item = (&'static str, Expression<F>)>,
> {
let check_spread_and_range =
Gate::three_bit_spread_and_range(a_lo.clone(), spread_a_lo.clone())
.chain(Gate::three_bit_spread_and_range(
Expand Down Expand Up @@ -128,12 +141,14 @@ impl<F: FieldExt> CompressionGate<F> {
+ spread_word_lo * (-F::one())
+ spread_word_hi * F::from_u64(1 << 32) * (-F::one());

check_spread_and_range
.chain(Some(("range_check_tag_c", range_check_tag_c)))
.chain(Some(("range_check_tag_d", range_check_tag_d)))
.chain(Some(("dense_check", dense_check)))
.chain(Some(("spread_check", spread_check)))
.map(move |(name, poly)| (name, s_decompose_efgh.clone() * poly))
Constraints::with_selector(
s_decompose_efgh,
check_spread_and_range
.chain(Some(("range_check_tag_c", range_check_tag_c)))
.chain(Some(("range_check_tag_d", range_check_tag_d)))
.chain(Some(("dense_check", dense_check)))
.chain(Some(("spread_check", spread_check))),
)
}

// s_upper_sigma_0 on abcd words
Expand Down Expand Up @@ -263,7 +278,11 @@ impl<F: FieldExt> CompressionGate<F> {
spread_e_neg_hi: Expression<F>,
spread_g_lo: Expression<F>,
spread_g_hi: Expression<F>,
) -> impl Iterator<Item = (&'static str, Expression<F>)> {
) -> Constraints<
F,
(&'static str, Expression<F>),
impl Iterator<Item = (&'static str, Expression<F>)>,
> {
let neg_check = {
let evens = Self::ones() * F::from_u64(MASK_EVEN_32 as u64);
// evens - spread_e_lo = spread_e_neg_lo
Expand All @@ -284,9 +303,7 @@ impl<F: FieldExt> CompressionGate<F> {
let rhs_odd = spread_q0_odd + spread_q1_odd * F::from_u64(1 << 32);
let rhs = rhs_even + rhs_odd * F::from_u64(2);

neg_check
.chain(Some(("s_ch_neg", lhs - rhs)))
.map(move |(name, poly)| (name, s_ch_neg.clone() * poly))
Constraints::with_selector(s_ch_neg, neg_check.chain(Some(("s_ch_neg", lhs - rhs))))
}

// Majority gate on (A, B, C)
Expand Down Expand Up @@ -409,17 +426,23 @@ impl<F: FieldExt> CompressionGate<F> {
lo_3: Expression<F>,
hi_3: Expression<F>,
word_3: Expression<F>,
) -> impl Iterator<Item = (&'static str, Expression<F>)> {
) -> Constraints<
F,
(&'static str, Expression<F>),
impl Iterator<Item = (&'static str, Expression<F>)>,
> {
let check_lo_hi = |lo: Expression<F>, hi: Expression<F>, word: Expression<F>| {
lo + hi * F::from_u64(1 << 16) - word
};

array::IntoIter::new([
("check_lo_hi_0", check_lo_hi(lo_0, hi_0, word_0)),
("check_lo_hi_1", check_lo_hi(lo_1, hi_1, word_1)),
("check_lo_hi_2", check_lo_hi(lo_2, hi_2, word_2)),
("check_lo_hi_3", check_lo_hi(lo_3, hi_3, word_3)),
])
.map(move |(name, poly)| (name, s_digest.clone() * poly))
Constraints::with_selector(
s_digest,
array::IntoIter::new([
("check_lo_hi_0", check_lo_hi(lo_0, hi_0, word_0)),
("check_lo_hi_1", check_lo_hi(lo_1, hi_1, word_1)),
("check_lo_hi_2", check_lo_hi(lo_2, hi_2, word_2)),
("check_lo_hi_3", check_lo_hi(lo_3, hi_3, word_3)),
]),
)
}
}
77 changes: 77 additions & 0 deletions src/plonk/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,83 @@ impl<F: Field> From<Expression<F>> for Vec<Constraint<F>> {
}
}

/// A set of polynomial constraints with a common selector.
///
/// ```
/// use halo2::{pasta::Fp, plonk::{Constraints, Expression}, poly::Rotation};
/// # use halo2::plonk::ConstraintSystem;
///
/// # let mut meta = ConstraintSystem::<Fp>::default();
/// let a = meta.advice_column();
/// let b = meta.advice_column();
/// let c = meta.advice_column();
/// let s = meta.selector();
///
/// meta.create_gate("foo", |meta| {
/// let next = meta.query_advice(a, Rotation::next());
/// let a = meta.query_advice(a, Rotation::cur());
/// let b = meta.query_advice(b, Rotation::cur());
/// let c = meta.query_advice(c, Rotation::cur());
/// let s_ternary = meta.query_selector(s);
///
/// let one_minus_a = Expression::Constant(Fp::one()) - a.clone();
///
/// Constraints::with_selector(
/// s_ternary,
/// std::array::IntoIter::new([
/// ("a is boolean", a.clone() * one_minus_a.clone()),
/// ("next == a ? b : c", next - (a * b + one_minus_a * c)),
/// ]),
/// )
/// });
/// ```
#[derive(Debug)]
pub struct Constraints<F: Field, C: Into<Constraint<F>>, Iter: IntoIterator<Item = C>> {
selector: Expression<F>,
constraints: Iter,
}

impl<F: Field, C: Into<Constraint<F>>, Iter: IntoIterator<Item = C>> Constraints<F, C, Iter> {
/// Constructs a set of constraints that are controlled by the given selector.
///
/// Each constraint `c` in `iterator` will be converted into the constraint
/// `selector * c`.
pub fn with_selector(selector: Expression<F>, constraints: Iter) -> Self {
Constraints {
selector,
constraints,
}
}
}

fn apply_selector_to_constraint<F: Field, C: Into<Constraint<F>>>(
(selector, c): (Expression<F>, C),
) -> Constraint<F> {
let constraint: Constraint<F> = c.into();
Constraint {
name: constraint.name,
poly: selector * constraint.poly,
}
}

type ApplySelectorToConstraint<F, C> = fn((Expression<F>, C)) -> Constraint<F>;

impl<F: Field, C: Into<Constraint<F>>, Iter: IntoIterator<Item = C>> IntoIterator
for Constraints<F, C, Iter>
{
type Item = Constraint<F>;
type IntoIter = std::iter::Map<
std::iter::Zip<std::iter::Repeat<Expression<F>>, Iter::IntoIter>,
ApplySelectorToConstraint<F, C>,
>;

fn into_iter(self) -> Self::IntoIter {
std::iter::repeat(self.selector)
.zip(self.constraints.into_iter())
.map(apply_selector_to_constraint)
}
}

#[derive(Clone, Debug)]
pub(crate) struct Gate<F: Field> {
name: &'static str,
Expand Down

0 comments on commit 34912cf

Please sign in to comment.