Geng #7 (Open)

Wants to merge 22 commits into base branch `master`.
12 changes: 12 additions & 0 deletions .gitignore
@@ -766,3 +766,15 @@ TSWLatexianTemp*


## This repository

.scripts/
scripts/
Results/
test_pretrained.py
/scripts/
/Checkpoints/
Checkpoints/
.Checkpoints/
/Checkpoints
/scripts

Binary file added Checkpoints/default/latest_net_D.pth
Binary file not shown.
Binary file added Checkpoints/default/latest_net_G.pth
Binary file not shown.
4 changes: 4 additions & 0 deletions Checkpoints/pretrained/LICENSE
@@ -0,0 +1,4 @@
Copyright (C) 2019 Mengtian (Martin) Li.

Licensed under the CC BY-NC-SA 4.0 International License
(http://creativecommons.org/licenses/by-nc-sa/4.0/).
Binary file added Checkpoints/pretrained/latest_net_D.pth
Binary file not shown.
Binary file added Checkpoints/pretrained/latest_net_G.pth
Binary file not shown.
56 changes: 56 additions & 0 deletions Checkpoints/pretrained/opt.txt
@@ -0,0 +1,56 @@
------------ Options -------------
aspect_ratio: 1.0
aug_folder: width-5
batchSize: 1
checkpoints_dir: /users/guest292/scratch/PhotoSketch/Checkpoints/
color_jitter: False
crop: False
dataroot: examples/
dataset_mode: test_dir
display_id: 1
display_port: 8097
display_server: http://localhost
display_winsize: 256
file_name:
fineSize: 256
how_many: 50
img_mean: None
img_std: None
init_type: normal
input_nc: 3
inverse_gamma: False
isTrain: False
jitter_amount: 0.02
loadSize: 286
lst_file: None
max_dataset_size: inf
model: pix2pix
nGT: 5
nThreads: 6
n_layers_D: 3
name: pretrained
ndf: 64
ngf: 64
no_cuda: False
no_dropout: True
no_flip: False
norm: batch
ntest: inf
output_nc: 1
phase: test
pretrain_path:
render_dir: sketch-rendered
resize_or_crop: resize_and_crop
results_dir: /users/guest292/scratch/PhotoSketch/Results/
rot_int_max: 3
rotate: False
serial_batches: False
stroke_dir:
stroke_no_couple: False
suffix:
use_cuda: False
which_direction: AtoB
which_epoch: latest
which_model_netD: basic
which_model_netG: resnet_9blocks
-------------- End ----------------
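
This dump records the exact flags the pretrained model was tested with, one `key: value` pair per line between a header and footer rule. As a hedged sketch (this helper is not part of the repo), a dump in this format could be parsed back into a dict, e.g. to compare or reproduce runs:

```
# Hypothetical helper (not in the repo): parse an opt.txt dump like the one
# above into a dict mapping option name -> string value.
def load_opt(path):
    opts = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            # skip the "Options"/"End" rules; keys with empty values
            # (e.g. "suffix:") are also skipped for simplicity
            if line.startswith('-') or ': ' not in line:
                continue
            key, _, value = line.partition(': ')
            opts[key] = value
    return opts

opts = load_opt('Checkpoints/pretrained/opt.txt')
print(opts['which_model_netG'])  # prints: resnet_9blocks
```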
Binary file added Checkpoints/pretrained/pretrained.zip
Binary file not shown.
125 changes: 68 additions & 57 deletions README.md
@@ -1,57 +1,68 @@
# Photo-Sketching: Inferring Contour Drawings from Images

<p align="center"><img alt="Teaser" src="doc/teaser.jpg"></p>

This repo contains the training & testing code for our sketch generator. We also provide a [[pre-trained model]](https://drive.google.com/file/d/1TQf-LyS8rRDDapdcTnEgWzYJllPgiXdj/view).

For technical details and the dataset, please refer to the [**[paper]**](https://arxiv.org/abs/1901.00542) and the [**[project page]**](http://www.cs.cmu.edu/~mengtial/proj/sketch/).

# Setting up

The code is now updated to use PyTorch 0.4 and runs on Windows, Mac and Linux. For the obsolete version with PyTorch 0.3 (Linux only), please check out the branch [pytorch-0.3-obsolete](../../tree/pytorch-0.3-obsolete).

Windows users should use the corresponding `.cmd` files instead of the `.sh` files mentioned below.

## One-line installation (with Conda environments)
`conda env create -f environment.yml`

Then activate the environment (sketch) and you are ready to go!

See [here](https://conda.io/docs/user-guide/tasks/manage-environments.html) for more information about conda environments.

## Manual installation
See `environment.yml` for a list of dependencies.

# Using the pre-trained model

- Download the [pre-trained model](https://drive.google.com/file/d/1TQf-LyS8rRDDapdcTnEgWzYJllPgiXdj/view)
- Modify the path in `scripts/test_pretrained.sh`
- From the repo's **root directory**, run `scripts/test_pretrained.sh`

It supports a folder of images as input.
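
The shell script is a thin wrapper over the repo's Python entry point. Below is a hedged sketch of the flow it drives; the module and function names (`TestOptions`, `create_model`, `set_input`, `test`) follow the old pix2pix layout this code is based on and should be treated as assumptions, not verified entry points of this fork:

```
# Hedged sketch of the test flow; entry-point names are assumed from the
# old pix2pix layout this repo derives from.
from options.test_options import TestOptions
from data.custom_dataset_data_loader import CustomDatasetDataLoader
from models.models import create_model

opt = TestOptions().parse()      # reads flags such as --dataroot and --name
loader = CustomDatasetDataLoader()
loader.initialize(opt)
model = create_model(opt)

for data in loader.load_data():
    model.set_input(data)        # feed one batch of input photos
    model.test()                 # forward pass, no gradients
```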

# Train & test on our contour drawing dataset

- Download the images and the rendered sketch from the [project page](http://www.cs.cmu.edu/~mengtial/proj/sketch/)
- Unzip and organize them into the following structure:
<p align="center"><img alt="File structure" src="doc/file_structure.png"></p>

- Modify the path in `scripts/train.sh` and `scripts/test.sh`
- From the repo's **root directory**, run `scripts/train.sh` to train the model
- From the repo's **root directory**, run `scripts/test.sh` to test on the val set or the test set (specified by the phase flag)

## Citation
If you use the code or the data for your research, please cite the paper:

```
@article{LIPS2019,
title={Photo-Sketching: Inferring Contour Drawings from Images},
author={Li, Mengtian and Lin, Zhe and M\v ech, Radom\'ir and Yumer, Ersin and Ramanan, Deva},
journal={WACV},
year={2019}
}
```

## Acknowledgement
This code is based on an old version of [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/).

# Inferring Contour Sketches from Images
## ENGN2560 - Computer Vision
### Final Project (*May, 2019*)


#### Xingchen Ming [email protected]
#### Ming Xu [email protected]
#### Geng Yang [email protected]

<p align="center"><img alt="Teaser" src="doc/teaser.jpg"></p>

# Dataset
## NOTE: Please download and extract the dataset under the `PhotoSketch/` directory

### https://drive.google.com/open?id=1ajNGbYSSxWZyCT3X4qlga7maUz6UZ3nD

# Setting up on Brown CCV

1. Load Anaconda3-5.2.0

```
module load anaconda/3-5.2.0
```


2. One-line installation (with Conda environments)
```
conda env create -f environment.yml
```

3. Activate the environment
```
source activate sketch
```

# Running Instructions
## NOTE: All scripts should be executed from the `PhotoSketch/` directory
## Train model
```
sbatch cuda.sh
```
## Test model
1. Request a GPU node
```
interact -n 16 -m 16g -q gpu -g 1
```
2. Run the test script
```
sh scripts/test_pretrained.sh
```



## Citation
If you use the code or the data for your research, please cite the paper:

```
@article{LIPS2019,
title={Photo-Sketching: Inferring Contour Drawings from Images},
author={Li, Mengtian and Lin, Zhe and M\v ech, Radom\'ir and Yumer, Ersin and Ramanan, Deva},
journal={WACV},
year={2019}
}
```

## Acknowledgement
This code is based on an old version of [pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/).

2 changes: 1 addition & 1 deletion data/base_data_loader.py
@@ -1,4 +1,4 @@

# Base class for data loaders.
class BaseDataLoader():
    def __init__(self):
        pass
39 changes: 37 additions & 2 deletions data/base_dataset.py
@@ -1,9 +1,15 @@
# An abstract class representing a dataset. All subclasses should override
# __len__, which returns the size of the dataset, and __getitem__, which
# supports integer indexing from 0 to len(self) - 1.
import torch.utils.data as data
# PIL's Image module represents images and provides factory functions for
# loading images from files and creating new ones.
from PIL import Image
# Common image transforms (resize, crop, flip, normalize).
import torchvision.transforms as transforms

# Base class for datasets.
class BaseDataset(data.Dataset):
    def __init__(self):
        # initialize the parent torch Dataset class
        super(BaseDataset, self).__init__()

    def name(self):
@@ -12,34 +18,63 @@ def name(self):
    def initialize(self, opt):
        pass

'''
* @name: get_transform
* @description: build the transform that scales and/or crops input images
* @param opt: the parameter set
* @return: the composed transform to apply to an image
'''
def get_transform(opt):
    # the list that collects every transform applied to input images
    transform_list = []
    # resize to loadSize, then take a random fineSize crop
    if opt.resize_or_crop == 'resize_and_crop':
        # target size for the initial resize
        osize = [opt.loadSize, opt.loadSize]
        # resize with bicubic interpolation
        transform_list.append(transforms.Scale(osize, Image.BICUBIC))
        # then crop a random fineSize patch
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    # crop only
    elif opt.resize_or_crop == 'crop':
        # random fineSize crop
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    # scale the width only
    elif opt.resize_or_crop == 'scale_width':
        # scale so the width equals fineSize
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    # scale the width, then crop
    elif opt.resize_or_crop == 'scale_width_and_crop':
        # scale so the width equals loadSize
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        # then crop a random fineSize patch
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    # optionally augment training data with random horizontal flips
    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    # compose all transforms into a single callable
    return transforms.Compose(transform_list)

'''
* @name: __scale_width
* @description: scale the input image to a target width, preserving aspect ratio
* @param img: the input image
* @param target_width: the target width after scaling
* @return: the scaled image
'''
def __scale_width(img, target_width):
    # get the original size of the image
    ow, oh = img.size
    if (ow == target_width):
        return img
    # compute the new size, preserving the aspect ratio
    w = target_width
    h = int(target_width * oh / ow)
    # resize with bicubic interpolation
    return img.resize((w, h), Image.BICUBIC)
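
To make the flow above concrete, here is a minimal usage sketch. The `SimpleNamespace` stands in for the repo's real option object, and the image path is hypothetical; note that `transforms.Scale` is the older torchvision API this code targets (renamed `Resize` in later versions):

```
# Minimal sketch of applying get_transform to one image.
from types import SimpleNamespace

from PIL import Image
from data.base_dataset import get_transform

# stand-in for the repo's option object; values mirror opt.txt above
opt = SimpleNamespace(resize_or_crop='resize_and_crop',
                      loadSize=286, fineSize=256,
                      isTrain=False, no_flip=True)
transform = get_transform(opt)

img = Image.open('examples/photo.jpg').convert('RGB')  # hypothetical path
tensor = transform(img)  # shape (3, 256, 256), values in [-1, 1]
```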
47 changes: 43 additions & 4 deletions data/custom_dataset_data_loader.py
@@ -1,42 +1,81 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


'''
* @name: CreateDataset
* @description: create the dataset specified by opt.dataset_mode
* @param opt: the parameter set
* @return: the initialized dataset
'''
def CreateDataset(opt):
    dataset = None
    # if the mode is "1_to_n", use the paired 1-to-N dataset
    if opt.dataset_mode == '1_to_n':
        from data.paired_1_to_n_dataset import Paired1ToNDataset
        dataset = Paired1ToNDataset()
    # if the mode is "test_dir", read images from a directory
    elif opt.dataset_mode == 'test_dir':
        from data.test_dir_dataset import TestDirDataset
        dataset = TestDirDataset()
    # otherwise the mode is not supported
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    # initialize the dataset (sets up paths and data augmentation)
    dataset.initialize(opt)
    return dataset

# a data loader built on top of BaseDataLoader
class CustomDatasetDataLoader(BaseDataLoader):

    '''
    * @name: name
    * @description: return the name of this loader
    * @return: the loader name
    '''
    def name(self):
        return 'CustomDatasetDataLoader'

    '''
    * @name: initialize
    * @description: create the dataset and wrap it in a PyTorch DataLoader
    * @param opt: the parameter set
    '''
    def initialize(self, opt):
        # initialize the base class
        BaseDataLoader.initialize(self, opt)
        # build the dataset selected by opt.dataset_mode
        self.dataset = CreateDataset(opt)
        # wrap it in a DataLoader to handle batching, shuffling and workers
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    '''
    * @name: load_data
    * @description: return the loader itself for iteration
    * @return: this data loader
    '''
    def load_data(self):
        return self

    '''
    * @name: __len__
    * @description: get the length of the dataset, capped at max_dataset_size
    * @return: the number of samples this loader will serve
    '''
    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    '''
    * @name: __iter__
    * @description: iterate over batches, stopping at max_dataset_size
    * @return: an iterator over batches
    '''
    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
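
Putting the pieces together, here is a hedged sketch of driving this loader end to end. The `SimpleNamespace` is a stand-in for the repo's option parser, and the concrete dataset class (`test_dir` mode here) will need more fields than shown, e.g. `dataroot`:

```
# Hedged sketch of using CustomDatasetDataLoader; option fields are
# illustrative, and the concrete dataset will require additional ones.
from types import SimpleNamespace

from data.custom_dataset_data_loader import CustomDatasetDataLoader

opt = SimpleNamespace(dataset_mode='test_dir', batchSize=1,
                      serial_batches=True, nThreads=0,
                      max_dataset_size=float('inf'))

loader = CustomDatasetDataLoader()
loader.initialize(opt)        # builds the dataset and the PyTorch DataLoader
dataset = loader.load_data()  # returns the loader itself
print(len(dataset))           # capped at opt.max_dataset_size

for i, data in enumerate(dataset):
    pass  # each `data` is one batch from the underlying dataset
```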