diff --git a/packages/state-transition/src/util/seed.ts b/packages/state-transition/src/util/seed.ts index 70604ac21fcc..c09d1205aa56 100644 --- a/packages/state-transition/src/util/seed.ts +++ b/packages/state-transition/src/util/seed.ts @@ -35,6 +35,7 @@ export function computeProposers( fork, effectiveBalanceIncrements, shuffling.activeIndices, + // TODO: if we use hashTree, we can precompute the roots for the next n loops digest(Buffer.concat([epochSeed, intToBytes(slot, 8)])) ) ); @@ -44,10 +45,11 @@ export function computeProposers( /** * Return from ``indices`` a random index sampled by effective balance. + * This is just to make sure lodestar follows the spec, this is not for production. * * SLOW CODE - 🐢 */ -export function computeProposerIndex( +export function naiveComputeProposerIndex( fork: ForkSeq, effectiveBalanceIncrements: EffectiveBalanceIncrements, indices: ArrayLike, @@ -95,7 +97,93 @@ export function computeProposerIndex( } /** - * TODO: NAIVE + * Optimized version of `naiveComputeProposerIndex`. + * It shows > 3x speedup according to the perf test. + */ +export function computeProposerIndex( + fork: ForkSeq, + effectiveBalanceIncrements: EffectiveBalanceIncrements, + indices: ArrayLike, + seed: Uint8Array +): ValidatorIndex { + if (indices.length === 0) { + throw Error("Validator indices must not be empty"); + } + + if (fork >= ForkSeq.electra) { + // electra, see inline comments for the optimization + const MAX_RANDOM_VALUE = 2 ** 16 - 1; + const MAX_EFFECTIVE_BALANCE_INCREMENT = MAX_EFFECTIVE_BALANCE_ELECTRA / EFFECTIVE_BALANCE_INCREMENT; + + const shuffledIndexFn = getComputeShuffledIndexFn(indices.length, seed); + // this simple cache makes sure we don't have to recompute the shuffled index for the next round of activeValidatorCount + const shuffledResult = new Map(); + + let i = 0; + const cachedHashInput = Buffer.allocUnsafe(32 + 8); + cachedHashInput.set(seed, 0); + cachedHashInput.writeUint32LE(0, 32 + 4); + let cachedHash: Uint8Array | null = null; + while (true) { + // an optimized version of the below naive code + // const candidateIndex = indices[computeShuffledIndex(i % indices.length, indices.length, seed)]; + const index = i % indices.length; + let shuffledIndex = shuffledResult.get(index); + if (shuffledIndex == null) { + shuffledIndex = shuffledIndexFn(index); + shuffledResult.set(index, shuffledIndex); + } + const candidateIndex = indices[shuffledIndex]; + + // compute a new hash every 16 iterations + if (i % 16 === 0) { + cachedHashInput.writeUint32LE(Math.floor(i / 16), 32); + cachedHash = digest(cachedHashInput); + } + + if (cachedHash == null) { + // there is always a cachedHash, handle this to make the compiler happy + throw new Error("cachedHash should not be null"); + } + + const randomBytes = cachedHash; + const offset = (i % 16) * 2; + // this is equivalent to bytesToInt(randomBytes.subarray(offset, offset + 2)); + // but it does not get through BigInt + const lowByte = randomBytes[offset]; + const highByte = randomBytes[offset + 1]; + const randomValue = lowByte + highByte * 256; + + const effectiveBalanceIncrement = effectiveBalanceIncrements[candidateIndex]; + if (effectiveBalanceIncrement * MAX_RANDOM_VALUE >= MAX_EFFECTIVE_BALANCE_INCREMENT * randomValue) { + return candidateIndex; + } + + i += 1; + } + } else { + // preelectra, this function is the same to the naive version + const MAX_RANDOM_BYTE = 2 ** 8 - 1; + const MAX_EFFECTIVE_BALANCE_INCREMENT = MAX_EFFECTIVE_BALANCE / EFFECTIVE_BALANCE_INCREMENT; + + let i = 0; + while (true) { + const candidateIndex = indices[computeShuffledIndex(i % indices.length, indices.length, seed)]; + const randomByte = digest(Buffer.concat([seed, intToBytes(Math.floor(i / 32), 8, "le")]))[i % 32]; + + const effectiveBalanceIncrement = effectiveBalanceIncrements[candidateIndex]; + if (effectiveBalanceIncrement * MAX_RANDOM_BYTE >= MAX_EFFECTIVE_BALANCE_INCREMENT * randomByte) { + return candidateIndex; + } + + i += 1; + } + } +} + +/** + * Naive version, this is not supposed to be used in production. + * See `computeProposerIndex` for the optimized version. * * Return the sync committee indices for a given state and epoch. * Aligns `epoch` to `baseEpoch` so the result is the same with any `epoch` within a sync period. @@ -104,7 +192,7 @@ export function computeProposerIndex( * * SLOW CODE - 🐢 */ -export function getNextSyncCommitteeIndices( +export function naiveGetNextSyncCommitteeIndices( fork: ForkSeq, state: BeaconStateAllForks, activeValidatorIndices: ArrayLike, @@ -161,6 +249,101 @@ export function getNextSyncCommitteeIndices( return syncCommitteeIndices; } +/** + * Optmized version of `naiveGetNextSyncCommitteeIndices`. + * + * In the worse case scenario, this could be >1000x speedup according to the perf test. + */ +export function getNextSyncCommitteeIndices( + fork: ForkSeq, + state: BeaconStateAllForks, + activeValidatorIndices: ArrayLike, + effectiveBalanceIncrements: EffectiveBalanceIncrements +): ValidatorIndex[] { + const syncCommitteeIndices = []; + + if (fork >= ForkSeq.electra) { + // electra, see inline comments for the optimization + const MAX_RANDOM_VALUE = 2 ** 16 - 1; + const MAX_EFFECTIVE_BALANCE_INCREMENT = MAX_EFFECTIVE_BALANCE_ELECTRA / EFFECTIVE_BALANCE_INCREMENT; + + const epoch = computeEpochAtSlot(state.slot) + 1; + const activeValidatorCount = activeValidatorIndices.length; + const seed = getSeed(state, epoch, DOMAIN_SYNC_COMMITTEE); + const shuffledIndexFn = getComputeShuffledIndexFn(activeValidatorCount, seed); + + let i = 0; + let cachedHash: Uint8Array | null = null; + const cachedHashInput = Buffer.allocUnsafe(32 + 8); + cachedHashInput.set(seed, 0); + cachedHashInput.writeUInt32LE(0, 32 + 4); + // this simple cache makes sure we don't have to recompute the shuffled index for the next round of activeValidatorCount + const shuffledResult = new Map(); + while (syncCommitteeIndices.length < SYNC_COMMITTEE_SIZE) { + // optimized version of the below naive code + // const shuffledIndex = shuffledIndexFn(i % activeValidatorCount); + const index = i % activeValidatorCount; + let shuffledIndex = shuffledResult.get(index); + if (shuffledIndex == null) { + shuffledIndex = shuffledIndexFn(index); + shuffledResult.set(index, shuffledIndex); + } + const candidateIndex = activeValidatorIndices[shuffledIndex]; + + // compute a new hash every 16 iterations + if (i % 16 === 0) { + cachedHashInput.writeUint32LE(Math.floor(i / 16), 32); + cachedHash = digest(cachedHashInput); + } + + if (cachedHash == null) { + // there is always a cachedHash, handle this to make the compiler happy + throw new Error("cachedHash should not be null"); + } + + const randomBytes = cachedHash; + const offset = (i % 16) * 2; + + // this is equivalent to bytesToInt(randomBytes.subarray(offset, offset + 2)); + // but it does not get through BigInt + const lowByte = randomBytes[offset]; + const highByte = randomBytes[offset + 1]; + const randomValue = lowByte + highByte * 256; + + const effectiveBalanceIncrement = effectiveBalanceIncrements[candidateIndex]; + if (effectiveBalanceIncrement * MAX_RANDOM_VALUE >= MAX_EFFECTIVE_BALANCE_INCREMENT * randomValue) { + syncCommitteeIndices.push(candidateIndex); + } + + i += 1; + } + } else { + // pre-electra, keep the same naive version + const MAX_RANDOM_BYTE = 2 ** 8 - 1; + const MAX_EFFECTIVE_BALANCE_INCREMENT = MAX_EFFECTIVE_BALANCE / EFFECTIVE_BALANCE_INCREMENT; + + const epoch = computeEpochAtSlot(state.slot) + 1; + const activeValidatorCount = activeValidatorIndices.length; + const seed = getSeed(state, epoch, DOMAIN_SYNC_COMMITTEE); + + let i = 0; + while (syncCommitteeIndices.length < SYNC_COMMITTEE_SIZE) { + const shuffledIndex = computeShuffledIndex(i % activeValidatorCount, activeValidatorCount, seed); + const candidateIndex = activeValidatorIndices[shuffledIndex]; + const randomByte = digest(Buffer.concat([seed, intToBytes(Math.floor(i / 32), 8, "le")]))[i % 32]; + + const effectiveBalanceIncrement = effectiveBalanceIncrements[candidateIndex]; + if (effectiveBalanceIncrement * MAX_RANDOM_BYTE >= MAX_EFFECTIVE_BALANCE_INCREMENT * randomByte) { + syncCommitteeIndices.push(candidateIndex); + } + + i += 1; + } + } + + return syncCommitteeIndices; +} + /** * Return the shuffled validator index corresponding to ``seed`` (and ``index_count``). * @@ -168,6 +351,8 @@ export function getNextSyncCommitteeIndices( * https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf * * See the 'generalized domain' algorithm on page 3. + * This is the naive implementation just to make sure lodestar follows the spec, this is not for production. + * The optimized version is in `getComputeShuffledIndexFn`. */ export function computeShuffledIndex(index: number, indexCount: number, seed: Bytes32): number { let permuted = index; @@ -188,6 +373,75 @@ export function computeShuffledIndex(index: number, indexCount: number, seed: By return permuted; } +type ComputeShuffledIndexFn = (index: number) => number; + +/** + * An optimized version of `computeShuffledIndex`, this is for production. + */ +export function getComputeShuffledIndexFn(indexCount: number, seed: Bytes32): ComputeShuffledIndexFn { + // there are possibly SHUFFLE_ROUND_COUNT (90 for mainnet) values for this cache + // this cache will always hit after the 1st call + const pivotByIndex: Map = new Map(); + // given 2M active validators, there are 2 M / 256 = 8k possible positionDiv + // it means there are at most 8k different sources for each round + const sourceByPositionDivByIndex: Map> = new Map(); + // 32 bytes seed + 1 byte i + const pivotBuffer = Buffer.alloc(32 + 1); + pivotBuffer.set(seed, 0); + // 32 bytes seed + 1 byte i + 4 bytes positionDiv + const sourceBuffer = Buffer.alloc(32 + 1 + 4); + sourceBuffer.set(seed, 0); + + return (index): number => { + assert.lt(index, indexCount, "indexCount must be less than index"); + assert.lte(indexCount, 2 ** 40, "indexCount too big"); + let permuted = index; + const _seed = seed; + for (let i = 0; i < SHUFFLE_ROUND_COUNT; i++) { + // optimized version of the below naive code + // const pivot = Number( + // bytesToBigInt(digest(Buffer.concat([_seed, intToBytes(i, 1)])).slice(0, 8)) % BigInt(indexCount) + // ); + + let pivot = pivotByIndex.get(i); + if (pivot == null) { + // naive version always creates a new buffer, we can reuse the buffer + // pivot = Number( + // bytesToBigInt(digest(Buffer.concat([_seed, intToBytes(i, 1)])).slice(0, 8)) % BigInt(indexCount) + // ); + pivotBuffer[32] = i % 256; + pivot = Number(bytesToBigInt(digest(pivotBuffer).subarray(0, 8)) % BigInt(indexCount)); + pivotByIndex.set(i, pivot); + } + + const flip = (pivot + indexCount - permuted) % indexCount; + const position = Math.max(permuted, flip); + + // optimized version of the below naive code + // const source = digest(Buffer.concat([_seed, intToBytes(i, 1), intToBytes(Math.floor(position / 256), 4)])); + let sourceByPositionDiv = sourceByPositionDivByIndex.get(i); + if (sourceByPositionDiv == null) { + sourceByPositionDiv = new Map(); + sourceByPositionDivByIndex.set(i, sourceByPositionDiv); + } + const positionDiv256 = Math.floor(position / 256); + let source = sourceByPositionDiv.get(positionDiv256); + if (source == null) { + // naive version always creates a new buffer, we can reuse the buffer + // don't want to go through intToBytes() to avoid BigInt + sourceBuffer[32] = i % 256; + sourceBuffer.writeUint32LE(positionDiv256, 33); + source = digest(sourceBuffer); + sourceByPositionDiv.set(positionDiv256, source); + } + const byte = source[Math.floor((position % 256) / 8)]; + const bit = (byte >> (position % 8)) % 2; + permuted = bit ? flip : permuted; + } + return permuted; + }; +} + /** * Return the randao mix at a recent [[epoch]]. */ diff --git a/packages/state-transition/test/perf/util/seed.test.ts b/packages/state-transition/test/perf/util/seed.test.ts new file mode 100644 index 000000000000..71919ab168eb --- /dev/null +++ b/packages/state-transition/test/perf/util/seed.test.ts @@ -0,0 +1,100 @@ +import {bench, describe} from "@chainsafe/benchmark"; +import {ForkSeq} from "@lodestar/params"; +import {fromHex} from "@lodestar/utils"; +import { + computeProposerIndex, + computeShuffledIndex, + getComputeShuffledIndexFn, + getNextSyncCommitteeIndices, + naiveComputeProposerIndex, + naiveGetNextSyncCommitteeIndices, +} from "../../../src/util/seed.js"; +import {generatePerfTestCachedStateAltair} from "../util.js"; + +// I'm not sure how to populate a good test data for this benchmark +describe("computeProposerIndex", () => { + // it's hard to find a seed that shows differences between naive and optimized version + // this was selected after a couple of time I run and try crytpo.randomBytes() + const seed = fromHex("0x902199936ba358175ec5eca9825fd0d26fc355d5fd4d37d1b10575a29d4bd5a8"); + + const vc = 100_000; + const effectiveBalanceIncrements = new Uint16Array(vc); + for (let i = 0; i < vc; i++) { + // make it the worse case where each validator has 32 ETH effective balance + effectiveBalanceIncrements[i] = 32; + } + + const activeIndices = Array.from({length: vc}, (_, i) => i); + const runsFactor = 100; + bench({ + id: `naive computeProposerIndex ${vc} validators`, + fn: () => { + for (let i = 0; i < runsFactor; i++) { + naiveComputeProposerIndex(ForkSeq.electra, effectiveBalanceIncrements, activeIndices, seed); + } + }, + runsFactor, + }); + + bench({ + id: `computeProposerIndex ${vc} validators`, + fn: () => { + for (let i = 0; i < runsFactor; i++) { + computeProposerIndex(ForkSeq.electra, effectiveBalanceIncrements, activeIndices, seed); + } + }, + runsFactor, + }); +}); + +describe("getNextSyncCommitteeIndices electra", () => { + for (const vc of [1_000, 10_000, 100_000]) { + const state = generatePerfTestCachedStateAltair({vc, goBackOneSlot: false}); + const activeIndices = Array.from({length: state.validators.length}, (_, i) => i); + const effectiveBalanceIncrements = new Uint16Array(state.validators.length); + for (let i = 0; i < state.validators.length; i++) { + // make it the worse case where each validator has 32 ETH effective balance + effectiveBalanceIncrements[i] = 32; + } + + bench({ + id: `naiveGetNextSyncCommitteeIndices ${vc} validators`, + fn: () => { + naiveGetNextSyncCommitteeIndices(ForkSeq.electra, state, activeIndices, effectiveBalanceIncrements); + }, + }); + + bench({ + id: `getNextSyncCommitteeIndices ${vc} validators`, + fn: () => { + getNextSyncCommitteeIndices(ForkSeq.electra, state, activeIndices, effectiveBalanceIncrements); + }, + }); + } +}); + +describe("computeShuffledIndex", () => { + const seed = new Uint8Array(Array.from({length: 32}, (_, i) => i)); + + for (const vc of [100_000, 2_000_000]) { + bench({ + id: `naive computeShuffledIndex ${vc} validators`, + fn: () => { + for (let i = 0; i < vc; i++) { + computeShuffledIndex(i, vc, seed); + } + }, + }); + + const shuffledIndexFn = getComputeShuffledIndexFn(vc, seed); + + bench({ + id: `cached computeShuffledIndex ${vc} validators`, + fn: () => { + for (let i = 0; i < vc; i++) { + shuffledIndexFn(i); + } + }, + }); + } +}); diff --git a/packages/state-transition/test/unit/util/seed.test.ts b/packages/state-transition/test/unit/util/seed.test.ts index e0f9e5d8ae67..baa8a764ae90 100644 --- a/packages/state-transition/test/unit/util/seed.test.ts +++ b/packages/state-transition/test/unit/util/seed.test.ts @@ -1,10 +1,21 @@ +import crypto from "node:crypto"; import {describe, expect, it} from "vitest"; import {toHexString} from "@chainsafe/ssz"; -import {GENESIS_EPOCH, GENESIS_SLOT, SLOTS_PER_EPOCH} from "@lodestar/params"; -import {getRandaoMix} from "../../../src/util/index.js"; +import {ForkSeq, GENESIS_EPOCH, GENESIS_SLOT, SLOTS_PER_EPOCH} from "@lodestar/params"; +import { + computeProposerIndex, + computeShuffledIndex, + getComputeShuffledIndexFn, + getNextSyncCommitteeIndices, + getRandaoMix, + naiveComputeProposerIndex, + naiveGetNextSyncCommitteeIndices, +} from "../../../src/util/index.js"; +import {bytesToInt} from "@lodestar/utils"; import {generateState} from "../../utils/state.js"; +import {generateValidators} from "../../utils/validator.js"; describe("getRandaoMix", () => { const randaoMix1 = Buffer.alloc(32, 1); @@ -28,3 +39,71 @@ describe("getRandaoMix", () => { expect(toHexString(res)).toBe(toHexString(randaoMix2)); }); }); + +describe("computeProposerIndex electra", () => { + const seed = crypto.randomBytes(32); + const vc = 1000; + const activeIndices = Array.from({length: vc}, (_, i) => i); + const effectiveBalanceIncrements = new Uint16Array(vc); + for (let i = 0; i < vc; i++) { + effectiveBalanceIncrements[i] = 32 + 32 * (i % 64); + } + + it("should be the same to the naive version", () => { + const expected = naiveComputeProposerIndex(ForkSeq.electra, effectiveBalanceIncrements, activeIndices, seed); + const result = computeProposerIndex(ForkSeq.electra, effectiveBalanceIncrements, activeIndices, seed); + expect(result).toBe(expected); + }); +}); + +describe("computeShuffledIndex", () => { + const seed = crypto.randomBytes(32); + const vc = 1000; + const shuffledIndexFn = getComputeShuffledIndexFn(vc, seed); + it("should be the same to the naive version", () => { + for (let i = 0; i < vc; i++) { + const expectedIndex = computeShuffledIndex(i, vc, seed); + expect(shuffledIndexFn(i)).toBe(expectedIndex); + } + }); +}); + +describe("electra getNextSyncCommitteeIndices", () => { + const vc = 1000; + const validators = generateValidators(vc); + const state = generateState({validators}); + const activeValidatorIndices = Array.from({length: vc}, (_, i) => i); + const effectiveBalanceIncrements = new Uint16Array(vc); + for (let i = 0; i < vc; i++) { + effectiveBalanceIncrements[i] = 32 + 32 * (i % 64); + } + + it("should return the same result as naiveGetNextSyncCommitteeIndices", () => { + const expected = naiveGetNextSyncCommitteeIndices( + ForkSeq.electra, + state, + activeValidatorIndices, + effectiveBalanceIncrements + ); + const result = getNextSyncCommitteeIndices( + ForkSeq.electra, + state, + activeValidatorIndices, + effectiveBalanceIncrements + ); + expect(result).toEqual(expected); + }); +}); + +describe("number from 2 bytes bytesToInt", () => { + it("should compute numbers manually from 2 bytes", () => { + // this is to be used in getNextSyncCommitteeIndices without getting through BigInt + for (let lowByte = 0; lowByte < 256; lowByte++) { + for (let highByte = 0; highByte < 256; highByte++) { + const bytes = new Uint8Array([lowByte, highByte]); + const n = lowByte + highByte * 256; + expect(n).toBe(bytesToInt(bytes)); + } + } + }); +});