Skip to content

Commit

Permalink
Fix multiple check-in/peer nodeId handling in icd client side (#35304)
Browse files Browse the repository at this point in the history
  • Loading branch information
yunhanw-google authored Aug 30, 2024
1 parent de9d906 commit 3d69583
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 47 deletions.
27 changes: 21 additions & 6 deletions examples/chip-tool/commands/clusters/ClusterCommand.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
const chip::app::Clusters::IcdManagement::Commands::UnregisterClient::Type & value)
{
ReturnErrorOnFailure(InteractionModelCommands::SendCommand(device, endpointId, clusterId, commandId, value));
mScopedNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
mPeerNodeId = chip::ScopedNodeId(device->GetDeviceId(), device->GetSecureSession().Value()->GetFabricIndex());
return CHIP_NO_ERROR;
}

Expand All @@ -69,7 +69,8 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
const chip::app::Clusters::IcdManagement::Commands::RegisterClient::Type & value)
{
ReturnErrorOnFailure(InteractionModelCommands::SendCommand(device, endpointId, clusterId, commandId, value));
mScopedNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
mPeerNodeId = chip::ScopedNodeId(device->GetDeviceId(), device->GetSecureSession().Value()->GetFabricIndex());
mCheckInNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
mMonitoredSubject = value.monitoredSubject;
mClientType = value.clientType;
memcpy(mICDSymmetricKey, value.key.data(), value.key.size());
Expand Down Expand Up @@ -147,7 +148,9 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
return;
}
chip::app::ICDClientInfo clientInfo;
clientInfo.peer_node = mScopedNodeId;

clientInfo.peer_node = mPeerNodeId;
clientInfo.check_in_node = mCheckInNodeId;
clientInfo.monitored_subject = mMonitoredSubject;
clientInfo.start_icd_counter = value.ICDCounter;
clientInfo.client_type = mClientType;
Expand All @@ -159,7 +162,7 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
if ((path.mEndpointId == chip::kRootEndpointId) && (path.mClusterId == chip::app::Clusters::IcdManagement::Id) &&
(path.mCommandId == chip::app::Clusters::IcdManagement::Commands::UnregisterClient::Id))
{
ClearICDEntry(mScopedNodeId);
ClearICDEntry(mPeerNodeId);
}
}

Expand Down Expand Up @@ -260,9 +263,21 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
private:
chip::ClusterId mClusterId;
chip::CommandId mCommandId;
chip::ScopedNodeId mScopedNodeId;
uint64_t mMonitoredSubject = static_cast<uint64_t>(0);
// The scoped node ID to which RegisterClient and UnregisterClient command will be sent. Not set for other commands.
chip::ScopedNodeId mPeerNodeId;
// The scoped node ID to which a Check-In message will be sent. Only set for the RegisterClient command.
chip::ScopedNodeId mCheckInNodeId;

// Used to determine if a particular client has an active subscription for the given entry.
// The MonitoredSubject, when it is a NodeID, MAY be the same as the CheckInNodeID.
// The MonitoredSubject gives the registering client the flexibility of having a different
// CheckInNodeID from the MonitoredSubject.
uint64_t mMonitoredSubject = static_cast<uint64_t>(0);

// Client type of the client registering
chip::app::Clusters::IcdManagement::ClientTypeEnum mClientType = chip::app::Clusters::IcdManagement::ClientTypeEnum::kPermanent;

// Shared secret between the client and the ICD to encrypt the Check-In message.
uint8_t mICDSymmetricKey[chip::Crypto::kAES_CCM128_Key_Length];

CHIP_ERROR mError = CHIP_NO_ERROR;
Expand Down
9 changes: 6 additions & 3 deletions examples/chip-tool/commands/icd/ICDCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ CHIP_ERROR ICDListCommand::RunCommand()
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
fprintf(stderr, " | %-88s |\n", "Known ICDs:");
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
fprintf(stderr, " | %20s | %15s | %15s | %16s | %10s |\n", "Fabric Index:Node ID", "Start Counter", "Counter Offset",
"MonitoredSubject", "ClientType");
fprintf(stderr, " | %20s | %20s | %15s | %15s | %16s | %10s |\n", "Fabric Index:Peer Node ID", "Fabric Index:CheckIn Node ID",
"Start Counter", "Counter Offset", "MonitoredSubject", "ClientType");

while (iter->Next(info))
{
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
fprintf(stderr, " | %3" PRIu32 ":" ChipLogFormatX64 " | %15" PRIu32 " | %15" PRIu32 " | " ChipLogFormatX64 " | %10u |\n",
fprintf(stderr,
" | %3" PRIu32 ":" ChipLogFormatX64 " | %3" PRIu32 ":" ChipLogFormatX64 " | %15" PRIu32 " | %15" PRIu32
" | " ChipLogFormatX64 " | %10u |\n",
static_cast<uint32_t>(info.peer_node.GetFabricIndex()), ChipLogValueX64(info.peer_node.GetNodeId()),
static_cast<uint32_t>(info.check_in_node.GetFabricIndex()), ChipLogValueX64(info.check_in_node.GetNodeId()),
info.start_icd_counter, info.offset, ChipLogValueX64(info.monitored_subject),
static_cast<uint8_t>(info.client_type));

Expand Down
3 changes: 2 additions & 1 deletion examples/chip-tool/commands/pairing/PairingCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ void PairingCommand::OnICDRegistrationComplete(ScopedNodeId nodeId, uint32_t icd
sizeof(icdSymmetricKeyHex), chip::Encoding::HexFlags::kNullTerminate);

app::ICDClientInfo clientInfo;
clientInfo.peer_node = chip::ScopedNodeId(mICDCheckInNodeId.Value(), nodeId.GetFabricIndex());
clientInfo.check_in_node = chip::ScopedNodeId(mICDCheckInNodeId.Value(), nodeId.GetFabricIndex());
clientInfo.peer_node = nodeId;
clientInfo.monitored_subject = mICDMonitoredSubject.Value();
clientInfo.start_icd_counter = icdCounter;

Expand Down
9 changes: 8 additions & 1 deletion src/app/icd/client/DefaultICDClientStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,22 @@ CHIP_ERROR DefaultICDClientStorage::Load(FabricIndex fabricIndex, std::vector<IC
ICDClientInfo clientInfo;
TLV::TLVType ICDClientInfoType;
NodeId nodeId;
NodeId checkInNodeId;
FabricIndex fabric;
ReturnErrorOnFailure(reader.EnterContainer(ICDClientInfoType));
// Peer Node ID
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kPeerNodeId)));
ReturnErrorOnFailure(reader.Get(nodeId));

ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kCheckInNodeId)));
ReturnErrorOnFailure(reader.Get(checkInNodeId));

// Fabric Index
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kFabricIndex)));
ReturnErrorOnFailure(reader.Get(fabric));
clientInfo.peer_node = ScopedNodeId(nodeId, fabric);

clientInfo.peer_node = ScopedNodeId(nodeId, fabric);
clientInfo.check_in_node = ScopedNodeId(checkInNodeId, fabric);

// Start ICD Counter
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kStartICDCounter)));
Expand Down Expand Up @@ -323,6 +329,7 @@ CHIP_ERROR DefaultICDClientStorage::SerializeToTlv(TLV::TLVWriter & writer, cons
TLV::TLVType ICDClientInfoContainerType;
ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, ICDClientInfoContainerType));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kPeerNodeId), clientInfo.peer_node.GetNodeId()));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kCheckInNodeId), clientInfo.check_in_node.GetNodeId()));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kFabricIndex), clientInfo.peer_node.GetFabricIndex()));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kStartICDCounter), clientInfo.start_icd_counter));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kOffset), clientInfo.offset));
Expand Down
20 changes: 11 additions & 9 deletions src/app/icd/client/DefaultICDClientStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,14 @@ class DefaultICDClientStorage : public ICDClientStorage
enum class ClientInfoTag : uint8_t
{
kPeerNodeId = 1,
kFabricIndex = 2,
kStartICDCounter = 3,
kOffset = 4,
kMonitoredSubject = 5,
kAesKeyHandle = 6,
kHmacKeyHandle = 7,
kClientType = 8,
kCheckInNodeId = 2,
kFabricIndex = 3,
kStartICDCounter = 4,
kOffset = 5,
kMonitoredSubject = 6,
kAesKeyHandle = 7,
kHmacKeyHandle = 8,
kClientType = 9,
};

enum class CounterTag : uint8_t
Expand Down Expand Up @@ -158,8 +159,9 @@ class DefaultICDClientStorage : public ICDClientStorage
{
// All the fields added together
return TLV::EstimateStructOverhead(
sizeof(NodeId), sizeof(FabricIndex), sizeof(uint32_t) /*start_icd_counter*/, sizeof(uint32_t) /*offset*/,
sizeof(uint64_t) /*monitored_subject*/, sizeof(Crypto::Symmetric128BitsKeyByteArray) /*aes_key_handle*/,
sizeof(NodeId), sizeof(NodeId), sizeof(FabricIndex), sizeof(uint32_t) /*start_icd_counter*/,
sizeof(uint32_t) /*offset*/, sizeof(uint64_t) /*monitored_subject*/,
sizeof(Crypto::Symmetric128BitsKeyByteArray) /*aes_key_handle*/,
sizeof(Crypto::Symmetric128BitsKeyByteArray) /*hmac_key_handle*/, sizeof(uint8_t) /*client_type*/);
}

Expand Down
2 changes: 2 additions & 0 deletions src/app/icd/client/ICDClientInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace app {
struct ICDClientInfo
{
ScopedNodeId peer_node;
ScopedNodeId check_in_node;
uint32_t start_icd_counter = 0;
uint32_t offset = 0;
Clusters::IcdManagement::ClientTypeEnum client_type = Clusters::IcdManagement::ClientTypeEnum::kPermanent;
Expand All @@ -44,6 +45,7 @@ struct ICDClientInfo
ICDClientInfo & operator=(const ICDClientInfo & other)
{
peer_node = other.peer_node;
check_in_node = other.check_in_node;
start_icd_counter = other.start_icd_counter;
offset = other.offset;
client_type = other.client_type;
Expand Down
2 changes: 1 addition & 1 deletion src/app/icd/client/RefreshKeySender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ CHIP_ERROR RefreshKeySender::RegisterClientWithNewKey(Messaging::ExchangeManager
EndpointId endpointId = 0;

Clusters::IcdManagement::Commands::RegisterClient::Type registerClientCommand;
registerClientCommand.checkInNodeID = mICDClientInfo.peer_node.GetNodeId();
registerClientCommand.checkInNodeID = mICDClientInfo.check_in_node.GetNodeId();
registerClientCommand.monitoredSubject = mICDClientInfo.monitored_subject;
registerClientCommand.key = mNewKey.Span();
return Controller::InvokeCommandRequest(&exchangeMgr, sessionHandle, endpointId, registerClientCommand, onSuccess, onFailure);
Expand Down
31 changes: 18 additions & 13 deletions src/controller/java/AndroidCheckInDelegate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
#include <lib/support/JniReferences.h>
#include <lib/support/logging/CHIPLogging.h>

#define PARSE_CLIENT_INFO(_clientInfo, _peerNodeId, _startCounter, _offset, _monitoredSubject, _jniICDAesKey, _jniICDHmacKey) \
#define PARSE_CLIENT_INFO(_clientInfo, _peerNodeId, _checkInNodeId, _startCounter, _offset, _monitoredSubject, _jniICDAesKey, \
_jniICDHmacKey) \
jlong _peerNodeId = static_cast<jlong>(_clientInfo.peer_node.GetNodeId()); \
jlong _checkInNodeId = static_cast<jlong>(_clientInfo.check_in_node.GetNodeId()); \
jlong _startCounter = static_cast<jlong>(_clientInfo.start_icd_counter); \
jlong _offset = static_cast<jlong>(_clientInfo.offset); \
jlong _monitoredSubject = static_cast<jlong>(_clientInfo.monitored_subject); \
Expand Down Expand Up @@ -53,24 +55,26 @@ CHIP_ERROR AndroidCheckInDelegate::SetDelegate(jobject checkInDelegateObj)

void AndroidCheckInDelegate::OnCheckInComplete(const ICDClientInfo & clientInfo)
{
ChipLogProgress(
ICD, "Check In Message processing complete: start_counter=%" PRIu32 " offset=%" PRIu32 " nodeid=" ChipLogFormatScopedNodeId,
clientInfo.start_icd_counter, clientInfo.offset, ChipLogValueScopedNodeId(clientInfo.peer_node));
ChipLogProgress(ICD,
"Check In Message processing complete: start_counter=%" PRIu32 " offset=%" PRIu32
" peernodeid=" ChipLogFormatScopedNodeId " checkinnodeid=" ChipLogFormatScopedNodeId,
clientInfo.start_icd_counter, clientInfo.offset, ChipLogValueScopedNodeId(clientInfo.peer_node),
ChipLogValueScopedNodeId(clientInfo.check_in_node));

VerifyOrReturn(mCheckInDelegate.HasValidObjectRef(), ChipLogProgress(ICD, "check-in delegate is not implemented!"));

JNIEnv * env = chip::JniReferences::GetInstance().GetEnvForCurrentThread();
VerifyOrReturn(env != nullptr, ChipLogError(Controller, "JNIEnv is null!"));
PARSE_CLIENT_INFO(clientInfo, peerNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
PARSE_CLIENT_INFO(clientInfo, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)

jmethodID onCheckInCompleteMethodID = nullptr;
CHIP_ERROR err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onCheckInComplete",
"(JJJJ[B[B)V", &onCheckInCompleteMethodID);
"(JJJJJ[B[B)V", &onCheckInCompleteMethodID);
VerifyOrReturn(err == CHIP_NO_ERROR,
ChipLogProgress(ICD, "onCheckInComplete - FindMethod is failed! : %" CHIP_ERROR_FORMAT, err.Format()));

env->CallVoidMethod(mCheckInDelegate.ObjectRef(), onCheckInCompleteMethodID, peerNodeId, startCounter, offset, monitoredSubject,
jniICDAesKey.jniValue(), jniICDHmacKey.jniValue());
env->CallVoidMethod(mCheckInDelegate.ObjectRef(), onCheckInCompleteMethodID, peerNodeId, checkInNodeId, startCounter, offset,
monitoredSubject, jniICDAesKey.jniValue(), jniICDHmacKey.jniValue());
}

RefreshKeySender * AndroidCheckInDelegate::OnKeyRefreshNeeded(ICDClientInfo & clientInfo, ICDClientStorage * clientStorage)
Expand All @@ -84,17 +88,18 @@ RefreshKeySender * AndroidCheckInDelegate::OnKeyRefreshNeeded(ICDClientInfo & cl
JNIEnv * env = chip::JniReferences::GetInstance().GetEnvForCurrentThread();
VerifyOrReturnValue(env != nullptr, nullptr, ChipLogError(Controller, "JNIEnv is null!"));

PARSE_CLIENT_INFO(clientInfo, peerNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
PARSE_CLIENT_INFO(clientInfo, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject, jniICDAesKey,
jniICDHmacKey)

jmethodID onKeyRefreshNeededMethodID = nullptr;
err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onKeyRefreshNeeded", "(JJJJ[B[B)V",
err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onKeyRefreshNeeded", "(JJJJJ[B[B)V",
&onKeyRefreshNeededMethodID);
VerifyOrReturnValue(err == CHIP_NO_ERROR, nullptr,
ChipLogProgress(ICD, "onKeyRefreshNeeded - FindMethod is failed! : %" CHIP_ERROR_FORMAT, err.Format()));

jbyteArray key = static_cast<jbyteArray>(env->CallObjectMethod(mCheckInDelegate.ObjectRef(), onKeyRefreshNeededMethodID,
peerNodeId, startCounter, offset, monitoredSubject,
jniICDAesKey.jniValue(), jniICDHmacKey.jniValue()));
jbyteArray key = static_cast<jbyteArray>(
env->CallObjectMethod(mCheckInDelegate.ObjectRef(), onKeyRefreshNeededMethodID, peerNodeId, checkInNodeId, startCounter,
offset, monitoredSubject, jniICDAesKey.jniValue(), jniICDHmacKey.jniValue()));

if (key != nullptr)
{
Expand Down
8 changes: 6 additions & 2 deletions src/controller/java/AndroidDeviceControllerWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,8 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
CHIP_ERROR err = CHIP_NO_ERROR;
chip::app::ICDClientInfo clientInfo;
clientInfo.peer_node = icdNodeId;
clientInfo.check_in_node = chip::ScopedNodeId(mAutoCommissioner.GetCommissioningParameters().GetICDCheckInNodeId().Value(),
icdNodeId.GetFabricIndex());
clientInfo.monitored_subject = mAutoCommissioner.GetCommissioningParameters().GetICDMonitoredSubject().Value();
clientInfo.start_icd_counter = icdCounter;

Expand Down Expand Up @@ -1056,7 +1058,7 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
methodErr = chip::JniReferences::GetInstance().GetLocalClassRef(env, "chip/devicecontroller/ICDDeviceInfo", icdDeviceInfoClass);
VerifyOrReturn(methodErr == CHIP_NO_ERROR, ChipLogError(Controller, "Could not find class ICDDeviceInfo"));

icdDeviceInfoStructCtor = env->GetMethodID(icdDeviceInfoClass, "<init>", "([BILjava/lang/String;JJIJJJJI)V");
icdDeviceInfoStructCtor = env->GetMethodID(icdDeviceInfoClass, "<init>", "([BILjava/lang/String;JJIJJJJJI)V");
VerifyOrReturn(icdDeviceInfoStructCtor != nullptr, ChipLogError(Controller, "Could not find ICDDeviceInfo constructor"));

methodErr =
Expand All @@ -1069,7 +1071,9 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
icdDeviceInfoObj = env->NewObject(
icdDeviceInfoClass, icdDeviceInfoStructCtor, jSymmetricKey, static_cast<jint>(mUserActiveModeTriggerHint.Raw()),
jUserActiveModeTriggerInstruction, static_cast<jlong>(mIdleModeDuration), static_cast<jlong>(mActiveModeDuration),
static_cast<jint>(mActiveModeThreshold), static_cast<jlong>(icdNodeId.GetNodeId()), static_cast<jlong>(icdCounter),
static_cast<jint>(mActiveModeThreshold), static_cast<jlong>(icdNodeId.GetNodeId()),
static_cast<jlong>(mAutoCommissioner.GetCommissioningParameters().GetICDCheckInNodeId().Value()),
static_cast<jlong>(icdCounter),
static_cast<jlong>(mAutoCommissioner.GetCommissioningParameters().GetICDMonitoredSubject().Value()),
static_cast<jlong>(Controller()->GetFabricId()), static_cast<jint>(Controller()->GetFabricIndex()));

Expand Down
Loading

0 comments on commit 3d69583

Please sign in to comment.