diff --git a/constantine/math/pairings/gt_multiexp.nim b/constantine/math/pairings/gt_multiexp.nim
index 635e77708..2fc5ac3f7 100644
--- a/constantine/math/pairings/gt_multiexp.nim
+++ b/constantine/math/pairings/gt_multiexp.nim
@@ -460,6 +460,7 @@ template withTorus[exponentsBits: static int, GT](
   var r_torus {.noInit.}: T2Prj[F]
   multiExpProc(r_torus, elemsTorus, expos, len, c)
   r.fromTorus2_vartime(r_torus)
+  freeHeap(elemsTorus)
 
 # Combined accel
 # -----------------------------------------------------------------------------------------------------------------------
diff --git a/constantine/math/pairings/gt_multiexp_parallel.nim b/constantine/math/pairings/gt_multiexp_parallel.nim
index 9203f9bfa..c4e7a9a5f 100644
--- a/constantine/math/pairings/gt_multiexp_parallel.nim
+++ b/constantine/math/pairings/gt_multiexp_parallel.nim
@@ -12,7 +12,7 @@ import
   constantine/named/algebras,
   constantine/math/arithmetic,
   constantine/named/zoo_endomorphisms,
   constantine/platforms/abstractions,
-  ./cyclotomic_subgroups,
+  ./cyclotomic_subgroups, ./gt_prj,
   constantine/threadpool
 
 import ./gt_multiexp {.all.}
@@ -27,21 +27,21 @@ import ./gt_multiexp {.all.}
 #                                                             #
 # ########################################################### #
 
-proc bucketAccumReduce_withInit[bits: static int, GT](
-       windowProd: ptr GT,
-       buckets: ptr GT or ptr UncheckedArray[GT],
+proc bucketAccumReduce_withInit[bits: static int, GtAcc, GtElt](
+       windowProd: ptr GtAcc,
+       buckets: ptr GtAcc or ptr UncheckedArray[GtAcc],
        bitIndex: int, miniMultiExpKind: static MiniMultiExpKind, c: static int,
-       elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]], N: int) =
+       elems: ptr UncheckedArray[GtElt], expos: ptr UncheckedArray[BigInt[bits]], N: int) =
   const numBuckets = 1 shl (c-1)
-  let buckets = cast[ptr UncheckedArray[GT]](buckets)
+  let buckets = cast[ptr UncheckedArray[GtAcc]](buckets)
   for i in 0 ..< numBuckets:
     buckets[i].setNeutral()
   bucketAccumReduce(windowProd[], buckets, bitIndex, miniMultiExpKind, c, elems, expos, N)
 
-proc multiexpImpl_vartime_parallel[bits: static int, GT](
+proc multiexpImpl_vartime_parallel[bits: static int, GtAcc, GtElt](
        tp: Threadpool,
-       r: ptr GT,
-       elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]],
+       r: ptr GtAcc,
+       elems: ptr UncheckedArray[GtElt], expos: ptr UncheckedArray[BigInt[bits]],
        N: int, c: static int) =
 
   # Prologue
@@ -53,10 +53,10 @@ proc multiexpImpl_vartime_parallel[bits: static int, GT](
   # Instead of storing the result in futures, risking them being scattered in memory
   # we store them in a contiguous array, and the synchronizing future just returns a bool.
   # top window is done on this thread
-  let miniMultiExpsResults = allocHeapArray(GT, numFullWindows)
+  let miniMultiExpsResults = allocHeapArray(GtAcc, numFullWindows)
   let miniMultiExpsReady = allocStackArray(FlowVar[bool], numFullWindows)
 
-  let bucketsMatrix = allocHeapArray(GT, numBuckets*numWindows)
+  let bucketsMatrix = allocHeapArray(GtAcc, numBuckets*numWindows)
 
   # Algorithm
   # ---------
@@ -78,32 +78,22 @@ proc multiexpImpl_vartime_parallel[bits: static int, GT](
   # Last window is done sync on this thread, directly initializing r
   const excess = bits mod c
   const top = bits-excess
-
-  when top != 0:
-    when excess != 0:
-      bucketAccumReduce_withInit(
-        r,
-        bucketsMatrix[numFullWindows*numBuckets].addr,
-        bitIndex = top, kTopWindow, c,
-        elems, expos, N)
-    else:
-      r[].setNeutral()
-
-  # 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
-  when excess != 0:
-    for w in countdown(numWindows-2, 0):
-      for _ in 0 ..< c:
-        r[].cyclotomic_square()
-      discard sync miniMultiExpsReady[w]
-      r[] ~*= miniMultiExpsResults[w]
-  elif numWindows >= 2:
-    discard sync miniMultiExpsReady[numWindows-2]
-    r[] = miniMultiExpsResults[numWindows-2]
-    for w in countdown(numWindows-3, 0):
-      for _ in 0 ..< c:
-        r[].cyclotomic_square()
-      discard sync miniMultiExpsReady[w]
-      r[] ~*= miniMultiExpsResults[w]
+  const msmKind = if top == 0: kBottomWindow
+                  elif excess == 0: kFullWindow
+                  else: kTopWindow
+
+  bucketAccumReduce_withInit(
+    r,
+    bucketsMatrix[numFullWindows*numBuckets].addr,
+    bitIndex = top, msmKind, c,
+    elems, expos, N)
+
+  # 3. Final reduction
+  for w in countdown(numFullWindows-1, 0):
+    for _ in 0 ..< c:
+      r[].cyclotomic_square()
+    discard sync miniMultiExpsReady[w]
+    r[] ~*= miniMultiExpsResults[w]
 
   # Cleanup
   # -------
@@ -170,6 +160,110 @@ template withEndo[exponentsBits: static int, GT](
   else:
     multiExpProc(tp, r, elems, expos, N, c)
 
+# Torus acceleration
+# -----------------------------------------------------------------------------------------------------------------------
+
+template withTorus[exponentsBits: static int, GT](
+           multiExpProc: untyped,
+           tp: Threadpool,
+           r: ptr GT,
+           elems: ptr UncheckedArray[GT],
+           expos: ptr UncheckedArray[BigInt[exponentsBits]],
+           len: int, c: static int) =
+  static: doAssert Gt is QuadraticExt, "GT was: " & $Gt
+  type F = typeof(elems[0].c0)
+  var elemsTorus = allocHeapArrayAligned(T2Aff[F], len, alignment = 64)
+  # TODO: macro symbol resolution bug
+  # syncScope:
+  #   tp.parallelFor i in 0 ..< N:
+  #     captures: {elems, elemsTorus}
+  #     # TODO: Parallel batch conversion
+  #     elemsTorus.fromGT_vartime(elems[i])
+  elemsTorus.toOpenArray(0, len-1).batchFromGT_vartime(
+    elems.toOpenArray(0, len-1)
+  )
+  var r_torus {.noInit.}: T2Prj[F]
+  multiExpProc(tp, r_torus.addr, elemsTorus, expos, len, c)
+  r[].fromTorus2_vartime(r_torus)
+  freeHeapAligned(elemsTorus)
+
+# Combined accel
+# -----------------------------------------------------------------------------------------------------------------------
+
+# Endomorphism acceleration on a torus can be implemented through either of the following approaches:
+# - First convert to Torus, then apply endomorphism acceleration,
+# - or apply endomorphism acceleration, then convert to Torus.
+#
+# The first approach minimizes memory as we use a compressed torus representation and is easier to compose (no withEndoTorus needed),
+# the second approach reuses Constantine's Frobenius implementation.
+# It is unclear which one is more efficient, but the difference is dwarfed by the rest of the compute.
+
+proc applyEndoTorus_parallel[bits: static int, GT](
+       tp: Threadpool,
+       elems: ptr UncheckedArray[GT],
+       expos: ptr UncheckedArray[BigInt[bits]],
+       N: int): auto =
+  ## Decompose (elems, expos) into mini-scalars
+  ## and apply Torus conversion
+  ## Returns a new triplet (endoTorusElems, endoExpos, M*N)
+  ## endoTorusElems and endoExpos MUST be freed afterwards
+
+  const M = when Gt.Name.getEmbeddingDegree() == 6: 2
+            elif Gt.Name.getEmbeddingDegree() == 12: 4
+            else: {.error: "Unconfigured".}
+
+  const L = Fr[Gt.Name].bits().computeEndoRecodedLength(M)
+  let splitExpos = allocHeapArray(array[M, BigInt[L]], N)
+  let endoBasis = allocHeapArray(array[M, GT], N)
+
+  type F = typeof(elems[0].c0)
+  let endoTorusBasis = allocHeapArray(array[M, T2Aff[F]], N)
+
+  syncScope:
+    tp.parallelFor i in 0 ..< N:
+      captures: {elems, expos, splitExpos, endoBasis, endoTorusBasis}
+
+      var negateElems {.noinit.}: array[M, SecretBool]
+      splitExpos[i].decomposeEndo(negateElems, expos[i], Fr[Gt.Name].bits(), Gt.Name, G2) # 𝔾ₜ has the same decomposition as 𝔾₂
+      if negateElems[0].bool:
+        endoBasis[i][0].cyclotomic_inv(elems[i])
+      else:
+        endoBasis[i][0] = elems[i]
+
+      cast[ptr array[M-1, GT]](endoBasis[i][1].addr)[].computeEndomorphisms(elems[i])
+      for m in 1 ..< M:
+        if negateElems[m].bool:
+          endoBasis[i][m].cyclotomic_inv()
+
+      # TODO: we batch-convert to Torus M elements at a time,
+      #       but we could batch-convert the whole range in parallel
+      endoTorusBasis[i].batchFromGT_vartime(endoBasis[i])
+
+  let endoTorusElems = cast[ptr UncheckedArray[GT]](endoTorusBasis)
+  let endoExpos = cast[ptr UncheckedArray[BigInt[L]]](splitExpos)
+  freeHeap(endoBasis)
+
+  return (endoTorusElems, endoExpos, M*N)
+
+template withEndoTorus[exponentsBits: static int, GT](
+           multiExpProc: untyped,
+           tp: Threadpool,
+           r: ptr GT,
+           elems: ptr UncheckedArray[GT],
+           expos: ptr UncheckedArray[BigInt[exponentsBits]],
+           N: int, c: static int) =
+  when Gt.Name.hasEndomorphismAcceleration() and
+       EndomorphismThreshold <= exponentsBits and
+       exponentsBits <= Fr[Gt.Name].bits():
+    let (endoTorusElems, endoExpos, endoN) = applyEndoTorus_parallel(tp, elems, expos, N)
+    # Given that bits and N changed, we are able to use a bigger `c`
+    # TODO: bench
+    multiExpProc(tp, r, endoTorusElems, endoExpos, endoN, c)
+    freeHeap(endoTorusElems)
+    freeHeap(endoExpos)
+  else:
+    withTorus(multiExpProc, tp, r, elems, expos, N, c)
+
 # Algorithm selection
 # -----------------------------------------------------------------------------------------------------------------------
 
@@ -177,7 +271,8 @@ proc multiexp_dispatch_vartime_parallel[bits: static int, GT](
        tp: Threadpool,
        r: ptr GT,
        elems: ptr UncheckedArray[GT],
-       expos: ptr UncheckedArray[BigInt[bits]], N: int) =
+       expos: ptr UncheckedArray[BigInt[bits]], N: int,
+       useTorus: static bool) =
   ## Multiexponentiation:
   ## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
   let c = bestBucketBitSize(N, bits, useSignedBuckets = true, useManualTuning = true)
@@ -186,53 +281,77 @@
   # we are able to use a bigger `c`
   # TODO: benchmark
 
-  case c
-  of 2: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 2)
-  of 3: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 3)
-  of 4: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 4)
-  of 5: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 5)
-  of 6: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 6)
-  of 7: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 7)
-  of 8: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 8)
-  of 9: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 9)
-  of 10: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 10)
-  of 11: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 11)
-  of 12: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 12)
-  of 13: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 13)
-  of 14: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 14)
-  of 15: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 15)
-
-  of 16..17: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 16)
+  when useTorus:
+    case c
+    of 2: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 2)
+    of 3: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 3)
+    of 4: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 4)
+    of 5: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 5)
+    of 6: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 6)
+    of 7: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 7)
+    of 8: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 8)
+    of 9: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 9)
+    of 10: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 10)
+    of 11: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 11)
+    of 12: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 12)
+    of 13: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 13)
+    of 14: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 14)
+    of 15: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 15)
+
+    of 16..17: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 16)
+    else:
+      unreachable()
   else:
-    unreachable()
+    case c
+    of 2: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 2)
+    of 3: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 3)
+    of 4: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 4)
+    of 5: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 5)
+    of 6: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 6)
+    of 7: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 7)
+    of 8: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 8)
+    of 9: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 9)
+    of 10: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 10)
+    of 11: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 11)
+    of 12: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 12)
+    of 13: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 13)
+    of 14: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 14)
+    of 15: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 15)
+
+    of 16..17: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 16)
+    else:
+      unreachable()
 
 proc multiExp_vartime_parallel*[bits: static int, GT](
        tp: Threadpool,
        r: ptr GT,
       elems: ptr UncheckedArray[GT],
       expos: ptr UncheckedArray[BigInt[bits]],
-       len: int) {.meter, inline.} =
+       len: int,
+       useTorus: static bool = false) {.meter, inline.} =
   ## Multiexponentiation:
   ## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
-  tp.multiExp_dispatch_vartime_parallel(r, elems, expos, len)
+  tp.multiExp_dispatch_vartime_parallel(r, elems, expos, len, useTorus)
 
 proc multiExp_vartime_parallel*[bits: static int, GT](
        tp: Threadpool,
       r: var GT,
       elems: openArray[GT],
-       expos: openArray[BigInt[bits]]) {.meter, inline.} =
+       expos: openArray[BigInt[bits]],
+       useTorus: static bool = false) {.meter, inline.} =
   ## Multiexponentiation:
   ## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
   debug: doAssert elems.len == expos.len
   let N = elems.len
-  tp.multiExp_dispatch_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N)
+  tp.multiExp_dispatch_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N, useTorus)
 
 proc multiExp_vartime_parallel*[F, GT](
        tp: Threadpool,
       r: ptr GT,
       elems: ptr UncheckedArray[GT],
       expos: ptr UncheckedArray[F],
-       len: int) {.meter.} =
+       len: int,
+       useTorus: static bool = false) {.meter.} =
   ## Multiexponentiation:
   ## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
   let n = cast[int](len)
@@ -242,7 +361,7 @@ proc multiExp_vartime_parallel*[F, GT](
     tp.parallelFor i in 0 ..< n:
       captures: {expos, expos_big}
       expos_big[i].fromField(expos[i])
-  tp.multiExp_vartime_parallel(r, elems, expos_big, n)
+  tp.multiExp_vartime_parallel(r, elems, expos_big, n, useTorus)
 
   freeHeapAligned(expos_big)
 
@@ -250,9 +369,10 @@ proc multiExp_vartime_parallel*[GT](
        tp: Threadpool,
       r: var GT,
       elems: openArray[GT],
-       expos: openArray[Fr]) {.meter, inline.} =
+       expos: openArray[Fr],
+       useTorus: static bool = false) {.meter, inline.} =
   ## Multiexponentiation:
   ## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
   debug: doAssert elems.len == expos.len
   let N = elems.len
-  tp.multiExp_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N)
+  tp.multiExp_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N, useTorus)
diff --git a/constantine/math/pairings/gt_prj.nim b/constantine/math/pairings/gt_prj.nim
index 338a9de50..c1d8cc290 100644
--- a/constantine/math/pairings/gt_prj.nim
+++ b/constantine/math/pairings/gt_prj.nim
@@ -464,6 +464,7 @@ proc batchFromGT_vartime*[F](dst: var openArray[T2Aff[F]],
   ## so this is about a ~25% speedup
 
   # TODO: handle neutral element
+  # TODO: Parallel batch inversion
 
   debug: doAssert dst.len == src.len
diff --git a/tests/math_pairings/t_pairing_bls12_381_gt_multiexp.nim b/tests/math_pairings/t_pairing_bls12_381_gt_multiexp.nim
index 4379547a4..6eb371fa8 100644
--- a/tests/math_pairings/t_pairing_bls12_381_gt_multiexp.nim
+++ b/tests/math_pairings/t_pairing_bls12_381_gt_multiexp.nim
@@ -10,7 +10,7 @@ import
   # Test utilities
   ./t_pairing_template
 
-const numPoints = [1, 2, 8, 16, 128, 256, 1024]
+const numPoints = [1, 2, 3, 4, 5, 6, 7, 8, 16, 128, 256, 1024]
 
 runGTmultiexpTests(
   # Torus-based cryptography requires quadratic extension
diff --git a/tests/parallel/t_pairing_bls12_381_gt_multiexp_parallel.nim b/tests/parallel/t_pairing_bls12_381_gt_multiexp_parallel.nim
index 98fd820c4..af84faa01 100644
--- a/tests/parallel/t_pairing_bls12_381_gt_multiexp_parallel.nim
+++ b/tests/parallel/t_pairing_bls12_381_gt_multiexp_parallel.nim
@@ -10,9 +10,12 @@ import
   # Test utilities
   ./t_pairing_template_parallel
 
-const numPoints = [1, 2, 8, 16, 128, 256, 1024]
+const numPoints = [1, 2, 3, 4, 5, 6, 7, 8, 16, 128, 256, 1024]
 
 runGTmultiexp_parallel_Tests(
-  GT = Fp12[BLS12_381],
+  # Torus-based cryptography requires quadratic extension
+  # but by default cubic extensions are faster
+  # GT = Fp12[BLS12_381],
+  GT = QuadraticExt[Fp6[BLS12_381]],
   numPoints,
   Iters = 4)
diff --git a/tests/parallel/t_pairing_template_parallel.nim b/tests/parallel/t_pairing_template_parallel.nim
index cfa6db6a4..aedd7a4cf 100644
--- a/tests/parallel/t_pairing_template_parallel.nim
+++ b/tests/parallel/t_pairing_template_parallel.nim
@@ -74,10 +74,12 @@ proc runGTmultiexp_parallel_Tests*[N: static int](GT: typedesc, num_points: arra
         t.gtExp_vartime(elems[i], exponents[i])
         naive *= t
 
-      var mexp: GT
-      tp.multiExp_vartime_parallel(mexp, elems, exponents)
+      var mexp, mexp_torus: GT
+      tp.multiExp_vartime_parallel(mexp, elems, exponents, useTorus = false)
+      tp.multiExp_vartime_parallel(mexp_torus, elems, exponents, useTorus = true)
 
       doAssert bool(naive == mexp)
+      doAssert bool(naive == mexp_torus)
 
       stdout.write '.'
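
The rewritten final reduction in multiexpImpl_vartime_parallel recombines the per-window results Horner-style: the accumulator is seeded from the most-significant (possibly partial) window, then for each remaining window from high to low it is squared c times and multiplied by that window's mini-multiexponentiation. Below is a minimal serial sketch of that recombination step, for illustration only; it is not part of the patch, the imports and the recombineWindows helper are assumptions, and only cyclotomic_square and *= mirror calls that appear in the diff.

import
  constantine/named/algebras,
  constantine/math/extension_fields,
  constantine/math/pairings/cyclotomic_subgroups

# Sketch: Horner-style recombination of per-window results,
#   r = (((W[k-1])^(2^c) * W[k-2])^(2^c) * ...)^(2^c) * W[0]
# where W[w] is the mini-multiexponentiation of the w-th c-bit window.
proc recombineWindows[GT](windows: openArray[GT], c: int): GT =
  result = windows[windows.len-1]     # most-significant window seeds the accumulator
  for w in countdown(windows.len-2, 0):
    for _ in 0 ..< c:
      result.cyclotomic_square()      # shift the accumulator up by c bits
    result *= windows[w]              # fold in the next window's result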
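
applyEndoTorus_parallel relies on the 𝔾ₜ endomorphism ψ (computeEndomorphisms) to split each exponent into M mini-scalars of roughly bits/M bits, so that g^a = g^a₀ Ā· ψ(g)^a₁ Ā· ψ²(g)^a₂ Ā· ψ³(g)^a₃ for embedding degree 12 (M = 4); N inputs thus become M*N smaller ones before the Torus conversion. The single-input sanity check below is an illustrative sketch only, not part of the patch: it reuses decomposeEndo, computeEndomorphisms and cyclotomic_inv exactly as they appear above, and it assumes the BigInt overload of gtExp_vartime as well as the listed module paths.

import
  constantine/named/algebras,
  constantine/named/zoo_endomorphisms,
  constantine/math/arithmetic,
  constantine/math/extension_fields,
  constantine/math/pairings/[cyclotomic_subgroups, gt_exponentiations_vartime],
  constantine/platforms/abstractions

proc checkEndoDecomposition[bits: static int](g: Fp12[BLS12_381], a: BigInt[bits]): bool =
  ## Checks g^a == g^a₀ Ā· ψ(g)^a₁ Ā· ψ²(g)^a₂ Ā· ψ³(g)^a₃ for one input.
  const M = 4                                  # embedding degree 12
  const L = Fr[BLS12_381].bits().computeEndoRecodedLength(M)
  var miniScalars {.noInit.}: array[M, BigInt[L]]
  var negate {.noInit.}: array[M, SecretBool]
  miniScalars.decomposeEndo(negate, a, Fr[BLS12_381].bits(), BLS12_381, G2)

  # Build the basis (g, ψ(g), ψ²(g), ψ³(g)), fixing signs as the patch does
  var basis {.noInit.}: array[M, Fp12[BLS12_381]]
  if negate[0].bool: basis[0].cyclotomic_inv(g)
  else: basis[0] = g
  cast[ptr array[M-1, Fp12[BLS12_381]]](basis[1].addr)[].computeEndomorphisms(g)
  for m in 1 ..< M:
    if negate[m].bool: basis[m].cyclotomic_inv()

  # Recombine naively and compare against a direct exponentiation
  var lhs {.noInit.}: Fp12[BLS12_381]
  var acc {.noInit.}: Fp12[BLS12_381]
  var t {.noInit.}: Fp12[BLS12_381]
  lhs.gtExp_vartime(g, a)
  acc.setOne()
  for m in 0 ..< M:
    t.gtExp_vartime(basis[m], miniScalars[m])
    acc *= t
  result = bool(lhs == acc)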
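
For callers, the new surface is the static useTorus flag on multiExp_vartime_parallel (default false), plus the requirement that 𝔾ₜ be instantiated as a quadratic extension (QuadraticExt[Fp6[...]]) for the Torus path, as the updated tests show. The end-to-end sketch below mirrors the updated parallel test; it is hypothetical, and the module paths and the caller-provided elems/expos inputs are assumptions, not part of the patch.

import
  constantine/threadpool,
  constantine/named/algebras,
  constantine/math/arithmetic,
  constantine/math/extension_fields,
  constantine/math/pairings/gt_multiexp_parallel

type GT = QuadraticExt[Fp6[BLS12_381]]   # quadratic-over-cubic tower, required by the Torus path

proc multiExpBothPaths(elems: openArray[GT], expos: openArray[BigInt[255]]) =
  ## `elems` are assumed to be valid 𝔾ₜ elements and `expos` their exponents.
  let tp = Threadpool.new()
  defer: tp.shutdown()

  var r_direct, r_torus: GT
  # Default path: endomorphism-accelerated bucketed multiexponentiation
  tp.multiExp_vartime_parallel(r_direct, elems, expos, useTorus = false)
  # Torus path: elements are first batch-compressed to the T2Aff representation
  tp.multiExp_vartime_parallel(r_torus, elems, expos, useTorus = true)

  doAssert bool(r_direct == r_torus)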