diff --git a/src/aead/aes_gcm.rs b/src/aead/aes_gcm.rs
index 712cec8399..4c368c6380 100644
--- a/src/aead/aes_gcm.rs
+++ b/src/aead/aes_gcm.rs
@@ -176,9 +176,9 @@ pub(super) fn seal(
             unreachable!()
         }
     };
-    let (whole, remainder) = slice::as_chunks_mut(ramaining);
-    aes_key.ctr32_encrypt_within(slice::flatten_mut(whole).into(), &mut ctr);
-    auth.update_blocks(whole);
+    let (mut whole, remainder) = slice::as_chunks_mut(ramaining);
+    aes_key.ctr32_encrypt_within(whole.as_flattened_mut().into(), &mut ctr);
+    auth.update_blocks(whole.into());
     let remainder = OverlappingPartialBlock::new(remainder.into())
         .unwrap_or_else(|InputTooLongError { .. }| unreachable!());
     seal_finish(aes_key, auth, remainder, ctr, tag_iv)
@@ -264,11 +264,11 @@ fn seal_strided<
 ) -> Result<Tag, error::Unspecified> {
     let mut auth = gcm::Context::new(gcm_key, aad, in_out.len())?;
 
-    let (whole, remainder) = slice::as_chunks_mut(in_out);
+    let (mut whole, remainder) = slice::as_chunks_mut(in_out);
 
-    for chunk in whole.chunks_mut(CHUNK_BLOCKS) {
-        aes_key.ctr32_encrypt_within(slice::flatten_mut(chunk).into(), &mut ctr);
-        auth.update_blocks(chunk);
+    for mut chunk in whole.chunks_mut::<CHUNK_BLOCKS>() {
+        aes_key.ctr32_encrypt_within(chunk.as_flattened_mut().into(), &mut ctr);
+        auth.update_blocks(chunk.into());
     }
 
     let remainder = OverlappingPartialBlock::new(remainder.into())
@@ -361,7 +361,7 @@ pub(super) fn open(
     let (whole, _) = slice::as_chunks(in_out.input());
     auth.update_blocks(whole);
 
-    let whole_len = slice::flatten(whole).len();
+    let whole_len = whole.as_flattened().len();
 
     // Decrypt any remaining whole blocks.
     let whole = Overlapping::new(&mut in_out_slice[..(src.start + whole_len)], src.clone())
diff --git a/src/aead/gcm.rs b/src/aead/gcm.rs
index 0f2f9d145b..1f28b84ac4 100644
--- a/src/aead/gcm.rs
+++ b/src/aead/gcm.rs
@@ -17,7 +17,7 @@ use super::{aes_gcm, Aad};
 use crate::{
     bits::{BitLength, FromByteLen as _},
     error::{self, InputTooLongError},
-    polyfill::{sliceutil::overwrite_at_start, NotSend},
+    polyfill::{slice::AsChunks, sliceutil::overwrite_at_start, NotSend},
 };
 use cfg_if::cfg_if;
 
@@ -120,7 +120,7 @@ impl Context<'_, clmulavxmovbe::Key> {
 
 impl<K: UpdateBlocks> Context<'_, K> {
     #[inline(always)]
-    pub fn update_blocks(&mut self, input: &[[u8; BLOCK_LEN]]) {
+    pub fn update_blocks(&mut self, input: AsChunks<u8, BLOCK_LEN>) {
         self.key.update_blocks(&mut self.Xi, input);
     }
 }
@@ -150,5 +150,5 @@ pub(super) trait Gmult {
 }
 
 pub(super) trait UpdateBlocks {
-    fn update_blocks(&self, xi: &mut Xi, input: &[[u8; BLOCK_LEN]]);
+    fn update_blocks(&self, xi: &mut Xi, input: AsChunks<u8, BLOCK_LEN>);
 }
diff --git a/src/aead/gcm/clmul.rs b/src/aead/gcm/clmul.rs
index 2a5316a6a8..92a9bec424 100644
--- a/src/aead/gcm/clmul.rs
+++ b/src/aead/gcm/clmul.rs
@@ -19,7 +19,7 @@
 ))]
 
 use super::{ffi::KeyValue, Gmult, HTable, Xi};
-use crate::cpu;
+use crate::{cpu, polyfill::slice::AsChunks};
 
 #[cfg(all(target_arch = "aarch64", target_endian = "little"))]
 pub(in super::super) type RequiredCpuFeatures = cpu::arm::PMull;
@@ -66,7 +66,7 @@ impl Gmult for Key {
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 impl super::UpdateBlocks for Key {
-    fn update_blocks(&self, xi: &mut Xi, input: &[[u8; super::BLOCK_LEN]]) {
+    fn update_blocks(&self, xi: &mut Xi, input: AsChunks<u8, { super::BLOCK_LEN }>) {
         let _: cpu::Features = cpu::features();
         unsafe { ghash!(gcm_ghash_clmul, xi, &self.h_table, input) }
     }
diff --git a/src/aead/gcm/clmulavxmovbe.rs b/src/aead/gcm/clmulavxmovbe.rs
index 92d2bf46ac..5de1be7c69 100644
--- a/src/aead/gcm/clmulavxmovbe.rs
+++ b/src/aead/gcm/clmulavxmovbe.rs
@@ -15,7 +15,7 @@
 #![cfg(target_arch = "x86_64")]
 
 use super::{clmul, Gmult, HTable, KeyValue, UpdateBlocks, Xi, BLOCK_LEN};
-use crate::cpu;
+use crate::{cpu, polyfill::slice::AsChunks};
 
 pub(in super::super) type RequiredCpuFeatures = (
     clmul::RequiredCpuFeatures,
@@ -47,7 +47,7 @@ impl Gmult for Key {
 }
 
 impl UpdateBlocks for Key {
-    fn update_blocks(&self, xi: &mut Xi, input: &[[u8; BLOCK_LEN]]) {
+    fn update_blocks(&self, xi: &mut Xi, input: AsChunks<u8, BLOCK_LEN>) {
         unsafe { ghash!(gcm_ghash_avx, xi, &self.inner.inner(), input) }
     }
 }
diff --git a/src/aead/gcm/fallback.rs b/src/aead/gcm/fallback.rs
index 219fbcc81f..62779ec7d4 100644
--- a/src/aead/gcm/fallback.rs
+++ b/src/aead/gcm/fallback.rs
@@ -23,7 +23,7 @@
 // Unlike the BearSSL notes, we use u128 in the 64-bit implementation.
 
 use super::{ffi::U128, Gmult, KeyValue, UpdateBlocks, Xi, BLOCK_LEN};
-use crate::polyfill::ArraySplitMap as _;
+use crate::polyfill::{slice::AsChunks, ArraySplitMap as _};
 
 #[derive(Clone)]
 pub struct Key {
@@ -43,7 +43,7 @@ impl Gmult for Key {
 }
 
 impl UpdateBlocks for Key {
-    fn update_blocks(&self, xi: &mut Xi, input: &[[u8; BLOCK_LEN]]) {
+    fn update_blocks(&self, xi: &mut Xi, input: AsChunks<u8, BLOCK_LEN>) {
         ghash(xi, self.h, input);
     }
 }
@@ -248,9 +248,9 @@ fn gmult(xi: &mut Xi, h: U128) {
     })
 }
 
-fn ghash(xi: &mut Xi, h: U128, input: &[[u8; BLOCK_LEN]]) {
+fn ghash(xi: &mut Xi, h: U128, input: AsChunks<u8, BLOCK_LEN>) {
     with_swapped_xi(xi, |swapped| {
-        input.iter().for_each(|&input| {
+        input.into_iter().for_each(|&input| {
             let input = input.array_split_map(u64::from_be_bytes);
             swapped[0] ^= input[1];
             swapped[1] ^= input[0];
diff --git a/src/aead/gcm/ffi.rs b/src/aead/gcm/ffi.rs
index 1cb863d78d..cd76a6f825 100644
--- a/src/aead/gcm/ffi.rs
+++ b/src/aead/gcm/ffi.rs
@@ -12,7 +12,10 @@
 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-use crate::{constant_time, polyfill::ArraySplitMap};
+use crate::{
+    constant_time,
+    polyfill::{slice::AsChunks, ArraySplitMap},
+};
 
 pub(in super::super) const BLOCK_LEN: usize = 16;
 pub(in super::super) type Block = [u8; BLOCK_LEN];
@@ -125,12 +128,11 @@ impl HTable {
             len: crate::c::NonZero_size_t,
         ),
         xi: &mut Xi,
-        input: &[[u8; BLOCK_LEN]],
+        input: AsChunks<u8, BLOCK_LEN>,
     ) {
-        use crate::polyfill::slice;
         use core::num::NonZeroUsize;
 
-        let input = slice::flatten(input);
+        let input = input.as_flattened();
 
         let input_len = match NonZeroUsize::new(input.len()) {
             Some(len) => len,
diff --git a/src/aead/gcm/neon.rs b/src/aead/gcm/neon.rs
index e163748db4..49794a3582 100644
--- a/src/aead/gcm/neon.rs
+++ b/src/aead/gcm/neon.rs
@@ -42,7 +42,7 @@ impl Gmult for Key {
 }
 
 impl UpdateBlocks for Key {
-    fn update_blocks(&self, xi: &mut Xi, input: &[[u8; BLOCK_LEN]]) {
+    fn update_blocks(&self, xi: &mut Xi, input: AsChunks<u8, BLOCK_LEN>) {
         unsafe { ghash!(gcm_ghash_neon, xi, &self.h_table, input) }
     }
 }
diff --git a/src/aead/poly1305.rs b/src/aead/poly1305.rs
index 45229d560f..99f04f7eb9 100644
--- a/src/aead/poly1305.rs
+++ b/src/aead/poly1305.rs
@@ -17,8 +17,7 @@ use super::{Tag, TAG_LEN};
 
 #[cfg(all(target_arch = "arm", target_endian = "little"))]
 use crate::cpu::GetFeature as _;
-use crate::{cpu, polyfill::slice};
-use core::array;
+use crate::{cpu, polyfill::slice::AsChunks};
 
 mod ffi_arm_neon;
 mod ffi_fallback;
@@ -64,11 +63,11 @@ impl Context {
     }
 
     pub fn update_block(&mut self, input: [u8; BLOCK_LEN]) {
-        self.update(array::from_ref(&input))
+        self.update(AsChunks::from_ref(&input))
     }
 
-    pub fn update(&mut self, input: &[[u8; BLOCK_LEN]]) {
-        self.update_internal(slice::flatten(input));
+    pub fn update(&mut self, input: AsChunks<u8, BLOCK_LEN>) {
+        self.update_internal(input.as_flattened());
     }
 
     fn update_internal(&mut self, input: &[u8]) {
diff --git a/src/digest/dynstate.rs b/src/digest/dynstate.rs
index 0615213a69..f79fba2e95 100644
--- a/src/digest/dynstate.rs
+++ b/src/digest/dynstate.rs
@@ -60,7 +60,7 @@ pub(super) fn sha1_block_data_order<'d>(
 
     let (full_blocks, leftover) = slice::as_chunks(data);
     sha1::sha1_block_data_order(state, full_blocks);
-    (full_blocks.len() * sha1::BLOCK_LEN.into(), leftover)
+    (full_blocks.as_flattened().len(), leftover)
 }
 
 pub(super) fn sha256_block_data_order<'d>(
diff --git a/src/digest/sha1.rs b/src/digest/sha1.rs
index c31a93914a..98808d50ff 100644
--- a/src/digest/sha1.rs
+++ b/src/digest/sha1.rs
@@ -20,7 +20,7 @@ use super::{
     },
     BlockLen, OutputLen,
 };
-use crate::polyfill::slice;
+use crate::polyfill::slice::{self, AsChunks};
 use core::num::Wrapping;
 
 pub(super) const BLOCK_LEN: BlockLen = BlockLen::_512;
@@ -39,7 +39,7 @@ fn parity(x: W32, y: W32, z: W32) -> W32 {
 type State = [W32; CHAINING_WORDS];
 const ROUNDS: usize = 80;
 
-pub fn sha1_block_data_order(state: &mut State32, data: &[[u8; BLOCK_LEN.into()]]) {
+pub fn sha1_block_data_order(state: &mut State32, data: AsChunks<u8, { BLOCK_LEN.into() }>) {
     // The unwrap won't fail because `CHAINING_WORDS` is smaller than the
     // length.
     let state: &mut State = (&mut state[..CHAINING_WORDS]).try_into().unwrap();
@@ -52,11 +52,11 @@ pub fn sha1_block_data_order(state: &mut State32, data: &[[u8; BLOCK_LEN.into()]
 
 #[rustfmt::skip]
 fn block_data_order(
     mut H: [W32; CHAINING_WORDS],
-    M: &[[u8; BLOCK_LEN.into()]],
+    M: AsChunks<u8, { BLOCK_LEN.into() }>,
 ) -> [W32; CHAINING_WORDS] {
     for M in M {
-        let (M, remainder): (&[<W32 as Word>::InputBytes], &[u8]) = slice::as_chunks(M);
+        let (M, remainder): (AsChunks<u8, { size_of::<W32>() }>, &[u8]) = slice::as_chunks(M);
         debug_assert!(remainder.is_empty());
 
         // FIPS 180-4 6.1.2 Step 1
diff --git a/src/digest/sha2/fallback.rs b/src/digest/sha2/fallback.rs
index 6180015235..02e14f628b 100644
--- a/src/digest/sha2/fallback.rs
+++ b/src/digest/sha2/fallback.rs
@@ -13,7 +13,7 @@
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 use super::CHAINING_WORDS;
-use crate::polyfill::slice;
+use crate::polyfill::slice::{self, AsChunks};
 use core::{
     num::Wrapping,
     ops::{Add, AddAssign, BitAnd, BitOr, BitXor, Not, Shr},
@@ -30,13 +30,13 @@ use core::{
 #[inline]
 pub(super) fn block_data_order<S: Sha2, const BLOCK_LEN: usize, const BYTES_LEN: usize>(
     mut H: [S; CHAINING_WORDS],
-    M: &[[u8; BLOCK_LEN]],
+    M: AsChunks<u8, BLOCK_LEN>,
 ) -> [S; CHAINING_WORDS]
 where
     for<'a> &'a S::InputBytes: From<&'a [u8; BYTES_LEN]>,
 {
     for M in M {
-        let (M, remainder): (&[[u8; BYTES_LEN]], &[u8]) = slice::as_chunks(M);
+        let (M, remainder): (AsChunks<u8, BYTES_LEN>, &[u8]) = slice::as_chunks(M);
         debug_assert!(remainder.is_empty());
 
         // FIPS 180-4 {6.2.2, 6.4.2} Step 1
diff --git a/src/digest/sha2/ffi.rs b/src/digest/sha2/ffi.rs
index a359aa2e79..6c119057db 100644
--- a/src/digest/sha2/ffi.rs
+++ b/src/digest/sha2/ffi.rs
@@ -13,6 +13,7 @@
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 use super::CHAINING_WORDS;
+use crate::polyfill::slice::AsChunks;
 use core::num::{NonZeroUsize, Wrapping};
 
 /// `unsafe { T => f }` means it is safe to call `f` iff we can construct
@@ -49,7 +50,7 @@ macro_rules! sha2_64_ffi {
 
 pub(super) unsafe fn sha2_ffi<U, Cpu, const BLOCK_LEN: usize>(
     state: &mut [Wrapping<U>; CHAINING_WORDS],
-    data: &[[u8; BLOCK_LEN]],
+    data: AsChunks<u8, BLOCK_LEN>,
     cpu: Cpu,
     f: unsafe extern "C" fn(
        &mut [Wrapping<U>; CHAINING_WORDS],
diff --git a/src/digest/sha2/sha2_32.rs b/src/digest/sha2/sha2_32.rs
index f704beed4a..fb76b25035 100644
--- a/src/digest/sha2/sha2_32.rs
+++ b/src/digest/sha2/sha2_32.rs
@@ -13,7 +13,7 @@
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 use super::{BlockLen, CHAINING_WORDS};
-use crate::cpu;
+use crate::{cpu, polyfill::slice::AsChunks};
 use cfg_if::cfg_if;
 use core::num::Wrapping;
 
@@ -23,7 +23,7 @@ pub type State32 = [Wrapping<u32>; CHAINING_WORDS];
 
 pub(crate) fn block_data_order_32(
     state: &mut State32,
-    data: &[[u8; SHA256_BLOCK_LEN.into()]],
+    data: AsChunks<u8, { SHA256_BLOCK_LEN.into() }>,
     cpu: cpu::Features,
 ) {
     cfg_if! {
diff --git a/src/digest/sha2/sha2_64.rs b/src/digest/sha2/sha2_64.rs
index 5215d7ac26..9325cf034f 100644
--- a/src/digest/sha2/sha2_64.rs
+++ b/src/digest/sha2/sha2_64.rs
@@ -13,7 +13,7 @@
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 use super::{BlockLen, CHAINING_WORDS};
-use crate::cpu;
+use crate::{cpu, polyfill::slice::AsChunks};
 use cfg_if::cfg_if;
 use core::num::Wrapping;
 
@@ -23,7 +23,7 @@ pub type State64 = [Wrapping<u64>; CHAINING_WORDS];
 
 pub(crate) fn block_data_order_64(
     state: &mut State64,
-    data: &[[u8; SHA512_BLOCK_LEN.into()]],
+    data: AsChunks<u8, { SHA512_BLOCK_LEN.into() }>,
     cpu: cpu::Features,
 ) {
     cfg_if! {
diff --git a/src/ec/curve25519/scalar.rs b/src/ec/curve25519/scalar.rs
index da3d41aa5e..2754550eef 100644
--- a/src/ec/curve25519/scalar.rs
+++ b/src/ec/curve25519/scalar.rs
@@ -12,7 +12,11 @@
 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-use crate::{arithmetic::limbs_from_hex, digest, error, limb, polyfill::slice};
+use crate::{
+    arithmetic::limbs_from_hex,
+    digest, error, limb,
+    polyfill::slice::{self, AsChunks},
+};
 use core::array;
 
 #[repr(transparent)]
@@ -28,7 +32,8 @@ impl Scalar {
             limbs_from_hex("1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed");
        let order = ORDER.map(limb::Limb::from);
 
-        let (limbs_as_bytes, _empty): (&[[u8; limb::LIMB_BYTES]], _) = slice::as_chunks(&bytes);
+        let (limbs_as_bytes, _empty): (AsChunks<u8, { limb::LIMB_BYTES }>, _) =
+            slice::as_chunks(&bytes);
         debug_assert!(_empty.is_empty());
         let limbs: [limb::Limb; SCALAR_LEN / limb::LIMB_BYTES] =
             array::from_fn(|i| limb::Limb::from_le_bytes(limbs_as_bytes[i]));
diff --git a/src/polyfill/slice.rs b/src/polyfill/slice.rs
index b585429c21..2a29ef8fe5 100644
--- a/src/polyfill/slice.rs
+++ b/src/polyfill/slice.rs
@@ -22,66 +22,11 @@
 // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 // DEALINGS IN THE SOFTWARE.
 
-use core::mem::size_of;
+mod as_chunks;
+mod as_chunks_mut;
 
-// TODO(MSRV feature(slice_as_chunks)): Use `slice::as_chunks` instead.
-// This is copied from the libcore implementation of `slice::as_chunks`.
-#[inline(always)]
-pub fn as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
-    assert!(N != 0, "chunk size must be non-zero");
-    let len = slice.len() / N;
-    let (multiple_of_n, remainder) = slice.split_at(len * N);
-    // SAFETY: We already panicked for zero, and ensured by construction
-    // that the length of the subslice is a multiple of N.
-    // SAFETY: We cast a slice of `new_len * N` elements into
-    // a slice of `new_len` many `N` elements chunks.
-    let chunked = unsafe { core::slice::from_raw_parts(multiple_of_n.as_ptr().cast(), len) };
-    (chunked, remainder)
-}
-
-// TODO(MSRV feature(slice_as_chunks)): Use `slice::as_chunks_mut` instead.
-// This is adapted from above implementation of `slice::as_chunks`, as the
-// libcore implementation uses other unstable APIs.
-pub fn as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
-    assert!(N != 0, "chunk size must be non-zero");
-    let len = slice.len() / N;
-    let (multiple_of_n, remainder) = slice.split_at_mut(len * N);
-    // SAFETY: We already panicked for zero, and ensured by construction
-    // that the length of the subslice is a multiple of N.
-    // SAFETY: We cast a slice of `new_len * N` elements into
-    // a slice of `new_len` many `N` elements chunks.
-    let chunked =
-        unsafe { core::slice::from_raw_parts_mut(multiple_of_n.as_mut_ptr().cast(), len) };
-    (chunked, remainder)
-}
-
-// TODO(MSRV feature(slice_flatten)): Use `slice::flatten` instead.
-// This is derived from the libcore implementation, using only stable APIs.
-pub fn flatten<T, const N: usize>(slice: &[[T; N]]) -> &[T] {
-    let len = if size_of::<T>() == 0 {
-        slice.len().checked_mul(N).expect("slice len overflow")
-    } else {
-        // SAFETY: `slice.len() * N` cannot overflow because `slice` is
-        // already in the address space.
-        slice.len() * N
-    };
-    // SAFETY: `[T]` is layout-identical to `[T; N]`
-    unsafe { core::slice::from_raw_parts(slice.as_ptr().cast(), len) }
-}
-
-// TODO(MSRV feature(slice_flatten)): Use `slice::flatten_mut` instead.
-// This is derived from the libcore implementation, using only stable APIs.
-pub fn flatten_mut<T, const N: usize>(slice: &mut [[T; N]]) -> &mut [T] {
-    let len = if size_of::<T>() == 0 {
-        slice.len().checked_mul(N).expect("slice len overflow")
-    } else {
-        // SAFETY: `slice.len() * N` cannot overflow because `slice` is
-        // already in the address space.
-        slice.len() * N
-    };
-    // SAFETY: `[T]` is layout-identical to `[T; N]`
-    unsafe { core::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), len) }
-}
+pub use as_chunks::{as_chunks, AsChunks};
+pub use as_chunks_mut::{as_chunks_mut, AsChunksMut};
 
 // TODO(MSRV feature(split_at_checked)): Use `slice::split_at_checked`.
 //
diff --git a/src/polyfill/slice/as_chunks.rs b/src/polyfill/slice/as_chunks.rs
new file mode 100644
index 0000000000..4ed688feb0
--- /dev/null
+++ b/src/polyfill/slice/as_chunks.rs
@@ -0,0 +1,98 @@
+// Copyright 2025 Brian Smith.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+use super::AsChunksMut;
+use core::ops;
+
+#[inline(always)]
+pub fn as_chunks<T, const N: usize>(slice: &[T]) -> (AsChunks<T, N>, &[T]) {
+    assert!(N != 0, "chunk size must be non-zero");
+    let len = slice.len() / N;
+    let (multiple_of_n, remainder) = slice.split_at(len * N);
+    (AsChunks(multiple_of_n), remainder)
+}
+
+#[derive(Clone, Copy)]
+pub struct AsChunks<'a, T, const N: usize>(&'a [T]);
+
+impl<'a, T, const N: usize> AsChunks<'a, T, N> {
+    #[inline(always)]
+    pub fn from_ref(value: &'a [T; N]) -> Self {
+        Self(value)
+    }
+
+    #[inline(always)]
+    pub fn as_flattened(&self) -> &[T] {
+        self.0
+    }
+
+    #[inline(always)]
+    pub fn as_ptr(&self) -> *const [T; N] {
+        self.0.as_ptr().cast()
+    }
+
+    #[inline(always)]
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    #[inline(always)]
+    pub fn len(&self) -> usize {
+        self.0.len() / N
+    }
+}
+
+impl<'a, T, const N: usize> ops::Index<usize> for AsChunks<'a, T, N>
+where
+    [T]: ops::Index<ops::Range<usize>, Output = [T]>,
+{
+    type Output = [T; N];
+
+    #[inline(always)]
+    fn index(&self, index: usize) -> &Self::Output {
+        let start = N * index;
+        let slice = &self.0[start..(start + N)];
+        slice.try_into().unwrap()
+    }
+}
+
+impl<'a, T, const N: usize> IntoIterator for AsChunks<'a, T, N> {
+    type IntoIter = AsChunksIter<'a, T, N>;
+    type Item = &'a [T; N];
+
+    #[inline(always)]
+    fn into_iter(self) -> Self::IntoIter {
+        AsChunksIter(self.0.chunks_exact(N))
+    }
+}
+
+pub struct AsChunksIter<'a, T, const N: usize>(core::slice::ChunksExact<'a, T>);
+
+impl<'a, T, const N: usize> Iterator for AsChunksIter<'a, T, N> {
+    type Item = &'a [T; N];
+
+    #[inline(always)]
+    fn next(&mut self) -> Option<Self::Item> {
+        self.0.next().map(|x| x.try_into().unwrap())
+    }
+}
+
+// `&mut [[T; N]]` is implicitly convertible to `&[[T; N]]` but our types can't
+// do that.
+impl<'a, T, const N: usize> From<AsChunksMut<'a, T, N>> for AsChunks<'a, T, N> {
+    #[inline(always)]
+    fn from(as_mut: AsChunksMut<'a, T, N>) -> Self {
+        Self(as_mut.into_inner_for_conversion())
+    }
+}
diff --git a/src/polyfill/slice/as_chunks_mut.rs b/src/polyfill/slice/as_chunks_mut.rs
new file mode 100644
index 0000000000..dac574c7a2
--- /dev/null
+++ b/src/polyfill/slice/as_chunks_mut.rs
@@ -0,0 +1,86 @@
+// Copyright 2025 Brian Smith.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#[inline(always)]
+pub fn as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (AsChunksMut<T, N>, &mut [T]) {
+    assert!(N != 0, "chunk size must be non-zero");
+    let len = slice.len() / N;
+    let (multiple_of_n, remainder) = slice.split_at_mut(len * N);
+    (AsChunksMut(multiple_of_n), remainder)
+}
+
+pub struct AsChunksMut<'a, T, const N: usize>(&'a mut [T]);
+
+impl<'a, T, const N: usize> AsChunksMut<'a, T, N> {
+    #[inline(always)]
+    pub fn as_flattened_mut(&mut self) -> &mut [T] {
+        self.0
+    }
+
+    #[inline(always)]
+    pub fn as_mut_ptr(&mut self) -> *mut T {
+        self.0.as_mut_ptr().cast()
+    }
+
+    #[inline(always)]
+    pub fn len(&self) -> usize {
+        self.0.len() / N
+    }
+
+    // Argument moved from runtime argument to `const` argument so that
+    // `CHUNK_LEN * N` is checked at compile time for overflow.
+    #[inline(always)]
+    pub fn chunks_mut<'s, const CHUNK_LEN: usize>(
+        &'s mut self,
+    ) -> AsChunksMutChunksMutIter<'s, T, N> {
+        AsChunksMutChunksMutIter(self.0.chunks_mut(CHUNK_LEN * N))
+    }
+
+    #[inline(always)]
+    pub(super) fn into_inner_for_conversion(self) -> &'a [T] {
+        self.0
+    }
+}
+
+impl<'a, T, const N: usize> IntoIterator for &'a mut AsChunksMut<'_, T, N> {
+    type IntoIter = AsChunksMutIter<'a, T, N>;
+    type Item = &'a mut [T; N];
+
+    #[inline(always)]
+    fn into_iter(self) -> Self::IntoIter {
+        AsChunksMutIter(self.0.chunks_exact_mut(N))
+    }
+}
+
+pub struct AsChunksMutIter<'a, T, const N: usize>(core::slice::ChunksExactMut<'a, T>);
+
+impl<'a, T, const N: usize> Iterator for AsChunksMutIter<'a, T, N> {
+    type Item = &'a mut [T; N];
+
+    #[inline(always)]
+    fn next(&mut self) -> Option<Self::Item> {
+        self.0.next().map(|x| x.try_into().unwrap())
+    }
+}
+
+pub struct AsChunksMutChunksMutIter<'a, T, const N: usize>(core::slice::ChunksMut<'a, T>);
+
+impl<'a, T, const N: usize> Iterator for AsChunksMutChunksMutIter<'a, T, N> {
+    type Item = AsChunksMut<'a, T, N>;
+
+    #[inline(always)]
+    fn next(&mut self) -> Option<Self::Item> {
+        self.0.next().map(AsChunksMut)
+    }
+}