Skip to content

Commit

Permalink
Fix multiple check-in/peer nodeId handling in icd client side
Browse files Browse the repository at this point in the history
  • Loading branch information
yunhanw-google committed Aug 29, 2024
1 parent 82ef12b commit 224cbcd
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 39 deletions.
15 changes: 10 additions & 5 deletions examples/chip-tool/commands/clusters/ClusterCommand.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
const chip::app::Clusters::IcdManagement::Commands::UnregisterClient::Type & value)
{
ReturnErrorOnFailure(InteractionModelCommands::SendCommand(device, endpointId, clusterId, commandId, value));
mScopedNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
mPeerNodeId = chip::ScopedNodeId(device->GetDeviceId(), device->GetSecureSession().Value()->GetFabricIndex());
mCheckInNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
return CHIP_NO_ERROR;
}

Expand All @@ -69,7 +70,8 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
const chip::app::Clusters::IcdManagement::Commands::RegisterClient::Type & value)
{
ReturnErrorOnFailure(InteractionModelCommands::SendCommand(device, endpointId, clusterId, commandId, value));
mScopedNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
mPeerNodeId = chip::ScopedNodeId(device->GetDeviceId(), device->GetSecureSession().Value()->GetFabricIndex());
mCheckInNodeId = chip::ScopedNodeId(value.checkInNodeID, device->GetSecureSession().Value()->GetFabricIndex());
mMonitoredSubject = value.monitoredSubject;
mClientType = value.clientType;
memcpy(mICDSymmetricKey, value.key.data(), value.key.size());
Expand Down Expand Up @@ -147,7 +149,9 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
return;
}
chip::app::ICDClientInfo clientInfo;
clientInfo.peer_node = mScopedNodeId;

clientInfo.peer_node = mPeerNodeId;
clientInfo.check_in_node = mCheckInNodeId;
clientInfo.monitored_subject = mMonitoredSubject;
clientInfo.start_icd_counter = value.ICDCounter;
clientInfo.client_type = mClientType;
Expand All @@ -159,7 +163,7 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
if ((path.mEndpointId == chip::kRootEndpointId) && (path.mClusterId == chip::app::Clusters::IcdManagement::Id) &&
(path.mCommandId == chip::app::Clusters::IcdManagement::Commands::UnregisterClient::Id))
{
ClearICDEntry(mScopedNodeId);
ClearICDEntry(mPeerNodeId);
}
}

Expand Down Expand Up @@ -260,7 +264,8 @@ class ClusterCommand : public InteractionModelCommands, public ModelCommand, pub
private:
chip::ClusterId mClusterId;
chip::CommandId mCommandId;
chip::ScopedNodeId mScopedNodeId;
chip::ScopedNodeId mPeerNodeId;
chip::ScopedNodeId mCheckInNodeId;
uint64_t mMonitoredSubject = static_cast<uint64_t>(0);
chip::app::Clusters::IcdManagement::ClientTypeEnum mClientType = chip::app::Clusters::IcdManagement::ClientTypeEnum::kPermanent;
uint8_t mICDSymmetricKey[chip::Crypto::kAES_CCM128_Key_Length];
Expand Down
5 changes: 3 additions & 2 deletions examples/chip-tool/commands/icd/ICDCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ CHIP_ERROR ICDListCommand::RunCommand()
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
fprintf(stderr, " | %-88s |\n", "Known ICDs:");
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
fprintf(stderr, " | %20s | %15s | %15s | %16s | %10s |\n", "Fabric Index:Node ID", "Start Counter", "Counter Offset",
fprintf(stderr, " | %20s | %20s | %15s | %15s | %16s | %10s |\n", "Fabric Index:Peer Node ID", "Fabric Index:CheckIn Node ID", "Start Counter", "Counter Offset",
"MonitoredSubject", "ClientType");

while (iter->Next(info))
{
fprintf(stderr, " +------------------------------------------------------------------------------------------+\n");
fprintf(stderr, " | %3" PRIu32 ":" ChipLogFormatX64 " | %15" PRIu32 " | %15" PRIu32 " | " ChipLogFormatX64 " | %10u |\n",
fprintf(stderr, " | %3" PRIu32 ":" ChipLogFormatX64 " | %3" PRIu32 ":" ChipLogFormatX64 " | %15" PRIu32 " | %15" PRIu32 " | " ChipLogFormatX64 " | %10u |\n",
static_cast<uint32_t>(info.peer_node.GetFabricIndex()), ChipLogValueX64(info.peer_node.GetNodeId()),
static_cast<uint32_t>(info.check_in_node.GetFabricIndex()), ChipLogValueX64(info.check_in_node.GetNodeId()),
info.start_icd_counter, info.offset, ChipLogValueX64(info.monitored_subject),
static_cast<uint8_t>(info.client_type));

Expand Down
3 changes: 2 additions & 1 deletion examples/chip-tool/commands/pairing/PairingCommand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ void PairingCommand::OnICDRegistrationComplete(ScopedNodeId nodeId, uint32_t icd
sizeof(icdSymmetricKeyHex), chip::Encoding::HexFlags::kNullTerminate);

app::ICDClientInfo clientInfo;
clientInfo.peer_node = chip::ScopedNodeId(mICDCheckInNodeId.Value(), nodeId.GetFabricIndex());
clientInfo.check_in_node = chip::ScopedNodeId(mICDCheckInNodeId.Value(), nodeId.GetFabricIndex());
clientInfo.peer_node = nodeId;
clientInfo.monitored_subject = mICDMonitoredSubject.Value();
clientInfo.start_icd_counter = icdCounter;

Expand Down
9 changes: 8 additions & 1 deletion src/app/icd/client/DefaultICDClientStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,22 @@ CHIP_ERROR DefaultICDClientStorage::Load(FabricIndex fabricIndex, std::vector<IC
ICDClientInfo clientInfo;
TLV::TLVType ICDClientInfoType;
NodeId nodeId;
NodeId checkInNodeId;
FabricIndex fabric;
ReturnErrorOnFailure(reader.EnterContainer(ICDClientInfoType));
// Peer Node ID
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kPeerNodeId)));
ReturnErrorOnFailure(reader.Get(nodeId));

ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kCheckInNodeId)));
ReturnErrorOnFailure(reader.Get(checkInNodeId));

// Fabric Index
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kFabricIndex)));
ReturnErrorOnFailure(reader.Get(fabric));
clientInfo.peer_node = ScopedNodeId(nodeId, fabric);

clientInfo.peer_node = ScopedNodeId(nodeId, fabric);
clientInfo.check_in_node = ScopedNodeId(checkInNodeId, fabric);

// Start ICD Counter
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(ClientInfoTag::kStartICDCounter)));
Expand Down Expand Up @@ -323,6 +329,7 @@ CHIP_ERROR DefaultICDClientStorage::SerializeToTlv(TLV::TLVWriter & writer, cons
TLV::TLVType ICDClientInfoContainerType;
ReturnErrorOnFailure(writer.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, ICDClientInfoContainerType));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kPeerNodeId), clientInfo.peer_node.GetNodeId()));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kCheckInNodeId), clientInfo.check_in_node.GetNodeId()));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kFabricIndex), clientInfo.peer_node.GetFabricIndex()));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kStartICDCounter), clientInfo.start_icd_counter));
ReturnErrorOnFailure(writer.Put(TLV::ContextTag(ClientInfoTag::kOffset), clientInfo.offset));
Expand Down
17 changes: 9 additions & 8 deletions src/app/icd/client/DefaultICDClientStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,14 @@ class DefaultICDClientStorage : public ICDClientStorage
enum class ClientInfoTag : uint8_t
{
kPeerNodeId = 1,
kFabricIndex = 2,
kStartICDCounter = 3,
kOffset = 4,
kMonitoredSubject = 5,
kAesKeyHandle = 6,
kHmacKeyHandle = 7,
kClientType = 8,
kCheckInNodeId = 2,
kFabricIndex = 3,
kStartICDCounter = 4,
kOffset = 5,
kMonitoredSubject = 6,
kAesKeyHandle = 7,
kHmacKeyHandle = 8,
kClientType = 9,
};

enum class CounterTag : uint8_t
Expand Down Expand Up @@ -158,7 +159,7 @@ class DefaultICDClientStorage : public ICDClientStorage
{
// All the fields added together
return TLV::EstimateStructOverhead(
sizeof(NodeId), sizeof(FabricIndex), sizeof(uint32_t) /*start_icd_counter*/, sizeof(uint32_t) /*offset*/,
sizeof(NodeId), sizeof(NodeId), sizeof(FabricIndex), sizeof(uint32_t) /*start_icd_counter*/, sizeof(uint32_t) /*offset*/,
sizeof(uint64_t) /*monitored_subject*/, sizeof(Crypto::Symmetric128BitsKeyByteArray) /*aes_key_handle*/,
sizeof(Crypto::Symmetric128BitsKeyByteArray) /*hmac_key_handle*/, sizeof(uint8_t) /*client_type*/);
}
Expand Down
2 changes: 2 additions & 0 deletions src/app/icd/client/ICDClientInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace app {
struct ICDClientInfo
{
ScopedNodeId peer_node;
ScopedNodeId check_in_node;
uint32_t start_icd_counter = 0;
uint32_t offset = 0;
Clusters::IcdManagement::ClientTypeEnum client_type = Clusters::IcdManagement::ClientTypeEnum::kPermanent;
Expand All @@ -44,6 +45,7 @@ struct ICDClientInfo
ICDClientInfo & operator=(const ICDClientInfo & other)
{
peer_node = other.peer_node;
check_in_node = other.check_in_node;
start_icd_counter = other.start_icd_counter;
offset = other.offset;
client_type = other.client_type;
Expand Down
2 changes: 1 addition & 1 deletion src/app/icd/client/RefreshKeySender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ CHIP_ERROR RefreshKeySender::RegisterClientWithNewKey(Messaging::ExchangeManager
EndpointId endpointId = 0;

Clusters::IcdManagement::Commands::RegisterClient::Type registerClientCommand;
registerClientCommand.checkInNodeID = mICDClientInfo.peer_node.GetNodeId();
registerClientCommand.checkInNodeID = mICDClientInfo.check_in_node.GetNodeId();
registerClientCommand.monitoredSubject = mICDClientInfo.monitored_subject;
registerClientCommand.key = mNewKey.Span();
return Controller::InvokeCommandRequest(&exchangeMgr, sessionHandle, endpointId, registerClientCommand, onSuccess, onFailure);
Expand Down
21 changes: 11 additions & 10 deletions src/controller/java/AndroidCheckInDelegate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
#include <lib/support/JniReferences.h>
#include <lib/support/logging/CHIPLogging.h>

#define PARSE_CLIENT_INFO(_clientInfo, _peerNodeId, _startCounter, _offset, _monitoredSubject, _jniICDAesKey, _jniICDHmacKey) \
jlong _peerNodeId = static_cast<jlong>(_clientInfo.peer_node.GetNodeId()); \
#define PARSE_CLIENT_INFO(_clientInfo, _peerNodeId, _checkInNodeId, _startCounter, _offset, _monitoredSubject, _jniICDAesKey, _jniICDHmacKey) \
jlong _peerNodeId = static_cast<jlong>(_clientInfo.peer_node.GetNodeId());
jlong _checkInNodeId = static_cast<jlong>(_clientInfo.check_in_node.GetNodeId()); \
jlong _startCounter = static_cast<jlong>(_clientInfo.start_icd_counter); \
jlong _offset = static_cast<jlong>(_clientInfo.offset); \
jlong _monitoredSubject = static_cast<jlong>(_clientInfo.monitored_subject); \
Expand Down Expand Up @@ -54,22 +55,22 @@ CHIP_ERROR AndroidCheckInDelegate::SetDelegate(jobject checkInDelegateObj)
void AndroidCheckInDelegate::OnCheckInComplete(const ICDClientInfo & clientInfo)
{
ChipLogProgress(
ICD, "Check In Message processing complete: start_counter=%" PRIu32 " offset=%" PRIu32 " nodeid=" ChipLogFormatScopedNodeId,
clientInfo.start_icd_counter, clientInfo.offset, ChipLogValueScopedNodeId(clientInfo.peer_node));
ICD, "Check In Message processing complete: start_counter=%" PRIu32 " offset=%" PRIu32 " peernodeid=" ChipLogFormatScopedNodeId " checkinnodeid=" ChipLogFormatScopedNodeId,
clientInfo.start_icd_counter, clientInfo.offset, ChipLogValueScopedNodeId(clientInfo.peer_node), ChipLogValueScopedNodeId(clientInfo.check_in_node));

VerifyOrReturn(mCheckInDelegate.HasValidObjectRef(), ChipLogProgress(ICD, "check-in delegate is not implemented!"));

JNIEnv * env = chip::JniReferences::GetInstance().GetEnvForCurrentThread();
VerifyOrReturn(env != nullptr, ChipLogError(Controller, "JNIEnv is null!"));
PARSE_CLIENT_INFO(clientInfo, peerNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
PARSE_CLIENT_INFO(clientInfo, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)

jmethodID onCheckInCompleteMethodID = nullptr;
CHIP_ERROR err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onCheckInComplete",
"(JJJJ[B[B)V", &onCheckInCompleteMethodID);
"(JJJJJ[B[B)V", &onCheckInCompleteMethodID);
VerifyOrReturn(err == CHIP_NO_ERROR,
ChipLogProgress(ICD, "onCheckInComplete - FindMethod is failed! : %" CHIP_ERROR_FORMAT, err.Format()));

env->CallVoidMethod(mCheckInDelegate.ObjectRef(), onCheckInCompleteMethodID, peerNodeId, startCounter, offset, monitoredSubject,
env->CallVoidMethod(mCheckInDelegate.ObjectRef(), onCheckInCompleteMethodID, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject,
jniICDAesKey.jniValue(), jniICDHmacKey.jniValue());
}

Expand All @@ -84,16 +85,16 @@ RefreshKeySender * AndroidCheckInDelegate::OnKeyRefreshNeeded(ICDClientInfo & cl
JNIEnv * env = chip::JniReferences::GetInstance().GetEnvForCurrentThread();
VerifyOrReturnValue(env != nullptr, nullptr, ChipLogError(Controller, "JNIEnv is null!"));

PARSE_CLIENT_INFO(clientInfo, peerNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)
PARSE_CLIENT_INFO(clientInfo, peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject, jniICDAesKey, jniICDHmacKey)

jmethodID onKeyRefreshNeededMethodID = nullptr;
err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onKeyRefreshNeeded", "(JJJJ[B[B)V",
err = chip::JniReferences::GetInstance().FindMethod(env, mCheckInDelegate.ObjectRef(), "onKeyRefreshNeeded", "(JJJJJ[B[B)V",
&onKeyRefreshNeededMethodID);
VerifyOrReturnValue(err == CHIP_NO_ERROR, nullptr,
ChipLogProgress(ICD, "onKeyRefreshNeeded - FindMethod is failed! : %" CHIP_ERROR_FORMAT, err.Format()));

jbyteArray key = static_cast<jbyteArray>(env->CallObjectMethod(mCheckInDelegate.ObjectRef(), onKeyRefreshNeededMethodID,
peerNodeId, startCounter, offset, monitoredSubject,
peerNodeId, checkInNodeId, startCounter, offset, monitoredSubject,
jniICDAesKey.jniValue(), jniICDHmacKey.jniValue()));

if (key != nullptr)
Expand Down
5 changes: 3 additions & 2 deletions src/controller/java/AndroidDeviceControllerWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,7 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
CHIP_ERROR err = CHIP_NO_ERROR;
chip::app::ICDClientInfo clientInfo;
clientInfo.peer_node = icdNodeId;
clientInfo.check_in_node = chip::ScopedNodeId(mAutoCommissioner.GetCommissioningParameters().GetICDCheckInNodeId(), icdNodeId.GetFabricIndex());
clientInfo.monitored_subject = mAutoCommissioner.GetCommissioningParameters().GetICDMonitoredSubject().Value();
clientInfo.start_icd_counter = icdCounter;

Expand Down Expand Up @@ -1056,7 +1057,7 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
methodErr = chip::JniReferences::GetInstance().GetLocalClassRef(env, "chip/devicecontroller/ICDDeviceInfo", icdDeviceInfoClass);
VerifyOrReturn(methodErr == CHIP_NO_ERROR, ChipLogError(Controller, "Could not find class ICDDeviceInfo"));

icdDeviceInfoStructCtor = env->GetMethodID(icdDeviceInfoClass, "<init>", "([BILjava/lang/String;JJIJJJJI)V");
icdDeviceInfoStructCtor = env->GetMethodID(icdDeviceInfoClass, "<init>", "([BILjava/lang/String;JJIJJJJJI)V");
VerifyOrReturn(icdDeviceInfoStructCtor != nullptr, ChipLogError(Controller, "Could not find ICDDeviceInfo constructor"));

methodErr =
Expand All @@ -1069,7 +1070,7 @@ void AndroidDeviceControllerWrapper::OnICDRegistrationComplete(chip::ScopedNodeI
icdDeviceInfoObj = env->NewObject(
icdDeviceInfoClass, icdDeviceInfoStructCtor, jSymmetricKey, static_cast<jint>(mUserActiveModeTriggerHint.Raw()),
jUserActiveModeTriggerInstruction, static_cast<jlong>(mIdleModeDuration), static_cast<jlong>(mActiveModeDuration),
static_cast<jint>(mActiveModeThreshold), static_cast<jlong>(icdNodeId.GetNodeId()), static_cast<jlong>(icdCounter),
static_cast<jint>(mActiveModeThreshold), static_cast<jlong>(icdNodeId.GetNodeId()), static_cast<jlong>(checkInNodeId.GetNodeId()), static_cast<jlong>(icdCounter),
static_cast<jlong>(mAutoCommissioner.GetCommissioningParameters().GetICDMonitoredSubject().Value()),
static_cast<jlong>(Controller()->GetFabricId()), static_cast<jint>(Controller()->GetFabricIndex()));

Expand Down
Loading

0 comments on commit 224cbcd

Please sign in to comment.