This commit is contained in:
		
							parent
							
								
									22ef225c8a
								
							
						
					
					
						commit
						164e5967b5
					
				| 
						 | 
					@ -0,0 +1,73 @@
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.utils.prune as prune
 | 
				
			||||||
 | 
					from torchvision import datasets, transforms
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 1. 定义模型
 | 
				
			||||||
 | 
					class SimpleCNN(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
 | 
				
			||||||
 | 
					        self.fc1 = nn.Linear(32 * 13 * 13, 10)  # 假设输入为28x28,经过池化后为13x13
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        x = torch.relu(self.conv1(x))
 | 
				
			||||||
 | 
					        x = x.view(x.size(0), -1)
 | 
				
			||||||
 | 
					        x = self.fc1(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 2. 加载数据
 | 
				
			||||||
 | 
					transform = transforms.Compose([transforms.ToTensor()])
 | 
				
			||||||
 | 
					train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
 | 
				
			||||||
 | 
					train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 3. 初始化模型
 | 
				
			||||||
 | 
					model = SimpleCNN()
 | 
				
			||||||
 | 
					criterion = nn.CrossEntropyLoss()
 | 
				
			||||||
 | 
					optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 4. 训练原始模型(可选)
 | 
				
			||||||
 | 
					def train(model, epochs=5):
 | 
				
			||||||
 | 
					    model.train()
 | 
				
			||||||
 | 
					    for epoch in range(epochs):
 | 
				
			||||||
 | 
					        for data, target in train_loader:
 | 
				
			||||||
 | 
					            optimizer.zero_grad()
 | 
				
			||||||
 | 
					            output = model(data)
 | 
				
			||||||
 | 
					            loss = criterion(output, target)
 | 
				
			||||||
 | 
					            loss.backward()
 | 
				
			||||||
 | 
					            optimizer.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					train(model)  # 可省略,直接使用预训练权重
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 5. 剪枝
 | 
				
			||||||
 | 
					# 对conv1层按L1范数剪枝50%
 | 
				
			||||||
 | 
					prune.l1_unstructured(module=model.conv1, name='weight', amount=0.5)
 | 
				
			||||||
 | 
					# 对fc1层剪枝30%
 | 
				
			||||||
 | 
					prune.l1_unstructured(module=model.fc1, name='weight', amount=0.3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 6. 移除剪枝掩码(永久剪枝)
 | 
				
			||||||
 | 
					def remove_pruning(model):
 | 
				
			||||||
 | 
					    for name, module in model.named_modules():
 | 
				
			||||||
 | 
					        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
 | 
				
			||||||
 | 
					            prune.remove(module, 'weight')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					remove_pruning(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 7. 微调剪枝后模型
 | 
				
			||||||
 | 
					train(model, epochs=3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 8. 评估模型
 | 
				
			||||||
 | 
					def evaluate(model):
 | 
				
			||||||
 | 
					    model.eval()
 | 
				
			||||||
 | 
					    test_data = datasets.MNIST('./data', train=False, transform=transform)
 | 
				
			||||||
 | 
					    test_loader = DataLoader(test_data, batch_size=64)
 | 
				
			||||||
 | 
					    correct = 0
 | 
				
			||||||
 | 
					    with torch.no_grad():
 | 
				
			||||||
 | 
					        for data, target in test_loader:
 | 
				
			||||||
 | 
					            output = model(data)
 | 
				
			||||||
 | 
					            pred = output.argmax(dim=1)
 | 
				
			||||||
 | 
					            correct += pred.eq(target).sum().item()
 | 
				
			||||||
 | 
					    print(f"Accuracy: {correct / len(test_data):.2f}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					evaluate(model)
 | 
				
			||||||
		Loading…
	
		Reference in New Issue