keras实现图像数字多分类
目标:基于mnist数据集,建立mlp模型,实现0-9数字的十分类
1.实现mnist数据载入,可视化图形数字
2.完成数据预处理,图像数据维度转化与归一化,输出结果格式转化
3.计算模型在预测数据集的准确率
4.模型结构:两层隐藏层,每层有392个模型
一、数据处理及可视化
1、获取数据集
from keras.datasets import mnist
(X_train,y_train),(X_test,y_test) = mnist.load_data()
2、查看数据
X_train.shape
#(60000, 28, 28)
3、部分数据的可视化
import matplotlib.pyplot as plt
img1 = X_train[0]
fig = plt.figure(figsize=(3,3))
plt.imshow(img1)
plt.title(y_train[0])
plt.show()
可视化效果为:数字5
二、数据预处理
1、查看图片大小
# 图片的大小
img1.shape
# (28, 28)
2、维度转换
feature_size = img1.shape[0]*img1.shape[1]
X_train_format = X_train.reshape(X_train.shape[0],feature_size)
X_test_format = X_test.reshape(X_test.shape[0],feature_size)
X_train_format.shape
# (60000, 784)
3、归一化处理
由于对图像进行数字处理,所以归一化时除以255即可
X_train_normal = X_train_format/255
X_test_normal = X_test_format/255
4、输出结果格式转化
tf版本过高时,导入包: from keras.utils import to_categorical
会显示报错
ImportError: cannot import name ‘to_categorical’ from ‘keras.utils’ (/usr/local/lib/python3.7/dist-packages/keras/utils/init.py)
现在keras完全置于tf模块中,这个要从tensoflow根模块导入,修改为:
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.utils import to_categorical
y_train_format = to_categorical(y_train)
y_test_format = to_categorical(y_test)
print(y_train[0])
print(y_test_format[0])
三、建立模型,训练并预测
1、建立模型
2个隐藏层,每层有392个
最后的分类结果是10个
from keras.models import Sequential
from keras.layers import Dense,Activation
mlp = Sequential()
mlp.add(Dense(units=392,activation='sigmoid',input_dim=feature_size))
mlp.add(Dense(units=392,activation='sigmoid'))
mlp.add(Dense(units=10,activation='softmax'))
mlp.summary()
2、训练模型
#模型训练
mlp.fit(X_train_normal,y_train_format,epochs=10)
3、模型评估
3.1、训练集
# 训练集
import numpy as np
y_train_predict = mlp.predict(X_train_normal)
y_train_predict=np.argmax(y_train_predict,axis=1)
y_train_predict
# array([5, 0, 4, ..., 5, 6, 8], dtype=int64)
准确度
# 计算准确率
from sklearn.metrics import accuracy_score
accuracy_train = accuracy_score(y_train,y_train_predict)
accuracy_train
# 0.9942833333333333
3.2 测试集
y_test_predict = mlp.predict(X_test_normal)
y_test_predict = np.argmax(y_test_predict,axis=1)
accuracy_test = accuracy_score(y_test,y_test_predict)
accuracy_test
# 0.98
四、可视化验证结果
mg2 = X_test[10]
fig2 = plt.figure(figsize=(3,3))
plt.imshow(img2)
plt.title(y_test_predict[10])
分类成功
完整代码已上传至 https://github.com/jrt-20/-