Skip to content

Commit

Permalink
update nb with fill_missing_args
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Feb 1, 2023
1 parent 958f6f4 commit 5b98a60
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
5 changes: 4 additions & 1 deletion notebooks/01_model_results.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@
}
],
"source": [
"# fill missing args with default values from argparse\n",
"r = notebook_helper.fill_missing_args_with_default(r, fname='01_train_model.py')\n",
"\n",
"# group using these experiment hyperparams when averaging over random seeds\n",
"group_keys = [k for k in notebook_helper.get_main_args_list(fname='01_train_model.py') if not k == 'seed']\n",
"ravg = (\n",
Expand Down Expand Up @@ -266,7 +269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
17 changes: 17 additions & 0 deletions notebooks/notebook_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,20 @@ def get_main_args_list(fname='01_train_model.py'):
train_script = __import__(fname)
args = train_script.add_main_args(argparse.ArgumentParser()).parse_args([])
return list(vars(args).keys())

def fill_missing_args_with_default(df, fname='01_train_model.py'):
"""Returns main arguments from the argparser used by an experiments script
"""
if fname.endswith('.py'):
fname = fname[:-3]
sys.path.append(join(repo_dir, 'experiments'))
train_script = __import__(fname)
parser = train_script.add_main_args(argparse.ArgumentParser())
parser = train_script.add_computational_args(parser)
args = parser.parse_args([])
args_dict = vars(args)
for k, v in args_dict.items():
if k not in df.columns:
df[k] = v
df[k] = df[k].fillna(v)
return df

0 comments on commit 5b98a60

Please sign in to comment.