forked from NVlabs/stylegan
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f3a0446
Showing
32 changed files
with
6,346 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# This work is licensed under the Creative Commons Attribution-NonCommercial | ||
# 4.0 International License. To view a copy of this license, visit | ||
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | ||
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | ||
|
||
"""Global configuration.""" | ||
|
||
#---------------------------------------------------------------------------- | ||
# Paths. | ||
|
||
result_dir = 'results' | ||
data_dir = 'datasets' | ||
cache_dir = 'cache' | ||
|
||
#---------------------------------------------------------------------------- |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# This work is licensed under the Creative Commons Attribution-NonCommercial | ||
# 4.0 International License. To view a copy of this license, visit | ||
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | ||
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | ||
|
||
from . import submission | ||
|
||
from .submission.run_context import RunContext | ||
|
||
from .submission.submit import SubmitTarget | ||
from .submission.submit import PathType | ||
from .submission.submit import SubmitConfig | ||
from .submission.submit import get_path_from_template | ||
from .submission.submit import submit_run | ||
|
||
from .util import EasyDict | ||
|
||
submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# This work is licensed under the Creative Commons Attribution-NonCommercial | ||
# 4.0 International License. To view a copy of this license, visit | ||
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | ||
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | ||
|
||
from . import run_context | ||
from . import submit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# This work is licensed under the Creative Commons Attribution-NonCommercial | ||
# 4.0 International License. To view a copy of this license, visit | ||
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | ||
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | ||
|
||
"""Helper for launching run functions in computing clusters. | ||
During the submit process, this file is copied to the appropriate run dir. | ||
When the job is launched in the cluster, this module is the first thing that | ||
is run inside the docker container. | ||
""" | ||
|
||
import os | ||
import pickle | ||
import sys | ||
|
||
# PYTHONPATH should have been set so that the run_dir/src is in it | ||
import dnnlib | ||
|
||
def main(): | ||
if not len(sys.argv) >= 4: | ||
raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") | ||
|
||
run_dir = str(sys.argv[1]) | ||
task_name = str(sys.argv[2]) | ||
host_name = str(sys.argv[3]) | ||
|
||
submit_config_path = os.path.join(run_dir, "submit_config.pkl") | ||
|
||
# SubmitConfig should have been pickled to the run dir | ||
if not os.path.exists(submit_config_path): | ||
raise RuntimeError("SubmitConfig pickle file does not exist!") | ||
|
||
submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) | ||
dnnlib.submission.submit.set_user_name_override(submit_config.user_name) | ||
|
||
submit_config.task_name = task_name | ||
submit_config.host_name = host_name | ||
|
||
dnnlib.submission.submit.run_wrapper(submit_config) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# This work is licensed under the Creative Commons Attribution-NonCommercial | ||
# 4.0 International License. To view a copy of this license, visit | ||
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | ||
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | ||
|
||
"""Helpers for managing the run/training loop.""" | ||
|
||
import datetime | ||
import json | ||
import os | ||
import pprint | ||
import time | ||
import types | ||
|
||
from typing import Any | ||
|
||
from . import submit | ||
|
||
|
||
class RunContext(object): | ||
"""Helper class for managing the run/training loop. | ||
The context will hide the implementation details of a basic run/training loop. | ||
It will set things up properly, tell if run should be stopped, and then cleans up. | ||
User should call update periodically and use should_stop to determine if run should be stopped. | ||
Args: | ||
submit_config: The SubmitConfig that is used for the current run. | ||
config_module: The whole config module that is used for the current run. | ||
max_epoch: Optional cached value for the max_epoch variable used in update. | ||
""" | ||
|
||
def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): | ||
self.submit_config = submit_config | ||
self.should_stop_flag = False | ||
self.has_closed = False | ||
self.start_time = time.time() | ||
self.last_update_time = time.time() | ||
self.last_update_interval = 0.0 | ||
self.max_epoch = max_epoch | ||
|
||
# pretty print the all the relevant content of the config module to a text file | ||
if config_module is not None: | ||
with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: | ||
filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} | ||
pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) | ||
|
||
# write out details about the run to a text file | ||
self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} | ||
with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: | ||
pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) | ||
|
||
def __enter__(self) -> "RunContext": | ||
return self | ||
|
||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: | ||
self.close() | ||
|
||
def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: | ||
"""Do general housekeeping and keep the state of the context up-to-date. | ||
Should be called often enough but not in a tight loop.""" | ||
assert not self.has_closed | ||
|
||
self.last_update_interval = time.time() - self.last_update_time | ||
self.last_update_time = time.time() | ||
|
||
if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): | ||
self.should_stop_flag = True | ||
|
||
max_epoch_val = self.max_epoch if max_epoch is None else max_epoch | ||
|
||
def should_stop(self) -> bool: | ||
"""Tell whether a stopping condition has been triggered one way or another.""" | ||
return self.should_stop_flag | ||
|
||
def get_time_since_start(self) -> float: | ||
"""How much time has passed since the creation of the context.""" | ||
return time.time() - self.start_time | ||
|
||
def get_time_since_last_update(self) -> float: | ||
"""How much time has passed since the last call to update.""" | ||
return time.time() - self.last_update_time | ||
|
||
def get_last_update_interval(self) -> float: | ||
"""How much time passed between the previous two calls to update.""" | ||
return self.last_update_interval | ||
|
||
def close(self) -> None: | ||
"""Close the context and clean up. | ||
Should only be called once.""" | ||
if not self.has_closed: | ||
# update the run.txt with stopping time | ||
self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") | ||
with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: | ||
pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) | ||
|
||
self.has_closed = True |
Oops, something went wrong.