add __hash__ to FunctionSchema (pytorch#90730)
This PR adds __hash__ to the FunctionSchema pybind binding, so that
schemas can be used for things like dict indexing (sketched below).
Pull Request resolved: pytorch#90730
Approved by: https://github.com/ezyang
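
For illustration (not part of the commit): a minimal sketch of the dict-indexing use case, assuming this commit is applied. parse_schema is the binding exercised by the test file below; the registry dict and handler string are made-up names.

from torch._C import parse_schema

schema = parse_schema('any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)')
registry = {schema: 'any_out_handler'}  # schemas are now valid dict keys

# an equal schema parsed separately hashes (and compares) the same,
# so it finds the same entry
key = parse_schema('any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)')
assert registry[key] == 'any_out_handler'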
wanchaol authored and pytorchmergebot committed Jan 4, 2023
1 parent a7749ae commit 17bc40c
Showing 4 changed files with 109 additions and 0 deletions.
32 changes: 32 additions & 0 deletions aten/src/ATen/core/alias_info.h
@@ -3,6 +3,7 @@
#include <vector>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>

namespace c10 {
/**
@@ -117,3 +118,34 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
return out;
}
} // namespace c10

namespace std {
template <>
struct hash<c10::AliasInfo> {
size_t operator()(const c10::AliasInfo& aliasInfo) const {
auto hash = std::hash<bool>()(aliasInfo.isWrite());

// NOTE: for the unordered before/after sets we can't use hash_combine,
// because hash_combine is order-dependent. Instead, we use XOR as the
// combining function, since XOR is commutative.
size_t before_set_hash_seed = 0;
for (auto &e: aliasInfo.beforeSets()) {
auto symbol_hash = std::hash<c10::Symbol>()(e);
before_set_hash_seed = before_set_hash_seed ^ symbol_hash;
}
size_t after_set_hash_seed = 0;
for (auto &e: aliasInfo.afterSets()) {
auto symbol_hash = std::hash<c10::Symbol>()(e);
after_set_hash_seed = after_set_hash_seed ^ symbol_hash;
}

hash = c10::hash_combine(hash, before_set_hash_seed);
hash = c10::hash_combine(hash, after_set_hash_seed);
for (auto &e: aliasInfo.containedTypes()) {
auto contained_type_hash = std::hash<c10::AliasInfo>()(e);
hash = c10::hash_combine(hash, contained_type_hash);
}
return hash;
}
};
}
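
As an aside on the NOTE in the hash above: a small Python sketch (separate from the C++ diff) of why a commutative combiner is needed for the unordered before/after sets. hash_combine here is a hypothetical Python transliteration of the boost-style formula; the real c10::hash_combine lives in c10/util/hash.h.

# XOR folding is order-independent, so the iteration order of an
# unordered_set does not change the combined hash
a, b = hash('alias_a'), hash('alias_b')
assert a ^ b == b ^ a

# a boost-style combine is order-dependent: swapping the operands
# generally produces a different result
def hash_combine(seed: int, value: int) -> int:
    return (seed ^ (value + 0x9E3779B9 + (seed << 6) + (seed >> 2))) & ((1 << 64) - 1)

print(hash_combine(hash_combine(0, a), b) == hash_combine(hash_combine(0, b), a))  # almost always False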
41 changes: 41 additions & 0 deletions aten/src/ATen/core/function_schema.h
@@ -639,6 +639,47 @@ template<>
return c10::hash_combine(std::hash<size_t>()(arg.index), std::hash<size_t>()(static_cast<std::size_t>(arg.type)));
}
};
template<>
struct hash<c10::Argument> {
size_t operator()(const c10::Argument& arg) const
{
auto hash = std::hash<std::string>{}(arg.name());
auto type_hash = std::hash<c10::TypePtr>{}(arg.type());
auto kwarg_only_hash = std::hash<bool>{}(arg.kwarg_only());
hash = c10::hash_combine(hash, type_hash);
hash = c10::hash_combine(hash, kwarg_only_hash);
// hashing optional fields if they exist
if (arg.default_value()) {
auto default_value_hash = c10::hash<c10::IValue>{}(arg.default_value().value());
hash = c10::hash_combine(hash, default_value_hash);
}
if (arg.N()) {
auto N_hash = std::hash<int64_t>{}(*arg.N());
hash = c10::hash_combine(hash, N_hash);
}
if (arg.alias_info()) {
auto alias_info_hash = std::hash<c10::AliasInfo>{}(*arg.alias_info());
hash = c10::hash_combine(hash, alias_info_hash);
}
return hash;
}
};
template<>
struct hash<c10::FunctionSchema> {
size_t operator()(const c10::FunctionSchema& schema) const
{
auto hash = std::hash<c10::OperatorName>{}(schema.operator_name());
auto args_hash = c10::hash<std::vector<c10::Argument>>{}(schema.arguments());
auto returns_hash = c10::hash<std::vector<c10::Argument>>{}(schema.returns());
auto is_vararg_hash = std::hash<bool>{}(schema.is_vararg());
auto is_varret_hash = std::hash<bool>{}(schema.is_varret());
hash = c10::hash_combine(hash, args_hash);
hash = c10::hash_combine(hash, returns_hash);
hash = c10::hash_combine(hash, is_vararg_hash);
hash = c10::hash_combine(hash, is_varret_hash);
return hash;
}
};
} // namespace std


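A quick contract check (a sketch, not in the diff): Python's dict and set machinery requires that objects which compare equal also hash equal, and the hash above appears to combine the same fields that FunctionSchema equality inspects.

from torch._C import parse_schema

s1 = parse_schema('foo(Tensor self, int a=2) -> Tensor')
s2 = parse_schema('foo(Tensor self, int a=2) -> Tensor')
assert s1 == s2 and hash(s1) == hash(s2)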
31 changes: 31 additions & 0 deletions test/test_function_schema.py
@@ -21,6 +21,37 @@ def test_out_schema(self):
schema_without_out = parse_schema('any.not_out(Tensor self, Tensor b) -> Tensor')
self.assertFalse(schema_without_out.arguments[-1].is_out)

def test_hash_schema(self):
schema1 = parse_schema('any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)')
schema2 = parse_schema('any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)')
self.assertEqual(hash(schema1), hash(schema2))

schema3 = parse_schema('any.not_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)')
self.assertNotEqual(hash(schema2), hash(schema3))

schema4 = parse_schema('foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)')
self.assertNotEqual(hash(schema2), hash(schema4))

# schemas with different default values or different kw-only args should have different hashes
default_val_schema0 = parse_schema('foo(Tensor self, int a = 2) -> Tensor(a!)')
default_val_schema1 = parse_schema('foo(Tensor self, int a = 3) -> Tensor(a!)')
default_val_schema2 = parse_schema('foo(Tensor self, *, int a = 2) -> Tensor(a!)')
self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema1))
self.assertNotEqual(hash(default_val_schema0), hash(default_val_schema2))

# schemas with different alias annotations should have different hashes
alias_schema = parse_schema('foo(Tensor(a!) self, int a = 2) -> Tensor(a!)')
self.assertNotEqual(hash(default_val_schema0), hash(alias_schema))
alias_schema2 = parse_schema('foo(Tensor(b!) self, int a = 2) -> Tensor(a!)')
self.assertNotEqual(hash(alias_schema), hash(alias_schema2))

# schemas with different alias infos should have different hashes
alias_schema3 = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)')
alias_schema4 = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(b!)')
alias_schema5 = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(b!) out, Tensor(a!) b) -> Tensor(a!)')
self.assertNotEqual(hash(alias_schema3), hash(alias_schema4))
self.assertNotEqual(hash(alias_schema3), hash(alias_schema5))

def test_backward_compatible_structure(self):
old_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> Tensor')
# BC: A new schema without changes.
5 changes: 5 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -1702,6 +1702,11 @@ void initJITBindings(PyObject* module) {
[](const FunctionSchema& self, const FunctionSchema& other) {
return self == other;
})
.def(
"__hash__",
[](const FunctionSchema& self) {
return std::hash<FunctionSchema>{}(self);
})
.def(
"__str__",
[](FunctionSchema& self) {
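Usage sketch for the new binding (not part of the diff): with __hash__ exposed, schemas also deduplicate naturally in Python sets; the schema strings are reused from the tests above.

from torch._C import parse_schema

schemas = {
    parse_schema('any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)'),
    parse_schema('any.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)'),
    parse_schema('any.not_out(Tensor self, Tensor b) -> Tensor'),
}
assert len(schemas) == 2  # duplicates collapse because equal schemas hash equal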
