Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle invalid ip address from cluster slots and added tests #7984

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions source/extensions/clusters/redis/redis_cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ RedisCluster::RedisDiscoverySession::RedisDiscoverySession(
resolve_timer_(parent.dispatcher_.createTimer([this]() -> void { startResolveRedis(); })),
client_factory_(client_factory), buffer_timeout_(0) {}

namespace {
// Convert the cluster slot IP/Port response to and address, return null if the response does not
// match the expected type.
Network::Address::InstanceConstSharedPtr
ProcessCluster(const NetworkFilters::Common::Redis::RespValue& value) {
RedisCluster::RedisDiscoverySession::RedisDiscoverySession::ProcessCluster(
const NetworkFilters::Common::Redis::RespValue& value) {
if (value.type() != NetworkFilters::Common::Redis::RespType::Array) {
return nullptr;
}
Expand All @@ -187,12 +187,16 @@ ProcessCluster(const NetworkFilters::Common::Redis::RespValue& value) {

std::string address = array[0].asString();
bool ipv6 = (address.find(':') != std::string::npos);
if (ipv6) {
return std::make_shared<Network::Address::Ipv6Instance>(address, array[1].asInteger());
try {
if (ipv6) {
return std::make_shared<Network::Address::Ipv6Instance>(address, array[1].asInteger());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry just realized, can you use parseInternetAddress here and avoid the ipv6 detection logic?

/wait-any

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for letting me know, I don't know why I didn't find this function I last looked. Let me change that now.

}
return std::make_shared<Network::Address::Ipv4Instance>(address, array[1].asInteger());
} catch (const EnvoyException& ex) {
ENVOY_LOG(debug, "Invalid ip address in CLUSTER SLOTS response: {}", ex.what());
return nullptr;
}
return std::make_shared<Network::Address::Ipv4Instance>(address, array[1].asInteger());
}
} // namespace

RedisCluster::RedisDiscoverySession::~RedisDiscoverySession() {
if (current_request_) {
Expand Down
3 changes: 3 additions & 0 deletions source/extensions/clusters/redis/redis_cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ class RedisCluster : public Upstream::BaseDynamicClusterImpl {
bool onRedirection(const NetworkFilters::Common::Redis::RespValue&) override { return true; }
void onUnexpectedResponse(const NetworkFilters::Common::Redis::RespValuePtr&);

Network::Address::InstanceConstSharedPtr
ProcessCluster(const NetworkFilters::Common::Redis::RespValue& value);

RedisCluster& parent_;
Event::Dispatcher& dispatcher_;
std::string current_host_address_;
Expand Down
48 changes: 31 additions & 17 deletions test/extensions/clusters/redis/redis_cluster_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ const std::string BasicConfig = R"EOF(
)EOF";
}

static const int ResponseFlagSize = 11;
static const int ResponseReplicaFlagSize = 4;
class RedisClusterTest : public testing::Test,
public Extensions::NetworkFilters::Common::Redis::Client::ClientFactory {
public:
Expand Down Expand Up @@ -192,7 +194,7 @@ class RedisClusterTest : public testing::Test,
replica_1[1].type(NetworkFilters::Common::Redis::RespType::Integer);
replica_1[1].asInteger() = port;

std::vector<NetworkFilters::Common::Redis::RespValue> slot_1(4);
std::vector<NetworkFilters::Common::Redis::RespValue> slot_1(ResponseReplicaFlagSize);
slot_1[0].type(NetworkFilters::Common::Redis::RespType::Integer);
slot_1[0].asInteger() = 0;
slot_1[1].type(NetworkFilters::Common::Redis::RespType::Integer);
Expand Down Expand Up @@ -280,7 +282,7 @@ class RedisClusterTest : public testing::Test,
replica_2[1].type(NetworkFilters::Common::Redis::RespType::Integer);
replica_2[1].asInteger() = 22120;

std::vector<NetworkFilters::Common::Redis::RespValue> slot_1(4);
std::vector<NetworkFilters::Common::Redis::RespValue> slot_1(ResponseReplicaFlagSize);
slot_1[0].type(NetworkFilters::Common::Redis::RespType::Integer);
slot_1[0].asInteger() = 0;
slot_1[1].type(NetworkFilters::Common::Redis::RespType::Integer);
Expand All @@ -290,7 +292,7 @@ class RedisClusterTest : public testing::Test,
slot_1[3].type(NetworkFilters::Common::Redis::RespType::Array);
slot_1[3].asArray().swap(replica_1);

std::vector<NetworkFilters::Common::Redis::RespValue> slot_2(4);
std::vector<NetworkFilters::Common::Redis::RespValue> slot_2(ResponseReplicaFlagSize);
slot_2[0].type(NetworkFilters::Common::Redis::RespType::Integer);
slot_2[0].asInteger() = 10000;
slot_2[1].type(NetworkFilters::Common::Redis::RespType::Integer);
Expand Down Expand Up @@ -321,7 +323,7 @@ class RedisClusterTest : public testing::Test,
respValue.asString() = correct_value;
} else {
respValue.type(NetworkFilters::Common::Redis::RespType::Integer);
respValue.asInteger() = 10;
respValue.asInteger() = ResponseFlagSize;
}
return respValue;
}
Expand Down Expand Up @@ -355,8 +357,9 @@ class RedisClusterTest : public testing::Test,

// Create a redis cluster slot response. If a bit is set in the bitset, then that part of
// of the response is correct, otherwise it's incorrect.
NetworkFilters::Common::Redis::RespValuePtr createResponse(std::bitset<10> flags,
std::bitset<3> replica_flags) const {
NetworkFilters::Common::Redis::RespValuePtr
createResponse(std::bitset<ResponseFlagSize> flags,
std::bitset<ResponseReplicaFlagSize> replica_flags) const {
int64_t idx(0);
int64_t slots_type = idx++;
int64_t slots_size = idx++;
Expand All @@ -367,25 +370,36 @@ class RedisClusterTest : public testing::Test,
int64_t master_type = idx++;
int64_t master_size = idx++;
int64_t master_ip_type = idx++;
int64_t master_ip_value = idx++;
int64_t master_port_type = idx++;
idx = 0;
int64_t replica_size = idx++;
int64_t replica_ip_type = idx++;
int64_t replica_ip_value = idx++;
int64_t replica_port_type = idx++;

std::vector<NetworkFilters::Common::Redis::RespValue> master_1_array;
if (flags.test(master_size)) {
// Ip field.
master_1_array.push_back(createStringField(flags.test(master_ip_type), "127.0.0.1"));
if (flags.test(master_ip_value)) {
master_1_array.push_back(createStringField(flags.test(master_ip_type), "127.0.0.1"));
} else {
master_1_array.push_back(createStringField(flags.test(master_ip_type), "bad ip foo"));
}
// Port field.
master_1_array.push_back(createIntegerField(flags.test(master_port_type), 22120));
}

std::vector<NetworkFilters::Common::Redis::RespValue> replica_1_array;
if (replica_flags.any()) {
// Ip field.
replica_1_array.push_back(
createStringField(replica_flags.test(replica_ip_type), "127.0.0.2"));
if (replica_flags.test(replica_ip_value)) {
replica_1_array.push_back(
createStringField(replica_flags.test(replica_ip_type), "127.0.0.2"));
} else {
replica_1_array.push_back(
createStringField(replica_flags.test(replica_ip_type), "bad ip bar"));
}
// Port field.
replica_1_array.push_back(createIntegerField(replica_flags.test(replica_port_type), 22120));
}
Expand Down Expand Up @@ -771,17 +785,17 @@ TEST_F(RedisClusterTest, RedisErrorResponse) {
EXPECT_CALL(membership_updated_, ready());
EXPECT_CALL(initialized_, ready());
EXPECT_CALL(*cluster_callback_, onClusterSlotUpdate(_, _)).Times(1);
std::bitset<10> single_slot_master(0x7ff);
std::bitset<3> no_replica(0);
std::bitset<ResponseFlagSize> single_slot_master(0xfff);
std::bitset<ResponseReplicaFlagSize> no_replica(0);
expectClusterSlotResponse(createResponse(single_slot_master, no_replica));
expectHealthyHosts(std::list<std::string>({"127.0.0.1:22120"}));

// Expect no change if resolve failed.
uint64_t update_attempt = 2;
uint64_t update_failure = 1;
// Test every combination the cluster slots response.
for (uint64_t i = 0; i < (1 << 10); i++) {
std::bitset<10> flags(i);
for (uint64_t i = 0; i < (1 << ResponseFlagSize); i++) {
std::bitset<ResponseFlagSize> flags(i);
expectRedisResolve();
resolve_timer_->invokeCallback();
if (flags.all()) {
Expand All @@ -807,17 +821,17 @@ TEST_F(RedisClusterTest, RedisReplicaErrorResponse) {
EXPECT_CALL(membership_updated_, ready());
EXPECT_CALL(initialized_, ready());
EXPECT_CALL(*cluster_callback_, onClusterSlotUpdate(_, _)).Times(1);
std::bitset<10> single_slot_master(0x7ff);
std::bitset<3> no_replica(0);
std::bitset<ResponseFlagSize> single_slot_master(0xfff);
std::bitset<ResponseReplicaFlagSize> no_replica(0);
expectClusterSlotResponse(createResponse(single_slot_master, no_replica));
expectHealthyHosts(std::list<std::string>({"127.0.0.1:22120"}));

// Expect no change if resolve failed.
uint64_t update_attempt = 1;
uint64_t update_failure = 0;
// Test every combination the replica error response.
for (uint64_t i = 1; i < (1 << 3); i++) {
std::bitset<3> replica_flags(i);
for (uint64_t i = 1; i < (1 << ResponseReplicaFlagSize); i++) {
std::bitset<ResponseReplicaFlagSize> replica_flags(i);
expectRedisResolve();
resolve_timer_->invokeCallback();
if (replica_flags.all()) {
Expand Down