Skip to content

Commit

Permalink
Remove a bunch of X509 PAL closures from FindCore
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Apr 1, 2021
1 parent 00d4dd8 commit 6617892
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ public string NormalizeOid(string maybeOid, OidGroup expectedGroup)

public void FindByThumbprint(byte[] thumbprint)
{
FindCore(cert => cert.GetCertHash().ContentsEqual(thumbprint));
FindCore(thumbprint, static (thumbprint, cert) => cert.GetCertHash().ContentsEqual(thumbprint));
}

public void FindBySubjectName(string subjectName)
{
FindCore(
cert =>
subjectName,
static (subjectName, cert) =>
{
string formedSubject = X500NameEncoder.X500DistinguishedNameDecode(cert.SubjectName.RawData, false, X500DistinguishedNameFlags.None);

Expand All @@ -57,13 +58,14 @@ public void FindBySubjectName(string subjectName)

public void FindBySubjectDistinguishedName(string subjectDistinguishedName)
{
FindCore(cert => StringComparer.OrdinalIgnoreCase.Equals(subjectDistinguishedName, cert.Subject));
FindCore(subjectDistinguishedName, static (subjectDistinguishedName, cert) => StringComparer.OrdinalIgnoreCase.Equals(subjectDistinguishedName, cert.Subject));
}

public void FindByIssuerName(string issuerName)
{
FindCore(
cert =>
issuerName,
static (issuerName, cert) =>
{
string formedIssuer = X500NameEncoder.X500DistinguishedNameDecode(cert.IssuerName.RawData, false, X500DistinguishedNameFlags.None);

Expand All @@ -73,17 +75,18 @@ public void FindByIssuerName(string issuerName)

public void FindByIssuerDistinguishedName(string issuerDistinguishedName)
{
FindCore(cert => StringComparer.OrdinalIgnoreCase.Equals(issuerDistinguishedName, cert.Issuer));
FindCore(issuerDistinguishedName, static (issuerDistinguishedName, cert) => StringComparer.OrdinalIgnoreCase.Equals(issuerDistinguishedName, cert.Issuer));
}

public void FindBySerialNumber(BigInteger hexValue, BigInteger decimalValue)
{
FindCore(
cert =>
(hexValue, decimalValue),
static (state, cert) =>
{
byte[] serialBytes = cert.GetSerialNumber();
BigInteger serialNumber = FindPal.PositiveBigIntegerFromByteArray(serialBytes);
bool match = hexValue.Equals(serialNumber) || decimalValue.Equals(serialNumber);
bool match = state.hexValue.Equals(serialNumber) || state.decimalValue.Equals(serialNumber);

return match;
});
Expand All @@ -107,27 +110,28 @@ public void FindByTimeValid(DateTime dateTime)
{
DateTime normalized = NormalizeDateTime(dateTime);

FindCore(cert => cert.NotBefore <= normalized && normalized <= cert.NotAfter);
FindCore(normalized, static (normalized, cert) => cert.NotBefore <= normalized && normalized <= cert.NotAfter);
}

public void FindByTimeNotYetValid(DateTime dateTime)
{
DateTime normalized = NormalizeDateTime(dateTime);

FindCore(cert => cert.NotBefore > normalized);
FindCore(normalized, static (normalized, cert) => cert.NotBefore > normalized);
}

public void FindByTimeExpired(DateTime dateTime)
{
DateTime normalized = NormalizeDateTime(dateTime);

FindCore(cert => cert.NotAfter < normalized);
FindCore(normalized, static (normalized, cert) => cert.NotAfter < normalized);
}

public void FindByTemplateName(string templateName)
{
FindCore(
cert =>
templateName,
static (templateName, cert) =>
{
X509Extension? ext = FindExtension(cert, Oids.EnrollCertTypeExtension);

Expand Down Expand Up @@ -172,7 +176,8 @@ public void FindByTemplateName(string templateName)
public void FindByApplicationPolicy(string oidValue)
{
FindCore(
cert =>
oidValue,
static (oidValue, cert) =>
{
X509Extension? ext = FindExtension(cert, Oids.EnhancedKeyUsage);

Expand Down Expand Up @@ -201,7 +206,8 @@ public void FindByApplicationPolicy(string oidValue)
public void FindByCertificatePolicy(string oidValue)
{
FindCore(
cert =>
oidValue,
static (oidValue, cert) =>
{
X509Extension? ext = FindExtension(cert, Oids.CertPolicies);

Expand All @@ -218,13 +224,14 @@ public void FindByCertificatePolicy(string oidValue)

public void FindByExtension(string oidValue)
{
FindCore(cert => FindExtension(cert, oidValue) != null);
FindCore(oidValue, static (oidValue, cert) => FindExtension(cert, oidValue) != null);
}

public void FindByKeyUsage(X509KeyUsageFlags keyUsage)
{
FindCore(
cert =>
keyUsage,
static (keyUsage, cert) =>
{
X509Extension? ext = FindExtension(cert, Oids.KeyUsage);

Expand All @@ -246,7 +253,8 @@ public void FindByKeyUsage(X509KeyUsageFlags keyUsage)
public void FindBySubjectKeyIdentifier(byte[] keyIdentifier)
{
FindCore(
cert =>
keyIdentifier,
(keyIdentifier, cert) =>
{
X509Extension? ext = FindExtension(cert, Oids.SubjectKeyIdentifier);
byte[] certKeyId;
Expand Down Expand Up @@ -305,11 +313,14 @@ protected virtual void Dispose(bool disposing)

protected abstract X509Certificate2 CloneCertificate(X509Certificate2 cert);

private void FindCore(Predicate<X509Certificate2> predicate)
private void FindCore<TState>(TState state, Func<TState, X509Certificate2, bool> predicate)
{
foreach (X509Certificate2 cert in _findFrom)
X509Certificate2Collection findFrom = _findFrom;
int count = findFrom.Count;
for (int i = 0; i < count; i++)
{
if (predicate(cert))
X509Certificate2 cert = findFrom[i];
if (predicate(state, cert))
{
if (!_validOnly || IsCertValid(cert))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,23 @@ public unsafe void FindByThumbprint(byte[] thumbPrint)
fixed (byte* pThumbPrint = thumbPrint)
{
CRYPTOAPI_BLOB blob = new CRYPTOAPI_BLOB(thumbPrint.Length, pThumbPrint);
FindCore(CertFindType.CERT_FIND_HASH, &blob);
FindCore<object>(CertFindType.CERT_FIND_HASH, &blob);
}
}

public unsafe void FindBySubjectName(string subjectName)
{
fixed (char* pSubjectName = subjectName)
{
FindCore(CertFindType.CERT_FIND_SUBJECT_STR, pSubjectName);
FindCore<object>(CertFindType.CERT_FIND_SUBJECT_STR, pSubjectName);
}
}

public void FindBySubjectDistinguishedName(string subjectDistinguishedName)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
subjectDistinguishedName,
static (subjectDistinguishedName, pCertContext) =>
{
string actual = GetCertNameInfo(pCertContext, CertNameType.CERT_NAME_RDN_TYPE, CertNameFlags.None);
return subjectDistinguishedName.Equals(actual, StringComparison.OrdinalIgnoreCase);
Expand All @@ -76,14 +77,15 @@ public unsafe void FindByIssuerName(string issuerName)
{
fixed (char* pIssuerName = issuerName)
{
FindCore(CertFindType.CERT_FIND_ISSUER_STR, pIssuerName);
FindCore<object>(CertFindType.CERT_FIND_ISSUER_STR, pIssuerName);
}
}

public void FindByIssuerDistinguishedName(string issuerDistinguishedName)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
issuerDistinguishedName,
static (issuerDistinguishedName, pCertContext) =>
{
string actual = GetCertNameInfo(pCertContext, CertNameType.CERT_NAME_RDN_TYPE, CertNameFlags.CERT_NAME_ISSUER_FLAG);
return issuerDistinguishedName.Equals(actual, StringComparison.OrdinalIgnoreCase);
Expand All @@ -93,15 +95,16 @@ public void FindByIssuerDistinguishedName(string issuerDistinguishedName)
public unsafe void FindBySerialNumber(BigInteger hexValue, BigInteger decimalValue)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
(hexValue, decimalValue),
static (state, pCertContext) =>
{
byte[] actual = pCertContext.CertContext->pCertInfo->SerialNumber.ToByteArray();
GC.KeepAlive(pCertContext);

// Convert to BigInteger as the comparison must not fail due to spurious leading zeros
BigInteger actualAsBigInteger = PositiveBigIntegerFromByteArray(actual);

return hexValue.Equals(actualAsBigInteger) || decimalValue.Equals(actualAsBigInteger);
return state.hexValue.Equals(actualAsBigInteger) || state.decimalValue.Equals(actualAsBigInteger);
});
}

Expand All @@ -125,19 +128,21 @@ private unsafe void FindByTime(DateTime dateTime, int compareResult)
FILETIME fileTime = FILETIME.FromDateTime(dateTime);

FindCore(
delegate (SafeCertContextHandle pCertContext)
(fileTime, compareResult),
static (state, pCertContext) =>
{
int comparison = Interop.crypt32.CertVerifyTimeValidity(ref fileTime,
int comparison = Interop.crypt32.CertVerifyTimeValidity(ref state.fileTime,
pCertContext.CertContext->pCertInfo);
GC.KeepAlive(pCertContext);
return comparison == compareResult;
return comparison == state.compareResult;
});
}

public unsafe void FindByTemplateName(string templateName)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
templateName,
static (templateName, pCertContext) =>
{
// The template name can have 2 different formats: V1 format (<= Win2K) is just a string
// V2 format (XP only) can be a friendly name or an OID.
Expand Down Expand Up @@ -203,7 +208,8 @@ public unsafe void FindByTemplateName(string templateName)
public unsafe void FindByApplicationPolicy(string oidValue)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
oidValue,
static (oidValue, pCertContext) =>
{
int numOids;
int cbData = 0;
Expand Down Expand Up @@ -234,7 +240,8 @@ public unsafe void FindByApplicationPolicy(string oidValue)
public unsafe void FindByCertificatePolicy(string oidValue)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
oidValue,
static (oidValue, pCertContext) =>
{
CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
CERT_EXTENSION* pCertExtension = Interop.crypt32.CertFindExtension(Oids.CertPolicies,
Expand Down Expand Up @@ -274,7 +281,8 @@ public unsafe void FindByCertificatePolicy(string oidValue)
public unsafe void FindByExtension(string oidValue)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
oidValue,
static (oidValue, pCertContext) =>
{
CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
CERT_EXTENSION* pCertExtension = Interop.crypt32.CertFindExtension(oidValue, pCertInfo->cExtension, pCertInfo->rgExtension);
Expand All @@ -286,7 +294,8 @@ public unsafe void FindByExtension(string oidValue)
public unsafe void FindByKeyUsage(X509KeyUsageFlags keyUsage)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
keyUsage,
static (keyUsage, pCertContext) =>
{
CERT_INFO* pCertInfo = pCertContext.CertContext->pCertInfo;
X509KeyUsageFlags actual;
Expand All @@ -300,7 +309,8 @@ public unsafe void FindByKeyUsage(X509KeyUsageFlags keyUsage)
public void FindBySubjectKeyIdentifier(byte[] keyIdentifier)
{
FindCore(
delegate (SafeCertContextHandle pCertContext)
keyIdentifier,
static (keyIdentifier, pCertContext) =>
{
int cbData = 0;
if (!Interop.crypt32.CertGetCertificateContextProperty(pCertContext, CertContextPropId.CERT_KEY_IDENTIFIER_PROP_ID, null, ref cbData))
Expand All @@ -319,12 +329,12 @@ public void Dispose()
_storePal.Dispose();
}

private unsafe void FindCore(Func<SafeCertContextHandle, bool> filter)
private unsafe void FindCore<TState>(TState state, Func<TState, SafeCertContextHandle, bool> filter)
{
FindCore(CertFindType.CERT_FIND_ANY, null, filter);
FindCore(CertFindType.CERT_FIND_ANY, null, state, filter);
}

private unsafe void FindCore(CertFindType dwFindType, void* pvFindPara, Func<SafeCertContextHandle, bool>? filter = null)
private unsafe void FindCore<TState>(CertFindType dwFindType, void* pvFindPara, TState state = default!, Func<TState, SafeCertContextHandle, bool>? filter = null)
{
SafeCertStoreHandle findResults = Interop.crypt32.CertOpenStore(
CertStoreProvider.CERT_STORE_PROV_MEMORY,
Expand All @@ -338,7 +348,7 @@ private unsafe void FindCore(CertFindType dwFindType, void* pvFindPara, Func<Saf
SafeCertContextHandle? pCertContext = null;
while (Interop.crypt32.CertFindCertificateInStore(_storePal.SafeCertStoreHandle, dwFindType, pvFindPara, ref pCertContext))
{
if (filter != null && !filter(pCertContext))
if (filter != null && !filter(state, pCertContext))
continue;

if (_validOnly)
Expand Down

0 comments on commit 6617892

Please sign in to comment.