diff --git a/include/oneapi/math/rng/device/detail/uniform_impl.hpp b/include/oneapi/math/rng/device/detail/uniform_impl.hpp index 5580de75b..09c640df5 100644 --- a/include/oneapi/math/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/math/rng/device/detail/uniform_impl.hpp @@ -118,17 +118,15 @@ class distribution_base> { using OutType = typename std::conditional>::type; using FpType = - typename std::conditional::value, double, - float>::type; + typename std::conditional || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, float, double>::type; OutType res; if constexpr (std::is_integral::value) { - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v) { - return generate_single_int(engine); - } - else if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr (!std::is_same_v && !std::is_same_v) { return generate_single_int(engine); } else { @@ -244,21 +242,15 @@ class distribution_base> { template Type generate_single(EngineType& engine) { using FpType = - typename std::conditional::value, double, - float>::type; + typename std::conditional || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, float, double>::type; Type res; if constexpr (std::is_integral::value) { - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || - std::is_same_v) { - float res_fp = - engine.generate_single(static_cast(a_), static_cast(b_)); - res_fp = sycl::floor(res_fp); - res = static_cast(res_fp); - return res; - } - else if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr (!std::is_same_v && !std::is_same_v) { FpType res_fp = engine.generate_single(static_cast(a_), static_cast(b_)); res_fp = sycl::floor(res_fp);