Skip to content

PyTorch reimplementation of Moving Semantic Transfer Network

Notifications You must be signed in to change notification settings

EasonApolo/mstn

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

13 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

PyTorch-MSTN

pytorch reimplementation of Moving Semantic Transfer Network

@inproceedings{xie2018learning,
  title={Learning Semantic Representations for Unsupervised Domain Adaptation},
  author={Xie, Shaoan and Zheng, Zibin and Chen, Liang and Chen, Chuan},
  booktitle={International Conference on Machine Learning},
  pages={5419--5428},
  year={2018}
}

Environment

  • Python 2.7
  • PyTorch 1.0.0

Note

  • Amazon-Webcam实验复现成功,使用的是pytorch_imagenet中提供的pretrained weights和LRN实现。
  • MSTN作者使用的AlexNet来自于Finetuning AlexNet with Tensorflow(和pytorch_imagenet的模型几乎相同)。我尝试[1]将这个模型及weights经过转换应用到PyTorch中;[2]使用torchvision提供的预训练的AlexNet(和[1]架构不同)。但这两种方式结果都只能到达67-70%。
  • 尝试了SGD和Adam,目前实验中momentum=0.9,init_lr=0.01的SGD的效果更好。
  • 对A-W任务,约在9000-12000次迭代时收敛。
  • 在将代码从TF迁移到PyTorch时可能需要注意的问题:[1]OpenCV默认的图像通道是BGR,而PyTorch(PIL)使用的通道一般是RGB(这个Repo没有转换通道);[2]npy文件中的模型参数需要转置才能赋给PyTorch模型;[3]LRN层PyTorch已有官方实现,但它的参数size似乎和TF有所不同(这个Repo没有使用PyTorch的LRN层);[4]PyTorch一般认为输入数据是0-1的,而caffe是0-255(这个Repo没有/255)。代码的迁移过程中还有很多没有理解的问题。
  • 在train.py中import model或PretrainedAlexnet来使用上面提到的[1][2]两种预训练AlexNet。如果要使用[2]的模型及weights,请在train.py中注释掉model.load_state_dict一行。

Result

Amazon-Webcam Amazon-Dslr Dslr-Webcam
(paper)Source Only 0.616 63.8 95.4
(paper)MSTN 0.805 74.5 96.9
(this repo) Source Only 0.618 63.2 95.6
(this repo) MSTN 0.805 76.1 96.8

Reference

About

PyTorch reimplementation of Moving Semantic Transfer Network

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages