# llm_train/create_model.py
# (73 lines, 2.2 KiB, Python)
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters that might be confused with other characters.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 1. Model definition
class SimpleCNN(nn.Module):
    """Small CNN for 1-channel 28x28 images (the MNIST digits loaded below).

    Pipeline: 3x3 conv (1 -> 32 channels) -> ReLU -> 2x2 max-pool -> flatten
    -> linear layer producing 10 class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        # 28x28 input -> 26x26 after the unpadded 3x3 conv -> 13x13 after this
        # 2x2 pooling. MaxPool2d has no parameters, so state_dict() is unchanged.
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(32 * 13 * 13, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for x of shape (batch, 1, 28, 28)."""
        x = torch.relu(self.conv1(x))
        # Without pooling the flattened size would be 32*26*26, which does not
        # match fc1's 32*13*13 inputs and raises a shape-mismatch error.
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc1(x)
# 2. Load the MNIST training split. download=True fetches it into ./data when
# it is not already there. ToTensor is the only transform (no normalization).
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
# 3. Build the model, its loss function and its optimizer.
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
# Adam over this particular model's parameters, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
# 4. Training loop (used both for the initial training and for fine-tuning)
def train(model, epochs=5, loader=None, opt=None, loss_fn=None):
    """Train *model* in place for *epochs* passes over a data loader.

    Args:
        model: the nn.Module to train; it is switched to training mode.
        epochs: number of full passes over *loader*.
        loader: iterable of (inputs, targets) batches; defaults to the
            module-level ``train_loader``.
        opt: optimizer over *model*'s parameters; defaults to the module-level
            ``optimizer``. NOTE: that default was built from the global
            ``model``'s parameters, so pass your own optimizer when training
            any other model -- otherwise that model is never updated.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar loss;
            defaults to the module-level ``criterion``.

    Returns:
        List with the mean batch loss of each epoch (empty if epochs <= 0;
        an epoch over an empty loader reports ``nan``).
    """
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running, batches = 0.0, 0
        for data, target in loader:
            opt.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            opt.step()
            running += loss.item()
            batches += 1
        epoch_losses.append(running / batches if batches else float('nan'))
    return epoch_losses
# Train the dense model first so that magnitude pruning removes weights that
# are actually small. No pretrained weights are loaded anywhere in this script,
# so skipping this call would prune randomly initialized weights.
train(model)

# 5. Unstructured L1-magnitude pruning: mask the 50% smallest-magnitude entries
# of conv1.weight, then the 30% smallest-magnitude entries of fc1.weight.
prune.l1_unstructured(model.conv1, 'weight', amount=0.5)
prune.l1_unstructured(model.fc1, 'weight', amount=0.3)
# 6. Remove the pruning masks (make the pruning permanent)
def remove_pruning(model):
    """Make the pruning of every Conv2d/Linear ``weight`` in *model* permanent.

    ``torch.nn.utils.prune`` stores a pruned weight as a ``weight_orig``
    parameter plus a ``weight_mask`` buffer. ``prune.remove`` folds the mask
    into a plain ``weight`` parameter and drops that reparametrization.
    Conv2d/Linear layers whose weight was never pruned are skipped:
    ``prune.remove`` would raise ``ValueError`` on them.

    Args:
        model: the nn.Module to walk (the module itself and all submodules).
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, 'weight_orig'):
            prune.remove(module, 'weight')
# Fine-tune while the pruning masks are still attached. The pruning forward
# pre-hook recomputes weight = weight_orig * weight_mask on every forward pass,
# so pruned entries stay zero and receive zero gradient. Removing the masks
# first would turn `weight` back into a dense Parameter, and the optimizer
# would immediately regrow the pruned connections.
train(model, epochs=3)
# Only now make the pruning permanent (fold the masks into the weights).
remove_pruning(model)
# 8. Evaluate the model
def evaluate(model, loader=None):
    """Print and return the classification accuracy of *model*.

    Args:
        model: the nn.Module to evaluate. It runs in eval mode under
            ``torch.no_grad()``; its previous train/eval mode is restored
            afterwards.
        loader: iterable of (inputs, targets) batches. Defaults to the MNIST
            test split read from './data' with the module-level ``transform``,
            batched by 64. download=True is not passed, so the files are
            presumably already present from the training-split download.
            TODO confirm.

    Returns:
        Fraction of correctly classified samples (0.0 for an empty loader).
    """
    if loader is None:
        test_data = datasets.MNIST('./data', train=False, transform=transform)
        loader = DataLoader(test_data, batch_size=64)

    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for data, target in loader:
                pred = model(data).argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)
    finally:
        model.train(was_training)

    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {accuracy:.2f}")
    return accuracy
# Report the test-set accuracy of the pruned and fine-tuned model.
evaluate(model=model)