"""Train a small MNIST CNN, apply L1 unstructured pruning, fine-tune, and evaluate."""

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader

DATA_DIR = "./data"
BATCH_SIZE = 64


class SimpleCNN(nn.Module):
    """Tiny CNN for 1x28x28 inputs: conv(3x3) -> ReLU -> 2x2 max-pool -> linear(10).

    ``fc1`` expects 32 * 13 * 13 features. The unpadded 3x3 conv turns a 28x28
    image into 26x26, so the 2x2 max-pool is required to reach 13x13.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # 28x28 -> 26x26 (no padding)
        self.pool = nn.MaxPool2d(2)  # 26x26 -> 13x13
        self.fc1 = nn.Linear(32 * 13 * 13, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10); ``x`` is expected to be (batch, 1, 28, 28)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.flatten(1)
        return self.fc1(x)


def _mnist_loader(is_train, *, shuffle=False, batch_size=BATCH_SIZE):
    """Build a DataLoader over the MNIST train or test split, downloading it if needed."""
    # Imported here so the model and pruning helpers remain usable without torchvision.
    from torchvision import datasets, transforms

    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(DATA_DIR, train=is_train, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def train(model, epochs=5, loader=None, optimizer=None, criterion=None):
    """Train ``model`` in place for ``epochs`` passes over ``loader``.

    Args:
        model: the module to train.
        epochs: number of passes over ``loader``.
        loader: iterable of ``(inputs, targets)`` batches; defaults to the
            shuffled MNIST training split.
        optimizer: defaults to ``Adam(model.parameters(), lr=1e-3)``. Pass the
            same optimizer to successive calls to keep its state.
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``.
    """
    if loader is None:
        loader = _mnist_loader(True, shuffle=True)
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for data, target in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()


def remove_pruning(model):
    """Make weight pruning permanent on every pruned Conv2d/Linear layer of ``model``.

    ``prune.remove`` folds ``weight_orig * weight_mask`` into a plain ``weight``
    parameter and drops the reparametrization. Layers whose weight was never
    pruned are skipped, because ``prune.remove`` raises ValueError for them.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, "weight_orig"):
            prune.remove(module, "weight")


def evaluate(model, loader=None):
    """Print and return the top-1 accuracy of ``model`` over ``loader``.

    ``loader`` defaults to the MNIST test split. The accuracy is a fraction in
    [0, 1]; an empty loader yields 0.0.
    """
    if loader is None:
        loader = _mnist_loader(False)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {accuracy:.2f}")
    return accuracy


def main():
    """Run the pipeline: train, prune, fine-tune with masks attached, finalize, evaluate."""
    train_loader = _mnist_loader(True, shuffle=True)
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    # One optimizer serves both phases: pruning re-registers each pruned weight
    # as ``weight_orig`` using the same Parameter object, so it stays tracked.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Pre-training. Magnitude pruning assumes trained weights; skip this only
    # if you load pretrained weights instead (none are loaded here).
    train(model, loader=train_loader, optimizer=optimizer, criterion=criterion)

    # Prune the 50% smallest-|w| conv1 weights and 30% of fc1 weights (L1 norm).
    prune.l1_unstructured(module=model.conv1, name="weight", amount=0.5)
    prune.l1_unstructured(module=model.fc1, name="weight", amount=0.3)

    # Fine-tune while the masks are still attached: forward uses
    # weight_orig * weight_mask, so pruned entries stay zero. Removing the
    # masks first would make them ordinary trainable weights that regrow.
    train(model, epochs=3, loader=train_loader, optimizer=optimizer, criterion=criterion)

    # Make the pruning permanent only after fine-tuning.
    remove_pruning(model)

    evaluate(model)


if __name__ == "__main__":
    main()