-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
62 lines (50 loc) · 1.17 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import numpy as np
import lstm
import sys
import mlflow
import data
import pickle
import calendar
import time
# Device strings forwarded to lstm.LSTM by main(). The command-line flags
# handled in the __main__ block (--gpu / --cpu) may rebind them.
heavy_device = light_device = "/cpu:0"
# test:    when True, main() prints generated text instead of training.
# restore: when True, main() loads the cached dataset instead of re-reading
#          the corpus; the LSTM is also constructed with restore=True.
test = restore = False
def main(data_path="warandpeace.txt", sequence_length=100, batch_size=128,
         state_path="./saves/state.pkl"):
    """Load (or build and cache) the character dataset, then train or sample.

    Reads the module-level settings ``restore``, ``test``, ``heavy_device``
    and ``light_device`` (which the ``__main__`` block may rebind from the
    command line).

    Args:
        data_path: Text corpus passed to ``data.read_data`` when not restoring.
        sequence_length: ``sequence_length`` passed to ``data.read_data``.
        batch_size: Batch size passed to ``data.train_set``.
        state_path: Pickle file caching ``[X, Y, char2ix, ix2char]``; written
            when not restoring, read when ``restore`` is True.
    """
    import os

    # NOTE(review): run_id is never used below — presumably meant as an
    # mlflow run identifier; confirm. The draw is kept so the global NumPy
    # RNG state seen by later code is unchanged.
    run_id = np.random.randint(1000)

    if restore:
        # Reuse the dataset and vocabulary maps cached by a previous run.
        # pickle.load can execute arbitrary code: only load a cache file this
        # script wrote itself, never one from an untrusted source.
        with open(state_path, "rb") as f:
            X, Y, char2ix, ix2char = pickle.load(f)
    else:
        X, Y, char2ix, ix2char = data.read_data(
            data_path, sequence_length=sequence_length)
        # open(..., "wb") fails if the cache directory is missing (e.g. on a
        # fresh checkout), so create it first.
        state_dir = os.path.dirname(state_path)
        if state_dir:
            os.makedirs(state_dir, exist_ok=True)
        with open(state_path, "wb") as f:
            pickle.dump([X, Y, char2ix, ix2char], f)

    # Built unconditionally (as before), even when only generating text.
    train_set = data.train_set(X, Y, batch_size)
    solver = lstm.LSTM(
        num_classes=len(char2ix),
        heavy_device=heavy_device,
        light_device=light_device,
        restore=restore,
    )
    if not test:
        solver.train(train_set)
    else:
        # Generate and print 100 characters using the vocabulary maps.
        print(solver.generate(char2ix, ix2char, 100))
if __name__ == "__main__":
    # Hand-rolled flag parsing: each flag rebinds a module-level setting that
    # main() reads. Flags are processed in order, so later ones override
    # earlier ones (e.g. "--test --train" ends in training mode).
    for o in sys.argv[1:]:
        if o == '--gpu':
            # heavy_device moves to the GPU; light_device stays on the CPU.
            # Both are forwarded to lstm.LSTM by main().
            heavy_device = "/gpu:0"
            light_device = "/cpu:0"
        elif o == '--cpu':
            heavy_device = "/cpu:0"
            light_device = "/cpu:0"
        elif o == '--test':
            test = True
        elif o == '--train':
            test = False
        elif o == '--restore':
            restore = True
        else:
            raise ValueError("Unknown argument: " + o)
    main()