Skip to content

Commit

Permalink
Upgrade to torch 1.6 & ddp (#9)
Browse files Browse the repository at this point in the history
* doc

* upgrade for segmentation main

* update pip

* split batches

* modify other segmentation codes

* upgrade classification codes for multi-GPU

* doc

* notes
  • Loading branch information
voldemortX authored Jun 7, 2021
1 parent 53853f6 commit 6011a78
Show file tree
Hide file tree
Showing 12 changed files with 286 additions and 179 deletions.
14 changes: 11 additions & 3 deletions CLASSIFICATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ The CIFAR-10 dataset can be downloaded and splitted to 5 random splits and valid

## Run the code

We provide examples in scripts and commands. Final results can be found at log.txt after training.
For multi-GPU/TPU/Distributed machine users, first run:

```
accelerate config
```

More details can be found at [Accelerate](https://github.com/huggingface/accelerate). Note that the mixed precision config cannot be used, you should still use `--mixed-precision` for that.

We provide examples in scripts and commands. Final results can be found at `log.txt` after training.

For example, with 1000 labels, to compare CL and DMT in a controlled experiment with same baseline model to start training:

Expand All @@ -43,6 +51,6 @@ For example, with 1000 labels, to compare CL and DMT in a controlled experiment
./ss-dmt-full-1.sh
```

Of course you'll need to run 5 times average to determine performance by changing the *seed* parameter (we used 1,2,3,4,5) in shell scripts.
You'll need to run 5 times average to determine performance by changing the `seed` parameter (we used 1,2,3,4,5) in shell scripts.

For small validation set, use *--valtiny*; for fine-grained testing, use *--fine-grain*.
For small validation set, use `--valtiny`; for fine-grained testing, use `--fine-grain`.
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@ This repository contains the code for our paper [DMT: Dynamic Mutual Training fo

Some might know it as the previous version **DST-CBC**, or *Semi-Supervised Semantic Segmentation via Dynamic Self-Training and Class-Balanced Curriculum*, if you want the old code, you can check out the [dst-cbc](https://github.com/voldemortX/DST-CBC/tree/dst-cbc) branch.

Also, for older PyTorch version (<1.6.0) users, or the **exact** same environment that produced the paper's results, refer to 53853f6.

<div align="center">
<img src="overview.png"/>
</div>

## News

### 2021.6.7

**Multi-GPU** training support (based on [Accelerate](https://github.com/huggingface/accelerate)) is added, and the whole project is upgraded to PyTorch 1.6.
Thanks to the codes & testing by [**@jinhuan-hit**](https://github.com/jinhuan-hit), and discussions from [**@lorenmt**](https://github.com/lorenmt), [**@TiankaiHang**](https://github.com/TiankaiHang).

### 2021.2.10

A slight backbone architecture difference in the segmentation task has just been identified and described in Acknowledgement.
Expand All @@ -31,24 +38,18 @@ Also, thanks to [**@lorenmt**](https://github.com/lorenmt), a data augmentation

## Setup

You'll need a CUDA 10, Python3 environment (best on Linux) with PyTorch 1.2.0, TorchVision 0.4.0 and Apex to run the code in this repo.
First, you'll need a CUDA 10, Python3 environment (best on Linux).

### 1. Setup the exact version of Apex & PyTorch & TorchVision for mixed precision training:
### 1. Setup PyTorch & TorchVision:

```
pip install https://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl && pip install https://download.pytorch.org/whl/cu100/torchvision-0.4.0-cp36-cp36m-manylinux1_x86_64.whl
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
pip install torch==1.6.0 torchvision==0.7.0
```
!There seems to be an issue of apex installations from the official repo sometimes. If you encounter errors, we suggest you use our stored older apex [codes](https://drive.google.com/open?id=1x8enpvdTTZ3RChf17XvcLdSYulUPg3sR).

**PyTorch 1.6** now includes automatic mixed precision at apex level "O1". We probably will update this repo accordingly in the future.

### 2. Install other python packages you may require:

```
pip install future matplotlib tensorboard tqdm
pip install packaging accelerate future matplotlib tensorboard tqdm
```

### 3. Download the code and prepare the scripts:
Expand All @@ -67,9 +68,11 @@ Get started with [SEGMENTATION.md](SEGMENTATION.md) for semantic segmentation.
Get started with [CLASSIFICATION.md](CLASSIFICATION.md) for image classification.

## Understand the code

We refer interested readers to this repository's [wiki](https://github.com/voldemortX/DST-CBC/wiki). *It is not updated for DMT yet.*

## Notes

It's best to use a **Turing** or **Volta** architecture GPU when running our code, since they have tensor cores and the computation speed is much faster with mixed precision. For instance, RTX 2080 Ti (which is what we used) or Tesla V100, RTX 20/30 series.

Our implementation is fast and memory efficient. A whole run (train 2 models by DMT on PASCAL VOC 2012) takes about 8 hours on a single RTX 2080 Ti using up to 6GB graphic memory, including on-the-fly evaluations and training baselines. The Cityscapes experiments are even faster.
Expand Down Expand Up @@ -99,3 +102,4 @@ The CBC part of the older version DST-CBC is adapted from [CRST](https://github.

The overall implementation is based on [TorchVision](https://github.com/pytorch/vision) and [PyTorch](https://github.com/pytorch/pytorch).

The people who've helped to make the method & code better: [**lorenmt**](https://github.com/lorenmt), [**jinhuan-hit**](https://github.com/jinhuan-hit), [**TiankaiHang**](https://github.com/TiankaiHang), etc.
12 changes: 10 additions & 2 deletions SEGMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ ImageNet pre-trained weights will be automatically downloaded when running code.

## Run the code

For multi-GPU/TPU/Distributed machine users, first run:

```
accelerate config
```

More details can be found at [Accelerate](https://github.com/huggingface/accelerate). Note that the mixed precision config cannot be used, you should still use `--mixed-precision` for that.

We provide examples in scripts and commands. Final results can be found at log.txt after training.

For example, run DMT with different pre-trained weights:
Expand All @@ -120,8 +128,8 @@ python pascal_sbd_split.py
```


Of course you'll need to run 3 times average to determine performance by changing the *sid* parameter (we used 0,1,2) in shell scripts.
Of course you'll need to run 3 times average to determine performance by changing the `sid` parameter (we used 0,1,2) in shell scripts.

We also provide scripts for ablations, be sure to run *abl_baseline.sh* first.

For small validation set, use *--valtiny*.
For small validation set, use `--valtiny`.
63 changes: 39 additions & 24 deletions classification/main_dmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from apex import amp
from models.wideresnet import wrn_28_2
from utils.common import num_classes_cifar10, mean_cifar10, std_cifar10, input_sizes_cifar10, base_cifar10, \
load_checkpoint, save_checkpoint, EMA, rank_label_confidence
Expand All @@ -18,6 +17,8 @@
from utils.randomrandaugment import RandomRandAugment
from utils.cutout import Cutout
from utils.autoaugment import CIFAR10Policy
from accelerate import Accelerator
from torch.cuda.amp import autocast, GradScaler


def get_transforms(auto_augment, input_sizes, m, mean, n, std):
Expand Down Expand Up @@ -48,8 +49,9 @@ def get_transforms(auto_augment, input_sizes, m, mean, n, std):
return test_transforms, train_transforms


def generate_pseudo_labels(net, device, loader, label_ratio, num_images, filename):
k = rank_label_confidence(net=net, device=device, loader=loader, ratio=label_ratio, num_images=num_images)
def generate_pseudo_labels(net, device, loader, label_ratio, num_images, filename, is_mixed_precision):
k = rank_label_confidence(net=net, device=device, loader=loader, ratio=label_ratio, num_images=num_images,
is_mixed_precision=is_mixed_precision)
print(k)
# 1 forward pass (build pickle file)
selected_files = None
Expand All @@ -59,9 +61,10 @@ def generate_pseudo_labels(net, device, loader, label_ratio, num_images, filenam
for images, original_file in tqdm(loader):
# Inference
images = images.to(device)
outputs = net(images)
temp = torch.nn.functional.softmax(input=outputs, dim=-1) # ! softmax
pseudo_probabilities = temp.max(dim=-1).values
with autocast(is_mixed_precision):
outputs = net(images)
temp = torch.nn.functional.softmax(input=outputs, dim=-1) # ! softmax
pseudo_probabilities = temp.max(dim=-1).values

# Select
temp_predictions = temp[pseudo_probabilities > k].cpu().numpy()
Expand Down Expand Up @@ -110,7 +113,7 @@ def init(mean, std, input_sizes, base, num_workers, prefix, val_set, train, batc
return labeled_loader, unlabeled_loader, pseudo_labeled_loader, val_loader, unlabeled_set.__len__()


def test(loader, device, net, fine_grain=False):
def test(loader, device, net, fine_grain=False, is_mixed_precision=False):
# Evaluate
net.eval()
test_correct = 0
Expand All @@ -119,7 +122,8 @@ def test(loader, device, net, fine_grain=False):
with torch.no_grad():
for image, target in tqdm(loader):
image, target = image.to(device), target.to(device)
output = net(image)
with autocast(is_mixed_precision):
output = net(image)
test_all += target.shape[0]
if fine_grain:
predictions = output.softmax(1)
Expand Down Expand Up @@ -154,6 +158,9 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
if val_num_steps is None:
val_num_steps = min_len

if is_mixed_precision:
scaler = GradScaler()

net.train()

# Use EMA to report final performance instead of select best checkpoint with valtiny
Expand Down Expand Up @@ -203,7 +210,8 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
split_index=inputs_pseudo.shape[0], labeled_weight=labeled_weight)
inputs, dynamic_weights, labels_a, labels_b, lam = mixup_data(x=inputs, w=dynamic_weights, y=labels,
alpha=alpha, keep_max=True)
outputs = net(inputs)
with autocast(is_mixed_precision):
outputs = net(inputs)

if alpha != -1:
# Pseudo training accuracy & interesting loss
Expand All @@ -218,12 +226,12 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
gamma1=gamma1, gamma2=gamma2)

if is_mixed_precision:
# 2/3 & 3/3 of mixed precision training with amp
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
accelerator.backward(scaler.scale(loss))
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
accelerator.backward(loss)
optimizer.step()
criterion.step()
if lr_scheduler is not None:
lr_scheduler.step()
Expand Down Expand Up @@ -252,8 +260,9 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
# Validate and find the best snapshot
if current_step_num % val_num_steps == (val_num_steps - 1) or \
current_step_num == num_epochs * len(pseudo_labeled_loader) - 1:
# A bug in Apex? https://github.com/NVIDIA/apex/issues/706
test_acc = test(loader=val_loader, device=device, net=net, fine_grain=fine_grain)
# Apex bug https://github.com/NVIDIA/apex/issues/706, fixed in PyTorch1.6, kept here for BC
test_acc = test(loader=val_loader, device=device, net=net, fine_grain=fine_grain,
is_mixed_precision=is_mixed_precision)
writer.add_scalar(tensorboard_prefix + 'test accuracy',
test_acc,
current_step_num)
Expand Down Expand Up @@ -284,7 +293,7 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri

if __name__ == '__main__':
# Settings
parser = argparse.ArgumentParser(description='PyTorch 1.2.0 && torchvision 0.4.0')
parser = argparse.ArgumentParser(description='PyTorch 1.6.0 && torchvision 0.7.0')
parser.add_argument('--exp-name', type=str, default='auto',
help='Name of the experiment (default: auto)')
parser.add_argument('--dataset', type=str, default='cifar10',
Expand Down Expand Up @@ -350,9 +359,11 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
# torch.backends.cudnn.benchmark = False # Might hurt performance
if args.exp_name != 'auto':
exp_name = args.exp_name
device = torch.device('cpu')
if torch.cuda.is_available():
device = torch.device('cuda:0')
# device = torch.device('cpu')
# if torch.cuda.is_available():
# device = torch.device('cuda:0')
accelerator = Accelerator(split_batches=True)
device = accelerator.device
if args.valtiny:
val_set = 'valtiny_seed1'
else:
Expand All @@ -373,8 +384,6 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
params_to_optimize = net.parameters()
optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
# optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr, weight_decay=args.weight_decay)
if args.mixed_precision:
net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

if args.continue_from is not None:
load_checkpoint(net=net, optimizer=None, lr_scheduler=None,
Expand All @@ -385,13 +394,18 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
dataset=args.dataset, n=args.n, m=args.m, auto_augment=args.aa, input_sizes=input_sizes, std=std,
num_workers=args.num_workers, batch_size_pseudo=args.batch_size_pseudo, train=False if args.labeling else True)

net, optimizer, labeled_loader, pseudo_labeled_loader = accelerator.prepare(net, optimizer,
labeled_loader,
pseudo_labeled_loader)

# Pseudo labeling
if args.labeling:
time_now = time.time()
sub_base = CIFAR10.base_folder
filename = os.path.join(base, sub_base, args.train_set + '_pseudo')
generate_pseudo_labels(net=net, device=device, loader=unlabeled_loader, filename=filename,
label_ratio=args.label_ratio, num_images=num_images)
label_ratio=args.label_ratio, num_images=num_images,
is_mixed_precision=args.mixed_precision)
print('Pseudo labeling time: %.2fs' % (time.time() - time_now))
else:
# Mutual-training
Expand All @@ -402,7 +416,8 @@ def train(writer, labeled_loader, pseudo_labeled_loader, val_loader, device, cri
T_max=args.epochs * len(pseudo_labeled_loader))
writer = SummaryWriter('logs/' + exp_name)

best_acc = test(loader=val_loader, device=device, net=net, fine_grain=args.fine_grain)
best_acc = test(loader=val_loader, device=device, net=net, fine_grain=args.fine_grain,
is_mixed_precision=args.mixed_precision)
save_checkpoint(net=net, optimizer=None, lr_scheduler=None, is_mixed_precision=args.mixed_precision)
print('Original acc: ' + str(best_acc))

Expand Down
Loading

0 comments on commit 6011a78

Please sign in to comment.