From bc6a8cd4575fe0cc8030242752980debbe71012c Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewa Date: Tue, 3 Dec 2024 17:32:15 +0100 Subject: [PATCH] properly parametrize skyscrapers lookup table size --- example/main.go | 39 ++---------------- hash/skyscraper.go | 88 +++++++++++++++++++++++------------------ hash/skyscraper_test.go | 59 ++++++++++++++++----------- 3 files changed, 88 insertions(+), 98 deletions(-) diff --git a/example/main.go b/example/main.go index ad67b9c..bce79e1 100644 --- a/example/main.go +++ b/example/main.go @@ -6,11 +6,9 @@ import ( "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/std/lookup/logderivlookup" "github.com/consensys/gnark/std/math/uints" gnark_nimue "github.com/reilabs/gnark-nimue" "github.com/reilabs/gnark-nimue/hash" - "math/bits" ) type TestCircuit struct { @@ -192,7 +190,7 @@ type Manhattan struct { } func (c *Manhattan) Define(api frontend.API) error { - s := hash.NewSkyscraper(api) + s := hash.NewSkyscraper(api, 1) a := c.I for range 3000 { a = s.Compress(a, a) @@ -208,41 +206,10 @@ func ExampleManhattan() { fmt.Println(err) return } - pk, vk, _ := groth16.Setup(ccs) - assignment := Manhattan{ - I: 1, - O: 1000, - } - witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - publicWitness, _ := witness.Public() - - proof, _ := groth16.Prove(ccs, pk, witness) - vErr := groth16.Verify(proof, vk, publicWitness) - fmt.Printf("%v\n", vErr) -} - -type TestLookup struct { - In frontend.Variable -} + fmt.Printf("constraints: %d\n", ccs.GetNbConstraints()) -func (c *TestLookup) Define(api frontend.API) error { - table := logderivlookup.New(api) - for i := range 256 { - table.Insert(bits.RotateLeft8(uint8(i), 3)) - } - c0 := c.In - for range 256 { - c0 = table.Lookup(c0)[0] - } - api.AssertIsEqual(c0, c.In) - return nil } func main() { - ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &TestLookup{}) - fmt.Printf("constraints: %d\n", ccs.GetNbConstraints()) - - //Example1() - //ExampleWhir() - //ExampleManhattan() + ExampleManhattan() } diff --git a/hash/skyscraper.go b/hash/skyscraper.go index eb8e0c8..cc30dca 100644 --- a/hash/skyscraper.go +++ b/hash/skyscraper.go @@ -11,22 +11,25 @@ import ( "math/bits" ) -func bytesBeHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { +func wordsBeHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { if field.Cmp(ecc.BN254.ScalarField()) != 0 { return fmt.Errorf("bytesHint: expected BN254 Fr, got %s", field) } - if len(inputs) != 1 { - return fmt.Errorf("bytesHint: expected 1 input, got %d", len(inputs)) + if len(inputs) != 2 { + return fmt.Errorf("bytesHint: expected 2 inputs, got %d", len(inputs)) } - if len(outputs) != 16 { - return fmt.Errorf("bytesHint: expected 32 outputs, got %d", len(outputs)) + wordLen := int(inputs[0].Int64()) + if len(outputs) != 32/wordLen { + return fmt.Errorf("bytesHint: expected %d outputs, got %d", 32/wordLen, len(outputs)) } bytes := make([]byte, 32) - inputs[0].FillBytes(bytes) + inputs[1].FillBytes(bytes) for i, o := range outputs { - o.SetUint64(uint64(bytes[2*i])) - o.Mul(o, big.NewInt(256)) - o.Add(o, big.NewInt(int64(bytes[2*i+1]))) + o.SetUint64(0) + for j := range wordLen { + o.Mul(o, big.NewInt(256)) + o.Add(o, big.NewInt(int64(bytes[wordLen*i+j]))) + } } return nil } @@ -48,16 +51,17 @@ func gtHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { } func init() { - solver.RegisterHint(bytesBeHint) + solver.RegisterHint(wordsBeHint) solver.RegisterHint(gtHint) } type Skyscraper struct { - rc [8]big.Int - sigma big.Int - sboxT *logderivlookup.Table - rchk frontend.Rangechecker - api frontend.API + rc [8]big.Int + sigma big.Int + sboxT *logderivlookup.Table + rchk frontend.Rangechecker + wordSize int + api frontend.API } func sboxByte(b byte) byte { @@ -67,19 +71,22 @@ func sboxByte(b byte) byte { return bits.RotateLeft8(b^(x&y&z), 1) } -func initSbox(api frontend.API) *logderivlookup.Table { +func initSbox(api frontend.API, wordSize int) *logderivlookup.Table { t := logderivlookup.New(api) - for i := range 65536 { - w := uint16(i) - b1 := byte(w & 0xff) - b2 := byte(w >> 8) - r := uint16(sboxByte(b1)) | (uint16(sboxByte(b2)) << 8) + tableSize := 1 << (8 * wordSize) + for i := range tableSize { + r := uint64(0) + for j := range wordSize { + shiftSize := j * 8 + inpByte := byte((i >> shiftSize) & 0xff) + r |= uint64(sboxByte(inpByte)) << shiftSize + } t.Insert(r) } return t } -func NewSkyscraper(api frontend.API) *Skyscraper { +func NewSkyscraper(api frontend.API, wordSize int) *Skyscraper { rc := [8]big.Int{} rc[0].SetString("17829420340877239108687448009732280677191990375576158938221412342251481978692", 10) rc[1].SetString("5852100059362614845584985098022261541909346143980691326489891671321030921585", 10) @@ -95,8 +102,9 @@ func NewSkyscraper(api frontend.API) *Skyscraper { return &Skyscraper{ rc, sigma, - initSbox(api), + initSbox(api, wordSize), rangecheck.New(api), + wordSize, api, } } @@ -109,10 +117,10 @@ func (s *Skyscraper) square(v frontend.Variable) frontend.Variable { return s.api.Mul(s.api.Mul(v, v), s.sigma) } -func (s *Skyscraper) varFromBytesBe(bytes []frontend.Variable) frontend.Variable { +func (s *Skyscraper) varFromWordsBe(words []frontend.Variable) frontend.Variable { result := frontend.Variable(0) - for _, b := range bytes { - result = s.api.Mul(result, 65536) + for _, b := range words { + result = s.api.Mul(result, 1<<(8*s.wordSize)) result = s.api.Add(result, b) } return result @@ -137,26 +145,28 @@ func (s *Skyscraper) assertLessThanModulus(hi, lo frontend.Variable) { } // the result is NOT rangechecked, but if it is in range, it is canonical -func (s *Skyscraper) canonicalDecompose(v frontend.Variable) [16]frontend.Variable { - o, _ := s.api.Compiler().NewHint(bytesBeHint, 16, v) - result := [16]frontend.Variable{} +func (s *Skyscraper) canonicalDecompose(v frontend.Variable) []frontend.Variable { + wordsPerFelt := 32 / s.wordSize + o, _ := s.api.Compiler().NewHint(wordsBeHint, wordsPerFelt, s.wordSize, v) + result := make([]frontend.Variable, wordsPerFelt) copy(result[:], o) - s.api.AssertIsEqual(s.varFromBytesBe(result[:]), v) - s.assertLessThanModulus(s.varFromBytesBe(result[:8]), s.varFromBytesBe(result[8:])) + s.api.AssertIsEqual(s.varFromWordsBe(result[:]), v) + s.assertLessThanModulus(s.varFromWordsBe(result[:wordsPerFelt/2]), s.varFromWordsBe(result[wordsPerFelt/2:])) return result } func (s *Skyscraper) bar(v frontend.Variable) frontend.Variable { - bytes := s.canonicalDecompose(v) - tmp := [8]frontend.Variable{} - copy(tmp[:], bytes[:8]) - copy(bytes[:], bytes[8:]) - copy(bytes[8:], tmp[:]) - for i := range bytes { + words := s.canonicalDecompose(v) + wordsPerFelt := 32 / s.wordSize + tmp := make([]frontend.Variable, wordsPerFelt/2) + copy(tmp[:], words[:wordsPerFelt/2]) + copy(words[:], words[wordsPerFelt/2:]) + copy(words[wordsPerFelt/2:], tmp[:]) + for i := range words { // sbox implicitly rangechecks the input - bytes[i] = s.sbox(bytes[i]) + words[i] = s.sbox(words[i]) } - return s.varFromBytesBe(bytes[:]) + return s.varFromWordsBe(words[:]) } func (s *Skyscraper) Permute(state *[2]frontend.Variable) { diff --git a/hash/skyscraper_test.go b/hash/skyscraper_test.go index 0364a20..7e09c7d 100644 --- a/hash/skyscraper_test.go +++ b/hash/skyscraper_test.go @@ -1,6 +1,7 @@ package hash import ( + "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" @@ -16,23 +17,27 @@ func bigIntFromString(s string) frontend.Variable { } type TestSboxC struct { - In, Out frontend.Variable + WordSize int + In, Out frontend.Variable } func (c *TestSboxC) Define(api frontend.API) error { - s := NewSkyscraper(api) + s := NewSkyscraper(api, c.WordSize) api.AssertIsEqual(s.sbox(c.In), c.Out) return nil } func TestSbox(t *testing.T) { assert := test.NewAssert(t) - assert.CheckCircuit(&TestSboxC{}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithValidAssignment(&TestSboxC{0xcd, 0xd3}), - test.WithValidAssignment(&TestSboxC{0x17, 0x0e}), - test.WithInvalidAssignment(&TestSboxC{0x17, 0x0f}), - test.WithInvalidAssignment(&TestSboxC{0x1234, 0x0f})) - + for wordSize := 1; wordSize <= 2; wordSize++ { + assert.CheckCircuit(&TestSboxC{WordSize: wordSize}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), + test.WithValidAssignment(&TestSboxC{wordSize, 0xcd, 0xd3}), + test.WithValidAssignment(&TestSboxC{wordSize, 0x17, 0x0e}), + test.WithInvalidAssignment(&TestSboxC{wordSize, 0x17, 0x0f}), + test.WithInvalidAssignment(&TestSboxC{wordSize, 0x1234, 0x0f})) + } + assert.CheckCircuit(&TestSboxC{WordSize: 2}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), + test.WithValidAssignment(&TestSboxC{2, 0xcd17, 0xd30e})) } type TestSquareC struct { @@ -40,7 +45,7 @@ type TestSquareC struct { } func (c *TestSquareC) Define(api frontend.API) error { - s := NewSkyscraper(api) + s := NewSkyscraper(api, 1) s.sbox(123) // needed to silence an error about unused lookup tables api.AssertIsEqual(s.square(c.In), c.Out) return nil @@ -57,40 +62,48 @@ func TestSquare(t *testing.T) { } type TestBarC struct { - In, Out frontend.Variable + WordSize int + In, Out frontend.Variable } func (c *TestBarC) Define(api frontend.API) error { - s := NewSkyscraper(api) + s := NewSkyscraper(api, c.WordSize) api.AssertIsEqual(s.bar(c.In), c.Out) return nil } func TestBar(t *testing.T) { assert := test.NewAssert(t) - assert.CheckCircuit(&TestBarC{}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithValidAssignment(&TestBarC{0, 0}), - test.WithValidAssignment(&TestBarC{1, bigIntFromString("680564733841876926926749214863536422912")}), - test.WithValidAssignment(&TestBarC{2, bigIntFromString("1361129467683753853853498429727072845824")}), - test.WithValidAssignment(&TestBarC{bigIntFromString("4111585712030104139416666328230194227848755236259444667527487224433891325648"), bigIntFromString("18867677047139790809471719918880601980605904427073186248909139907505620573990")})) - + for wordSize := 1; wordSize <= 2; wordSize++ { + fmt.Printf("wordSize: %d\n", wordSize) + assert.CheckCircuit(&TestBarC{WordSize: wordSize}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), + test.WithValidAssignment(&TestBarC{wordSize, 0, 0}), + test.WithValidAssignment(&TestBarC{wordSize, 1, bigIntFromString("680564733841876926926749214863536422912")}), + test.WithValidAssignment(&TestBarC{wordSize, 2, bigIntFromString("1361129467683753853853498429727072845824")}), + test.WithValidAssignment(&TestBarC{wordSize, bigIntFromString("4111585712030104139416666328230194227848755236259444667527487224433891325648"), bigIntFromString("18867677047139790809471719918880601980605904427073186248909139907505620573990")})) + + } } type TestCompressC struct { + WordSize int In1, In2, Out frontend.Variable } func (c *TestCompressC) Define(api frontend.API) error { - s := NewSkyscraper(api) + s := NewSkyscraper(api, c.WordSize) api.AssertIsEqual(s.Compress(c.In1, c.In2), c.Out) return nil } func TestCompress(t *testing.T) { assert := test.NewAssert(t) - assert.CheckCircuit(&TestCompressC{}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithValidAssignment(&TestCompressC{ - bigIntFromString("21614608883591910674239883101354062083890746690626773887530227216615498812963"), - bigIntFromString("9813154100006487150380270585621895148484502414032888228750638800367218873447"), - bigIntFromString("3583228880285179354728993622328037400470978495633822008876840172083178912457")})) + for wordSize := 1; wordSize <= 2; wordSize++ { + assert.CheckCircuit(&TestCompressC{WordSize: wordSize}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), + test.WithValidAssignment(&TestCompressC{wordSize, + bigIntFromString("21614608883591910674239883101354062083890746690626773887530227216615498812963"), + bigIntFromString("9813154100006487150380270585621895148484502414032888228750638800367218873447"), + bigIntFromString("3583228880285179354728993622328037400470978495633822008876840172083178912457")})) + } + }