Skip to content

Commit

Permalink
Enable reusability of samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
ylongqi committed Oct 1, 2018
1 parent 16e176a commit aca2aee
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
4 changes: 4 additions & 0 deletions openrec/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def train(self, total_iter, eval_iter, save_iter, train_sampler, start_iter=0, e
acc_loss = 0
self._eval_manager = EvalManager(evaluators=evaluators)

train_sampler.reset()
for sampler in eval_samplers:
sampler.reset()

print(colored('[Training starts, total_iter: %d, eval_iter: %d, save_iter: %d]' \
% (total_iter, eval_iter, save_iter), 'blue'))

Expand Down
34 changes: 26 additions & 8 deletions openrec/utils/samplers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,35 @@ def __init__(self, dataset=None, generate_batch=None, num_process=5):

assert generate_batch is not None, "Batch generation function is not specified"
assert dataset is not None, "Dataset is not specified"
self._q = Queue(maxsize=num_process)
self._q = None
self._dataset = dataset
self._runner_list = []
self._start = False
self._num_process = num_process
self._generate_batch = generate_batch
self.name = self._dataset.name

for i in range(num_process):
runner = _Sampler(dataset, self._q, generate_batch)

def next_batch(self):

if not self._start:
self.reset()

return self._q.get(block=True)

def reset(self):

while len(self._runner_list) > 0:
runner = self._runner_list.pop()
runner.terminate()
del runner

if self._q is not None:
del self._q
self._q = Queue(maxsize=self._num_process)

for i in range(self._num_process):
runner = _Sampler(self._dataset, self._q, self._generate_batch)
runner.daemon = True
runner.start()
self._runner_list.append(runner)

def next_batch(self):

return self._q.get(block=True)
self._start = True

0 comments on commit aca2aee

Please sign in to comment.