Skip to content

Commit

Permalink
Relocated the file and fixed an issue with ending the training phase
Browse files (browse the repository at this point in the history)
  • Loading branch information
ruiliLaMeilleure committed Jul 28, 2018
1 parent 7d4b120 commit b66fb13
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions discrete_dppo.py → ...imal_Policy_Optimization/discrete_DPP0.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import numpy as np
import matplotlib.pyplot as plt
import gym, threading, queue
import time

EP_MAX = 4000
EP_MAX = 1000
EP_LEN = 500
N_WORKER = 4 # parallel workers
GAMMA = 0.9 # reward discount factor
Expand Down Expand Up @@ -171,14 +172,15 @@ def work(self):

QUEUE.put(q_in)

if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE or done:
if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
ROLLING_EVENT.clear() # stop collecting data
UPDATE_EVENT.set() # globalPPO update
break

if GLOBAL_EP >= EP_MAX: # stop training
COORD.request_stop()
break

if done:break

# record reward changes, plot later
if len(GLOBAL_RUNNING_R) == 0: GLOBAL_RUNNING_R.append(ep_r)
Expand All @@ -187,12 +189,6 @@ def work(self):
print("EP", GLOBAL_EP,'|W%i' % self.wid, '|step %i' %t, '|Ep_r: %.2f' % ep_r,)
np.save("Global_return",GLOBAL_RUNNING_R)
np.savez("PI_PARA",self.ppo.sess.run(GLOBAL_PPO.pi_params))
# np.savez("tfa",self.ppo.sess.run(GLOBAL_PPO.tfa))
# np.savez("tfadv",self.ppo.sess.run(GLOBAL_PPO.tfadv))
# np.savez("val1",self.ppo.sess.run(GLOBAL_PPO.val1))
# np.savez("val2",self.ppo.sess.run(GLOBAL_PPO.val2))
# print self.ppo.sess.run(GLOBAL_PPO.val2)



if __name__ == '__main__':
Expand All @@ -202,6 +198,8 @@ def work(self):
ROLLING_EVENT.set() # start to roll out
workers = [Worker(wid=i) for i in range(N_WORKER)]

start = time.time()

GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0
GLOBAL_RUNNING_R = []
COORD = tf.train.Coordinator()
Expand All @@ -216,6 +214,9 @@ def work(self):
threads[-1].start()
COORD.join(threads)

end = time.time()
print "Total time ", (end - start)

# plot reward change and test
plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
plt.xlabel('Episode'); plt.ylabel('Moving reward'); plt.ion(); plt.show()
Expand All @@ -224,6 +225,7 @@ def work(self):
s = env.reset()
for t in range(1000):
env.render()
s = env.step(GLOBAL_PPO.choose_action(s))[0]
s, r, done, info = env.step(GLOBAL_PPO.choose_action(s))
if done:
break

0 comments on commit b66fb13

Please sign in to comment.