-
Notifications
You must be signed in to change notification settings - Fork 25
/
make_prompt_datasets.py
77 lines (62 loc) · 2.75 KB
/
make_prompt_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
from feature_datasets import time_art, common, space_us, space_world, headline, space_nyc, historical
from transformers import AutoTokenizer
from functools import partial
# Registry of prompt templates per entity type, grouped by feature family
# (temporal vs. spatial). Keys must stay in sync with DATASET_FUNCTIONS below;
# each value is a dict of {short_prompt_name: full_prompt_template}.
ENTITY_PROMPTS = {
    # time
    'art': time_art.ART_PROMPTS,
    'headline': headline.HEADLINE_PROMPTS,
    'historical_figure': historical.HISTORICAL_PROMPTS,
    # space
    'world_place': space_world.PLACE_PROMPTS,
    'us_place': space_us.US_PLACE_PROMPTS,
    'nyc_place': space_nyc.NYC_PLACE_PROMPTS,
}
# Registry of dataset-builder callables per entity type. Each callable takes
# (short_prompt, full_prompt, tokenizer, entity_data) — partials pre-bind the
# entity-name column where the underlying builder needs it. Keys mirror
# ENTITY_PROMPTS above.
DATASET_FUNCTIONS = {
    'art': time_art.make_art_prompt_dataset,
    'headline': headline.make_headline_prompt_dataset,
    'historical_figure': historical.make_historical_figure_prompt_dataset,
    'world_place': partial(space_world.make_world_prompt_dataset, entity_col='name'),
    'us_place': partial(common.make_prompt_dataset, entity_col='name'),
    'nyc_place': partial(space_nyc.make_nyc_prompt_dataset, entity_col='name'),
}
def make_and_save_tokenized_datasets(
        tokenizer, model_family, entity_type, prompt_dict, ds_make_fn):
    """Build and persist one tokenized dataset per prompt variant.

    Loads the entity table for ``entity_type`` once, then, for every
    (short name, full template) pair in ``prompt_dict``, constructs a
    tokenized dataset via ``ds_make_fn`` and writes it to the path given
    by ``common.prompt_data_path``.

    Args:
        tokenizer: HuggingFace tokenizer used to encode the prompts.
        model_family: model-family tag used in the save path.
        entity_type: key identifying which entity table to load.
        prompt_dict: mapping of {short_prompt_name: full_prompt_template}.
        ds_make_fn: callable(short_prompt, full_prompt, tokenizer,
            entity_data) -> dataset with a ``save_to_disk`` method.
    """
    entity_data = common.load_entity_data(entity_type)
    for prompt_name, prompt_template in prompt_dict.items():
        built = ds_make_fn(prompt_name, prompt_template, tokenizer, entity_data)
        out_path = common.prompt_data_path(entity_type, prompt_name, model_family)
        built.save_to_disk(out_path)
def load_tokenizer(model_name):
    """Load the HuggingFace tokenizer for a supported model family.

    Args:
        model_name: model identifier; must contain 'Llama-2' or 'pythia'.

    Returns:
        An ``AutoTokenizer`` with ``pad_token`` set to the EOS token
        (neither family ships with a dedicated pad token).

    Raises:
        ValueError: if ``model_name`` matches no supported family.
    """
    if 'Llama-2' in model_name:
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
        # Right-pad so generation-style datasets keep prompts left-aligned.
        tokenizer.padding_side = 'right'
    elif 'pythia' in model_name:
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
    else:
        # Include the offending value so misconfigured runs are diagnosable.
        raise ValueError(f'invalid model name: {model_name!r}')
    # Common to both families: reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_family', type=str, default='pythia')
    parser.add_argument('--entity_type', type=str, default='all')
    args = parser.parse_args()

    model_family = args.model_family
    tokenizer = load_tokenizer(model_family)

    # Resolve the requested entity type(s) up front so a single loop handles
    # both the 'all' and single-type cases (previously duplicated logic).
    if args.entity_type == 'all':
        entity_types = list(DATASET_FUNCTIONS)
    elif args.entity_type in DATASET_FUNCTIONS:
        entity_types = [args.entity_type]
    else:
        # Name the unknown type so the failure is diagnosable from the log.
        raise ValueError(f'dataset not found: {args.entity_type!r}')

    for entity_type in entity_types:
        make_and_save_tokenized_datasets(
            tokenizer, model_family, entity_type,
            ENTITY_PROMPTS[entity_type], DATASET_FUNCTIONS[entity_type])