Skip to content

Commit

Permalink
Add a bunch of extra functionality to SymFloat (pytorch#86046)
Browse files Browse the repository at this point in the history
- SymInt to SymFloat conversion
- All the basic arithmetic operators on c10::SymFloat

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: pytorch#86046
Approved by: https://github.com/wconstab
  • Loading branch information
ezyang authored and pytorchmergebot committed Oct 2, 2022
1 parent 833edeb commit 0060d87
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 5 deletions.
54 changes: 54 additions & 0 deletions c10/core/SymFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,60 @@ SymFloatNode SymFloat::toSymFloatNodeImpl() const {
return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
}

static std::array<SymFloatNode, 2> normalize_symfloats(
SymFloat a_,
SymFloat b_) {
SymFloatNode a, b;
if (a_.is_symbolic())
a = a_.toSymFloatNodeImpl();
if (b_.is_symbolic())
b = b_.toSymFloatNodeImpl();

SymFloatNodeImpl* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match
if (!a) {
a = common->wrap(a_.as_float_unchecked());
a_.toSymFloat(a); //
}
if (!b) {
b = common->wrap(b_.as_float_unchecked());
b_.toSymFloat(b);
}
return {a, b};
}

SymFloat SymFloat::operator+(SymFloat sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymFloat(data_ + sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->add(res[1]));
}

SymFloat SymFloat::operator-(SymFloat sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymFloat(data_ - sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->sub(res[1]));
}

SymFloat SymFloat::operator*(SymFloat sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymFloat(data_ * sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->mul(res[1]));
}

SymFloat SymFloat::operator/(SymFloat sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymFloat(data_ / sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->truediv(res[1]));
}

c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
return c10::SymFloat(std::move(sin_sp));
}
Expand Down
5 changes: 5 additions & 0 deletions c10/core/SymFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class C10_API SymFloat {
return data_;
}

SymFloat operator+(SymFloat) const;
SymFloat operator-(SymFloat) const;
SymFloat operator*(SymFloat) const;
SymFloat operator/(SymFloat) const;

// N.B. It's important to keep this definition in the header
// as we expect if checks to be folded for mobile builds
// where `is_symbolic` is always false
Expand Down
8 changes: 8 additions & 0 deletions c10/core/SymInt.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h>
#include <array>
Expand Down Expand Up @@ -60,6 +61,13 @@ int64_t SymInt::guard_int(const char* file, int64_t line) const {
return a->guard_int(file, line);
}

SymInt::operator SymFloat() const {
if (!is_symbolic()) {
return SymFloat(double(data_));
}
return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float());
}

SymInt SymInt::operator+(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymInt(data_ + sci.data_);
Expand Down
4 changes: 4 additions & 0 deletions c10/core/SymInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

namespace c10 {

class SymFloat;

// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to
// represent concrete dimension values.
//
Expand Down Expand Up @@ -188,6 +190,8 @@ class C10_API SymInt {
bool operator>(int64_t sci) const;
bool operator>=(int64_t sci) const;

operator SymFloat() const;

int64_t as_int_unchecked() const {
return data_;
}
Expand Down
3 changes: 3 additions & 0 deletions c10/core/SymIntNodeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
virtual SymIntNode clone() {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode sym_float() {
TORCH_CHECK(false, "NYI");
}
virtual SymIntNode wrap(int64_t num) {
TORCH_CHECK(false, "NYI");
};
Expand Down
8 changes: 3 additions & 5 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
return getPyObj().attr("__int__")().cast<int64_t>();
}

// TODO: virtualize
SymFloat sym_float();
SymFloatNode sym_float() override;

virtual std::string str() override {
py::gil_scoped_acquire acquire;
Expand Down Expand Up @@ -299,11 +298,10 @@ SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) {
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}

SymFloat PythonSymIntNodeImpl::sym_float() {
SymFloatNode PythonSymIntNodeImpl::sym_float() {
py::gil_scoped_acquire acquire;
return c10::make_intrusive<PythonSymFloatNodeImpl>(
getPyObj().attr("__sym_float__")())
->toSymFloat();
getPyObj().attr("__sym_float__")());
}

namespace {
Expand Down

0 comments on commit 0060d87

Please sign in to comment.