Skip to content

Commit

Permalink
Reapply "[jit] Implement ScriptProfile to collect instruction profile…
Browse files Browse the repository at this point in the history
…s." (pytorch#58783)

Summary:
Pull Request resolved: pytorch#58783

This reverts commit fc804b5.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D28617037

Pulled By: zhxchen17

fbshipit-source-id: 645de2ede20500a5c218d6ec3c7faae94de37a14
  • Loading branch information
zhxchen17 authored and facebook-github-bot committed May 25, 2021
1 parent 705dd9f commit 2b0ec9c
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/cpp/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp
${JIT_TEST_ROOT}/test_subgraph_utils.cpp
${JIT_TEST_ROOT}/test_utils.cpp
${JIT_TEST_ROOT}/test_script_profile.cpp
)

if(USE_CUDA)
Expand Down
62 changes: 62 additions & 0 deletions test/cpp/jit/test_script_profile.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include <gtest/gtest.h>

#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/script_profile.h>

namespace torch {
namespace jit {

// Verifies that ScriptProfile aggregates InstructionSpan datapoints per source
// line: three spans created over the same node must show up as a single line
// entry whose count is 3.
TEST(ScriptProfileTest, Basic) {
  const std::string source_string = R"V0G0N(
def foo(a, b):
return a + b #
)V0G0N";
  // The source range attached to the node spans from "return" up to " #".
  auto begin = source_string.find("return");
  auto end = source_string.find(" #");

  Graph g;
  const auto graph_string = R"IR(
graph(%a : Tensor,
%b : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%a, %b, %2)
return (%3))IR";

  torch::jit::parseIR(graph_string, &g);
  auto source = std::make_shared<Source>(source_string, "", 0);
  // Attach the range to the first node of the parsed graph.
  auto node = *g.nodes().begin();
  node->setSourceRange(SourceRange{source, begin, end});

  ScriptProfile p;
  p.enable();
  {
    // Each span hands one datapoint to the enabled profile when it is
    // destroyed at the end of this scope.
    profiling::InstructionSpan g0(*node);
    profiling::InstructionSpan g1(*node);
    profiling::InstructionSpan g2(*node);
  }
  p.disable();

  auto stats = p.dumpStats();
  EXPECT_EQ(stats.size(), 1);
  auto it = stats.find(*source.get());
  // Fatal assertion on purpose: `it` is dereferenced right below, which would
  // be undefined behavior if the lookup had returned end().
  ASSERT_NE(it, stats.end());
  auto& lines = it->second;
  EXPECT_EQ(lines.size(), 1);
  const auto& stat = lines.at(source->lineno_for_offset(begin));
  EXPECT_EQ(stat.count, 3);
}

// Checks the enable/disable contract of ScriptProfile: stats cannot be dumped
// while the profile is enabled, and datapoints cannot be added once it has
// been disabled.
TEST(ScriptProfileTest, CallingOrder) {
  ScriptProfile profile;

  profile.enable();
  EXPECT_THROW(profile.dumpStats(), c10::Error);

  profile.disable();
  auto datapoint = std::make_shared<profiling::Datapoint>(SourceRange{});
  EXPECT_THROW(profile.addDatapoint(std::move(datapoint)), c10::Error);
}

} // namespace jit
} // namespace torch
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ core_sources_full_mobile = [
"torch/csrc/jit/runtime/logging.cpp",
"torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp",
"torch/csrc/jit/runtime/profiling_record.cpp",
"torch/csrc/jit/runtime/script_profile.cpp",
"torch/csrc/jit/runtime/symbolic_script.cpp",
"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
"torch/csrc/jit/serialization/import.cpp",
Expand Down
6 changes: 5 additions & 1 deletion torch/csrc/jit/frontend/source_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <functional>
#include <memory>

#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>

Expand All @@ -18,7 +19,7 @@ namespace jit {
* support heterogeneous lookup, and also shared_ptr is an implementation detail
* which should be encapsulated.
*/
class TORCH_API SourceRef {
class TORCH_API SourceRef : public CustomClassHolder {
public:
explicit SourceRef(std::shared_ptr<Source> source)
: source_(std::move(source)) {}
Expand All @@ -34,6 +35,9 @@ class TORCH_API SourceRef {
bool operator<(const SourceRef& other) const {
return *this < *other.source_.get();
}
const Source* operator->() const {
return source_.get();
}

private:
std::shared_ptr<Source> source_;
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/constants.h>
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/runtime/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/script_profile.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

#ifdef USE_RPC
Expand Down Expand Up @@ -229,6 +230,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
// std::cout << "RUNNING ";
// frames.back().function->dump(std::cout, frame.pc);
Instruction inst = frame.function->instructions_[frame.pc];
profiling::InstructionSpan instSpan{
*frame.function->instructions_source()[frame.pc]};
switch (inst.op) {
case ENTER: {
const auto& obj = peek(stack, 0, 1);
Expand Down
177 changes: 177 additions & 0 deletions torch/csrc/jit/runtime/script_profile.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#include <torch/csrc/jit/runtime/script_profile.h>

#include <atomic>
#include <chrono>
#include <mutex>
#include <unordered_set>

#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/api/function_impl.h>

namespace torch {
namespace jit {

namespace {

class ProfilesRegistry {
public:
bool empty() {
return empty_.load(std::memory_order_relaxed);
}

void addProfile(ScriptProfile& p) {
std::lock_guard<std::mutex> g(mutex_);
enabledProfiles_.emplace(&p);
empty_.store(false, std::memory_order_relaxed);
}

void removeProfile(ScriptProfile& p) {
std::lock_guard<std::mutex> g(mutex_);
enabledProfiles_.erase(&p);
if (enabledProfiles_.empty()) {
empty_.store(true, std::memory_order_relaxed);
}
}

void send(std::unique_ptr<profiling::Datapoint> datapoint) {
auto shared = std::shared_ptr<profiling::Datapoint>(std::move(datapoint));
std::lock_guard<std::mutex> g(mutex_);
for (auto* p : enabledProfiles_) {
p->addDatapoint(shared);
}
}

private:
std::atomic<bool> empty_{true};
std::mutex mutex_;
std::unordered_set<ScriptProfile*> enabledProfiles_;
};

// Returns the single shared registry. The instance is allocated on first use
// and is never deleted.
ProfilesRegistry& getProfilesRegistry() {
  static ProfilesRegistry* const instance = new ProfilesRegistry{};
  return *instance;
}

// Registers the profiling custom classes (SourceRef, InstructionStats,
// SourceStats and _ScriptProfile) under the "profiling" namespace.
// The return value carries no information; it only exists so the function can
// be used as the initializer of a namespace-scope variable.
auto initBindings() {
  // SourceRef: exposes the starting line number and the text of the source.
  torch::class_<SourceRef>("profiling", "SourceRef")
      .def(
          "starting_lineno",
          [](const c10::intrusive_ptr<SourceRef>& ref) {
            return static_cast<int64_t>((*ref)->starting_line_no());
          })
      .def("text", [](const c10::intrusive_ptr<SourceRef>& ref) {
        return (*ref)->text();
      });

  // InstructionStats: hit count and accumulated duration.
  torch::class_<InstructionStats>("profiling", "InstructionStats")
      .def(
          "count",
          [](const c10::intrusive_ptr<InstructionStats>& instStats) {
            return instStats->count;
          })
      .def(
          "duration_ns",
          [](const c10::intrusive_ptr<InstructionStats>& instStats) {
            return static_cast<int64_t>(instStats->duration.count());
          });

  // SourceStats: the source reference plus its per-line map.
  torch::class_<SourceStats>("profiling", "SourceStats")
      .def(
          "source",
          [](const c10::intrusive_ptr<SourceStats>& srcStats) {
            return c10::make_intrusive<SourceRef>(srcStats->getSourceRef());
          })
      .def("line_map", &SourceStats::getLineMap);

  // _ScriptProfile: enable/disable plus a dump that repackages the native
  // source map into a list of SourceStats objects.
  torch::class_<ScriptProfile>("profiling", "_ScriptProfile")
      .def(torch::init<>())
      .def("enable", &ScriptProfile::enable)
      .def("disable", &ScriptProfile::disable)
      .def(
          "_dump_stats",
          [](const c10::intrusive_ptr<ScriptProfile>& profile) {
            const auto& sourceMap = profile->dumpStats();
            c10::List<c10::intrusive_ptr<SourceStats>> result;
            for (const auto& sourceEntry : sourceMap) {
              SourceStats::LineMap perLine;
              for (const auto& lineEntry : sourceEntry.second) {
                perLine.insert(
                    lineEntry.first,
                    c10::make_intrusive<InstructionStats>(lineEntry.second));
              }
              result.push_back(c10::make_intrusive<SourceStats>(
                  sourceEntry.first, std::move(perLine)));
            }
            return result;
          });
  return nullptr;
}

// Initializing this variable is what runs the registrations above.
const auto torchBindInitializer = initBindings();

} // namespace

namespace profiling {

// Starts a span for `node`. A datapoint is only allocated when at least one
// profile is registered; otherwise datapoint_ stays null and the span is a
// no-op.
InstructionSpan::InstructionSpan(Node& node) {
  if (!getProfilesRegistry().empty()) {
    datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
  }
}

// Ends the span: if a datapoint was created, stamp its end time and hand it
// over to the registry.
InstructionSpan::~InstructionSpan() {
  if (datapoint_) {
    datapoint_->end = std::chrono::steady_clock::now();
    getProfilesRegistry().send(std::move(datapoint_));
  }
}

} // namespace profiling

// Marks this profile as enabled and registers it with the global registry.
// Calling it on an already enabled profile does nothing.
void ScriptProfile::enable() {
  if (enabled_) {
    return;
  }
  enabled_ = true;
  getProfilesRegistry().addProfile(*this);
}

// Unregisters this profile from the global registry and marks it as disabled.
// Calling it on an already disabled profile does nothing.
//
// The order matters: the registry calls addDatapoint() on registered profiles
// while holding its mutex, and addDatapoint() throws when enabled_ is false.
// That call is reached from ~InstructionSpan, so clearing the flag while the
// profile is still registered could make a destructor throw on another
// thread. Removing the profile first guarantees that no delivery can observe
// a registered-but-disabled profile.
void ScriptProfile::disable() {
  if (!enabled_) {
    return;
  }
  getProfilesRegistry().removeProfile(*this);
  enabled_ = false;
}

// Appends `datapoint` to the list of pending datapoints.
// Raises an error through TORCH_CHECK when the profile is not enabled.
void ScriptProfile::addDatapoint(
    std::shared_ptr<profiling::Datapoint> datapoint) {
  TORCH_CHECK(enabled_, "Cannot add datapoint to disabled profilers.");
  datapoints_.push_back(std::move(datapoint));
}

// Folds all pending datapoints into sourceMap_ and returns a reference to it.
// Raises an error through TORCH_CHECK when the profile is still enabled.
// Datapoints without a source, or without a file/line/col location, are
// dropped. The pending list is cleared afterwards.
const ScriptProfile::SourceMap& ScriptProfile::dumpStats() {
  TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats.");

  for (const auto& point : datapoints_) {
    const auto& src = point->sourceRange.source();
    if (!src) {
      continue;
    }
    const auto location = point->sourceRange.file_line_col();
    if (!location) {
      continue;
    }

    // Find the per-source entry, creating it on first sight of this source.
    auto entry = sourceMap_.find(*src.get());
    if (entry == sourceMap_.end()) {
      entry = sourceMap_.emplace(SourceRef{src}, LineMap{}).first;
    }

    // Element 1 of the file/line/col tuple keys the per-line stats.
    auto& lineStats = entry->second[std::get<1>(*location)];
    lineStats.count++;
    lineStats.duration += point->end - point->start;
  }
  datapoints_.clear();

  return sourceMap_;
}

// Makes sure a profile that is still enabled is removed from the registry
// before it goes away, so the registry never keeps a pointer to it.
ScriptProfile::~ScriptProfile() {
  if (!enabled_) {
    return;
  }
  getProfilesRegistry().removeProfile(*this);
}

} // namespace jit
} // namespace torch
Loading

0 comments on commit 2b0ec9c

Please sign in to comment.