Skip to content

Commit

Permalink
Removing networking bits from CASESession ParseSigma1 and creating En…
Browse files Browse the repository at this point in the history
…codeSigma1
  • Loading branch information
Alami-Amine committed Nov 29, 2024
1 parent d37eae1 commit 13a8f48
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 107 deletions.
198 changes: 109 additions & 89 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
#include <protocols/secure_channel/SessionResumptionStorage.h>
#include <protocols/secure_channel/StatusReport.h>
#include <system/SystemClock.h>
#include <system/TLVPacketBufferBackingStore.h>
#include <tracing/macros.h>
#include <tracing/metric_event.h>
#include <transport/SessionManager.h>
Expand All @@ -68,16 +67,13 @@ enum
kTag_TBSData_ReceiverPubKey = 4,
};

enum
{
kTag_Sigma1_InitiatorRandom = 1,
kTag_Sigma1_InitiatorSessionId = 2,
kTag_Sigma1_DestinationId = 3,
kTag_Sigma1_InitiatorEphPubKey = 4,
kTag_Sigma1_InitiatorMRPParams = 5,
kTag_Sigma1_ResumptionID = 6,
kTag_Sigma1_InitiatorResumeMIC = 7,
};
inline constexpr uint8_t kInitiatorRandomTag = 1;
inline constexpr uint8_t kInitiatorSessionIdTag = 2;
inline constexpr uint8_t kDestinationIdTag = 3;
inline constexpr uint8_t kInitiatorPubKeyTag = 4;
inline constexpr uint8_t kInitiatorMRPParamsTag = 5;
inline constexpr uint8_t kResumptionIDTag = 6;
inline constexpr uint8_t kResume1MICTag = 7;

enum
{
Expand Down Expand Up @@ -770,24 +766,19 @@ void CASESession::HandleConnectionClosed(Transport::ActiveTCPConnectionState * c
CHIP_ERROR CASESession::SendSigma1()
{
MATTER_TRACE_SCOPE("SendSigma1", "CASESession");
size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom
sizeof(uint16_t), // initiatorSessionId,
kSHA256_Hash_Length, // destinationId
kP256_PublicKey_Length, // InitiatorEphPubKey,
SessionParameters::kEstimatedTLVSize, // initiatorSessionParams
SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES);

System::PacketBufferTLVWriter tlvWriter;
System::PacketBufferHandle msg_R1;
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;
uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 };

Sigma1Param encodeSigma1Params;

// Lookup fabric info.
const auto * fabricInfo = mFabricsTable->FindFabricWithIndex(mFabricIndex);
VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INCORRECT_STATE);

// Validate that we have a session ID allocated.
VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE);
encodeSigma1Params.initiatorSessionId = GetLocalSessionId().Value();

// Generate an ephemeral keypair
mEphemeralKey = mFabricsTable->AllocateEphemeralKeypairForCASE();
Expand All @@ -797,16 +788,6 @@ CHIP_ERROR CASESession::SendSigma1()
// Fill in the random value
ReturnErrorOnFailure(DRBG_get_bytes(mInitiatorRandom, sizeof(mInitiatorRandom)));

// Construct Sigma1 Msg
msg_R1 = System::PacketBufferHandle::New(data_len);
VerifyOrReturnError(!msg_R1.IsNull(), CHIP_ERROR_NO_MEMORY);

tlvWriter.Init(std::move(msg_R1));
ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(mInitiatorRandom)));
// Retrieve Session Identifier
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value()));

// Generate a Destination Identifier based on the node we are attempting to reach
{
// Obtain originator IPK matching the fabric where we are trying to open a session. mIPK
Expand All @@ -821,14 +802,10 @@ CHIP_ERROR CASESession::SendSigma1()
MutableByteSpan destinationIdSpan(destinationIdentifier);
ReturnErrorOnFailure(GenerateCaseDestinationId(ByteSpan(mIPK), ByteSpan(mInitiatorRandom), rootPubKeySpan, fabricId,
mPeerNodeId, destinationIdSpan));
encodeSigma1Params.destinationId = destinationIdSpan;
}
ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(3), destinationIdentifier, sizeof(destinationIdentifier)));

ReturnErrorOnFailure(
tlvWriter.PutBytes(TLV::ContextTag(4), mEphemeralKey->Pubkey(), static_cast<uint32_t>(mEphemeralKey->Pubkey().Length())));

VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter));

// Try to find persistent session, and resume it.
bool resuming = false;
Expand All @@ -839,20 +816,20 @@ CHIP_ERROR CASESession::SendSigma1()
if (err == CHIP_NO_ERROR)
{
// Found valid resumption state, try to resume the session.
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(6), mResumeResumptionId));

uint8_t initiatorResume1MIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES];
MutableByteSpan resumeMICSpan(initiatorResume1MIC);
MutableByteSpan resumeMICSpan(encodeSigma1Params.initiatorResume1MIC);
ReturnErrorOnFailure(GenerateSigmaResumeMIC(ByteSpan(mInitiatorRandom), ByteSpan(mResumeResumptionId),
ByteSpan(kKDFS1RKeyInfo), ByteSpan(kResume1MIC_Nonce), resumeMICSpan));

ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(7), resumeMICSpan));
encodeSigma1Params.initiatorResumeMICSpan = resumeMICSpan;
encodeSigma1Params.sessionResumptionRequested = true;

resuming = true;
}
}

ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R1));
// Encode Sigma1 into into msg_R1
ReturnErrorOnFailure(EncodeSigma1(msg_R1, encodeSigma1Params));

ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() }));

Expand Down Expand Up @@ -884,6 +861,52 @@ CHIP_ERROR CASESession::SendSigma1()
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::EncodeSigma1(System::PacketBufferHandle & msg, Sigma1Param & inputParams)
{

MATTER_TRACE_SCOPE("EncodeSigma1", "CASESession");

size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom
sizeof(uint16_t), // initiatorSessionId,
kSHA256_Hash_Length, // destinationId
kP256_PublicKey_Length, // InitiatorEphPubKey,
SessionParameters::kEstimatedTLVSize, // initiatorSessionParams
SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES);

msg = System::PacketBufferHandle::New(data_len);
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY);

System::PacketBufferTLVWriter tlvWriter;
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;

tlvWriter.Init(std::move(msg));
ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
// TODO Pass this in the struct?
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kInitiatorRandomTag), ByteSpan(mInitiatorRandom)));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kInitiatorSessionIdTag), inputParams.initiatorSessionId));

ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kDestinationIdTag), inputParams.destinationId));

// TODO Pass this in the struct?
ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kInitiatorPubKeyTag), mEphemeralKey->Pubkey(),
static_cast<uint32_t>(mEphemeralKey->Pubkey().Length())));

// TODO is it redudunt?
VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(kInitiatorMRPParamsTag), mLocalMRPConfig.Value(), tlvWriter));

if (inputParams.sessionResumptionRequested)
{
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kResumptionIDTag), mResumeResumptionId));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kResume1MICTag), inputParams.initiatorResumeMICSpan));
}

ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize(&msg));

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::HandleSigma1_and_SendSigma2(System::PacketBufferHandle && msg)
{
MATTER_TRACE_SCOPE("HandleSigma1_and_SendSigma2", "CASESession");
Expand Down Expand Up @@ -923,7 +946,7 @@ CHIP_ERROR CASESession::FindLocalNodeFromDestinationId(const ByteSpan & destinat
MutableByteSpan candidateDestinationIdSpan(candidateDestinationId);
ByteSpan candidateIpkSpan(ipkKeySet.epoch_keys[keyIdx].key);

err = GenerateCaseDestinationId(ByteSpan(candidateIpkSpan), ByteSpan(initiatorRandom), rootPubKeySpan, fabricId, nodeId,
err = GenerateCaseDestinationId(candidateIpkSpan, initiatorRandom, rootPubKeySpan, fabricId, nodeId,
candidateDestinationIdSpan);
if ((err == CHIP_NO_ERROR) && (candidateDestinationIdSpan.data_equal(destinationId)))
{
Expand Down Expand Up @@ -974,38 +997,43 @@ CHIP_ERROR CASESession::TryResumeSession(SessionResumptionStorage::ConstResumpti
CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
{
MATTER_TRACE_SCOPE("HandleSigma1", "CASESession");
CHIP_ERROR err = CHIP_NO_ERROR;
System::PacketBufferTLVReader tlvReader;

uint16_t initiatorSessionId;
ByteSpan destinationIdentifier;
ByteSpan initiatorRandom;

ChipLogProgress(SecureChannel, "Received Sigma1 msg");
MATTER_TRACE_COUNTER("Sigma1");

bool sessionResumptionRequested = false;
ByteSpan resumptionId;
ByteSpan resume1MIC;
ByteSpan initiatorPubKey;
CHIP_ERROR err = CHIP_NO_ERROR;
System::PacketBufferTLVReader tlvReader;

Sigma1Param parsedSigma1;

SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() }));

tlvReader.Init(std::move(msg));
SuccessOrExit(err = ParseSigma1(tlvReader, initiatorRandom, initiatorSessionId, destinationIdentifier, initiatorPubKey,
sessionResumptionRequested, resumptionId, resume1MIC));

ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId);
SetPeerSessionId(initiatorSessionId);
SuccessOrExit(err = ParseSigma1(tlvReader, parsedSigma1));

ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", parsedSigma1.initiatorSessionId);
SetPeerSessionId(parsedSigma1.initiatorSessionId);

VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE);

if (sessionResumptionRequested && resumptionId.size() == SessionResumptionStorage::kResumptionIdSize &&
// TODO: Added by Amine, taken from inside ParseSigma1
// This was removed to remove the non-parsing parts from ParseSigma1, decoupling it from higher levels
// TODO: Should i change it?
// Set the recieved MRP parameters included with Sigma1
if (parsedSigma1.InitiatorMRPParamsPresent == true)
{
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
GetRemoteSessionParameters());
}

if (parsedSigma1.sessionResumptionRequested &&
parsedSigma1.resumptionId.size() == SessionResumptionStorage::kResumptionIdSize &&
CHIP_NO_ERROR ==
TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(resumptionId.data()), resume1MIC, initiatorRandom))
TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(parsedSigma1.resumptionId.data()),
parsedSigma1.initiatorResumeMICSpan, parsedSigma1.initiatorRandom))
{
std::copy(initiatorRandom.begin(), initiatorRandom.end(), mInitiatorRandom);
std::copy(resumptionId.begin(), resumptionId.end(), mResumeResumptionId.begin());
std::copy(parsedSigma1.initiatorRandom.begin(), parsedSigma1.initiatorRandom.end(), mInitiatorRandom);
std::copy(parsedSigma1.resumptionId.begin(), parsedSigma1.resumptionId.end(), mResumeResumptionId.begin());

// Send Sigma2Resume message to the initiator
MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2Resume);
Expand All @@ -1023,7 +1051,7 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
}

// Attempt to match the initiator's desired destination based on local fabric table.
err = FindLocalNodeFromDestinationId(destinationIdentifier, initiatorRandom);
err = FindLocalNodeFromDestinationId(parsedSigma1.destinationId, parsedSigma1.initiatorRandom);
if (err == CHIP_NO_ERROR)
{
ChipLogProgress(SecureChannel, "CASE matched destination ID: fabricIndex %u, NodeID 0x" ChipLogFormatX64,
Expand All @@ -1035,13 +1063,13 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
else
{
ChipLogError(SecureChannel, "CASE failed to match destination ID with local fabrics");
ChipLogByteSpan(SecureChannel, destinationIdentifier);
ChipLogByteSpan(SecureChannel, parsedSigma1.destinationId);
}
SuccessOrExit(err);

// ParseSigma1 ensures that:
// mRemotePubKey.Length() == initiatorPubKey.size() == kP256_PublicKey_Length.
memcpy(mRemotePubKey.Bytes(), initiatorPubKey.data(), mRemotePubKey.Length());
memcpy(mRemotePubKey.Bytes(), parsedSigma1.initiatorEphPubKey.data(), mRemotePubKey.Length());

MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2);
err = SendSigma2();
Expand Down Expand Up @@ -2163,46 +2191,36 @@ CHIP_ERROR CASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralS
return err;
}

CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, ByteSpan & initiatorRandom,
uint16_t & initiatorSessionId, ByteSpan & destinationId, ByteSpan & initiatorEphPubKey,
bool & resumptionRequested, ByteSpan & resumptionId, ByteSpan & initiatorResumeMIC)
CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, Sigma1Param & output)
{
using namespace TLV;

constexpr uint8_t kInitiatorRandomTag = 1;
constexpr uint8_t kInitiatorSessionIdTag = 2;
constexpr uint8_t kDestinationIdTag = 3;
constexpr uint8_t kInitiatorPubKeyTag = 4;
constexpr uint8_t kInitiatorMRPParamsTag = 5;
constexpr uint8_t kResumptionIDTag = 6;
constexpr uint8_t kResume1MICTag = 7;

TLVType containerType = kTLVType_Structure;
ReturnErrorOnFailure(tlvReader.Next(containerType, AnonymousTag()));
ReturnErrorOnFailure(tlvReader.EnterContainer(containerType));

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorRandomTag)));
ReturnErrorOnFailure(tlvReader.GetByteView(initiatorRandom));
VerifyOrReturnError(initiatorRandom.size() == kSigmaParamRandomNumberSize, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.initiatorRandom));
VerifyOrReturnError(output.initiatorRandom.size() == kSigmaParamRandomNumberSize, CHIP_ERROR_INVALID_CASE_PARAMETER);

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorSessionIdTag)));
ReturnErrorOnFailure(tlvReader.Get(initiatorSessionId));
ReturnErrorOnFailure(tlvReader.Get(output.initiatorSessionId));

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kDestinationIdTag)));
ReturnErrorOnFailure(tlvReader.GetByteView(destinationId));
VerifyOrReturnError(destinationId.size() == kSHA256_Hash_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.destinationId));
VerifyOrReturnError(output.destinationId.size() == kSHA256_Hash_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);

ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorPubKeyTag)));
ReturnErrorOnFailure(tlvReader.GetByteView(initiatorEphPubKey));
VerifyOrReturnError(initiatorEphPubKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.initiatorEphPubKey));
VerifyOrReturnError(output.initiatorEphPubKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INVALID_CASE_PARAMETER);

// Optional members start here.
CHIP_ERROR err = tlvReader.Next();
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag))
{
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader));
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
GetRemoteSessionParameters());
output.InitiatorMRPParamsPresent = true;

err = tlvReader.Next();
}

Expand All @@ -2212,16 +2230,18 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kResumptionIDTag))
{
resumptionIDTagFound = true;
ReturnErrorOnFailure(tlvReader.GetByteView(resumptionId));
VerifyOrReturnError(resumptionId.size() == SessionResumptionStorage::kResumptionIdSize, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.resumptionId));
VerifyOrReturnError(output.resumptionId.size() == SessionResumptionStorage::kResumptionIdSize,
CHIP_ERROR_INVALID_CASE_PARAMETER);
err = tlvReader.Next();
}

if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kResume1MICTag))
{
resume1MICTagFound = true;
ReturnErrorOnFailure(tlvReader.GetByteView(initiatorResumeMIC));
VerifyOrReturnError(initiatorResumeMIC.size() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(tlvReader.GetByteView(output.initiatorResumeMICSpan));
VerifyOrReturnError(output.initiatorResumeMICSpan.size() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES,
CHIP_ERROR_INVALID_CASE_PARAMETER);
err = tlvReader.Next();
}

Expand All @@ -2236,11 +2256,11 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,

if (resumptionIDTagFound && resume1MICTagFound)
{
resumptionRequested = true;
output.sessionResumptionRequested = true;
}
else if (!resumptionIDTagFound && !resume1MICTagFound)
{
resumptionRequested = false;
output.sessionResumptionRequested = false;
}
else
{
Expand Down
Loading

0 comments on commit 13a8f48

Please sign in to comment.