Adding Early Exit branch
Using the Early-Exit method requires attaching exit branches to an existing model. This can be done by making static changes to the PyTorch model code, or by dynamically adding the branches at runtime.
In Distiller, distiller.EarlyExitMgr is a utility class that helps us make dynamic changes to the model in order to attach exit branches.
The Early-Exit (EE) implementation for the CIFAR10 ResNet architecture, resnet_cifar_earlyexit.py, is a good example of how this is done in Distiller:
- Define the exit branches.
- Attach the exit branches.
- Define the forward() method.
An early-exit branch is simply a PyTorch sub-model. It can perform any processing you like, as long as its input can be attached to the output of a layer in the model. The output of the branch must have the same shape as the original model's output. For example, in a CIFAR10 image classification model the output is a vector of 10 class-probabilities, and this must also be the output of each branch.
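    def get_exits_def():
        # Each entry pairs the fully-qualified name of the layer to branch from
        # with the sub-model that implements the exit branch.
        exits_def = [('layer1.2.relu2', nn.Sequential(nn.AvgPool2d(3),
                                                      nn.Flatten(),
                                                      nn.Linear(1600, NUM_CLASSES)))]
        return exits_def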
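The 1600 passed to nn.Linear in the branch above has to match the number of features left after pooling and flattening. The sketch below reproduces that arithmetic in plain Python. It assumes the feature map at layer1.2.relu2 has 16 channels and a 32x32 spatial size, which the text above does not state. The helper names are hypothetical and are not part of Distiller or PyTorch.

```python
def avgpool2d_out_size(size, kernel_size, stride=None, padding=0):
    """Output length of nn.AvgPool2d along one spatial dimension (ceil_mode=False)."""
    if stride is None:
        stride = kernel_size  # PyTorch uses the kernel size when no stride is given
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_features(channels, height, width, kernel_size):
    """Feature count after AvgPool2d(kernel_size) followed by Flatten()."""
    out_h = avgpool2d_out_size(height, kernel_size)
    out_w = avgpool2d_out_size(width, kernel_size)
    return channels * out_h * out_w


# ASSUMPTION: the tensor leaving 'layer1.2.relu2' is 16 x 32 x 32 per sample.
in_features = flattened_features(channels=16, height=32, width=32, kernel_size=3)
print(in_features)  # 1600
```

If you attach a branch at a different layer, recompute this value from that layer's output shape; a mismatch only shows up as a shape error the first time the branch runs.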
Attaching exit branches is straightforward: instantiate a distiller.EarlyExitMgr and invoke its attach_exits method, passing the model we are attaching to and the exit-branch definitions. Each definition pairs the fully-qualified name of the layer we are attaching to with the branch sub-model. In the example above, we are attaching to layer layer1.2.relu2.
ee_mgr = distiller.EarlyExitMgr()
ee_mgr.attach_exits(my_model, get_exits_def())
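A fully-qualified layer name such as layer1.2.relu2 follows the dotted naming that PyTorch's named_modules() reports, so listing those names is one way to find valid attachment points. The sketch below uses a toy model; Block and ToyNet are hypothetical stand-ins, not Distiller's ResNet classes.

```python
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for a residual block (skip connection omitted)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 16, 3, padding=1)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        return self.relu2(self.conv2(self.relu1(self.conv1(x))))


class ToyNet(nn.Module):
    """Hypothetical model with one stage of three blocks."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(Block(), Block(), Block())

    def forward(self, x):
        return self.layer1(x)


model = ToyNet()
# named_modules() yields (name, module) pairs; the root module has an empty name.
names = [name for name, _ in model.named_modules() if name]
print('layer1.2.relu2' in names)  # True
```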
The forward method of our new model (the original model, now with the attached exits) should return the output of the original model plus the outputs of all the newly attached exits. Exits cache their outputs, so before computing new outputs we first clear the caches by invoking ee_mgr.delete_exits_outputs, then run the forward method of our model, and finally collect and return the newly cached outputs using ee_mgr.get_exits_outputs.
    class ResNetCifarEarlyExit(ResNetCifar):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.ee_mgr = distiller.EarlyExitMgr()
            self.ee_mgr.attach_exits(self, get_exits_def())

        def forward(self, x):
            self.ee_mgr.delete_exits_outputs(self)
            # Run the input through the network (including exits)
            x = super().forward(x)
            outputs = self.ee_mgr.get_exits_outputs(self) + [x]
            return outputs
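The clear, run, collect sequence can also be tried without Distiller. The following is a minimal sketch in plain PyTorch, under the assumption that an exit is a wrapper module which caches the branch output as a side effect of the main forward pass. BranchPoint and ToyEarlyExitNet are hypothetical names for illustration, and this should not be read as Distiller's actual implementation.

```python
import torch
import torch.nn as nn


class BranchPoint(nn.Module):
    """Hypothetical wrapper: runs the wrapped layer, then caches the exit branch's output."""

    def __init__(self, wrapped, branch):
        super().__init__()
        self.wrapped = wrapped
        self.branch = branch
        self.output = None  # cached exit output, filled in during forward()

    def forward(self, x):
        x = self.wrapped(x)
        self.output = self.branch(x)
        return x  # the main path continues unchanged


class ToyEarlyExitNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        exit_branch = nn.Sequential(nn.AvgPool2d(3), nn.Flatten(),
                                    nn.Linear(1600, num_classes))
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            BranchPoint(nn.ReLU(), exit_branch),
        )
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                        nn.Linear(16, num_classes))

    def _branch_points(self):
        return [m for m in self.modules() if isinstance(m, BranchPoint)]

    def forward(self, x):
        # 1. Clear outputs cached by a previous forward pass.
        for bp in self._branch_points():
            bp.output = None
        # 2. Run the network; each BranchPoint caches its exit output on the way.
        x = self.classifier(self.features(x))
        # 3. Collect the exit outputs and append the final output.
        return [bp.output for bp in self._branch_points()] + [x]


model = ToyEarlyExitNet()
outputs = model(torch.randn(4, 3, 32, 32))
print([tuple(o.shape) for o in outputs])  # [(4, 10), (4, 10)]
```

As in the Distiller example, the final element of the returned list is the original model's output and the preceding elements are the exit outputs, all with the same shape.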