
While learning Keras I had only ever used keras.datasets.mnist.load_data(), i.e. datasets that come fully prepared. When I wanted to try images I had collected myself, I did not know how to load them. This post records the approach I pieced together from how other people build custom datasets.

Link: https://pan.baidu.com/s/12ddpk2eKDCAu7Z1UTZWtEA  extraction code: gxqj

Folder structure:

  • images
    • hua (folder containing the flower images)
    • niao (birds)
    • yu (fish)
    • chong (insects)
    • imgcsv.csv (generated on the first run of my_load_data.py)
  • my_load_data.py
  • train.py

Approach:

  • List the class folders under images (images/hua, images/niao, and so on)
  • Write the path and label of every image under those folders into imgcsv.csv, so next time they can simply be re-read (see the sample rows below)
  • Read the csv file, load each image from its path, and convert it into an array
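
For reference, each row of the generated CSV simply pairs an image path with its integer label; the labels follow the sorted folder-name order (chong=0, hua=1, niao=2, yu=3) and the rows are shuffled. With made-up file names it looks like this:

images/hua/001.jpg,1
images/chong/008.jpg,0
images/niao/017.jpg,2
images/yu/042.jpg,3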

my_load_data.py

import os, glob, csv, random
import tensorflow as tf


def load_csv(root, csvname, name2label):
    '''
    root: path of the images folder
    csvname: name of the csv file to write / read back
    name2label: dict whose keys are class names and whose values are the corresponding integer labels
    '''
    # If the csv file does not exist yet, create it and fill it in
    if not os.path.exists(os.path.join(root, csvname)):
        images = []
        for name in name2label.keys():
            # Collect image paths such as images/hua/*** (the pattern matches .jpg/.jpeg/.png)
            images += glob.glob(os.path.join(root, name, "*.*g"))
        # The paths were collected class by class, so shuffle them
        random.shuffle(images)
        # Write path,label pairs into the csv file
        with open(os.path.join(root, csvname), 'w', newline="") as f:
            writer = csv.writer(f)
            for img in images:
                # Split the path on the separator -> [images, hua, **.jpg] and take the class name
                name = img.split(os.sep)[-2]
                # Look the class name up in the dict to get its integer label
                label = name2label[name]
                writer.writerow([img, label])
    # The csv file now exists (either just written or from a previous run): read it back
    images, labels = [], []
    with open(os.path.join(root, csvname)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
    assert len(images) == len(labels)
    return images, labels


def load_data(root, mode='train'):
    name2label = {}
    # List the class folders under root
    for name in sorted(os.listdir(os.path.join(root))):
        # Skip anything under root that is not a folder (e.g. the csv file itself)
        if not os.path.isdir(os.path.join(root, name)):
            continue
        else:
            name2label[name] = len(name2label.keys())
    images, labels = load_csv(root, "imgcsv.csv", name2label)
    if mode == 'train':  # first 60%
        images = images[:int(0.6 * len(images))]
        labels = labels[:int(0.6 * len(labels))]
    elif mode == 'val':  # 20%: 60% -> 80%
        images = images[int(0.6 * len(images)):int(0.8 * len(images))]
        labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
    else:  # 20%: 80% -> 100%
        images = images[int(0.8 * len(images)):]
        labels = labels[int(0.8 * len(labels)):]

    return images, labels, name2label


def preprocess(x, y):
    # x is an image path, y the corresponding integer label
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)
    # Resize to the 224x224 input size the model is built with below
    x = tf.image.resize(x, [224, 224])
    x = tf.cast(x, dtype=tf.float32) / 255.0
    y = tf.convert_to_tensor(y)
    # One-hot encode over the 4 classes (chong, hua, niao, yu)
    y = tf.one_hot(y, depth=4)
    return x, y
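
Before wiring the loader into a training script, it can be sanity-checked on its own. This is a minimal sketch, assuming the images folder described above sits next to my_load_data.py:

if __name__ == '__main__':
    imgs, lbls, table = load_data('images', mode='train')
    print('class -> label mapping:', table)
    print('number of training samples:', len(imgs))
    print('first entry:', imgs[0], lbls[0])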
            
            

train.py

import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
from my_load_data import load_data, preprocess

batchsz = 32
# Build the training Dataset object
images, labels, table = load_data('images', mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# Build the validation Dataset object
images2, labels2, table = load_data('images', mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# Build the test Dataset object
images3, labels3, table = load_data('images', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
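
# Optional sanity check: take one batch and confirm the shapes before building the model.
# With the setup above the first batch should be (32, 224, 224, 3) images and (32, 4) one-hot labels.
sample_x, sample_y = next(iter(db_train))
print(sample_x.shape, sample_y.shape)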

# Build the model: a frozen ImageNet-pretrained VGG16 convolutional base plus a new classification head
net = keras.applications.VGG16(weights='imagenet', include_top=False, pooling='max')
net.trainable = False
newnet = keras.Sequential([
    net,
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(4)  # 4 classes; outputs raw logits (no softmax)
])
newnet.build(input_shape=(None, 224, 224, 3))
newnet.summary()

# Stop training once val_accuracy has failed to improve by at least min_delta for `patience` epochs
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=5
)

newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
           callbacks=[early_stopping])
newnet.evaluate(db_test)
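
Once training finishes, the same preprocess function can be reused to classify a single image. A minimal sketch, assuming a hypothetical file path under the images folder:

img, _ = preprocess('images/hua/example.jpg', 0)  # the label argument is irrelevant here
logits = newnet.predict(tf.expand_dims(img, axis=0))
pred = int(tf.argmax(logits, axis=1)[0])
# Map the predicted integer label back to its class name using the name2label table
label2name = {v: k for k, v in table.items()}
print('predicted class:', label2name[pred])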