目录
一、安装步骤
1、pytorch下载
确保是在安装好的python环境下进行,先进入Start Locally | PyTorch官网。
根据机器的配置选择相应的信息,然后将下面的代码从控制台运行即可
2、环境验证
在python中导入库的名称为torch
import torch
torch.randn(2,2,2)
二、CNN
初学,数据集用的是MNist手写数据集,CNN的处理步骤如下,在这里直接继承torch.nn类然后设置参数即可。
1、导入库
这是需要导入的库,如果提示no moudel的话直接install即可
import torch
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import pandas as pd
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import time
2、导入训练集与测试集
如果训练集没有下载的话,添加一个download属性然后赋值为True即可
#导入训练集测试集,如果没有下载,download设置为True
train_data=dataset.MNIST("mnist-data",train=True,transform=transforms.ToTensor())
test_data=dataset.MNIST("mnist-data",train=False,transform=transforms.ToTensor())
3、定义CNN
第一句代码需要注意一下,如果你的电脑是m1的话device的括号就填mps,一般是填cpu或者suad
device = torch.device('mps')
class CNN(torch.nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.fc = nn.Linear(64 * 7 * 7, 10)
def forward(self,x):
x = self.conv(x)
x = x.view(x.size(0), -1)
y = self.fc(x)
return y
cnn=CNN().to(device)
4、损失函数、优化函数
#损失函数
los=torch.nn.CrossEntropyLoss()
#优化函数
optime=torch.optim.Adam(cnn.parameters(), lr=0.01)
5、训练模型
这里在输入lmages和lab的时候注意也要加上to(device),device在前面定义过
#训练模型
start = time.time()
for epo in range(1):
for i, (images,lab) in enumerate(train_loader):
optime.zero_grad()
images=images.to(device)
lab=lab.to(device)
out = cnn(images)
loss=los(out,lab)
loss.backward()
optime.step()
print("epo:{},i:{},loss:{}".format(epo+1,i,loss))
end = time.time()
print(end-start,"s")
6、测试与保存模型
loss = 0
total = 0
correct = 0
with torch.no_grad():
for data, targets in test_loader:
data = data.to(device)
targets = targets.to(device)
output = cnn(data)
_,p=output.max(1)
loss += los(output, targets)
correct += (p == targets).sum()
total += data.size(0)
loss = loss.item()/len(test_loader)
acc = correct.item()/total
print(loss,acc)
#保存模型
torch.save(cnn.state_dict(), 'model.pt')
7、加载模型与预测
这里可以展示看一下模型的具体详细信息
#加载模型
model = CNN().to(device='mps')
model.load_state_dict(torch.load('model.pt', map_location=torch.device('mps')))
model.eval()
接着通过画图软件在本地制作了0-9十个手写数字图片,格式为bmp灰度图,为了方便输入,在制作的时候就将大小设置为了28*28这样输入的时候就不用再做处理。保存之后进行读取看下效果。
#读取本地图片
plt.rcParams['font.sans-serif']='Heiti TC'
plt.rcParams['axes.unicode_minus'] = False # 负号正常显示
fig = plt.figure(figsize=(10, 4))
plt.title("手写的数字",fontsize=20)
path_img=[]
for i in range(10):
img=cv2.imread(f"/Users/van/Downloads/bmp/number{i}.bmp",cv2.IMREAD_GRAYSCALE)
path_img.append(f"/Users/van/Downloads/bmp/number{i}.bmp")
ax = fig.add_subplot(2, 5, i + 1, xticks=[], yticks=[])
plt.subplots_adjust(wspace=0, hspace=0)
plt.imshow(img)
plt.show()
接着就可以将这些图片先转化为torch.tensor类型然后输入到模型了。
#读识别开始识别
import cv2
fig = plt.figure(figsize=(10, 4))
plt.xlabel("识别结果",fontsize=20)
j=0
for i in path_img:
img=cv2.imread(i,cv2.IMREAD_GRAYSCALE)
ax = fig.add_subplot(2, 5, j + 1, xticks=[], yticks=[])
plt.subplots_adjust(wspace=0, hspace=0.4)
plt.imshow(img)
imgtensor=torch.from_numpy(img.reshape((1,1,28,28)))
inputs = imgtensor.to(torch.float32).to(device)
output=model(inputs)
_,p=output.max(1)
j=j+1
plt.title(f"预测结果为数字: {p.item()}")
三、总结
go on!