modified_ignite_engine.py
from ignite.utils import convert_tensor
from ignite.engine.engine import Engine
from apex import amp  # NVIDIA apex; the caller is expected to have run amp.initialize() on model and optimizer


def _prepare_batch(batch, device=None, non_blocking=False):
    """Prepare batch for training: pass to a device with options."""
    x, y = batch
    return (convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking))


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, accumulation_steps=1, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              output_transform=lambda x, y, y_pred, loss: loss.item()):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        accumulation_steps (int, optional): number of iterations over which gradients are
            accumulated before `optimizer.step()` is called (default: 1).
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur
            asynchronously with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            a tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives `x`, `y`, `y_pred`, `loss` and returns the value
            to be assigned to the engine's `state.output` after each iteration. Default is returning `loss.item()`.

    Note: `engine.state.output` for this engine is defined by the `output_transform` parameter and is the loss
        of the processed batch by default.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        # Scale the loss so that gradients summed over `accumulation_steps`
        # iterations average out to one effective batch.
        loss = loss_fn(y_pred, y) / accumulation_steps
        # amp.scale_loss scales the loss to avoid float16 gradient underflow;
        # gradients keep accumulating in .grad until the optimizer steps.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # Step and reset only every `accumulation_steps` iterations
        # (engine.state.iteration is 1-based).
        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # Note: the `loss` passed on here is already divided by accumulation_steps.
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)
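

# --- Usage sketch (not part of the original file) ---
# A minimal example of wiring this trainer together, assuming apex is
# installed and a CUDA device is available. `MyModel`, `train_loader`,
# and the hyperparameters below are hypothetical placeholders.
#
#     import torch
#     from torch import nn, optim
#
#     model = MyModel().cuda()
#     optimizer = optim.SGD(model.parameters(), lr=0.01)
#
#     # amp.initialize patches the model and optimizer for mixed precision;
#     # it must run before training so that amp.scale_loss works in _update.
#     model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#
#     trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(),
#                                         device="cuda", accumulation_steps=4)
#     trainer.run(train_loader, max_epochs=10)
#
# With accumulation_steps=4, optimizer.step() fires on iterations 4, 8, 12, ...,
# giving an effective batch size of 4x the loader's batch size.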