Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#298 from zhouwei25/0D
Browse files Browse the repository at this point in the history
change loss.numpy()[0] to float(loss) to adapt 0D
  • Loading branch information
XiaoguangHu01 committed Jan 18, 2023
2 parents 8fac5b5 + 4e5f09c commit 9f08f2f
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def train(self, start_iter, end_iter):
lr_sche.step()

self.recorder.record("batch_cost_time", time.time() - batch_start)
self.recorder.record("loss", loss.numpy()[0])
self.recorder.record("loss", float(loss))
self.recorder.record("lr", current_lr)

if iter % self.log_iters == 0 and local_rank == 0:
Expand Down
2 changes: 1 addition & 1 deletion CV/PWCNet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def val(model, batch_reader, epoch, batch_num):
network_output = model(im_all, output_more=False)
loss = realEPE(network_output, label)
end = time.time()
loss_cnt.update(loss.numpy()[0], step)
loss_cnt.update(float(loss), step)
print('val epoch {} batch {}/{} run time: {}s read data time {}s loss {}'.format(epoch, batch_id, batch_num,
round(end - start, 2),
round(read_data_time, 2),
Expand Down
4 changes: 2 additions & 2 deletions CV/SemSegPaddleDG/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def train(cfg):
epoch + 1,
during_time[:6],
batch_id + 1,
train_avg_loss.numpy()[0],
float(train_avg_loss),
lr)

if batch_id % 100 == 0:
Expand Down Expand Up @@ -334,7 +334,7 @@ def train(cfg):
print('Validation trained_id:{}, epoch:{}, batch_id:{}, time:{}s, val_loss:{}.'.format(train_id, epoch+1,
batch_id,
during_time[:6],
val_avg_loss.numpy()[0],
float(val_avg_loss),
))
acc, acc_cls, iu, mean_iu_val, fwavacc, kappa = val_iou_np.evaluate()
print('Validation epoch:{}, val_loss{}, val_iou_np:{}'.format(epoch+1, test_avg_loss_manager.eval()[0], mean_iu_val))
Expand Down
2 changes: 1 addition & 1 deletion ST_DM/CIKM2022-DuMapper/DME/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def valid(self, epochs):
embedding=list(embeddings.detach().cpu()[num].numpy())
embedding=[str(x) for x in embedding]
content.extend([' '.join(embedding)])
x_norm = str(embeddings_norm.detach().cpu()[num].numpy()[0])
x_norm = str(float(embeddings_norm.detach().cpu()[num]))
content.append(x_norm)
content='\t'.join(content)
f.write(content + '\n')
Expand Down

0 comments on commit 9f08f2f

Please sign in to comment.