Skip to content

Commit

Permalink
file io
Browse files Browse the repository at this point in the history
  • Loading branch information
ShivamShrirao committed Aug 12, 2020
1 parent 94fdbcf commit c749bba
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions nnet_gpu/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,18 @@ def save_weights(self, path): # TODO - make a proper saving mechanism.
sv_me.append((obj.weights, obj.biases, obj.moving_mean, obj.moving_var))
else:
sv_me.append((obj.weights, obj.biases)) # ,obj.w_m,obj.w_v,obj.b_m,obj.b_v))
with open(path, 'wb') as f:
pickle.dump(sv_me, f)
if isinstance(path, str):
with open(path, 'wb') as f:
pickle.dump(sv_me, f)
else:
pickle.dump(sv_me, path)

def load_weights(self, path):
with open(path, 'rb') as f:
sv_me = pickle.load(f)
if isinstance(path, str):
with open(path, 'rb') as f:
sv_me = pickle.load(f)
else:
sv_me = pickle.load(path)
idx = 0
for obj in self.sequence:
if obj.param > 0:
Expand Down

0 comments on commit c749bba

Please sign in to comment.