From 028ed6acca7331c1233907d6ca559c8216dcf9fa Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 12 Apr 2019 14:06:32 +0900 Subject: [PATCH] Fix .data[0] and tensor index problem --- layers/modules/multibox_loss.py | 2 +- train.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/layers/modules/multibox_loss.py b/layers/modules/multibox_loss.py index fb49cf439..0014fa310 100644 --- a/layers/modules/multibox_loss.py +++ b/layers/modules/multibox_loss.py @@ -94,7 +94,7 @@ def forward(self, predictions, targets): loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) # Hard Negative Mining - loss_c[pos] = 0 # filter out pos boxes for now + loss_c[pos.view(-1)] = 0 # filter out pos boxes for now loss_c = loss_c.view(num, -1) _, loss_idx = loss_c.sort(1, descending=True) _, idx_rank = loss_idx.sort(1) diff --git a/train.py b/train.py index 427dd9244..dd0a2a92c 100644 --- a/train.py +++ b/train.py @@ -180,15 +180,15 @@ def train(): loss.backward() optimizer.step() t1 = time.time() - loc_loss += loss_l.data[0] - conf_loss += loss_c.data[0] + loc_loss += loss_l.item() + conf_loss += loss_c.item() if iteration % 10 == 0: print('timer: %.4f sec.' % (t1 - t0)) - print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ') + print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ') if args.visdom: - update_vis_plot(iteration, loss_l.data[0], loss_c.data[0], + update_vis_plot(iteration, loss_l.item(), loss_c.item(), iter_plot, epoch_plot, 'append') if iteration != 0 and iteration % 5000 == 0: