Skip to content

Commit

Permalink
[black][codemod] formatting changes from black 22.3.0
Browse files Browse the repository at this point in the history
Summary:
Applies the black-fbsource codemod with the new build of pyfmt.

paintitblack

Test Plan:
Probably going to have to ignore signals and just land it.

**Static Docs Preview: pyre**
|[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D36324783/V4/pyre/)|

|**Modified Pages**|

**Static Docs Preview: classyvision**
|[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D36324783/V4/classyvision/)|

|**Modified Pages**|

**Static Docs Preview: antlir**
|[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D36324783/V4/antlir/)|

|**Modified Pages**|

Reviewed By: lisroach

Differential Revision: D36324783

fbshipit-source-id: 280c09e88257e5e569ab729691165d8dedd767bc
(cherry picked from commit 704f50c)
  • Loading branch information
amyreese authored and pytorchmergebot committed May 12, 2022
1 parent c25bdee commit d973ece
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 96 deletions.
1 change: 0 additions & 1 deletion test/package/package_a/use_dunder_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
def is_from_package():
return True


else:

def is_from_package():
Expand Down
1 change: 0 additions & 1 deletion test/package/package_c/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def forward(self, x):
x = a_non_torch_leaf(x, x)
return torch.relu(x + 3.0)


except ImportError:
pass

Expand Down
132 changes: 38 additions & 94 deletions tools/fast_nvcc/fast_nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
import subprocess
import sys
import time
from typing import Awaitable, DefaultDict, Dict, List, Match, Optional, Set, cast

from typing_extensions import TypedDict

help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]...
Expand All @@ -39,7 +37,7 @@
)
parser.add_argument(
"--graph",
metavar="FILE.gv",
metavar="FILE.dot",
help="write Graphviz DOT file with execution graph",
)
parser.add_argument(
Expand Down Expand Up @@ -80,14 +78,14 @@
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"


def fast_nvcc_warn(warning: str) -> None:
def fast_nvcc_warn(warning):
"""
Warn the user about something regarding fast_nvcc.
"""
print(f"warning (fast_nvcc): {warning}", file=sys.stderr)


def warn_if_windows() -> None:
def warn_if_windows():
"""
Warn the user that using fast_nvcc on Windows might not work.
"""
Expand All @@ -99,7 +97,7 @@ def warn_if_windows() -> None:
fast_nvcc_warn(url_vars)


def warn_if_tmpdir_flag(args: List[str]) -> None:
def warn_if_tmpdir_flag(args):
"""
Warn the user that using fast_nvcc with some flags might not work.
"""
Expand All @@ -123,17 +121,11 @@ def warn_if_tmpdir_flag(args: List[str]) -> None:
fast_nvcc_warn(f"{url_base}#{frag}")


class DryunData(TypedDict):
env: Dict[str, str]
commands: List[str]
exit_code: int


def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
def nvcc_dryrun_data(binary, args):
"""
Return parsed environment variables and commands from nvcc --dryrun.
"""
result = subprocess.run( # type: ignore[call-overload]
result = subprocess.run(
[binary, "--dryrun"] + args,
capture_output=True,
encoding="ascii", # this is just a guess
Expand All @@ -156,7 +148,7 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
return {"env": env, "commands": commands, "exit_code": result.returncode}


def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
def warn_if_tmpdir_set(env):
"""
Warn the user that setting TMPDIR with fast_nvcc might not work.
"""
Expand All @@ -165,7 +157,7 @@ def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
fast_nvcc_warn(url_vars)


def contains_non_executable(commands: List[str]) -> bool:
def contains_non_executable(commands):
for command in commands:
# This is to deal with special command dry-run result from NVCC such as:
# ```
Expand All @@ -178,7 +170,7 @@ def contains_non_executable(commands: List[str]) -> bool:
return False


def module_id_contents(command: List[str]) -> str:
def module_id_contents(command):
"""
Guess the contents of the .module_id file contained within command.
"""
Expand All @@ -195,7 +187,7 @@ def module_id_contents(command: List[str]) -> str:
return f"_{len(middle)}_{middle}_{suffix}"


def unique_module_id_files(commands: List[str]) -> List[str]:
def unique_module_id_files(commands):
"""
Give each command its own .module_id filename instead of sharing.
"""
Expand All @@ -204,7 +196,7 @@ def unique_module_id_files(commands: List[str]) -> List[str]:
for i, line in enumerate(commands):
arr = []

def uniqueify(s: Match[str]) -> str:
def uniqueify(s):
filename = re.sub(r"\-(\d+)", r"-\1-" + str(i), s.group(0))
arr.append(filename)
return filename
Expand All @@ -220,19 +212,14 @@ def uniqueify(s: Match[str]) -> str:
return uniqueified


def make_rm_force(commands: List[str]) -> List[str]:
def make_rm_force(commands):
"""
Add --force to all rm commands.
"""
return [f"{c} --force" if c.startswith("rm ") else c for c in commands]


def print_verbose_output(
*,
env: Dict[str, str],
commands: List[List[str]],
filename: str,
) -> None:
def print_verbose_output(*, env, commands, filename):
"""
Human-readably write nvcc --dryrun data to stderr.
"""
Expand All @@ -247,24 +234,21 @@ def print_verbose_output(
print(f'#{" "*len(prefix)}{part}', file=f)


Graph = List[Set[int]]


def straight_line_dependencies(commands: List[str]) -> Graph:
def straight_line_dependencies(commands):
"""
Return a straight-line dependency graph.
"""
return [({i - 1} if i > 0 else set()) for i in range(len(commands))]


def files_mentioned(command: str) -> List[str]:
def files_mentioned(command):
"""
Return fully-qualified names of all tmp files referenced by command.
"""
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]


def nvcc_data_dependencies(commands: List[str]) -> Graph:
def nvcc_data_dependencies(commands):
"""
Return a list of the set of dependencies for each command.
"""
Expand All @@ -277,8 +261,8 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph:
# data dependency is sort of flipped, because the steps that use the
# files generated by cicc need to wait for the fatbinary step to
# finish first
tmp_files: Dict[str, int] = {}
fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set)
tmp_files = {}
fatbins = collections.defaultdict(set)
graph = []
for i, line in enumerate(commands):
deps = set()
Expand All @@ -300,13 +284,13 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph:
return graph


def is_weakly_connected(graph: Graph) -> bool:
def is_weakly_connected(graph):
"""
Return true iff graph is weakly connected.
"""
if not graph:
return True
neighbors: List[Set[int]] = [set() for _ in graph]
neighbors = [set() for _ in graph]
for node, predecessors in enumerate(graph):
for pred in predecessors:
neighbors[pred].add(node)
Expand All @@ -323,25 +307,20 @@ def is_weakly_connected(graph: Graph) -> bool:
return len(found) == len(graph)


def warn_if_not_weakly_connected(graph: Graph) -> None:
def warn_if_not_weakly_connected(graph):
"""
Warn the user if the execution graph is not weakly connected.
"""
if not is_weakly_connected(graph):
fast_nvcc_warn("execution graph is not (weakly) connected")


def print_dot_graph(
*,
commands: List[List[str]],
graph: Graph,
filename: str,
) -> None:
def print_dot_graph(*, commands, graph, filename):
"""
Print a DOT file displaying short versions of the commands in graph.
"""

def name(k: int) -> str:
def name(k):
return f'"{k} {os.path.basename(commands[k][0])}"'

with open(filename, "w") as f:
Expand All @@ -355,23 +334,7 @@ def name(k: int) -> str:
print("}", file=f)


class Result(TypedDict, total=False):
exit_code: int
stdout: bytes
stderr: bytes
time: float
files: Dict[str, int]


async def run_command(
command: str,
*,
env: Dict[str, str],
deps: Set[Awaitable[Result]],
gather_data: bool,
i: int,
save: Optional[str],
) -> Result:
async def run_command(command, *, env, deps, gather_data, i, save):
"""
Run the command with the given env after waiting for deps.
"""
Expand All @@ -389,8 +352,8 @@ async def run_command(
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
code = cast(int, proc.returncode)
results: Result = {"exit_code": code, "stdout": stdout, "stderr": stderr}
code = proc.returncode
results = {"exit_code": code, "stdout": stdout, "stderr": stderr}
if gather_data:
t2 = time.monotonic()
results["time"] = t2 - t1
Expand All @@ -410,23 +373,16 @@ async def run_command(
return results


async def run_graph(
*,
env: Dict[str, str],
commands: List[str],
graph: Graph,
gather_data: bool = False,
save: Optional[str] = None,
) -> List[Result]:
async def run_graph(*, env, commands, graph, gather_data=False, save=None):
"""
Return outputs/errors (and optionally time/file info) from commands.
"""
tasks: List[Awaitable[Result]] = []
tasks = []
for i, (command, indices) in enumerate(zip(commands, graph)):
deps = {tasks[j] for j in indices}
tasks.append(
asyncio.create_task(
run_command( # type: ignore[attr-defined]
run_command(
command,
env=env,
deps=deps,
Expand All @@ -439,7 +395,7 @@ async def run_graph(
return [await task for task in tasks]


def print_command_outputs(command_results: List[Result]) -> None:
def print_command_outputs(command_results):
"""
Print captured stdout and stderr from commands.
"""
Expand All @@ -448,16 +404,11 @@ def print_command_outputs(command_results: List[Result]) -> None:
sys.stderr.write(result.get("stderr", b"").decode("ascii"))


def write_log_csv(
command_parts: List[List[str]],
command_results: List[Result],
*,
filename: str,
) -> None:
def write_log_csv(command_parts, command_results, *, filename):
"""
Write a CSV file of the times and /tmp file sizes from each command.
"""
tmp_files: List[str] = []
tmp_files = []
for result in command_results:
tmp_files.extend(result.get("files", {}).keys())
with open(filename, "w", newline="") as csvfile:
Expand All @@ -470,7 +421,7 @@ def write_log_csv(
writer.writerow({**row, **result.get("files", {})})


def exit_code(results: List[Result]) -> int:
def exit_code(results):
"""
Aggregate individual exit codes into a single code.
"""
Expand All @@ -481,18 +432,11 @@ def exit_code(results: List[Result]) -> int:
return 0


def wrap_nvcc(
args: List[str],
config: argparse.Namespace = default_config,
) -> int:
def wrap_nvcc(args, config=default_config):
return subprocess.call([config.nvcc] + args)


def fast_nvcc(
args: List[str],
*,
config: argparse.Namespace = default_config,
) -> int:
def fast_nvcc(args, *, config=default_config):
"""
Emulate the result of calling the given nvcc binary with args.
Expand Down Expand Up @@ -528,7 +472,7 @@ def fast_nvcc(
if config.sequential:
graph = straight_line_dependencies(commands)
results = asyncio.run(
run_graph( # type: ignore[attr-defined]
run_graph(
env=env,
commands=commands,
graph=graph,
Expand All @@ -539,10 +483,10 @@ def fast_nvcc(
print_command_outputs(results)
if config.table:
write_log_csv(command_parts, results, filename=config.table)
return exit_code([dryrun_data] + results) # type: ignore[arg-type, operator]
return exit_code([dryrun_data] + results)


def our_arg(arg: str) -> bool:
def our_arg(arg):
return arg != "--"


Expand Down

0 comments on commit d973ece

Please sign in to comment.