While learning Keras, every dataset I used came from keras.datasets.mnist.load_data() and similar calls that return pre-packaged data. When I wanted to try images I had collected myself, I did not know how to load them. This post records the custom-dataset approach I found by studying other people's code.
Dataset download link: https://pan.baidu.com/s/12ddpk2eKDCAu7Z1UTZWtEA (extraction code: gxqj)
Folder structure:
- images
  - hua (folder of flower images)
  - niao (bird images)
  - yu (fish images)
  - chong (insect images)
- imgcsv.csv
- my_load_data.py
- train.py
Approach:
- List the subfolders under images (images/hua, images/niao, etc.)
- Write the path and label of every image under images/hua etc. into imgcsv.csv, so later runs can read them back directly (sample rows below)
- Read the csv file, load each image from its path, and convert it to array form
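For reference, each row of imgcsv.csv is just an image path and its integer label. With the four folders above (the filenames here are hypothetical), the shuffled file might start like:

images/hua/001.jpg,1
images/chong/087.jpg,0
images/niao/033.jpg,2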
my_load_data.py
import os, glob, csv, random
import tensorflow as tf
def load_csv(root, csvname, name2label):
    '''
    root: path of the images folder
    csvname: name of the csv file to write
    name2label: dict mapping each class name to its integer label
    '''
    # If the csv file does not exist yet, create it and fill it in
    if not os.path.exists(os.path.join(root, csvname)):
        images = []
        for name in name2label.keys():
            # Collect image paths such as images/hua/*** ("*.*g" matches .jpg, .jpeg and .png)
            images += glob.glob(os.path.join(root, name, "*.*g"))
        # The paths were collected class by class, so shuffle them
        random.shuffle(images)
        # Write each path and its label into the csv file
        with open(os.path.join(root, csvname), 'w', newline="") as f:
            writer = csv.writer(f)
            for img in images:
                # Split the path on the separator, e.g. [images, hua, **.jpg], and take the class name
                name = img.split(os.sep)[-2]
                # Look up the label for this class name in the dict
                label = name2label[name]
                writer.writerow([img, label])
    # The csv file exists by now in either case, so read it back
    images, labels = [], []
    with open(os.path.join(root, csvname)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            images.append(img)
            labels.append(int(label))
    assert len(images) == len(labels)
    return images, labels
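# Note: the csv is written once and only read back on later runs, so deleting
# imgcsv.csv forces a re-scan of the folders and a fresh shuffle. A minimal
# sanity check (a sketch, assuming the folder layout above) could be:
#   name2label = {'chong': 0, 'hua': 1, 'niao': 2, 'yu': 3}
#   paths, labels = load_csv('images', 'imgcsv.csv', name2label)
#   assert len(paths) == len(labels)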
def load_data(root, mode='train'):
    name2label = {}
    # List the subfolders under images; each folder is one class
    for name in sorted(os.listdir(os.path.join(root))):
        # Skip anything under images/ that is not a folder
        if not os.path.isdir(os.path.join(root, name)):
            continue
        name2label[name] = len(name2label.keys())
    images, labels = load_csv(root, "imgcsv.csv", name2label)
    if mode == 'train':  # first 60%
        images = images[:int(0.6 * len(images))]
        labels = labels[:int(0.6 * len(labels))]
    elif mode == 'val':  # 20%: from 60% to 80%
        images = images[int(0.6 * len(images)):int(0.8 * len(images))]
        labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
    else:  # 20%: from 80% to 100%
        images = images[int(0.8 * len(images)):]
        labels = labels[int(0.8 * len(labels)):]
    return images, labels, name2label
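# Because os.listdir(...) is sorted, the label assignment is deterministic:
# with the four folders above, load_data('images') builds
# name2label == {'chong': 0, 'hua': 1, 'niao': 2, 'yu': 3}.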
def preprocess(x, y):
    # x: image path, y: integer label
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)
    x = tf.image.resize(x, [224, 224])        # match the 224x224 input the network is built with
    x = tf.cast(x, dtype=tf.float32) / 255.0  # scale pixels to [0, 1]
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=4)                # 4 classes: chong, hua, niao, yu
    return x, y
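preprocess takes a single (path, label) pair, so it plugs straight into Dataset.map. A minimal check (a sketch; the image path is hypothetical):

import tensorflow as tf
from my_load_data import preprocess

ds = tf.data.Dataset.from_tensor_slices((['images/hua/001.jpg'], [1]))  # hypothetical file
ds = ds.map(preprocess)
for x, y in ds:
    print(x.shape, y.shape)  # (224, 224, 3) (4,)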
train.py
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
from my_load_data import load_data, preprocess
batchsz = 32
# Build the training Dataset object
images, labels, table = load_data('images', mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# Build the validation Dataset object
images2, labels2, table = load_data('images', mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# Build the test Dataset object
images3, labels3, table = load_data('images', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
# Build the network: VGG16 as a frozen ImageNet feature extractor plus a new classifier head
net = keras.applications.VGG16(weights='imagenet', include_top=False, pooling='max')
net.trainable = False
newnet = keras.Sequential([
    net,
    layers.Dense(128, activation='relu'),  # without activations, stacked Dense layers collapse into one linear map
    layers.Dense(64, activation='relu'),
    layers.Dense(4)                        # one logit per class
])
newnet.build(input_shape=(None, 224, 224, 3))
newnet.summary()
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=5
)
newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
           callbacks=[early_stopping])
newnet.evaluate(db_test)  # final accuracy on the held-out test split
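After training, a new image can be classified by reusing the same preprocess; the table returned by load_data maps the predicted index back to a class name. A sketch (the image path is hypothetical):

x, _ = preprocess('images/hua/001.jpg', 0)        # the label argument is unused here
pred = newnet.predict(tf.expand_dims(x, axis=0))  # add the batch dimension
idx2name = {v: k for k, v in table.items()}
print(idx2name[int(np.argmax(pred[0]))])          # e.g. 'hua'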