Skip to content

Commit 0489f79

Browse files
committed
add remainder() function to MontgomeryForm
1 parent f3f98f0 commit 0489f79

File tree

4 files changed

+64
-1
lines changed

4 files changed

+64
-1
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/MontgomeryForm.h

+11
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,17 @@ class MontgomeryForm final {
486486
{
487487
return static_cast<T>(impl.gcd_with_modulus(x, gcd_functor));
488488
}
489+
490+
491+
// Returns a % modulus. A convenience function for better performance.
492+
// If you have already instantiated this MontgomeryForm, then calling
493+
// remainder() should be faster than directly computing a % modulus,
494+
// even if your CPU has extremely fast division (like many new CPUs).
495+
T remainder(T a) const
496+
{
497+
HPBC_PRECONDITION(a >= 0);
498+
return static_cast<T>(impl.remainder(static_cast<U>(a)));
499+
}
489500
};
490501

491502

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyCommonBase.h

+15
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,21 @@ class MontyCommonBase {
9898
return result;
9999
}
100100

101+
HURCHALLA_FORCE_INLINE T remainder(T a) const
102+
{
103+
HPBC_INVARIANT2(r_mod_n_ < n_);
104+
namespace hc = ::hurchalla;
105+
T u_lo;
106+
T u_hi = hc::unsigned_multiply_to_hilo_product(u_lo, a, r_mod_n_);
107+
// Since a is type T, 0 <= a < R. And since r_mod_n is type T and
108+
// r_mod_n < n, we know 0 <= r_mod_n < n. Therefore,
109+
// 0 <= u == a * r_mod_n < R*n, which will satisfy REDC's precondition.
110+
T result = hc::REDC_standard(u_hi, u_lo, n_, inv_n_, LowlatencyTag());
111+
112+
HPBC_POSTCONDITION2(result < n_);
113+
return result;
114+
}
115+
101116
HURCHALLA_FORCE_INLINE C getUnityValue() const
102117
{
103118
// as noted in constructor, unityValue == (1*R)%n_ == r_mod_n_

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyWrappedStandardMath.h

+5
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ class MontyWrappedStandardMath final {
105105
return ret;
106106
}
107107

108+
HURCHALLA_FORCE_INLINE T remainder(T a) const
109+
{
110+
return static_cast<T>(a % modulus_);
111+
}
112+
108113
HURCHALLA_FORCE_INLINE C getCanonicalValue(V x) const
109114
{
110115
HPBC_PRECONDITION2(isCanonical(x));

test/montgomery_arithmetic/test_MontgomeryForm.h

+33-1
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,29 @@ void test_square_variants(const M& mf, typename M::MontgomeryValue x,
224224
fusedSquareAdd<hc::LowuopsTag>(x,zc)) == answer);
225225
}
226226

227+
template <typename M>
228+
void test_remainder(const M& mf)
229+
{
230+
using T = typename M::IntegerType;
231+
namespace hc = ::hurchalla;
232+
233+
T max = hc::ut_numeric_limits<T>::max();
234+
T mid = static_cast<T>(max/2);
235+
T modulus = mf.getModulus();
236+
237+
EXPECT_TRUE(mf.remainder(0) == (0 % modulus));
238+
EXPECT_TRUE(mf.remainder(1) == (1 % modulus));
239+
EXPECT_TRUE(mf.remainder(2) == (2 % modulus));
240+
EXPECT_TRUE(mf.remainder(static_cast<T>(max-0)) == ((max-0) % modulus));
241+
EXPECT_TRUE(mf.remainder(static_cast<T>(max-1)) == ((max-1) % modulus));
242+
EXPECT_TRUE(mf.remainder(static_cast<T>(max-2)) == ((max-2) % modulus));
243+
EXPECT_TRUE(mf.remainder(static_cast<T>(mid-1)) == ((mid-1) % modulus));
244+
EXPECT_TRUE(mf.remainder(static_cast<T>(mid-0)) == ((mid-0) % modulus));
245+
EXPECT_TRUE(mf.remainder(static_cast<T>(mid+1)) == ((mid+1) % modulus));
246+
}
227247

228248
template <typename M>
229-
void test_mf_general_checks(M& mf, typename M::IntegerType a,
249+
void test_mf_general_checks(const M& mf, typename M::IntegerType a,
230250
typename M::IntegerType b, typename M::IntegerType c)
231251
{
232252
namespace hc = ::hurchalla;
@@ -590,6 +610,18 @@ void test_MontgomeryForm()
590610
M mf(modulus);
591611
EXPECT_TRUE(mf.gcd_with_modulus(mf.convertIn(12), GcdFunctor()) == 3);
592612
}
613+
614+
// test remainder()
615+
{
616+
T max = M::max_modulus();
617+
T mid = static_cast<T>(max/2);
618+
mid = (mid % 2 == 0) ? static_cast<T>(mid + 1) : mid;
619+
test_remainder(M(3)); // smallest possible modulus
620+
test_remainder(M(max)); // largest possible modulus
621+
if (121 <= max)
622+
test_remainder(M(121));
623+
test_remainder(M(mid));
624+
}
593625
}
594626

595627

0 commit comments

Comments
 (0)