Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RNG][Device API] Added (u)int8_t and (u)int16_t types support for Uniform #632

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
20 changes: 16 additions & 4 deletions include/oneapi/math/rng/device/detail/uniform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,12 @@ class distribution_base<oneapi::math::rng::device::uniform<Type, Method>> {
float>::type;
OutType res;
if constexpr (std::is_integral<Type>::value) {
if constexpr (std::is_same_v<Type, std::int32_t> ||
std::is_same_v<Type, std::uint32_t>) {
if constexpr (std::is_same_v<Type, std::int8_t> || std::is_same_v<Type, std::uint8_t> ||
std::is_same_v<Type, std::int16_t> || std::is_same_v<Type, std::uint16_t>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to move this dispatch above to have floating type definition in one place, e.g.

std::conditional_t<
    !std::is_same_v<Method, uniform_method::accurate> ||
     std::is_same_v<Type, std::int8_t> ||
     std::is_same_v<Type, std::uint8_t> ||
     std::is_same_v<Type, std::int16_t> ||
     std::is_same_v<Type, std::uint16_t>,
float, double>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it looks clearer with the extra condition. Furthermore, the existing condition relates to (u)int32, (u)int64 and floating point precision types. But with the option you suggested we will have less additional code.
@iMartyan, could you please share your opinion here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also like Andrey's approach more because the generation method is the same. So, better not to split it into several branches.

return generate_single_int<float, OutType>(engine);
}
else if constexpr (std::is_same_v<Type, std::int32_t> ||
std::is_same_v<Type, std::uint32_t>) {
return generate_single_int<FpType, OutType>(engine);
}
else {
Expand Down Expand Up @@ -243,8 +247,16 @@ class distribution_base<oneapi::math::rng::device::uniform<Type, Method>> {
float>::type;
Type res;
if constexpr (std::is_integral<Type>::value) {
if constexpr (std::is_same_v<Type, std::int32_t> ||
std::is_same_v<Type, std::uint32_t>) {
if constexpr (std::is_same_v<Type, std::int8_t> || std::is_same_v<Type, std::uint8_t> ||
std::is_same_v<Type, std::int16_t> || std::is_same_v<Type, std::uint16_t>) {
float res_fp =
engine.generate_single(static_cast<float>(a_), static_cast<float>(b_));
res_fp = sycl::floor(res_fp);
res = static_cast<Type>(res_fp);
return res;
}
else if constexpr (std::is_same_v<Type, std::int32_t> ||
std::is_same_v<Type, std::uint32_t>) {
FpType res_fp =
engine.generate_single(static_cast<FpType>(a_), static_cast<FpType>(b_));
res_fp = sycl::floor(res_fp);
Expand Down
32 changes: 25 additions & 7 deletions include/oneapi/math/rng/device/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ namespace oneapi::math::rng::device {
// Supported types:
// float
// double
// std::int8_t>
// std::uint8_t>
// std::int16_t>
// std::uint16_t>
// std::int32_t
// std::uint32_t
// std::int64_t
// std::uint64_t
//
// Supported methods:
// oneapi::math::rng::device::uniform_method::standard
Expand All @@ -46,7 +52,8 @@ namespace oneapi::math::rng::device {
// Input arguments:
// a - left bound. 0.0 by default
// b - right bound. 1.0 by default (for std::(u)int32_t std::numeric_limits<std::int32_t>::max()
// is used for accurate method and 2^23 is used for standard method)
// is used for accurate method and 2^23 is used for standard method;
// for std::(u)int8, std::(u)int16, std::(u)int64_t are used std::numeric_limits<one_of_these_three_types>::max())
//
// Note: using (un)signed integer uniform distribution with uniform_method::standard method may
// cause incorrect statistics of the produced random numbers (due to rounding error) if
Expand All @@ -61,6 +68,10 @@ class uniform : detail::distribution_base<uniform<Type, Method>> {
"oneMath: rng/uniform: method is incorrect");

static_assert(std::is_same<Type, float>::value || std::is_same<Type, double>::value ||
std::is_same_v<Type, std::int8_t> ||
std::is_same_v<Type, std::uint8_t> ||
std::is_same_v<Type, std::int16_t> ||
std::is_same_v<Type, std::uint16_t> ||
std::is_same<Type, std::int32_t>::value ||
std::is_same<Type, std::uint32_t>::value ||
std::is_same<Type, std::int64_t>::value ||
Expand All @@ -75,12 +86,15 @@ class uniform : detail::distribution_base<uniform<Type, Method>> {
: detail::distribution_base<uniform<Type, Method>>(
Type(0.0),
std::is_integral<Type>::value
? ((std::is_same_v<Type, std::uint64_t> || std::is_same_v<Type, std::int64_t>)
? (std::numeric_limits<Type>::max)()
: (std::is_same<Method, uniform_method::standard>::value
? (1 << 23)
: (std::numeric_limits<Type>::max)()))
: Type(1.0)) {}
? ((std::is_same_v<Type, std::uint64_t> || std::is_same_v<Type, std::int64_t>)
? (std::numeric_limits<Type>::max)()
: (std::is_same_v<Type, std::int8_t> || std::is_same_v<Type, std::uint8_t> ||
std::is_same_v<Type, std::int16_t> || std::is_same_v<Type, std::uint16_t>)
? (std::numeric_limits<Type>::max)()
: (std::is_same<Method, uniform_method::standard>::value
? (1 << 23)
: (std::numeric_limits<Type>::max)()))
: Type(1.0)) {}

explicit uniform(Type a, Type b) : detail::distribution_base<uniform<Type, Method>>(a, b) {}
explicit uniform(const param_type& pt)
Expand Down Expand Up @@ -581,6 +595,10 @@ class poisson : detail::distribution_base<poisson<IntType, Method>> {
// Supported types:
// std::uint32_t
// std::int32_t
// std::uint16_t
// std::int16_t
// std::uint8_t
// std::int8_t
//
// Supported methods:
// oneapi::math::rng::bernoulli_method::icdf;
Expand Down
76 changes: 76 additions & 0 deletions tests/unit_tests/rng/device/include/rng_device_test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,82 @@ struct statistics_device<oneapi::math::rng::device::uniform<Fp, Method>> {
}
};

template <typename Method>
struct statistics_device<oneapi::math::rng::device::uniform<std::int8_t, Method>> {
template <typename AllocType>
bool check(const std::vector<std::int8_t, AllocType>& r,
const oneapi::math::rng::device::uniform<std::int8_t, Method>& distr) {
double tM, tD, tQ;
float a = distr.a();
float b = distr.b();

// Theoretical moments
tM = (a + b - 1.0) / 2.0;
tD = ((b - a) * (b - a) - 1.0) / 12.0;
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
(7.0 / 240.0);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Method>
struct statistics_device<oneapi::math::rng::device::uniform<std::uint8_t, Method>> {
template <typename AllocType>
bool check(const std::vector<std::uint8_t, AllocType>& r,
const oneapi::math::rng::device::uniform<std::uint8_t, Method>& distr) {
double tM, tD, tQ;
float a = distr.a();
float b = distr.b();

// Theoretical moments
tM = (a + b - 1.0) / 2.0;
tD = ((b - a) * (b - a) - 1.0) / 12.0;
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
(7.0 / 240.0);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Method>
struct statistics_device<oneapi::math::rng::device::uniform<std::int16_t, Method>> {
template <typename AllocType>
bool check(const std::vector<std::int16_t, AllocType>& r,
const oneapi::math::rng::device::uniform<std::int16_t, Method>& distr) {
double tM, tD, tQ;
float a = distr.a();
float b = distr.b();

// Theoretical moments
tM = (a + b - 1.0) / 2.0;
tD = ((b - a) * (b - a) - 1.0) / 12.0;
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
(7.0 / 240.0);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Method>
struct statistics_device<oneapi::math::rng::device::uniform<std::uint16_t, Method>> {
template <typename AllocType>
bool check(const std::vector<std::uint16_t, AllocType>& r,
const oneapi::math::rng::device::uniform<std::uint16_t, Method>& distr) {
double tM, tD, tQ;
float a = distr.a();
float b = distr.b();

// Theoretical moments
tM = (a + b - 1.0) / 2.0;
tD = ((b - a) * (b - a) - 1.0) / 12.0;
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
(7.0 / 240.0);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Method>
struct statistics_device<oneapi::math::rng::device::uniform<std::int32_t, Method>> {
template <typename AllocType>
Expand Down
86 changes: 86 additions & 0 deletions tests/unit_tests/rng/device/moments/moments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,92 @@ TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, RealDoublePrecision) {
EXPECT_TRUEORSKIP((test3(GetParam())));
}

/* Test small types (u)int8, (u)int16 only with uniform_method::standard since numbers are always generated
as single precision numbers */
TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer8Precision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::uniform<
std::int8_t, oneapi::math::rng::device::uniform_method::standard>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::uniform<
std::int8_t, oneapi::math::rng::device::uniform_method::standard>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::uniform<
std::int8_t, oneapi::math::rng::device::uniform_method::standard>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger8Precision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::uniform<
std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::uniform<
std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::uniform<
std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer16Precision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::uniform<
std::int16_t, oneapi::math::rng::device::uniform_method::standard>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::uniform<
std::int16_t, oneapi::math::rng::device::uniform_method::standard>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::uniform<
std::int16_t, oneapi::math::rng::device::uniform_method::standard>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger16Precision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::uniform<
std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::uniform<
std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::uniform<
std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, IntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
Expand Down
Loading