Skip to content

Commit

Permalink
add notebook_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Jan 25, 2023
1 parent 4c7fd18 commit fcfc16e
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 319 deletions.
67 changes: 33 additions & 34 deletions experiments/01_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,42 +39,41 @@ def evaluate_model(model, X_train, X_cv, X_test, y_train, y_cv, y_test, r):

return r

# initialize args
def add_main_args(parser):
"""Caching uses the non-default values from argparse to name the saving directory.
Changing the default arg an argument will break cache compatibility with previous runs.
"""

if __name__ == '__main__':
# initialize args
def add_main_args(parser):
"""Caching uses the non-default values from argparse to name the saving directory.
Changing the default arg an argument will break cache compatibility with previous runs.
"""

# dataset args
parser.add_argument('--dataset_name', type=str,
default='rotten_tomatoes', help='name of dataset')
parser.add_argument('--subsample_frac', type=float,
default=1, help='fraction of samples to use')

# training misc args
parser.add_argument('--seed', type=int, default=1,
help='random seed')
parser.add_argument('--save_dir', type=str, default='results',
help='directory for saving')

# model args
parser.add_argument('--model_name', type=str, choices=['decision_tree', 'ridge'],
default='decision_tree', help='name of model')
parser.add_argument('--alpha', type=float, default=1,
help='regularization strength')
parser.add_argument('--max_depth', type=int,
default=2, help='max depth of tree')
return parser

def add_computational_args(parser):
"""Arguments that only affect computation and not the results (shouldnt use when checking cache)
"""
parser.add_argument('--use_cache', type=int, default=1, choices=[0, 1],
help='whether to check for cache')
return parser
# dataset args
parser.add_argument('--dataset_name', type=str,
default='rotten_tomatoes', help='name of dataset')
parser.add_argument('--subsample_frac', type=float,
default=1, help='fraction of samples to use')

# training misc args
parser.add_argument('--seed', type=int, default=1,
help='random seed')
parser.add_argument('--save_dir', type=str, default='results',
help='directory for saving')

# model args
parser.add_argument('--model_name', type=str, choices=['decision_tree', 'ridge'],
default='decision_tree', help='name of model')
parser.add_argument('--alpha', type=float, default=1,
help='regularization strength')
parser.add_argument('--max_depth', type=int,
default=2, help='max depth of tree')
return parser

def add_computational_args(parser):
"""Arguments that only affect computation and not the results (shouldnt use when checking cache)
"""
parser.add_argument('--use_cache', type=int, default=1, choices=[0, 1],
help='whether to check for cache')
return parser

if __name__ == '__main__':
# get args
parser = argparse.ArgumentParser()
parser_without_computational_args = add_main_args(parser)
Expand Down
Loading

0 comments on commit fcfc16e

Please sign in to comment.