[PyTorch] Take const Type& in TensorType::fromNumberType (pytorch#66716)
Summary:
Pull Request resolved: pytorch#66716

No need to require a refcount bump for this function.
ghstack-source-id: 140754065

Test Plan: CI

Reviewed By: suo

Differential Revision: D31696639

fbshipit-source-id: bf8aa3f542d52e82e0f6a444b8898330f3d16a31
swolchok authored and facebook-github-bot committed Oct 18, 2021
1 parent 6a7296b commit 622e19b
Showing 6 changed files with 13 additions and 15 deletions.
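
Why the signature change avoids a refcount bump (a minimal standalone C++ sketch, not part of the PR, under the assumption that TypePtr is a std::shared_ptr to Type as in PyTorch; the helper names below are illustrative): passing a shared_ptr by value copies the pointer, paying an atomic increment on entry and a decrement on exit, while a const Type& parameter lets the caller simply dereference a pointer it already owns.

    #include <memory>
    #include <string>

    // Stand-in for c10::Type; in PyTorch, TypePtr is a shared_ptr to Type.
    struct Type {
      std::string str() const { return "Number"; }
    };
    using TypePtr = std::shared_ptr<Type>;

    // Old-style signature: the by-value shared_ptr parameter copies the pointer,
    // which costs an atomic refcount increment/decrement around every call.
    std::string describeByValue(TypePtr typ) { return typ->str(); }

    // New-style signature: a const reference to the pointee, no refcount traffic.
    std::string describeByRef(const Type& typ) { return typ.str(); }

    int main() {
      TypePtr t = std::make_shared<Type>();
      describeByValue(t); // use_count briefly goes 1 -> 2 -> 1 around the call
      describeByRef(*t);  // use_count stays at 1; the caller retains ownership
    }
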
14 changes: 7 additions & 7 deletions aten/src/ATen/core/jit_type.h
@@ -557,7 +557,7 @@ struct TORCH_API TensorType : public Type {
       at::Device device,
       at::IntArrayRef sizes);
 
-  static TypePtr fromNumberType(TypePtr typ);
+  static TypePtr fromNumberType(const Type& typ);
   static TypePtr fromBoolType();
 
   c10::optional<size_t> dim() const {
@@ -1606,17 +1606,17 @@ inline TypePtr unshapedType(const TypePtr& type) {
   return type->withContained(fmap(type->containedTypes(), unshapedType));
 }
 
-inline TypePtr TensorType::fromNumberType(TypePtr typ) {
-  if (typ->isSubtypeOf(*IntType::get())) {
+inline TypePtr TensorType::fromNumberType(const Type& typ) {
+  if (typ.isSubtypeOf(*IntType::get())) {
     return TensorType::createContiguous(at::kLong, at::kCPU, {});
-  } else if (typ->isSubtypeOf(*FloatType::get())) {
+  } else if (typ.isSubtypeOf(*FloatType::get())) {
     return TensorType::createContiguous(at::kDouble, at::kCPU, {});
-  } else if (typ->isSubtypeOf(*BoolType::get())) {
+  } else if (typ.isSubtypeOf(*BoolType::get())) {
     return TensorType::createContiguous(at::kBool, at::kCPU, {});
-  } else if (typ->kind() == NumberType::Kind) {
+  } else if (typ.kind() == NumberType::Kind) {
     return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
   }
-  TORCH_CHECK(false, "Unknown number type: ", typ->str());
+  TORCH_CHECK(false, "Unknown number type: ", typ.str());
 }
 inline TypePtr TensorType::fromBoolType() {
   return TensorType::createContiguous(at::kBool, at::kCPU, {});
3 changes: 1 addition & 2 deletions torch/csrc/jit/ir/ir.cpp
@@ -1814,9 +1814,8 @@ Node* Graph::createDict(
 }
 
 Node* Graph::createNumToTensor(Value* value) {
-  auto typ = value->type();
   Node* result = create(prim::NumToTensor, {value});
-  result->output()->setType(TensorType::fromNumberType(std::move(typ)));
+  result->output()->setType(TensorType::fromNumberType(*value->type()));
   return result;
 }
 
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/erase_number_types.cpp
@@ -9,7 +9,7 @@ namespace jit {
 
 void SetNumTypeToTensorType(Value* v) {
   if (v->type()->isSubtypeOf(*NumberType::get())) {
-    v->setType(TensorType::fromNumberType(v->type()));
+    v->setType(TensorType::fromNumberType(*v->type()));
   } else if (v->type()->isSubtypeOf(*BoolType::get())) {
    v->setType(TensorType::fromBoolType());
  }
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
@@ -223,7 +223,7 @@ void FixupONNXLoopNodeInputs(Node* node) {
   cond->setType(BoolType::get());
 
   Value* i = sub_block->inputs().at(0);
-  i->setType(TensorType::fromNumberType(IntType::get()));
+  i->setType(TensorType::fromNumberType(*IntType::get()));
 
   // add cast to condition input inside the loop.
   Value* next_cond_val = sub_block->outputs().at(0);
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
@@ -31,7 +31,7 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
 
       it->replaceInput(0, floattensor_inputs[0]);
       it->replaceInput(1, floattensor_inputs[1]);
-      it->output()->setType(TensorType::fromNumberType(FloatType::get()));
+      it->output()->setType(TensorType::fromNumberType(*FloatType::get()));
     }
   }
 }
5 changes: 2 additions & 3 deletions torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp
@@ -144,16 +144,15 @@ static void ReplaceAddWithConcat(Block* b) {
         continue;
       }
 
-      TypePtr elem =
+      const auto& elem =
           it->input(0)->type()->castRaw<ListType>()->getElementType();
       if (elem->cast<IntType>()) {
         Node* concat_node = b->owningGraph()->create(onnx::Concat, 1);
         concat_node->i_(attr::axis, 0);
         concat_node->insertBefore(*it);
         concat_node->addInput(it->input(0));
         concat_node->addInput(it->input(1));
-        concat_node->outputs()[0]->setType(
-            TensorType::fromNumberType(std::move(elem)));
+        concat_node->outputs()[0]->setType(TensorType::fromNumberType(*elem));
         it->replaceAllUsesWith(concat_node);
         it->removeAllInputs();
         it.destroyCurrent();
