Skip to content

Issue with Performance Drop after Permutation Smoothing on ResNet-18 #3

@secret-hammer

Description

@secret-hammer

I'm attempting to reproduce the results of applying permutation-based parameter smoothing to the ResNet-18 model, but I'm seeing a significant performance drop after the permutation step. Below are the details of the issue, including the code I used and the experimental results.

Experiment Details:

  • Model: ResNet-18
  • Data: Office-Home dataset (Art domain)
  • Procedure:
    1. Loaded the pre-trained ResNet-18 model (model.pt).
    2. Tested baseline accuracy using test_model_base.
    3. Applied permutation-based smoothing using PermutationManager.
    4. Evaluated the accuracy of the permuted model.

Code:

Experiment Code:

import models
import loaders
import torch
import torch.nn as nn
import metrics
from utils.train_util import AverageMeter
from utils import permute

def test_model_base(test_loader, model_base, device):
    """Evaluate a model on a test loader and return the filled accuracy meter.

    Args:
        test_loader: iterable yielding (inputs, labels, extra) triples.
        model_base: the nn.Module to evaluate (switched to eval mode here).
        device: torch.device the inputs/labels are moved to.

    Returns:
        The AverageMeter holding batch-size-weighted top-1 accuracy
        (read the final value from ``.avg``).
    """
    model_base.eval()
    acc_meter = AverageMeter('Acc', ':6.2f')

    # Single no_grad context around the whole loop instead of
    # re-entering set_grad_enabled(False) for every batch.
    with torch.no_grad():
        for samples in test_loader:
            inputs, labels, _ = samples
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model_base(inputs)
            acc = metrics.accuracy(outputs, labels)[0]
            acc_meter.update(acc.item(), inputs.size(0))

    return acc_meter


def main():
    """Reproduce the permutation-smoothing experiment on a trained ResNet-18."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Restore the trained checkpoint and move the model to the target device.
    checkpoint = torch.load('/nfs196/wjx/projects/PP/outputs/rn18_OH_Ar_base_10/Art/model.pt', weights_only=True)
    model = models.load_model('rn18', num_classes=65)
    model.load_state_dict(checkpoint['last_param'])
    model = model.to(device)
    model.eval()

    test_loader = loaders.load_images("/nfs196/hjc/datasets/Office-Home/Art", 'Office-Home', data_type='test', batch_size=512)

    # Baseline accuracy before any smoothing.
    acc_init = test_model_base(test_loader, model, device)

    # Total variation of the parameters as trained.
    total_tv = permute.compute_tv_loss_for_network(model, lambda_tv=1.0)
    print("Total Total Variation After Training:", total_tv)

    # Permute channels to lower TV, then re-measure TV and accuracy.
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    manager = permute.PermutationManager(model, dummy_input)
    permute_dict = manager.compute_permute_dict()
    model_permute = manager.apply_permutations(permute_dict, ignored_keys=[])
    total_tv = permute.compute_tv_loss_for_network(model_permute, lambda_tv=1.0)
    print("Total Total Variation After Permute:", total_tv)

    acc_permute = test_model_base(test_loader, model_permute, device)

    print("Initial accuracy: ", acc_init.avg)
    print("Accuracy after permutation: ", acc_permute.avg)

if __name__ == '__main__':
    main()

ResNet-18 Model Code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (ResNet-18/34 flavor)."""

    # Output channels = planes * expansion; 1 for the basic block.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 projection matches the residual shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce, 3x3 transform, 1x1 expand (ResNet-50+ flavor)."""

    # Output channels = planes * expansion; the 1x1 expand multiplies by 4.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        out_planes = self.expansion * planes
        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 projection matches the residual shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + residual)


class ResNet(nn.Module):
    """ResNet trunk with a 3x3 stem (CIFAR-style, no initial max-pool).

    ``block`` is the residual block class (e.g. BasicBlock or Bottleneck)
    and ``num_blocks`` gives the block count for each of the four stages.
    """

    def __init__(self, block, num_blocks, num_classes=10, input_channel=3):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage may downsample; the rest use stride 1.
        blocks = []
        for s in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            y = stage(y)
        # Global average pool to 1x1, then classify the flattened features.
        y = F.adaptive_avg_pool2d(y, (1, 1))
        y = y.flatten(1)
        return self.linear(y)

Observations:

Here is the output from the experiment:

Total Total Variation After Training: tensor(3745.0908, device='cuda:0', grad_fn=<DivBackward0>)
Total Total Variation After Permute: tensor(3479.0449, device='cuda:0', grad_fn=<DivBackward0>)
Initial accuracy:  68.39719817763157
Accuracy after permutation:  1.8953440410392393

Issue:

The accuracy after permutation is unexpectedly much lower than the initial accuracy. The total variation loss decreases after smoothing, but the accuracy drops significantly, which is not the expected behavior. Is there any issue with the permutation procedure or its configuration that could cause such a drastic drop in performance?

Thank you for your help!

Best,
[Secret Hammer]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions