文件夹结构
mobilenet{
mainwin.ui
mainwin.py
img.jpg 预测的图片
class.json
mobilenet_v2.pth 预训练权重
MobileNetV2.pth 自己的数据集训练好的权重
mobilenet_v2.py
myMainWin.py 编写调用窗口程序
predict.py 预测
train.py 训练
}
其中myMainWin.py代码如下
import sys
from PyQt5.QtWidgets import QMainWindow, QApplication, QFileDialog, QLabel
from PyQt5.QtGui import QIcon, QImage, QPixmap
from mainwin import Ui_MainWindow
import cv2
import numpy
from predict_new_v2 import predict_new
from PIL import Image
class myMainWin(QMainWindow, Ui_MainWindow):
def __init__(self):
super(myMainWin, self).__init__()
self.setupUi(self)
# 设置主窗口的标题
self.setWindowTitle('基于xxx分类系统')
# 连接动作对应的函数
self.pushButton_2.clicked.connect(self.openimg) # 构造函数(label可以不用写)
self.pushButton_3.clicked.connect(self.detect)
def openimg(self):
global fname
# 选择且获取图片文件的地址
fileName, filetype = QFileDialog.getOpenFileName(
self,
"选取文件",
"F:/python/mobilenet",
"Image Files (*.bmp *.jpg *.jpeg *.png)")
self.showFile(fileName)
fname = fileName
# 将图片显示在label
def showFile(self, fileName):
srcImage = cv2.imdecode(numpy.fromfile(fileName, dtype=numpy.uint8), -1)
image_height, image_width, image_depth = srcImage.shape # 获取图像的高,宽以及深度。
# opencv读图片是BGR,qt显示要RGB,所以需要转换一下
QImg = cv2.cvtColor(srcImage, cv2.COLOR_BGR2RGB)
QShowImage = QImage(QImg.data, image_width, image_height, # 创建QImage格式的图像,并读入图像信息
image_width * image_depth,
QImage.Format_RGB888)
self.label_img.clear()
QShowImage = QShowImage.scaled(
self.label_img.width(),
self.label_img.height()) # 图片适应label大小
self.label_img.setPixmap(QPixmap.fromImage(QShowImage))
def detect(self):
img = Image.open(fname)
predict = predict_new(img)
self.label_result.setText(predict)
# 只有单独执行调用条件语句
# 加个程序入口
if __name__ == '__main__':
app = QApplication(sys.argv) # 传入参数
app.setWindowIcon(QIcon('./Knight.ico'))
main = myMainWin()
main.show()
sys.exit(app.exec_())
predict.py的代码如下
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model_v2 import MobileNetV2
def predict_new(img):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# load image
# img = Image.open("./00006.jpg")
# [N, C, H, W]
# img = Image.open(img_path)
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
json_file = open(json_path, "r")
class_indict = json.load(json_file)
# create model
model = MobileNetV2(num_classes=2).to(device)
# load model weights
model_weight_path = "./MobileNetV2.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print_res = "class: {} \n prob: {:.3}".format(class_indict[str(predict_cla)],
predict[predict_cla].numpy())
return print_res
- 犯了个小错误,在myMainWin.py中调用predict = predict_new(img),传入的参数img没有和predict.py中的def predict_new(img):保持一致。
- 传入的是图片路径,需要读取成图片再处理img = Image.open(fname)
- 本文中初始化的时候label可以不用初始化,需要连接动作的函数进行初始化
至此,初学pyqt5,希望越学越顺利,开头难解决了还有中间难结尾难,慢慢来吧!
主要参考视频及文章:
https://www.bilibili.com/video/BV15u41197EQ?p=3&share_source=copy_web
https://zhuanlan.zhihu.com/p/274436031