Skip to content

Commit

Permalink
Refactor MemoryPool passing for ExchangeClient (#2447)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2447

Setter methods make it difficult to reason about the state of the object. Setters are removed and MemoryPool is fed in through creation methods.

Reviewed By: xiaoxmeng

Differential Revision: D39058083

fbshipit-source-id: a9518886a74e4ac3727e04890837d2c39cdcf4b5
  • Loading branch information
tanjialiang authored and facebook-github-bot committed Sep 7, 2022
1 parent 9bce36f commit 0ae3c3c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 28 deletions.
35 changes: 15 additions & 20 deletions velox/exec/Exchange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "velox/exec/Exchange.h"
#include <velox/buffer/Buffer.h>
#include <velox/common/base/Exceptions.h>
#include <velox/common/memory/MappedMemory.h>
#include <velox/common/memory/Memory.h>
#include "velox/exec/PartitionedOutputBufferManager.h"

namespace facebook::velox::exec {
Expand Down Expand Up @@ -58,20 +60,17 @@ void SerializedPage::prepareStreamForDeserialize(ByteStream* input) {
std::shared_ptr<ExchangeSource> ExchangeSource::create(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue) {
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool) {
for (auto& factory : factories()) {
auto result = factory(taskId, destination, queue);
auto result = factory(taskId, destination, queue, pool);
if (result) {
return result;
}
}
VELOX_FAIL("No ExchangeSource factory matches {}", taskId);
}

void ExchangeSource::setMemoryPool(memory::MemoryPool* FOLLY_NULLABLE pool) {
pool_ = pool;
}

// static
std::vector<ExchangeSource::Factory>& ExchangeSource::factories() {
static std::vector<Factory> factories;
Expand All @@ -84,8 +83,9 @@ class LocalExchangeSource : public ExchangeSource {
LocalExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue)
: ExchangeSource(taskId, destination, queue) {}
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool)
: ExchangeSource(taskId, destination, queue, pool) {}

bool shouldRequestLocked() override {
if (atEnd_) {
Expand Down Expand Up @@ -131,7 +131,7 @@ class LocalExchangeSource : public ExchangeSource {
}
inputPage->unshare();
pages.push_back(
std::make_unique<SerializedPage>(std::move(inputPage)));
std::make_unique<SerializedPage>(std::move(inputPage), pool_));
inputPage = nullptr;
}
int64_t ackSequence;
Expand Down Expand Up @@ -168,23 +168,19 @@ class LocalExchangeSource : public ExchangeSource {
std::unique_ptr<ExchangeSource> createLocalExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue) {
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool) {
if (strncmp(taskId.c_str(), "local://", 8) == 0) {
return std::make_unique<LocalExchangeSource>(
taskId, destination, std::move(queue));
taskId, destination, std::move(queue), pool);
}
return nullptr;
}

} // namespace

void ExchangeClient::maybeSetMemoryPool(
memory::MemoryPool* FOLLY_NONNULL pool) {
// ExchangeClient could be shared by the same exchange operators from
// different drivers so we only need to set it on the first operator setup.
if (pool_ == nullptr) {
pool_ = pool;
}
void ExchangeClient::initialize(memory::MemoryPool* FOLLY_NONNULL pool) {
pool_ = pool;
}

void ExchangeClient::addRemoteTaskId(const std::string& taskId) {
Expand All @@ -199,8 +195,7 @@ void ExchangeClient::addRemoteTaskId(const std::string& taskId) {
// and the task updates have no guarantees of arriving in order.
return;
}
auto source = ExchangeSource::create(taskId, destination_, queue_);
source->setMemoryPool(pool_);
auto source = ExchangeSource::create(taskId, destination_, queue_, pool_);

if (closed_) {
toClose = std::move(source);
Expand Down
25 changes: 17 additions & 8 deletions velox/exec/Exchange.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,20 +211,26 @@ class ExchangeSource : public std::enable_shared_from_this<ExchangeSource> {
using Factory = std::function<std::shared_ptr<ExchangeSource>(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue)>;
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool)>;

ExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue)
: taskId_(taskId), destination_(destination), queue_(std::move(queue)) {}
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool)
: taskId_(taskId),
destination_(destination),
queue_(std::move(queue)),
pool_(pool) {}

virtual ~ExchangeSource() = default;

static std::shared_ptr<ExchangeSource> create(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue);
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool);

// Returns true if there is no request to the source pending or if
// this should be retried. If true, the caller is expected to call
Expand Down Expand Up @@ -262,7 +268,6 @@ class ExchangeSource : public std::enable_shared_from_this<ExchangeSource> {

static std::vector<Factory>& factories();

void setMemoryPool(memory::MemoryPool* FOLLY_NULLABLE pool);
// ID of the task producing data
const std::string taskId_;
// Destination number of 'this' on producer
Expand All @@ -273,7 +278,7 @@ class ExchangeSource : public std::enable_shared_from_this<ExchangeSource> {
bool atEnd_ = false;

protected:
memory::MemoryPool* FOLLY_NULLABLE pool_{nullptr};
memory::MemoryPool* FOLLY_NONNULL pool_;
};

struct RemoteConnectorSplit : public connector::ConnectorSplit {
Expand Down Expand Up @@ -304,7 +309,7 @@ class ExchangeClient {
return pool_;
}

void maybeSetMemoryPool(memory::MemoryPool* FOLLY_NONNULL pool);
void initialize(memory::MemoryPool* FOLLY_NONNULL pool);

// Creates an exchange source and starts fetching data from the specified
// upstream task. If 'close' has been called already, creates an exchange
Expand Down Expand Up @@ -351,7 +356,11 @@ class Exchange : public SourceOperator {
"Exchange"),
planNodeId_(exchangeNode->id()),
exchangeClient_(std::move(exchangeClient)) {
exchangeClient_->maybeSetMemoryPool(operatorCtx_->pool());
if (operatorCtx_->driverCtx()->driverId == 0) {
// As all Exchange operators share the same ExchangeClient, we only
// need one to do client initialization.
exchangeClient_->initialize(operatorCtx_->pool());
}
}

~Exchange() override {
Expand Down
1 change: 1 addition & 0 deletions velox/exec/MergeSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class MergeExchangeSource : public MergeSource {
int destination)
: mergeExchange_(mergeExchange),
client_(std::make_unique<ExchangeClient>(destination)) {
client_->initialize(mergeExchange->pool());
client_->addRemoteTaskId(taskId);
client_->noMoreRemoteTasks();
}
Expand Down
5 changes: 5 additions & 0 deletions velox/exec/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,11 @@ ContinueFuture Task::terminate(TaskState terminalState) {

for (auto& [planNodeId, splits] : remainingRemoteSplits) {
for (auto& split : splits.first) {
if (!exchangeClientByPlanNode_[planNodeId]->pool()) {
// If we terminate even before the client's initialization, we
// initialize the client with Task's memory pool.
exchangeClientByPlanNode_[planNodeId]->initialize(pool_.get());
}
addRemoteSplit(planNodeId, split);
}
if (splits.second) {
Expand Down

0 comments on commit 0ae3c3c

Please sign in to comment.