Skip to content

Commit

Permalink
fix run slowly gradually problem
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaoee committed May 27, 2019
1 parent 967c829 commit cd3b606
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions experiments/Robot_arm/DDPG.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, sess, action_dim, action_bound, learning_rate, t_replace_iter

self.e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval_net')
self.t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target_net')
self.replace = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]

def _build_net(self, s, scope, trainable):
with tf.variable_scope(scope):
Expand All @@ -97,7 +98,7 @@ def _build_net(self, s, scope, trainable):
def learn(self, s): # batch update
self.sess.run(self.train_op, feed_dict={S: s})
if self.t_replace_counter % self.t_replace_iter == 0:
self.sess.run([tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)])
self.sess.run(self.replace)
self.t_replace_counter += 1

def choose_action(self, s):
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(self, sess, state_dim, action_dim, learning_rate, gamma, t_replace_

with tf.variable_scope('a_grad'):
self.a_grads = tf.gradients(self.q, a)[0] # tensor of gradients of each sample (None, a_dim)
self.replace = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]

def _build_net(self, s, a, scope, trainable):
with tf.variable_scope(scope):
Expand All @@ -170,7 +172,7 @@ def _build_net(self, s, a, scope, trainable):
def learn(self, s, a, r, s_):
self.sess.run(self.train_op, feed_dict={S: s, self.a: a, R: r, S_: s_})
if self.t_replace_counter % self.t_replace_iter == 0:
self.sess.run([tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)])
self.sess.run(self.replace)
self.t_replace_counter += 1


Expand Down Expand Up @@ -273,4 +275,4 @@ def eval():
if LOAD:
eval()
else:
train()
train()

0 comments on commit cd3b606

Please sign in to comment.