From 164e5967b554538908621aa28111cf4eae9dc632 Mon Sep 17 00:00:00 2001
From: renzhiyuan <465386466@qq.com>
Date: Thu, 16 Oct 2025 13:46:57 +0800
Subject: [PATCH] Add create_model.py: train, prune, and fine-tune a simple MNIST CNN

---
 create_model.py | 76 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)
 create mode 100644 create_model.py

diff --git a/create_model.py b/create_model.py
new file mode 100644
index 0000000..4a89112
--- /dev/null
+++ b/create_model.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+import torch.nn.utils.prune as prune
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+
+# 1. Define the model
+class SimpleCNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
+        self.pool = nn.MaxPool2d(2)
+        # 28x28 input -> 26x26 after the 3x3 conv -> 13x13 after 2x2 max pooling
+        self.fc1 = nn.Linear(32 * 13 * 13, 10)
+
+    def forward(self, x):
+        x = self.pool(torch.relu(self.conv1(x)))
+        x = x.view(x.size(0), -1)
+        x = self.fc1(x)
+        return x
+
+# 2. Load the data
+transform = transforms.Compose([transforms.ToTensor()])
+train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
+train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
+
+# 3. Initialize the model
+model = SimpleCNN()
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+# 4. Train the original model (optional)
+def train(model, epochs=5):
+    model.train()
+    for epoch in range(epochs):
+        for data, target in train_loader:
+            optimizer.zero_grad()
+            output = model(data)
+            loss = criterion(output, target)
+            loss.backward()
+            optimizer.step()
+
+train(model)  # Can be skipped if pretrained weights are loaded instead
+
+# 5. Prune
+# Prune 50% of conv1's weights by L1 magnitude
+prune.l1_unstructured(module=model.conv1, name='weight', amount=0.5)
+# Prune 30% of fc1's weights
+prune.l1_unstructured(module=model.fc1, name='weight', amount=0.3)
+
+# 6. Fine-tune the pruned model while the pruning masks are still attached,
+# so the pruned weights stay at zero during the extra epochs
+train(model, epochs=3)
+
+# 7. Remove the pruning re-parametrization (make the pruning permanent)
+def remove_pruning(model):
+    for name, module in model.named_modules():
+        if isinstance(module, (nn.Conv2d, nn.Linear)):
+            prune.remove(module, 'weight')
+
+remove_pruning(model)
+
+# 8. Evaluate the model
+def evaluate(model):
+    model.eval()
+    test_data = datasets.MNIST('./data', train=False, transform=transform)
+    test_loader = DataLoader(test_data, batch_size=64)
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            output = model(data)
+            pred = output.argmax(dim=1)
+            correct += pred.eq(target).sum().item()
+    print(f"Accuracy: {correct / len(test_data):.2f}")
+
+evaluate(model)
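
A quick way to sanity-check the pruning amounts above is to measure the fraction of exactly-zero entries in each pruned weight tensor, either right after prune.l1_unstructured has been applied (the masked weight) or after prune.remove. The snippet below is a minimal sketch, assuming the model object from create_model.py; the sparsity helper is a hypothetical name introduced here, not part of the patch.

    def sparsity(module):
        # Fraction of exactly-zero entries in the module's (possibly masked) weight
        w = module.weight.detach()
        return (w == 0).float().mean().item()

    print(f"conv1 sparsity: {sparsity(model.conv1):.2%}")  # roughly 50% expected
    print(f"fc1 sparsity:   {sparsity(model.fc1):.2%}")    # roughly 30% expected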