Skip to content

Commit

Permalink
Update train_imagenet_distillation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jingyang2017 authored Jun 15, 2021
1 parent 051ccc1 commit 1f4009f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train_imagenet_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ def train(train_loader,net_t,net_s,optimizer, conector,epoch):
feat_s = conector(feat_s)
loss_stat = statm_loss(feat_s, feat_t.detach())
pred_sc = net_t(x=None,feat_s=feat_s)
loss_kd = loss_stat + F.mse_loss(pred_sc, pred_t)
loss_kd = loss_stat + F.mse_loss(pred_sc, pred_t)*args.weight
loss_ce = F.cross_entropy(pred_s, target)

loss = loss_ce+loss_kd*args.weight
loss = loss_ce+loss_kd
prec1, prec5 = accuracy(pred_s, target, topk=(1,5))
optimizer.zero_grad()
loss.backward()
Expand Down

0 comments on commit 1f4009f

Please sign in to comment.