# ResNet-10 model

class BasicBlock(nn.Module):
    """Two-conv residual block (3x3 -> 3x3) with identity/projection shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the spatial
    resolution or channel count changes, in which case a 1x1 conv + BN
    projection matches the residual branch.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        projected = self.expansion * out_channels
        needs_projection = stride != 1 or in_channels != projected
        if needs_projection:
            # Dimensions change: project the input with a strided 1x1 conv.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, projected,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(projected),
            )
        else:
            # Identity shortcut.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + residual
        return self.relu(y)

class ResNet10(nn.Module):
    """ResNet-10 classifier: 7x7 stem + four stages of one BasicBlock each.

    Supports plain classification, CAM extraction from the final feature map
    via the fc weights (``return_cam``), and Grad-CAM via
    :meth:`generate_gradcam`.
    """

    def __init__(self, num_classes=20):
        super(ResNet10, self).__init__()
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool -> 4x downsampling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # One block per stage (hence "ResNet-10"); stages 2-4 halve resolution.
        self.layer1 = self._make_layer(BasicBlock, 64, 1, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 1, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 1, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 1, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first block applies ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, return_cam=False):
        """Return logits ``(B, num_classes)``.

        If ``return_cam`` is True, additionally return per-class activation
        maps ``(B, num_classes, H', W')`` computed by applying the fc layer as
        a 1x1 convolution over the last feature map (standard CAM).
        """
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        fmap = self.layer4(out)
        out = self.avgpool(fmap)
        out = torch.flatten(out, 1)
        out = self.fc(out)

        if return_cam:
            # fc as 1x1 conv: equivalent to the fc logits before global pooling.
            cam = F.conv2d(fmap, self.fc.weight.view(self.fc.out_features, self.fc.in_features, 1, 1))
            if self.fc.bias is not None:
                cam = cam + self.fc.bias.view(1, -1, 1, 1)
            return out, cam
        return out

    def generate_gradcam(self, input_batch, target_class=None):
        """Compute Grad-CAM heatmaps for ``input_batch``.

        Args:
            input_batch: image tensor ``(B, 3, H, W)``.
            target_class: per-sample class indices; defaults to the argmax
                prediction for each sample.

        Returns:
            Detached tensor ``(B, 1, H, W)`` min-max normalized to [0, 1]
            per sample.
        """
        # Fix: remember and restore training mode instead of leaving the
        # model permanently in eval() as a side effect.
        was_training = self.training
        self.eval()  # freeze BN statistics / dropout during attribution

        gradients = []
        activations = []

        def save_gradient(grad):
            gradients.append(grad)

        def save_activation(module, input, output):
            activations.append(output)
            output.register_hook(save_gradient)

        # NOTE(review): this hooks the conv2 output of the last block, i.e. the
        # pre-BN, pre-residual feature map; hooking the block output
        # (post-residual) is the more common Grad-CAM choice — confirm intent.
        target_layer = self.layer4[-1].conv2
        hook = target_layer.register_forward_hook(save_activation)

        # Fix: try/finally so the hook is removed (and mode restored) even if
        # the forward/backward pass raises.
        try:
            with torch.enable_grad():
                output = self(input_batch)
                if target_class is None:
                    target_class = output.argmax(dim=1)

                self.zero_grad()

                # One-hot gradient seed selecting the target logit per sample.
                one_hot = torch.zeros_like(output)
                one_hot[torch.arange(output.shape[0]), target_class] = 1

                # Fix: dropped retain_graph=True — the graph is not reused,
                # retaining it only wasted memory.
                output.backward(gradient=one_hot)
        finally:
            hook.remove()
            if was_training:
                self.train()

        grads = gradients[0]
        fmaps = activations[0]

        # Grad-CAM: channel weights = spatially averaged gradients.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = (weights * fmaps).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        # Upsample to the input resolution.
        cam = F.interpolate(cam, size=input_batch.shape[2:], mode='bilinear', align_corners=False)

        # Per-sample min-max normalization to [0, 1] (eps guards flat maps).
        cam_min = cam.view(cam.size(0), -1).min(dim=1)[0].view(-1, 1, 1, 1)
        cam_max = cam.view(cam.size(0), -1).max(dim=1)[0].view(-1, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)

        return cam.detach()
============================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
ResNet10                                 [64, 3, 240, 240]         [64, 20]                  --                        --
├─Conv2d: 1-1                            [64, 3, 240, 240]         [64, 64, 120, 120]        9,408                     [7, 7]
├─BatchNorm2d: 1-2                       [64, 64, 120, 120]        [64, 64, 120, 120]        128                       --
├─ReLU: 1-3                              [64, 64, 120, 120]        [64, 64, 120, 120]        --                        --
├─MaxPool2d: 1-4                         [64, 64, 120, 120]        [64, 64, 60, 60]          --                        3
├─Sequential: 1-5                        [64, 64, 60, 60]          [64, 64, 60, 60]          --                        --
│    └─BasicBlock: 2-1                   [64, 64, 60, 60]          [64, 64, 60, 60]          --                        --
│    │    └─Conv2d: 3-1                  [64, 64, 60, 60]          [64, 64, 60, 60]          36,864                    [3, 3]
│    │    └─BatchNorm2d: 3-2             [64, 64, 60, 60]          [64, 64, 60, 60]          128                       --
│    │    └─ReLU: 3-3                    [64, 64, 60, 60]          [64, 64, 60, 60]          --                        --
│    │    └─Conv2d: 3-4                  [64, 64, 60, 60]          [64, 64, 60, 60]          36,864                    [3, 3]
│    │    └─BatchNorm2d: 3-5             [64, 64, 60, 60]          [64, 64, 60, 60]          128                       --
│    │    └─Sequential: 3-6              [64, 64, 60, 60]          [64, 64, 60, 60]          --                        --
│    │    └─ReLU: 3-7                    [64, 64, 60, 60]          [64, 64, 60, 60]          --                        --
├─Sequential: 1-6                        [64, 64, 60, 60]          [64, 128, 30, 30]         --                        --
│    └─BasicBlock: 2-2                   [64, 64, 60, 60]          [64, 128, 30, 30]         --                        --
│    │    └─Conv2d: 3-8                  [64, 64, 60, 60]          [64, 128, 30, 30]         73,728                    [3, 3]
│    │    └─BatchNorm2d: 3-9             [64, 128, 30, 30]         [64, 128, 30, 30]         256                       --
│    │    └─ReLU: 3-10                   [64, 128, 30, 30]         [64, 128, 30, 30]         --                        --
│    │    └─Conv2d: 3-11                 [64, 128, 30, 30]         [64, 128, 30, 30]         147,456                   [3, 3]
│    │    └─BatchNorm2d: 3-12            [64, 128, 30, 30]         [64, 128, 30, 30]         256                       --
│    │    └─Sequential: 3-13             [64, 64, 60, 60]          [64, 128, 30, 30]         8,448                     --
│    │    └─ReLU: 3-14                   [64, 128, 30, 30]         [64, 128, 30, 30]         --                        --
├─Sequential: 1-7                        [64, 128, 30, 30]         [64, 256, 15, 15]         --                        --
│    └─BasicBlock: 2-3                   [64, 128, 30, 30]         [64, 256, 15, 15]         --                        --
│    │    └─Conv2d: 3-15                 [64, 128, 30, 30]         [64, 256, 15, 15]         294,912                   [3, 3]
│    │    └─BatchNorm2d: 3-16            [64, 256, 15, 15]         [64, 256, 15, 15]         512                       --
│    │    └─ReLU: 3-17                   [64, 256, 15, 15]         [64, 256, 15, 15]         --                        --
│    │    └─Conv2d: 3-18                 [64, 256, 15, 15]         [64, 256, 15, 15]         589,824                   [3, 3]
│    │    └─BatchNorm2d: 3-19            [64, 256, 15, 15]         [64, 256, 15, 15]         512                       --
│    │    └─Sequential: 3-20             [64, 128, 30, 30]         [64, 256, 15, 15]         33,280                    --
│    │    └─ReLU: 3-21                   [64, 256, 15, 15]         [64, 256, 15, 15]         --                        --
├─Sequential: 1-8                        [64, 256, 15, 15]         [64, 512, 8, 8]           --                        --
│    └─BasicBlock: 2-4                   [64, 256, 15, 15]         [64, 512, 8, 8]           --                        --
│    │    └─Conv2d: 3-22                 [64, 256, 15, 15]         [64, 512, 8, 8]           1,179,648                 [3, 3]
│    │    └─BatchNorm2d: 3-23            [64, 512, 8, 8]           [64, 512, 8, 8]           1,024                     --
│    │    └─ReLU: 3-24                   [64, 512, 8, 8]           [64, 512, 8, 8]           --                        --
│    │    └─Conv2d: 3-25                 [64, 512, 8, 8]           [64, 512, 8, 8]           2,359,296                 [3, 3]
│    │    └─BatchNorm2d: 3-26            [64, 512, 8, 8]           [64, 512, 8, 8]           1,024                     --
│    │    └─Sequential: 3-27             [64, 256, 15, 15]         [64, 512, 8, 8]           132,096                   --
│    │    └─ReLU: 3-28                   [64, 512, 8, 8]           [64, 512, 8, 8]           --                        --
├─AdaptiveAvgPool2d: 1-9                 [64, 512, 8, 8]           [64, 512, 1, 1]           --                        --
├─Linear: 1-10                           [64, 512]                 [64, 20]                  10,260                    --
============================================================================================================================================
Total params: 4,916,052
Trainable params: 4,916,052
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 67.11
============================================================================================================================================
Input size (MB): 44.24
Forward/backward pass size (MB): 2047.09
Params size (MB): 19.66
Estimated Total Size (MB): 2110.99
============================================================================================================================================