[PHI] register c_comm_init_all_op to phi (#59672)
* reg c_comm_init_all_op to phi

* Apply suggestions from code review

* fix build error
AndSonder committed Dec 13, 2023
1 parent 538905c commit a8d5117
Showing 4 changed files with 143 additions and 97 deletions.
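In short, this change moves c_comm_init_all off the legacy OperatorBase::RunImpl path and onto the PHI struct-kernel path: the operator class becomes a thin OperatorWithKernel shell, the device-setup logic moves into a templated framework::OpKernel in a shared header, and each backend registers that kernel with PD_REGISTER_STRUCT_KERNEL. A minimal sketch of the pattern follows; my_init_op and MyInitKernel are placeholder names for illustration, and the real ones in this diff are c_comm_init_all and CCommInitAllKernel.

// Sketch of the struct-kernel migration pattern this commit applies.
// my_init_op / MyInitKernel are placeholders, not part of the diff.
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T, typename DeviceContext>
class MyInitKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Device-specific setup that used to live in OperatorBase::RunImpl
    // goes here; attributes are now read through ctx.Attr<...>() instead
    // of the OperatorBase Attr<...>() helper.
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// One registration per backend translation unit (e.g. *.cu.cc, *_xpu.cc):
PD_REGISTER_STRUCT_KERNEL(
    my_init_op, GPU, ALL_LAYOUT, ops::MyInitKernel, float) {}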
105 changes: 8 additions & 97 deletions paddle/fluid/operators/collective/c_comm_init_all_op.cc
@@ -11,110 +11,22 @@
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <string>
-
-#include "paddle/fluid/framework/op_info.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/platform/collective_helper.h"
-
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
-#endif
-
-#if defined(PADDLE_WITH_XPU_BKCL)
-#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
-#endif
-
-namespace paddle {
-namespace framework {
-class InferShapeContext;
-class Scope;
-}  // namespace framework
-}  // namespace paddle
+#include "paddle/fluid/operators/collective/c_comm_init_all_op.h"

 namespace paddle {
 namespace operators {

-class CCommInitAllInferShape : public framework::InferShapeBase {
+class CCommInitAllOp : public framework::OperatorWithKernel {
  public:
-  ~CCommInitAllInferShape() override = default;
-  void operator()(framework::InferShapeContext* ctx) const override{};
-};
-
-class CCommInitAllOp : public framework::OperatorBase {
- public:
-  CCommInitAllOp(const std::string& type,
-                 const framework::VariableNameMap& inputs,
-                 const framework::VariableNameMap& outputs,
-                 const framework::AttributeMap& attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
-
-  void RunImpl(const framework::Scope& scope,
-               const platform::Place& place) const override {
-    // PADDLE_ENFORCE_EQ(platform::is_gpu_place(place), true,
-    //                   platform::errors::PreconditionNotMet(
-    //                       "CCommInitAllOp can run on gpu place only"));
-
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    std::vector<int> devices = Attr<std::vector<int>>("devices");
-    if (devices.empty()) {
-      devices = platform::GetSelectedDevices();
-    }
-
-    int rid = Attr<int>("ring_id");
-
-    platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid);
-
-#elif defined(PADDLE_WITH_XPU_BKCL)
-    std::vector<int> devices = Attr<std::vector<int>>("devices");
-    int ring_id = Attr<int>("ring_id");
-
-    if (devices.empty()) {
-      int count = platform::GetXPUDeviceCount();
-      for (int i = 0; i < count; ++i) {
-        devices.push_back(i);
-      }
-    }
-
-    if (devices.size() > 1) {
-      std::vector<platform::Place> place_list_;
-      for (size_t i = 0; i < devices.size(); ++i) {
-        auto p = platform::XPUPlace(devices[i]);
-        place_list_.push_back(p);
-      }
-
-      // create pthread to bkcl_init_rank on all devices
-      auto ptr = new platform::BKCLContextMap(place_list_);
-      ptr->init();
-
-      for (size_t i = 0; i < devices.size(); ++i) {
-        platform::BKCLCommContext::Instance().AssignBKCLComm(
-            ptr->contexts_.at(devices[i]).comm_,
-            devices.size(),
-            devices[i],
-            devices[i],
-            ring_id);
-
-        VLOG(0) << "bkcl communicator of rank " << devices[i] << " in ring "
-                << ring_id << " has been created on device " << devices[i];
+  using framework::OperatorWithKernel::OperatorWithKernel;

-        // TODO(WorgenZhang): need release comm_map_ when quit
-        // std::call_once(once_flag_, []() {
-        //   std::atexit([]() {
-        //     platform::BKCLCommContext::Instance().ReleaseBKCLComms(); });
-        // });
-      }
+  void InferShape(framework::InferShapeContext* ctx) const override {}

-      VLOG(0) << "done bkcl_init_rank on all devices";
-    } else {
-      VLOG(0)
-          << "bkcl_init_rank doesn't support on one device, skip init process";
-    }
-#else
-    PADDLE_THROW(platform::errors::PreconditionNotMet(
-        "PaddlePaddle should compile with GPU or XPU."));
-#endif
-  }
+ protected:
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
+  }
 };

@@ -142,5 +54,4 @@ namespace ops = paddle::operators;

 REGISTER_OPERATOR(c_comm_init_all,
                   ops::CCommInitAllOp,
-                  ops::CCommInitAllInferShape,
                   ops::CCommInitAllOpMaker);
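A note on the kernel key: c_comm_init_all consumes and produces no tensors, so GetExpectedKernelType cannot derive a dtype from inputs; it pins the key to FP32 at the current place, which is what lets dispatch resolve to the float-only kernels registered in the backend files below. The pairing, with both pieces taken from this diff:

// In the operator (c_comm_init_all_op.cc): a fixed FP32 kernel key,
// since there are no tensor inputs to infer a dtype from.
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());

// In each backend file: a kernel registered for float only, so the
// fixed key above always matches.
PD_REGISTER_STRUCT_KERNEL(
    c_comm_init_all, GPU, ALL_LAYOUT, ops::CCommInitAllKernel, float) {}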
20 changes: 20 additions & 0 deletions paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc
@@ -0,0 +1,20 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_comm_init_all_op.h"

namespace ops = paddle::operators;

PD_REGISTER_STRUCT_KERNEL(
    c_comm_init_all, GPU, ALL_LAYOUT, ops::CCommInitAllKernel, float) {}
95 changes: 95 additions & 0 deletions paddle/fluid/operators/collective/c_comm_init_all_op.h
@@ -0,0 +1,95 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/platform/collective_helper.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T, typename DeviceContext>
class CCommInitAllKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    std::vector<int> devices = ctx.Attr<std::vector<int>>("devices");
    if (devices.empty()) {
      devices = platform::GetSelectedDevices();
    }

    int rid = ctx.Attr<int>("ring_id");

    platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid);
#elif defined(PADDLE_WITH_XPU_BKCL)
    std::vector<int> devices = ctx.Attr<std::vector<int>>("devices");
    int ring_id = ctx.Attr<int>("ring_id");

    if (devices.empty()) {
      int count = platform::GetXPUDeviceCount();
      for (int i = 0; i < count; ++i) {
        devices.push_back(i);
      }
    }

    if (devices.size() > 1) {
      std::vector<platform::Place> place_list_;
      for (size_t i = 0; i < devices.size(); ++i) {
        auto p = platform::XPUPlace(devices[i]);
        place_list_.push_back(p);
      }

      // create pthread to bkcl_init_rank on all devices
      auto ptr = new platform::BKCLContextMap(place_list_);
      ptr->init();

      for (size_t i = 0; i < devices.size(); ++i) {
        platform::BKCLCommContext::Instance().AssignBKCLComm(
            ptr->contexts_.at(devices[i]).comm_,
            devices.size(),
            devices[i],
            devices[i],
            ring_id);

        VLOG(0) << "bkcl communicator of rank " << devices[i] << " in ring "
                << ring_id << " has been created on device " << devices[i];

        // TODO(WorgenZhang): need release comm_map_ when quit
        // std::call_once(once_flag_, []() {
        //   std::atexit([]() {
        //     platform::BKCLCommContext::Instance().ReleaseBKCLComms(); });
        // });
      }

      VLOG(0) << "done bkcl_init_rank on all devices";
    }
#endif
  }
};

} // namespace operators
} // namespace paddle
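Two details of the header are worth flagging. First, the old #else branch that threw "PaddlePaddle should compile with GPU or XPU." does not reappear in Compute, and neither does the single-device "skip init process" log: the kernel is only registered from the .cu.cc and _xpu.cc files, so in other builds it is never dispatched. Second, the TODO about releasing comm_map_ survives the move verbatim. A minimal sketch of the cleanup that commented-out code describes, hedged: ReleaseBKCLComms is the name the original comment assumes rather than a verified API, and the once-flag becomes a local static here since the kernel holds no member state.

// Hypothetical completion of the TODO above (illustration only;
// requires <mutex> and <cstdlib>).
static std::once_flag bkcl_release_flag;
std::call_once(bkcl_release_flag, []() {
  std::atexit([]() {
    // Assumed API from the TODO comment, left unverified:
    // platform::BKCLCommContext::Instance().ReleaseBKCLComms();
  });
});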
20 changes: 20 additions & 0 deletions paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc
@@ -0,0 +1,20 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_comm_init_all_op.h"

namespace ops = paddle::operators;

PD_REGISTER_STRUCT_KERNEL(
    c_comm_init_all, XPU, ALL_LAYOUT, ops::CCommInitAllKernel, float) {}
