Skip to content

Commit

Permalink
Fix BigInteger.Rotate{Left,Right} for backport (#112878)
Browse files Browse the repository at this point in the history
* Add BigInteger.Rotate* tests

* Fix BigInteger.Rotate*

* avoid stackalloc

* Add comment
  • Loading branch information
kzrnm authored Feb 27, 2025
1 parent b54529f commit 2959612
Show file tree
Hide file tree
Showing 5 changed files with 799 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,7 @@ private static BigInteger Add(ReadOnlySpan<uint> leftBits, int leftSign, ReadOnl
}

if (bitsFromPool != null)
ArrayPool<uint>.Shared.Return(bitsFromPool);
ArrayPool<uint>.Shared.Return(bitsFromPool);

return result;
}
Expand Down Expand Up @@ -2636,7 +2636,7 @@ public static implicit operator BigInteger(nuint value)

if (zdFromPool != null)
ArrayPool<uint>.Shared.Return(zdFromPool);
exit:
exit:
if (xdFromPool != null)
ArrayPool<uint>.Shared.Return(xdFromPool);

Expand Down Expand Up @@ -3239,7 +3239,27 @@ public static BigInteger PopCount(BigInteger value)
public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
{
value.AssertValid();
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);

bool negx = value._sign < 0;
uint smallBits = NumericsHelpers.Abs(value._sign);
scoped ReadOnlySpan<uint> bits = value._bits;
if (bits.IsEmpty)
{
bits = new ReadOnlySpan<uint>(in smallBits);
}

int xl = bits.Length;
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
{
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
// If the 2's component's last element is a 0, we will track the sign externally
++xl;
}

int byteCount = xl * 4;

// Normalize the rotate amount to drop full rotations
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
Expand All @@ -3256,14 +3276,13 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);

uint[]? xdFromPool = null;
int xl = value._bits?.Length ?? 1;

Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
xd = xd.Slice(0, xl);
xd[^1] = 0;

bool negx = value.GetPartsForBitManipulation(xd);
bits.CopyTo(xd);

int zl = xl;
uint[]? zdFromPool = null;
Expand Down Expand Up @@ -3374,7 +3393,28 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
public static BigInteger RotateRight(BigInteger value, int rotateAmount)
{
value.AssertValid();
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);


bool negx = value._sign < 0;
uint smallBits = NumericsHelpers.Abs(value._sign);
scoped ReadOnlySpan<uint> bits = value._bits;
if (bits.IsEmpty)
{
bits = new ReadOnlySpan<uint>(in smallBits);
}

int xl = bits.Length;
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
{
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
// If the 2's component's last element is a 0, we will track the sign externally
++xl;
}

int byteCount = xl * 4;

// Normalize the rotate amount to drop full rotations
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
Expand All @@ -3391,14 +3431,13 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);

uint[]? xdFromPool = null;
int xl = value._bits?.Length ?? 1;

Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
xd = xd.Slice(0, xl);
xd[^1] = 0;

bool negx = value.GetPartsForBitManipulation(xd);
bits.CopyTo(xd);

int zl = xl;
uint[]? zdFromPool = null;
Expand Down Expand Up @@ -3445,19 +3484,12 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
{
int carryShift = kcbitUint - smallShift;

int dstIndex = 0;
int srcIndex = digitShift;
int dstIndex = xd.Length - 1;
int srcIndex = digitShift == 0
? xd.Length - 1
: digitShift - 1;

uint carry = 0;

if (digitShift == 0)
{
carry = xd[^1] << carryShift;
}
else
{
carry = xd[srcIndex - 1] << carryShift;
}
uint carry = xd[digitShift] << carryShift;

do
{
Expand All @@ -3466,22 +3498,22 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
zd[dstIndex] = (part >> smallShift) | carry;
carry = part << carryShift;

dstIndex++;
srcIndex++;
dstIndex--;
srcIndex--;
}
while (srcIndex < xd.Length);
while ((uint)srcIndex < (uint)xd.Length); // is equivalent to (srcIndex >= 0 && srcIndex < xd.Length)

srcIndex = 0;
srcIndex = xd.Length - 1;

while (dstIndex < zd.Length)
while ((uint)dstIndex < (uint)zd.Length) // is equivalent to (dstIndex >= 0 && dstIndex < zd.Length)
{
uint part = xd[srcIndex];

zd[dstIndex] = (part >> smallShift) | carry;
carry = part << carryShift;

dstIndex++;
srcIndex++;
dstIndex--;
srcIndex--;
}
}

Expand Down
103 changes: 103 additions & 0 deletions src/libraries/System.Runtime.Numerics/tests/BigInteger/MyBigInt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ public static BigInteger DoBinaryOperatorMine(BigInteger num1, BigInteger num2,
return new BigInteger(ShiftLeft(bytes1, Negate(bytes2)).ToArray());
case "b<<":
return new BigInteger(ShiftLeft(bytes1, bytes2).ToArray());
case "bRotateLeft":
return new BigInteger(RotateLeft(bytes1, bytes2).ToArray());
case "bRotateRight":
return new BigInteger(RotateLeft(bytes1, Negate(bytes2)).ToArray());
case "b^":
return new BigInteger(Xor(bytes1, bytes2).ToArray());
case "b|":
Expand Down Expand Up @@ -774,6 +778,105 @@ public static List<byte> ShiftRight(List<byte> bytes)
return bresult;
}

public static List<byte> RotateRight(List<byte> bytes)
{
List<byte> bresult = new List<byte>();

byte bottom = (byte)(bytes[0] & 0x01);

for (int i = 0; i < bytes.Count; i++)
{
byte newbyte = bytes[i];

newbyte = (byte)(newbyte / 2);
if ((i != (bytes.Count - 1)) && ((bytes[i + 1] & 0x01) == 1))
{
newbyte += 128;
}
if ((i == (bytes.Count - 1)) && (bottom != 0))
{
newbyte += 128;
}
bresult.Add(newbyte);
}

return bresult;
}

public static List<byte> RotateLeft(List<byte> bytes)
{
List<byte> bresult = new List<byte>();

bool prevHead = (bytes[bytes.Count - 1] & 0x80) != 0;

for (int i = 0; i < bytes.Count; i++)
{
byte newbyte = bytes[i];

newbyte = (byte)(newbyte * 2);
if (prevHead)
{
newbyte += 1;
}

bresult.Add(newbyte);

prevHead = (bytes[i] & 0x80) != 0;
}

return bresult;
}


public static List<byte> RotateLeft(List<byte> bytes1, List<byte> bytes2)
{
List<byte> bytes1Copy = Copy(bytes1);
int byteShift = (int)new BigInteger(Divide(Copy(bytes2), new List<byte>(new byte[] { 8 })).ToArray());
sbyte bitShift = (sbyte)new BigInteger(Remainder(bytes2, new List<byte>(new byte[] { 8 })).ToArray());

Trim(bytes1);

byte fill = (bytes1[bytes1.Count - 1] & 0x80) != 0 ? byte.MaxValue : (byte)0;

if (fill == 0 && bytes1.Count > 1 && bytes1[bytes1.Count - 1] == 0)
bytes1.RemoveAt(bytes1.Count - 1);

while (bytes1.Count % 4 != 0)
{
bytes1.Add(fill);
}

byteShift %= bytes1.Count;
if (byteShift == 0 && bitShift == 0)
return bytes1Copy;

for (int i = 0; i < Math.Abs(bitShift); i++)
{
if (bitShift < 0)
{
bytes1 = RotateRight(bytes1);
}
else
{
bytes1 = RotateLeft(bytes1);
}
}

List<byte> temp = new List<byte>();
for (int i = 0; i < bytes1.Count; i++)
{
temp.Add(bytes1[(i - byteShift + bytes1.Count) % bytes1.Count]);
}
bytes1 = temp;

if (fill == 0)
bytes1.Add(0);

Trim(bytes1);

return bytes1;
}

public static List<byte> SetLength(List<byte> bytes, int size)
{
List<byte> bresult = new List<byte>();
Expand Down
Loading

0 comments on commit 2959612

Please sign in to comment.