# resnet10-model: ResNet-10 with CAM / Grad-CAM support (torchinfo summary below)
class BasicBlock(nn.Module):
    """Residual basic block (ResNet v1): two 3x3 convs with a skip connection.

    When the spatial resolution or channel count changes, the skip path is a
    1x1 projection conv + BatchNorm; otherwise it is the identity.
    """

    expansion = 1  # output channels = expansion * out_channels

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Projection shortcut only when shapes would otherwise mismatch.
        needs_projection = stride != 1 or in_channels != self.expansion * out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        identity = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + identity  # residual connection
        return self.relu(y)
class ResNet10(nn.Module):
    """ResNet-10 classifier: one BasicBlock per stage (64/128/256/512).

    Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool, then four single-block
    stages, global average pooling, and a linear classifier.

    Args:
        num_classes: output dimension of the final linear layer (default 20).
    """

    def __init__(self, num_classes=20):
        super(ResNet10, self).__init__()
        self.in_channels = 64  # running channel count consumed by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(BasicBlock, 64, 1, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 1, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 1, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 1, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one may downsample."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, return_cam=False):
        """Classify `x`; optionally also return class activation maps (CAM).

        Args:
            x: image batch of shape (N, 3, H, W).
            return_cam: when True, additionally compute CAMs by applying the
                fc weights as 1x1 convolutions over the last feature map.

        Returns:
            logits of shape (N, num_classes), or (logits, cam) where cam has
            shape (N, num_classes, h, w) at layer4's spatial resolution.
        """
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        fmap = self.layer4(out)  # kept for CAM computation
        out = self.avgpool(fmap)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        if return_cam:
            # A linear layer on pooled features is equivalent to a 1x1 conv on
            # the feature map, so sliding fc over fmap gives per-class maps.
            cam = F.conv2d(
                fmap,
                self.fc.weight.view(self.fc.out_features, self.fc.in_features, 1, 1),
            )
            if self.fc.bias is not None:
                cam = cam + self.fc.bias.view(1, -1, 1, 1)
            return out, cam
        return out

    def generate_gradcam(self, input_batch, target_class=None):
        """Compute Grad-CAM heatmaps w.r.t. the output of layer4's last conv2.

        Args:
            input_batch: image batch of shape (N, 3, H, W).
            target_class: int or (N,) tensor of class indices; defaults to the
                argmax prediction for each sample.

        Returns:
            Detached heatmaps of shape (N, 1, H, W), min-max normalized to
            [0, 1] per sample.
        """
        # Fix: remember the caller's mode so eval() does not leak out of
        # this call as a permanent side effect.
        was_training = self.training
        self.eval()  # freeze BN statistics for a stable attribution pass
        gradients = []
        activations = []

        def save_gradient(grad):
            gradients.append(grad)

        def save_activation(module, input, output):
            activations.append(output)
            # Capture the gradient flowing back into this activation.
            output.register_hook(save_gradient)

        target_layer = self.layer4[-1].conv2
        hook = target_layer.register_forward_hook(save_activation)
        try:
            with torch.enable_grad():
                output = self(input_batch)
                if target_class is None:
                    target_class = output.argmax(dim=1)
                self.zero_grad()
                one_hot = torch.zeros_like(output)
                one_hot[torch.arange(output.shape[0]), target_class] = 1
                # Backprop only the selected logits. Fix: no retain_graph —
                # the graph is not reused, so retaining it only wastes memory.
                output.backward(gradient=one_hot)
        finally:
            # Fix: remove the hook and restore mode even if forward/backward
            # raises, so the model is never left in a corrupted state.
            hook.remove()
            if was_training:
                self.train()
        grads = gradients[0]
        fmaps = activations[0]
        # Grad-CAM: channel weights are the spatially averaged gradients.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = (weights * fmaps).sum(dim=1, keepdim=True)
        cam = F.relu(cam)  # keep only positive evidence
        cam = F.interpolate(cam, size=input_batch.shape[2:], mode='bilinear',
                            align_corners=False)
        # Per-sample min-max normalization to [0, 1]; epsilon guards an
        # all-zero map against division by zero.
        cam_min = cam.view(cam.size(0), -1).min(dim=1)[0].view(-1, 1, 1, 1)
        cam_max = cam.view(cam.size(0), -1).max(dim=1)[0].view(-1, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        return cam.detach()
============================================================================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param # Kernel Shape
============================================================================================================================================
ResNet10 [64, 3, 240, 240] [64, 20] -- --
├─Conv2d: 1-1 [64, 3, 240, 240] [64, 64, 120, 120] 9,408 [7, 7]
├─BatchNorm2d: 1-2 [64, 64, 120, 120] [64, 64, 120, 120] 128 --
├─ReLU: 1-3 [64, 64, 120, 120] [64, 64, 120, 120] -- --
├─MaxPool2d: 1-4 [64, 64, 120, 120] [64, 64, 60, 60] -- 3
├─Sequential: 1-5 [64, 64, 60, 60] [64, 64, 60, 60] -- --
│ └─BasicBlock: 2-1 [64, 64, 60, 60] [64, 64, 60, 60] -- --
│ │ └─Conv2d: 3-1 [64, 64, 60, 60] [64, 64, 60, 60] 36,864 [3, 3]
│ │ └─BatchNorm2d: 3-2 [64, 64, 60, 60] [64, 64, 60, 60] 128 --
│ │ └─ReLU: 3-3 [64, 64, 60, 60] [64, 64, 60, 60] -- --
│ │ └─Conv2d: 3-4 [64, 64, 60, 60] [64, 64, 60, 60] 36,864 [3, 3]
│ │ └─BatchNorm2d: 3-5 [64, 64, 60, 60] [64, 64, 60, 60] 128 --
│ │ └─Sequential: 3-6 [64, 64, 60, 60] [64, 64, 60, 60] -- --
│ │ └─ReLU: 3-7 [64, 64, 60, 60] [64, 64, 60, 60] -- --
├─Sequential: 1-6 [64, 64, 60, 60] [64, 128, 30, 30] -- --
│ └─BasicBlock: 2-2 [64, 64, 60, 60] [64, 128, 30, 30] -- --
│ │ └─Conv2d: 3-8 [64, 64, 60, 60] [64, 128, 30, 30] 73,728 [3, 3]
│ │ └─BatchNorm2d: 3-9 [64, 128, 30, 30] [64, 128, 30, 30] 256 --
│ │ └─ReLU: 3-10 [64, 128, 30, 30] [64, 128, 30, 30] -- --
│ │ └─Conv2d: 3-11 [64, 128, 30, 30] [64, 128, 30, 30] 147,456 [3, 3]
│ │ └─BatchNorm2d: 3-12 [64, 128, 30, 30] [64, 128, 30, 30] 256 --
│ │ └─Sequential: 3-13 [64, 64, 60, 60] [64, 128, 30, 30] 8,448 --
│ │ └─ReLU: 3-14 [64, 128, 30, 30] [64, 128, 30, 30] -- --
├─Sequential: 1-7 [64, 128, 30, 30] [64, 256, 15, 15] -- --
│ └─BasicBlock: 2-3 [64, 128, 30, 30] [64, 256, 15, 15] -- --
│ │ └─Conv2d: 3-15 [64, 128, 30, 30] [64, 256, 15, 15] 294,912 [3, 3]
│ │ └─BatchNorm2d: 3-16 [64, 256, 15, 15] [64, 256, 15, 15] 512 --
│ │ └─ReLU: 3-17 [64, 256, 15, 15] [64, 256, 15, 15] -- --
│ │ └─Conv2d: 3-18 [64, 256, 15, 15] [64, 256, 15, 15] 589,824 [3, 3]
│ │ └─BatchNorm2d: 3-19 [64, 256, 15, 15] [64, 256, 15, 15] 512 --
│ │ └─Sequential: 3-20 [64, 128, 30, 30] [64, 256, 15, 15] 33,280 --
│ │ └─ReLU: 3-21 [64, 256, 15, 15] [64, 256, 15, 15] -- --
├─Sequential: 1-8 [64, 256, 15, 15] [64, 512, 8, 8] -- --
│ └─BasicBlock: 2-4 [64, 256, 15, 15] [64, 512, 8, 8] -- --
│ │ └─Conv2d: 3-22 [64, 256, 15, 15] [64, 512, 8, 8] 1,179,648 [3, 3]
│ │ └─BatchNorm2d: 3-23 [64, 512, 8, 8] [64, 512, 8, 8] 1,024 --
│ │ └─ReLU: 3-24 [64, 512, 8, 8] [64, 512, 8, 8] -- --
│ │ └─Conv2d: 3-25 [64, 512, 8, 8] [64, 512, 8, 8] 2,359,296 [3, 3]
│ │ └─BatchNorm2d: 3-26 [64, 512, 8, 8] [64, 512, 8, 8] 1,024 --
│ │ └─Sequential: 3-27 [64, 256, 15, 15] [64, 512, 8, 8] 132,096 --
│ │ └─ReLU: 3-28 [64, 512, 8, 8] [64, 512, 8, 8] -- --
├─AdaptiveAvgPool2d: 1-9 [64, 512, 8, 8] [64, 512, 1, 1] -- --
├─Linear: 1-10 [64, 512] [64, 20] 10,260 --
============================================================================================================================================
Total params: 4,916,052
Trainable params: 4,916,052
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 67.11
============================================================================================================================================
Input size (MB): 44.24
Forward/backward pass size (MB): 2047.09
Params size (MB): 19.66
Estimated Total Size (MB): 2110.99
============================================================================================================================================