Commit

getter setter for weights
ShivamShrirao committed Aug 12, 2020
1 parent c749bba commit 68c7d02
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions nnet_gpu/network.py
@@ -64,14 +64,14 @@ def predict(self, X_inp):
	def train_on_batch(self, X_inp, labels):
		X_inp = self.forward(X_inp.astype(self.dtype, copy=False))
		grads = self.del_loss(X_inp, labels.astype(self.dtype, copy=False))
		self.backprop(grads, do_d_inp=False)  # The gradients with input layer will NOT be calculated.
		self.optimizer(self.sequence, self.learning_rate, self.beta)
		return X_inp

	def not_train_on_batch(self, X_inp, labels):
		X_inp = self.forward(X_inp.astype(self.dtype, copy=False))
		grads = self.del_loss(X_inp, labels.astype(self.dtype, copy=False))
		grads = self.backprop(grads, do_d_inp=True)  # Gradients with input layer will be calculated.
		return X_inp, grads

	def fit(self, X_inp=None, labels=None, iterator=None, batch_size=1, epochs=1, validation_data=None, shuffle=True, accuracy_metric=True,
@@ -156,14 +156,33 @@ def validate(self, validation_data, batch_size, info_beta=0.2):
		end = time.time()
		print(f"\rValidation Accuracy: {(vacc / lnvx).get():.4f} - val_loss: {vloss.get():.4f} - Time: {end - start:.3f}s")

-	def save_weights(self, path):  # TODO - make a proper saving mechanism.
+	@property
+	def weights(self):
		sv_me = []
		for obj in self.sequence:
			if obj.param > 0:
				if isinstance(obj, layers.BatchNormalization):
					sv_me.append((obj.weights, obj.biases, obj.moving_mean, obj.moving_var))
				else:
					sv_me.append((obj.weights, obj.biases))  # ,obj.w_m,obj.w_v,obj.b_m,obj.b_v))
+		return sv_me
+
+	@weights.setter
+	def weights(self, sv_me):
+		idx = 0
+		for obj in self.sequence:
+			if obj.param > 0:
+				if isinstance(obj, layers.BatchNormalization):
+					obj.kernels, obj.biases, obj.moving_mean, obj.moving_var = sv_me[idx]
+				else:
+					obj.kernels, obj.biases = sv_me[idx]
+					if isinstance(obj, layers.Conv2D):  # TODO - Verify isinstance works.
+						obj.init_back()
+				obj.weights = obj.kernels
+				idx += 1
+
+	def save_weights(self, path):  # TODO - make a proper saving mechanism.
+		sv_me = self.weights
		if isinstance(path, str):
			with open(path, 'wb') as f:
				pickle.dump(sv_me, f)
@@ -176,19 +195,9 @@ def load_weights(self, path):
				sv_me = pickle.load(f)
		else:
			sv_me = pickle.load(path)
-		idx = 0
-		for obj in self.sequence:
-			if obj.param > 0:
-				if isinstance(obj, layers.BatchNormalization):
-					obj.kernels, obj.biases, obj.moving_mean, obj.moving_var = sv_me[idx]
-				else:
-					obj.kernels, obj.biases = sv_me[idx]
-					if isinstance(obj, layers.Conv2D):  # TODO - Verify isinstance works.
-						obj.init_back()
-				obj.weights = obj.kernels
-				idx += 1
+		self.weights = sv_me

	def summary(self):  # TODO - Show connections.
		ipl = layers.InputLayer(self.sequence[0].input_shape)
		reps = 90
		print(chr(9149) * reps)
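
For context, and not part of the commit itself: a minimal usage sketch of how the new weights property pairs with save_weights and load_weights. build_model is a hypothetical helper standing in for however a nnet_gpu model is constructed here; only the weights getter/setter, save_weights, and load_weights calls come from this diff.

import pickle

def checkpoint_roundtrip(build_model, path="weights.pkl"):
	# build_model is assumed to return a freshly constructed model of the
	# same architecture on every call; it is a placeholder, not a repo API.
	model = build_model()
	model.save_weights(path)       # pickles the list returned by the weights getter

	clone = build_model()          # same architecture, fresh parameters
	clone.load_weights(path)       # unpickles and assigns through the weights setter

	# The property also allows an in-memory copy, with no file involved:
	clone.weights = model.weights

	# Each entry is one tuple per layer that has parameters:
	# (weights, biases) for most layers, plus moving_mean/moving_var
	# for BatchNormalization.
	for i, params in enumerate(model.weights):
		print(i, [p.shape for p in params])
	return clone

Exposing the parameter list as a property keeps save_weights and load_weights thin wrappers around pickle, so the same getter/setter pair could later back a different serialization format.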
