Skip to content

Commit

Permalink
Don't do BC check on ops with valid upgraders (pytorch#72313)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#72313

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D33997711

Pulled By: tugsbayasgalan

fbshipit-source-id: 73ad16eda4e519fb67403541a6533fe568a4743a
(cherry picked from commit 4304f95)
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Feb 11, 2022
1 parent 426f50e commit 5f2eed6
Showing 1 changed file with 46 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@
("aten::linalg_svd_out", datetime.date(2022, 3, 31)),
("aten::_max_pool1d_cpu_forward", datetime.date(2022, 2, 8)),
("aten::_convolution_nogroup", datetime.date(9999, 1, 1)),
("aten::linspace", datetime.date(2022, 3, 1)), # TODO this will be removed soon
("aten::logspace", datetime.date(2022, 3, 1)), # TODO this will be removed soon
("aten::miopen_convolution_backward", datetime.date(9999, 1, 1)),
("aten::miopen_convolution_backward_bias", datetime.date(9999, 1, 1)),
("aten::miopen_convolution_backward_input", datetime.date(9999, 1, 1)),
Expand Down Expand Up @@ -140,6 +138,33 @@ def allow_listed(schema):
("dist_c10d", datetime.date(2099, 9, 17)),
]

def has_valid_upgraders(schema, version_map):
# we want to parse through the map to find if
# the schema has valid upgraders. Since the
# version map has entry for each overload
# we need to do some ugly parsing.

# the name of the operator
schema_name = schema.name

if schema_name not in version_map:
return False

entries = version_map[schema_name]

possible_overloads = []
possible_schemas = []
for key, upgrader_schema_entries in entries.items():
possible_overloads.append(key)
possible_schemas.extend(upgrader_schema_entries)

# let's make sure this existing schema is part of possible
# schemas
for old_schema in possible_schemas:
if old_schema == schema:
return True

return False

def dont_parse(schema_line):
for item in dont_parse_list:
Expand All @@ -158,14 +183,33 @@ def load_schemas_to_dict():
new_schema_dict[s.name].append(s)
return new_schema_dict

def process_version_map(version_map):
# version map maps full schema name to
# list of upgraders. Since we only have
# the name of the schema (aka no overload)
# we want to first process the map to make
# the key lookup easier. After this it will be:
# Dict[schema_name, Dict[overload, List[schema]]]

output = defaultdict(dict)
for (key, entries) in version_map.items():
operator_name = key.split(".")[0]
schema_entries = [parse_schema(entry.old_schema) for entry in entries]
output[operator_name][key] = schema_entries
return output

def check_bc(existing_schemas):
new_schema_dict = load_schemas_to_dict()
version_map = process_version_map(torch._C._get_operator_version_map())
is_bc = True
broken_ops = []
for existing_schema in existing_schemas:
if allow_listed(existing_schema):
print("schema: ", str(existing_schema), " found on allowlist, skipping")
continue
if has_valid_upgraders(existing_schema, version_map):
print("schema: ", str(existing_schema), " has valid upgrader, skipping")
continue
print("processing existing schema: ", str(existing_schema))
matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
found = False
Expand Down

0 comments on commit 5f2eed6

Please sign in to comment.