Skip to content

Commit

Permalink
Improve allocations in NegotiateStreamPal (#71280)
Browse files Browse the repository at this point in the history
* Reduce buffer allocations during NTLM/Negotiate authentication

* Update ReadWriteAdapter.WriteAsync prototype to use Memory<byte> instead of explicit offset/count

* Spanify NTAuthentication.Decrypt and avoid couple of offset/count checks

* Spanify NegotiateStreamPal.VerifySignature/MakeSignature.
Remove indirect Encrypt/Decrypt layer from SSPIWrapper, it is unnecessarily cumbersome to use and SslStreamPal already migrated away from it.

* Update src/libraries/Common/src/System/Net/NTAuthentication.Common.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
filipnavara and stephentoub authored Jun 29, 2022
1 parent 6c1d321 commit 1e8eaef
Show file tree
Hide file tree
Showing 17 changed files with 325 additions and 457 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ private static unsafe partial Status Wrap(
ref GssBuffer outBuffer);

[LibraryImport(Interop.Libraries.NetSecurityNative, EntryPoint="NetSecurityNative_Unwrap")]
private static partial Status Unwrap(
private static unsafe partial Status Unwrap(
out Status minorStatus,
SafeGssContextHandle? contextHandle,
byte[] inputBytes,
byte* inputBytes,
int offset,
int count,
ref GssBuffer outBuffer);
Expand All @@ -231,19 +231,16 @@ internal static unsafe Status WrapBuffer(
}
}

internal static Status UnwrapBuffer(
internal static unsafe Status UnwrapBuffer(
out Status minorStatus,
SafeGssContextHandle? contextHandle,
byte[] inputBytes,
int offset,
int count,
ReadOnlySpan<byte> inputBytes,
ref GssBuffer outBuffer)
{
Debug.Assert(inputBytes != null, "inputBytes must be valid value");
Debug.Assert(offset >= 0 && offset <= inputBytes.Length, "offset must be valid");
Debug.Assert(count >= 0 && count <= inputBytes.Length, "count must be valid");

return Unwrap(out minorStatus, contextHandle, inputBytes, offset, count, ref outBuffer);
fixed (byte* inputBytesPtr = inputBytes)
{
return Unwrap(out minorStatus, contextHandle, inputBytesPtr, 0, inputBytes.Length, ref outBuffer);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal interface ISSPIInterface
int QueryContextChannelBinding(SafeDeleteContext phContext, Interop.SspiCli.ContextAttribute attribute, out SafeFreeContextBufferChannelBinding refHandle);
int QueryContextAttributes(SafeDeleteContext phContext, Interop.SspiCli.ContextAttribute attribute, Span<byte> buffer, Type? handleType, out SafeHandle? refHandle);
int QuerySecurityContextToken(SafeDeleteContext phContext, out SecurityContextTokenHandle phToken);
int CompleteAuthToken(ref SafeDeleteSslContext? refContext, in SecurityBuffer inputBuffer);
int CompleteAuthToken(ref SafeDeleteSslContext? refContext, in InputSecurityBuffer inputBuffer);
int ApplyControlToken(ref SafeDeleteContext? refContext, in SecurityBuffer inputBuffer);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public int QuerySecurityContextToken(SafeDeleteContext phContext, out SecurityCo
return GetSecurityContextToken(phContext, out phToken);
}

public int CompleteAuthToken(ref SafeDeleteSslContext? refContext, in SecurityBuffer inputBuffer)
public int CompleteAuthToken(ref SafeDeleteSslContext? refContext, in InputSecurityBuffer inputBuffer)
{
return SafeDeleteContext.CompleteAuthToken(ref refContext, in inputBuffer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public int QuerySecurityContextToken(SafeDeleteContext phContext, out SecurityCo
throw new NotSupportedException();
}

public int CompleteAuthToken(ref SafeDeleteSslContext? refContext, in SecurityBuffer inputBuffer)
public int CompleteAuthToken(ref SafeDeleteSslContext? refContext, in InputSecurityBuffer inputBuffer)
{
throw new NotSupportedException();
}
Expand Down
143 changes: 1 addition & 142 deletions src/libraries/Common/src/Interop/Windows/SspiCli/SSPIWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCred
return errorCode;
}

internal static int CompleteAuthToken(ISSPIInterface secModule, ref SafeDeleteSslContext? context, in SecurityBuffer inputBuffer)
internal static int CompleteAuthToken(ISSPIInterface secModule, ref SafeDeleteSslContext? context, in InputSecurityBuffer inputBuffer)
{
int errorCode = secModule.CompleteAuthToken(ref context, in inputBuffer);

Expand All @@ -186,147 +186,6 @@ public static int QuerySecurityContextToken(ISSPIInterface secModule, SafeDelete
return secModule.QuerySecurityContextToken(context, out token);
}

public static int EncryptMessage(ISSPIInterface secModule, SafeDeleteContext context, Span<SecurityBuffer> input, uint sequenceNumber)
{
return EncryptDecryptHelper(OP.Encrypt, secModule, context, input, sequenceNumber);
}

public static int DecryptMessage(ISSPIInterface secModule, SafeDeleteContext context, Span<SecurityBuffer> input, uint sequenceNumber)
{
return EncryptDecryptHelper(OP.Decrypt, secModule, context, input, sequenceNumber);
}

internal static int MakeSignature(ISSPIInterface secModule, SafeDeleteContext context, Span<SecurityBuffer> input, uint sequenceNumber)
{
return EncryptDecryptHelper(OP.MakeSignature, secModule, context, input, sequenceNumber);
}

public static int VerifySignature(ISSPIInterface secModule, SafeDeleteContext context, Span<SecurityBuffer> input, uint sequenceNumber)
{
return EncryptDecryptHelper(OP.VerifySignature, secModule, context, input, sequenceNumber);
}

private enum OP
{
Encrypt = 1,
Decrypt,
MakeSignature,
VerifySignature
}

[StructLayout(LayoutKind.Sequential)]
private ref struct ThreeByteArrays
{
public const int NumItems = 3;
internal byte[] _item0;
private byte[] _item1;
private byte[] _item2;
}

private static unsafe int EncryptDecryptHelper(OP op, ISSPIInterface secModule, SafeDeleteContext context, Span<SecurityBuffer> input, uint sequenceNumber)
{
Debug.Assert(Enum.IsDefined<OP>(op), $"Unknown op: {op}");
Debug.Assert(input.Length <= 3, "The below logic only works for 3 or fewer buffers.");

Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(input.Length);
Span<Interop.SspiCli.SecBuffer> unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[input.Length];
unmanagedBuffer.Clear();

fixed (Interop.SspiCli.SecBuffer* unmanagedBufferPtr = unmanagedBuffer)
fixed (byte* pinnedBuffer0 = input.Length > 0 ? input[0].token : null)
fixed (byte* pinnedBuffer1 = input.Length > 1 ? input[1].token : null)
fixed (byte* pinnedBuffer2 = input.Length > 2 ? input[2].token : null)
{
sdcInOut.pBuffers = unmanagedBufferPtr;

ThreeByteArrays byteArrayStruct = default;
Span<byte[]> buffers = MemoryMarshal.CreateSpan(ref byteArrayStruct._item0!, ThreeByteArrays.NumItems).Slice(0, input.Length);

for (int i = 0; i < input.Length; i++)
{
ref readonly SecurityBuffer iBuffer = ref input[i];
unmanagedBuffer[i].cbBuffer = iBuffer.size;
unmanagedBuffer[i].BufferType = iBuffer.type;
if (iBuffer.token == null || iBuffer.token.Length == 0)
{
unmanagedBuffer[i].pvBuffer = IntPtr.Zero;
}
else
{
unmanagedBuffer[i].pvBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(iBuffer.token, iBuffer.offset);
buffers[i] = iBuffer.token;
}
}

// The result is written in the input Buffer passed as type=BufferType.Data.
int errorCode = op switch
{
OP.Encrypt => secModule.EncryptMessage(context, ref sdcInOut, sequenceNumber),
OP.Decrypt => secModule.DecryptMessage(context, ref sdcInOut, sequenceNumber),
OP.MakeSignature => secModule.MakeSignature(context, ref sdcInOut, sequenceNumber),
_ /* OP.VerifySignature */ => secModule.VerifySignature(context, ref sdcInOut, sequenceNumber),
};

// Marshalling back returned sizes / data.
for (int i = 0; i < input.Length; i++)
{
ref SecurityBuffer iBuffer = ref input[i];
iBuffer.size = unmanagedBuffer[i].cbBuffer;
iBuffer.type = unmanagedBuffer[i].BufferType;

if (iBuffer.size == 0)
{
iBuffer.offset = 0;
iBuffer.token = null;
}
else
{

// Find the buffer this is inside of. Usually they all point inside buffer 0.
int j;
for (j = 0; j < input.Length; j++)
{
if (buffers[j] != null)
{
checked
{
byte* bufferAddress = (byte*)Marshal.UnsafeAddrOfPinnedArrayElement(buffers[j], 0);
if ((byte*)unmanagedBuffer[i].pvBuffer >= bufferAddress &&
(byte*)unmanagedBuffer[i].pvBuffer + iBuffer.size <= bufferAddress + buffers[j].Length)
{
iBuffer.offset = (int)((byte*)unmanagedBuffer[i].pvBuffer - bufferAddress);
iBuffer.token = buffers[j];
break;
}
}
}
}

if (j >= input.Length)
{
Debug.Fail("Output buffer out of range.");
iBuffer.size = 0;
iBuffer.offset = 0;
iBuffer.token = null;
}
}

// Backup validate the new sizes.
Debug.Assert(iBuffer.offset >= 0 && iBuffer.offset <= (iBuffer.token == null ? 0 : iBuffer.token.Length), $"'offset' out of range. [{iBuffer.offset}]");
Debug.Assert(iBuffer.size >= 0 && iBuffer.size <= (iBuffer.token == null ? 0 : iBuffer.token.Length - iBuffer.offset), $"'size' out of range. [{iBuffer.size}]");
}

if (NetEventSource.Log.IsEnabled() && errorCode != 0)
{
NetEventSource.Error(null, errorCode == Interop.SspiCli.SEC_I_RENEGOTIATE ?
SR.Format(SR.event_OperationReturnedSomething, op, "SEC_I_RENEGOTIATE") :
SR.Format(SR.net_log_operation_failed_with_error, op, $"0x{0:X}"));
}

return errorCode;
}
}

public static SafeFreeContextBufferChannelBinding? QueryContextChannelBinding(ISSPIInterface secModule, SafeDeleteContext securityContext, Interop.SspiCli.ContextAttribute contextAttribute)
{
SafeFreeContextBufferChannelBinding result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,13 @@ internal static unsafe int InitializeSecurityContext(
// Get unmanaged buffer with index 0 as the only one passed into PInvoke.
outSecBuffer.size = outUnmanagedBuffer.cbBuffer;
outSecBuffer.type = outUnmanagedBuffer.BufferType;
outSecBuffer.token = outSecBuffer.size > 0 ?
new Span<byte>((byte*)outUnmanagedBuffer.pvBuffer, outUnmanagedBuffer.cbBuffer).ToArray() :
null;

if (isSspiAllocated)
{
outSecBuffer.token = outSecBuffer.size > 0 ?
new Span<byte>((byte*)outUnmanagedBuffer.pvBuffer, outUnmanagedBuffer.cbBuffer).ToArray() :
null;
}

if (inSecBuffers.Count > 1 && inUnmanagedBuffer[1].BufferType == SecurityBufferType.SECBUFFER_EXTRA && inSecBuffers._item1.Type == SecurityBufferType.SECBUFFER_EMPTY)
{
Expand Down Expand Up @@ -952,25 +956,22 @@ private static unsafe int MustRunAcceptSecurityContext_SECURITY(

internal static unsafe int CompleteAuthToken(
ref SafeDeleteSslContext? refContext,
in SecurityBuffer inSecBuffer)
in InputSecurityBuffer inSecBuffer)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, $"refContext = {refContext}, inSecBuffer = {inSecBuffer}");
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, $"refContext = {refContext}");

var inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(1);
int errorCode = (int)Interop.SECURITY_STATUS.InvalidHandle;

Interop.SspiCli.SecBuffer inUnmanagedBuffer = default;
inSecurityBufferDescriptor.pBuffers = &inUnmanagedBuffer;
fixed (byte* pinnedToken = inSecBuffer.token)
fixed (byte* pinnedToken = inSecBuffer.Token)
{
inUnmanagedBuffer.cbBuffer = inSecBuffer.size;
inUnmanagedBuffer.BufferType = inSecBuffer.type;

// Use the unmanaged token if it's not null; otherwise use the managed buffer.
Debug.Assert(inSecBuffer.UnmanagedToken != null);
inUnmanagedBuffer.cbBuffer = inSecBuffer.Token.Length;
inUnmanagedBuffer.BufferType = inSecBuffer.Type;
inUnmanagedBuffer.pvBuffer =
inSecBuffer.unmanagedToken != null ? inSecBuffer.unmanagedToken.DangerousGetHandle() :
inSecBuffer.token == null || inSecBuffer.token.Length == 0 ? IntPtr.Zero :
(IntPtr)(pinnedToken + inSecBuffer.offset);
inSecBuffer.Token.IsEmpty ? IntPtr.Zero : (IntPtr)pinnedToken;

Interop.SspiCli.CredHandle contextHandle = refContext != null ? refContext._handle : default;
if (refContext == null || refContext.IsInvalid)
Expand Down
33 changes: 21 additions & 12 deletions src/libraries/Common/src/System/Net/NTAuthentication.Common.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal sealed partial class NTAuthentication
private string? _spn;

private int _tokenSize;
private byte[]? _tokenBuffer;
private ContextFlagsPal _requestedContextFlags;
private ContextFlagsPal _contextFlags;

Expand Down Expand Up @@ -147,14 +148,14 @@ internal void CloseContext()
_isCompleted = false;
}

internal int VerifySignature(byte[] buffer, int offset, int count)
internal int VerifySignature(ReadOnlySpan<byte> buffer)
{
return NegotiateStreamPal.VerifySignature(_securityContext!, buffer, offset, count);
return NegotiateStreamPal.VerifySignature(_securityContext!, buffer);
}

internal int MakeSignature(byte[] buffer, int offset, int count, [AllowNull] ref byte[] output)
internal int MakeSignature(ReadOnlySpan<byte> buffer, [AllowNull] ref byte[] output)
{
return NegotiateStreamPal.MakeSignature(_securityContext!, buffer, offset, count, ref output);
return NegotiateStreamPal.MakeSignature(_securityContext!, buffer, ref output);
}

internal string? GetOutgoingBlob(string? incomingBlob)
Expand Down Expand Up @@ -210,9 +211,10 @@ internal int MakeSignature(byte[] buffer, int offset, int count, [AllowNull] ref

internal byte[]? GetOutgoingBlob(ReadOnlySpan<byte> incomingBlob, bool throwOnError, out SecurityStatusPal statusCode)
{
byte[]? result = new byte[_tokenSize];
_tokenBuffer ??= _tokenSize == 0 ? Array.Empty<byte>() : new byte[_tokenSize];

bool firstTime = _securityContext == null;
int resultBlobLength;
try
{
if (!_isServer)
Expand All @@ -225,18 +227,19 @@ internal int MakeSignature(byte[] buffer, int offset, int count, [AllowNull] ref
_requestedContextFlags,
incomingBlob,
_channelBinding,
ref result,
ref _tokenBuffer,
out resultBlobLength,
ref _contextFlags);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SSPIWrapper.InitializeSecurityContext() returns statusCode:0x{((int)statusCode.ErrorCode):x8} ({statusCode})");

if (statusCode.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded)
{
statusCode = NegotiateStreamPal.CompleteAuthToken(ref _securityContext, result);
statusCode = NegotiateStreamPal.CompleteAuthToken(ref _securityContext, _tokenBuffer.AsSpan(0, resultBlobLength));

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SSPIWrapper.CompleteAuthToken() returns statusCode:0x{((int)statusCode.ErrorCode):x8} ({statusCode})");

result = null;
resultBlobLength = 0;
}
}
else
Expand All @@ -248,7 +251,8 @@ internal int MakeSignature(byte[] buffer, int offset, int count, [AllowNull] ref
_requestedContextFlags,
incomingBlob,
_channelBinding,
ref result,
ref _tokenBuffer,
out resultBlobLength,
ref _contextFlags);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"SSPIWrapper.AcceptSecurityContext() returns statusCode:0x{((int)statusCode.ErrorCode):x8} ({statusCode})");
Expand All @@ -273,6 +277,7 @@ internal int MakeSignature(byte[] buffer, int offset, int count, [AllowNull] ref
{
CloseContext();
_isCompleted = true;
_tokenBuffer = null;
if (throwOnError)
{
throw NegotiateStreamPal.CreateExceptionFromError(statusCode);
Expand All @@ -286,12 +291,18 @@ internal int MakeSignature(byte[] buffer, int offset, int count, [AllowNull] ref
SSPIHandleCache.CacheCredential(_credentialsHandle);
}

byte[]? result =
resultBlobLength == 0 || _tokenBuffer == null ? null :
_tokenBuffer.Length == resultBlobLength ? _tokenBuffer :
_tokenBuffer[0..resultBlobLength];

// The return value will tell us correctly if the handshake is over or not
if (statusCode.ErrorCode == SecurityStatusPalErrorCode.OK
|| (_isServer && statusCode.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded))
{
// Success.
_isCompleted = true;
_tokenBuffer = null;
}
else
{
Expand Down Expand Up @@ -324,13 +335,11 @@ internal int Encrypt(ReadOnlySpan<byte> buffer, [NotNull] ref byte[]? output, ui
sequenceNumber);
}

internal int Decrypt(byte[] payload, int offset, int count, out int newOffset, uint expectedSeqNumber)
internal int Decrypt(Span<byte> payload, out int newOffset, uint expectedSeqNumber)
{
return NegotiateStreamPal.Decrypt(
_securityContext!,
payload,
offset,
count,
(_contextFlags & ContextFlagsPal.Confidentiality) != 0,
IsNTLM,
out newOffset,
Expand Down
Loading

0 comments on commit 1e8eaef

Please sign in to comment.