-
Notifications
You must be signed in to change notification settings - Fork 1
/
01_train_basic_models.py
38 lines (34 loc) · 1.17 KB
/
01_train_basic_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import submit_utils
from os.path import dirname
import os.path
# The repository root is the parent of the directory that holds this script.
repo_dir = dirname(dirname(os.path.abspath(__file__)))

# Demonstrates the supported ways of sweeping over script arguments.
# Either dict below may be replaced with an empty dict ({}) to skip that
# kind of sweep.

# Independent sweeps: each argument name maps to a list of candidate values,
# and every combination of these values is swept over.
params_shared_dict = dict(
    seed=[1, 2],
    save_dir=['results'],
    # Binary flags are passed as 0/1 rather than the ambiguous strings
    # True/False.
    use_cache=[0],
)

# Coupled sweeps: each key is a tuple of argument names mapping to a list of
# value tuples; the values inside one tuple are always set together.
# NOTE(review): both groups set 'model_name' -- how the groups combine with
# each other is decided by submit_utils.get_args_list; confirm there.
params_coupled_dict = {
    ('model_name', 'alpha'): [('ridge', alpha) for alpha in (0.1, 1)],
    ('model_name', 'max_depth'): [
        ('decision_tree', depth) for depth in range(2, 4)
    ],
}

# The result is a list of dictionaries, one per run. To skip particular runs,
# filter this list before handing it to run_args_list.
args_list = submit_utils.get_args_list(
    params_shared_dict=params_shared_dict,
    params_coupled_dict=params_coupled_dict,
)

# Submit the runs, using experiments/01_train_model.py under the repository
# root as the script.
submit_utils.run_args_list(
    args_list,
    script_name=os.path.join(repo_dir, 'experiments', '01_train_model.py'),
    actually_run=True,
)