Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ntt.rs shorter #7

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions src/biguint/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,7 @@ impl<const P: u64, const INV: bool> NttKernelImpl<P, INV> {
const U4: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/4);
const U5: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/5);
const U6: u64 = Arith::<P>::mpowmod(Self::ROOTR, Arith::<P>::MAX_NTT_LEN/6);
const C51: u64 = Self::c5().0;
const C52: u64 = Self::c5().1;
const C53: u64 = Self::c5().2;
const C54: u64 = Self::c5().3;
const C55: u64 = Self::c5().4;
const fn c5() -> (u64, u64, u64, u64, u64) {
const C5: (u64, u64, u64, u64, u64, u64) = {
let w = Self::U5;
let w2 = Arith::<P>::mpowmod(w, 2);
let w4 = Arith::<P>::mpowmod(w, 4);
Expand All @@ -253,8 +248,8 @@ impl<const P: u64, const INV: bool> NttKernelImpl<P, INV> {
let c53 = Arith::<P>::mmulmod(inv2, Arith::<P>::submod(w, w4)); // 2^-1 * (w - w^4) mod P
let c54 = Arith::<P>::addmod(Arith::<P>::addmod(w, w2), inv2); // 2^-1 * (2*w + 2*w^2 + 1) mod P
let c55 = Arith::<P>::addmod(Arith::<P>::addmod(w2, w4), inv2); // 2^-1 * (2*w^2 + 2*w^4 + 1) mod P
(c51, c52, c53, c54, c55)
}
(0, c51, c52, c53, c54, c55)
};
}
const fn ntt2_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
w1: u64,
Expand Down Expand Up @@ -348,11 +343,11 @@ const fn ntt5_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
let t6 = Arith::<P>::submod(t1, t2);
let t7 = Arith::<P>::addmod64(t3, t4);
let m1 = Arith::<P>::addmod(a, t5);
let m2 = Arith::<P>::mmulsubmod(NttKernelImpl::<P, INV>::C51, t5, m1);
let m3 = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::C52, t6);
let m4 = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::C53, t7);
let m5 = Arith::<P>::mmulsubmod(NttKernelImpl::<P, INV>::C54, t4, m4);
let m6 = Arith::<P>::mmulsubmod(P.wrapping_sub(NttKernelImpl::<P, INV>::C55), t3, m4);
let m2 = Arith::<P>::mmulsubmod(NttKernelImpl::<P, INV>::C5.1, t5, m1);
let m3 = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::C5.2, t6);
let m4 = Arith::<P>::mmulmod(NttKernelImpl::<P, INV>::C5.3, t7);
let m5 = Arith::<P>::mmulsubmod(NttKernelImpl::<P, INV>::C5.4, t4, m4);
let m6 = Arith::<P>::mmulsubmod(P.wrapping_sub(NttKernelImpl::<P, INV>::C5.5), t3, m4);
let s1 = Arith::<P>::submod(m3, m2);
let s2 = Arith::<P>::addmod(m2, m3);
let out0 = m1;
Expand Down Expand Up @@ -547,7 +542,6 @@ const P1: u64 = 14_259_017_916_245_606_401; // Max NTT length = 2^22 * 3^21 * 5^
const P2: u64 = 17_984_575_660_032_000_001; // Max NTT length = 2^19 * 3^17 * 5^6 = 1_057_916_215_296_000_000
const P3: u64 = 17_995_154_822_184_960_001; // Max NTT length = 2^17 * 3^22 * 5^4 = 2_570_736_403_169_280_000

const P2P3: u128 = P2 as u128 * P3 as u128;
const P1INV_R_MOD_P2: u64 = Arith::<P2>::mmulmod(Arith::<P2>::R2, arith::invmod(P1, P2));
const P1P2INV_R_MOD_P3: u64 = Arith::<P3>::mmulmod(Arith::<P3>::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3));
const P1_R_MOD_P3: u64 = Arith::<P3>::mmulmod(Arith::<P3>::R2, P1);
Expand Down Expand Up @@ -618,7 +612,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) {
/* extract the convolution result */
let (a, b) = (x[i], y[i]);
let (mut v, overflow) = (a as u128 * P3 as u128 + carry).overflowing_sub(b as u128 * P2 as u128);
if overflow { v = v.wrapping_add(P2P3); }
if overflow { v = v.wrapping_add(P2 as u128 * P3 as u128); }
carry = v >> bits;

/* write to s */
Expand Down