From a8d5117371e8b9d16ff28011329bc04104eaf50a Mon Sep 17 00:00:00 2001 From: Sonder <55493212+AndSonder@users.noreply.github.com> Date: Wed, 13 Dec 2023 11:54:43 +0800 Subject: [PATCH] [PHI] register c_comm_init_all_op to phi (#59672) * reg c_comm_init_all_op to phi * Apply suggestions from code review * fix build error --- .../collective/c_comm_init_all_op.cc | 105 ++---------------- .../collective/c_comm_init_all_op.cu.cc | 20 ++++ .../operators/collective/c_comm_init_all_op.h | 95 ++++++++++++++++ .../collective/c_comm_init_all_op_xpu.cc | 20 ++++ 4 files changed, 143 insertions(+), 97 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_comm_init_all_op.h create mode 100644 paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op.cc b/paddle/fluid/operators/collective/c_comm_init_all_op.cc index 2dc9af0139546..3a6156eb96e71 100644 --- a/paddle/fluid/operators/collective/c_comm_init_all_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_all_op.cc @@ -11,110 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include "paddle/fluid/framework/op_info.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/platform/collective_helper.h" - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif - -#if defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" -#endif - -namespace paddle { -namespace framework { -class InferShapeContext; -class Scope; -} // namespace framework -} // namespace paddle +#include "paddle/fluid/operators/collective/c_comm_init_all_op.h" namespace paddle { namespace operators { -class CCommInitAllInferShape : public framework::InferShapeBase { +class CCommInitAllOp : public framework::OperatorWithKernel { public: - ~CCommInitAllInferShape() override = default; - void operator()(framework::InferShapeContext* ctx) const override{}; -}; - -class CCommInitAllOp : public framework::OperatorBase { - public: - CCommInitAllOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - // PADDLE_ENFORCE_EQ(platform::is_gpu_place(place), true, - // platform::errors::PreconditionNotMet( - // "CCommInitAllOp can run on gpu place only")); - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - std::vector devices = Attr>("devices"); - if (devices.empty()) { - devices = platform::GetSelectedDevices(); - } - - int rid = Attr("ring_id"); - - platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid); - -#elif defined(PADDLE_WITH_XPU_BKCL) - std::vector devices = Attr>("devices"); - int ring_id = Attr("ring_id"); - - if (devices.empty()) { - int count = platform::GetXPUDeviceCount(); - for (int i = 0; i < count; ++i) { - devices.push_back(i); - } - } - - if (devices.size() > 1) { - std::vector place_list_; - for (size_t i = 0; i < devices.size(); ++i) { - auto p = platform::XPUPlace(devices[i]); - place_list_.push_back(p); - } - - // create pthread to bkcl_init_rank on all devices - auto ptr = new platform::BKCLContextMap(place_list_); - ptr->init(); - - for (size_t i = 0; i < devices.size(); ++i) { - platform::BKCLCommContext::Instance().AssignBKCLComm( - ptr->contexts_.at(devices[i]).comm_, - devices.size(), - devices[i], - devices[i], - ring_id); - - VLOG(0) << "bkcl communicator of rank " << devices[i] << " in ring " - << ring_id << " has been created on device " << devices[i]; + using framework::OperatorWithKernel::OperatorWithKernel; - // TODO(WorgenZhang): need release comm_map_ when quit - // std::call_once(once_flag_, []() { - // std::atexit([]() { - // platform::BKCLCommContext::Instance().ReleaseBKCLComms(); }); - // }); - } + void InferShape(framework::InferShapeContext* ctx) const override {} - VLOG(0) << "done bkcl_init_rank on all devices"; - } else { - VLOG(0) - << "bkcl_init_rank doesn't support on one device, skip init process"; - } -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU or XPU.")); -#endif + protected: + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); } }; @@ -142,5 +54,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(c_comm_init_all, ops::CCommInitAllOp, - ops::CCommInitAllInferShape, ops::CCommInitAllOpMaker); diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc b/paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc new file mode 100644 index 0000000000000..43e40b6389932 --- /dev/null +++ b/paddle/fluid/operators/collective/c_comm_init_all_op.cu.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_comm_init_all_op.h" + +namespace ops = paddle::operators; + +PD_REGISTER_STRUCT_KERNEL( + c_comm_init_all, GPU, ALL_LAYOUT, ops::CCommInitAllKernel, float) {} diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op.h b/paddle/fluid/operators/collective/c_comm_init_all_op.h new file mode 100644 index 0000000000000..c5caff337e526 --- /dev/null +++ b/paddle/fluid/operators/collective/c_comm_init_all_op.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/platform/collective_helper.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CCommInitAllKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + std::vector devices = ctx.Attr>("devices"); + if (devices.empty()) { + devices = platform::GetSelectedDevices(); + } + + int rid = ctx.Attr("ring_id"); + + platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid); +#elif defined(PADDLE_WITH_XPU_BKCL) + std::vector devices = ctx.Attr>("devices"); + int ring_id = ctx.Attr("ring_id"); + + if (devices.empty()) { + int count = platform::GetXPUDeviceCount(); + for (int i = 0; i < count; ++i) { + devices.push_back(i); + } + } + + if (devices.size() > 1) { + std::vector place_list_; + for (size_t i = 0; i < devices.size(); ++i) { + auto p = platform::XPUPlace(devices[i]); + place_list_.push_back(p); + } + + // create pthread to bkcl_init_rank on all devices + auto ptr = new platform::BKCLContextMap(place_list_); + ptr->init(); + + for (size_t i = 0; i < devices.size(); ++i) { + platform::BKCLCommContext::Instance().AssignBKCLComm( + ptr->contexts_.at(devices[i]).comm_, + devices.size(), + devices[i], + devices[i], + ring_id); + + VLOG(0) << "bkcl communicator of rank " << devices[i] << " in ring " + << ring_id << " has been created on device " << devices[i]; + + // TODO(WorgenZhang): need release comm_map_ when quit + // std::call_once(once_flag_, []() { + // std::atexit([]() { + // platform::BKCLCommContext::Instance().ReleaseBKCLComms(); }); + // }); + } + + VLOG(0) << "done bkcl_init_rank on all devices"; + } +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc b/paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc new file mode 100644 index 0000000000000..d23f0931c6e73 --- /dev/null +++ b/paddle/fluid/operators/collective/c_comm_init_all_op_xpu.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_comm_init_all_op.h" + +namespace ops = paddle::operators; + +PD_REGISTER_STRUCT_KERNEL( + c_comm_init_all, XPU, ALL_LAYOUT, ops::CCommInitAllKernel, float) {}