Simple code to save activations of a model’s intermediate layers.

I recently had to run a trained model and save the activations of its intermediate layers. Here is some simple code that does that using PyTorch forward hooks.

First define the hook:

import torch
import torch.nn.functional as F
from torch import nn
import collections
from typing import DefaultDict, Tuple, List, Dict
from functools import partial

def save_activations(
        activations: DefaultDict,
        name: str,
        module: nn.Module,
        inp: Tuple,
        out: torch.Tensor
) -> None:
    """PyTorch forward hook that saves a module's output at each forward pass.

    Bind ``activations`` and ``name`` (e.g. with ``functools.partial``) so the
    remaining ``(module, inp, out)`` parameters match the signature PyTorch
    uses when calling a forward hook.

    Parameters
    ----------
    activations:
        Mapping from layer name to a list of saved outputs. It is mutated in
        place: one tensor is appended under ``name`` per forward pass.
    name:
        Key under which this module's outputs are stored.
    module:
        The module that produced ``out`` (unused; required by the hook
        signature).
    inp:
        Positional inputs the module was called with (unused; required by the
        hook signature).
    out:
        Output tensor of the module for the current forward pass.
    """
    # Detach so the stored tensor does not keep the autograd graph alive, and
    # move it to the CPU so saved activations do not pile up in device memory.
    # ``setdefault`` keeps this working for a plain ``dict`` as well as a
    # ``defaultdict(list)``.
    activations.setdefault(name, []).append(out.detach().cpu())

Then define a helper function that registers the hook on the specified layers of a model (requires functools.partial).

def register_activation_hooks(
        model: nn.Module,
        layers_to_save: List[str]
) -> DefaultDict[str, List[torch.Tensor]]:
    """Registers forward hooks in specified layers.

    Parameters
    ----------
    model:
        PyTorch model
    layers_to_save:
        Module names within ``model`` whose activations we want to save.
        Names are matched against the names yielded by
        ``model.named_modules()``; names that match no module are ignored.

    Returns
    -------
    activations_dict:
        dict of lists containing activations of specified layers in
        ``layers_to_save``. It starts empty and is filled by the hooks, one
        entry per layer per forward pass of ``model``.
    """
    activations_dict = collections.defaultdict(list)

    for name, module in model.named_modules():
        if name in layers_to_save:
            # Bind the shared dict and this layer's name so the remaining
            # parameters match the (module, input, output) hook signature.
            module.register_forward_hook(
                partial(save_activations, activations_dict, name)
            )
    return activations_dict

Now define a simple model to test out register_activation_hooks():

class Net(nn.Module):
    """Simple two layer conv net.

    Two strided convolutions, each followed by a ReLU: ``conv1`` maps 3 input
    channels to 8 with a 5x5 kernel, and ``conv2`` maps 8 channels to 8 with
    a 3x3 kernel. Both use stride 2.
    """
    def __init__(self):
        # nn.Module must be initialised before submodules are assigned;
        # otherwise assigning ``self.conv1`` raises AttributeError.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=(5, 5), stride=(2, 2))
        self.conv2 = nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2))

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x`` and return the result."""
        y = F.relu(self.conv1(x))
        z = F.relu(self.conv2(y))
        return z

mdl = Net()
to_save = ["conv1", "conv2"]

# register fwd hooks in specified layers
saved_activations = register_activation_hooks(mdl, layers_to_save=to_save)

# Run one forward pass per batch. Each hooked layer should then hold exactly
# one saved activation per pass, i.e. a list of length num_fwd.
num_fwd = 2
images = [torch.randn(10, 3, 256, 256) for _ in range(num_fwd)]
for batch in images:
    mdl(batch)

assert len(saved_activations["conv1"]) == num_fwd
assert len(saved_activations["conv2"]) == num_fwd