T10:数据增强

T10周:数据增强

      • **一、前期工作**
        • 1.设置GPU,导入库
        • 2.加载数据
      • **二、数据增强**
      • **三、增强方式**
        • 方法一:将其嵌入model中
        • 方法二:在Dataset数据集中进行数据增强
      • **四、训练模型**
      • **五、自定义增强函数**
      • **六、总结**

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

🍺 要求:

  1. 学会在代码中使用数据增强手段来提高acc
  2. 请探索更多的数据增强手段并记录

在本教程中,你将学会如何进行数据增强,并通过数据增强用少量数据达到非常非常棒的识别准确率。我将展示两种数据增强方式,以及如何自定义数据增强方式并将其放到我们代码当中,两种数据增强方式如下:

  • 将数据增强模块嵌入model中
  • 在Dataset数据集中进行数据增强

⌛ 我的环境:

  • 语言环境:Python3.6.5
  • 编译器:Google Colab
  • 深度学习环境:TensorFlow2.17.0

一、前期工作

1.设置GPU,导入库
import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')

from tensorflow.keras import layers
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2.加载数据
from google.colab import drive
drive.mount("/content/drive/")
%cd "/content/drive/Othercomputers/My laptop/jupyter notebook/data/"
Mounted at /content/drive/
/content/drive/Othercomputers/My laptop/jupyter notebook/data
data_dir   = "./10/"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 600 files belonging to 2 classes.
Using 420 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 600 files belonging to 2 classes.
Using 180 files for validation.

由于原始数据集不包含测试集,因此需要创建一个。使用 tf.data.experimental.cardinality 确定验证集中有多少批次的数据,然后将其中的 20% 移至测试集。

val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))
Number of validation batches: 5
Number of test batches: 1
class_names = train_ds.class_names
print(class_names)
['cat', 'dog']
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image,label):
    return (image/255.0,label)

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])

        plt.axis("off")

在这里插入图片描述

二、数据增强

我们可以使用 tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip("horizontal_and_vertical"), #进行随机水平和垂直翻转
  tf.keras.layers.RandomRotation(0.2), #按照0.2的弧度值进行随机旋转
  tf.keras.layers.RandomBrightness(factor=0.5, value_range=(0.0, 1.0), seed=123), #按照0.1比例调整亮度
  tf.keras.layers.RandomContrast(factor=0.1)
])

# 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转。
#其余增强方式:https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomRotation
# Add the image to a batch.
image = tf.expand_dims(images[i], 0)

plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

在这里插入图片描述

三、增强方式

方法一:将其嵌入model中
model_base = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])

这样做的好处是:

  • 数据增强这块的工作可以得到GPU的加速(如果你使用了GPU训练的话)

    注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。
方法二:在Dataset数据集中进行数据增强
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE
#定义处理函数,只有training数据集会处理
def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds
train_ds = prepare(train_ds)

四、训练模型

model = tf.keras.Sequential([
  model_base,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
epochs=20
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 116ms/step - accuracy: 0.5252 - loss: 0.6915 - val_accuracy: 0.4257 - val_loss: 0.7044
Epoch 2/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5461 - loss: 0.6897 - val_accuracy: 0.5068 - val_loss: 0.6871
Epoch 3/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 35ms/step - accuracy: 0.6517 - loss: 0.6707 - val_accuracy: 0.7500 - val_loss: 0.6056
Epoch 4/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 45ms/step - accuracy: 0.6466 - loss: 0.6242 - val_accuracy: 0.7568 - val_loss: 0.5237
Epoch 5/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 41ms/step - accuracy: 0.6958 - loss: 0.5998 - val_accuracy: 0.7365 - val_loss: 0.6128
Epoch 6/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 37ms/step - accuracy: 0.7187 - loss: 0.6327 - val_accuracy: 0.7230 - val_loss: 0.5495
Epoch 7/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 37ms/step - accuracy: 0.7182 - loss: 0.5746 - val_accuracy: 0.7770 - val_loss: 0.5527
Epoch 8/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.7288 - loss: 0.5921 - val_accuracy: 0.8176 - val_loss: 0.4373
Epoch 9/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7544 - loss: 0.4827 - val_accuracy: 0.8243 - val_loss: 0.3807
Epoch 10/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.7535 - loss: 0.5227 - val_accuracy: 0.8446 - val_loss: 0.4185
Epoch 11/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7958 - loss: 0.4514 - val_accuracy: 0.8649 - val_loss: 0.3466
Epoch 12/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8288 - loss: 0.3707 - val_accuracy: 0.8581 - val_loss: 0.3243
Epoch 13/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8466 - loss: 0.3749 - val_accuracy: 0.7973 - val_loss: 0.4575
Epoch 14/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.8277 - loss: 0.4484 - val_accuracy: 0.8784 - val_loss: 0.3096
Epoch 15/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8629 - loss: 0.2800 - val_accuracy: 0.8716 - val_loss: 0.2454
Epoch 16/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8496 - loss: 0.3098 - val_accuracy: 0.9054 - val_loss: 0.1984
Epoch 17/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - accuracy: 0.8630 - loss: 0.2889 - val_accuracy: 0.9257 - val_loss: 0.2030
Epoch 18/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8820 - loss: 0.2658 - val_accuracy: 0.9257 - val_loss: 0.2008
Epoch 19/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8940 - loss: 0.2895 - val_accuracy: 0.9324 - val_loss: 0.1662
Epoch 20/20
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 31ms/step - accuracy: 0.8923 - loss: 0.2681 - val_accuracy: 0.8919 - val_loss: 0.2506
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 409ms/step - accuracy: 0.8438 - loss: 0.2309
Accuracy 0.84375

五、自定义增强函数

import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):
    seed = (random.randint(0,9), 0)
    # 随机改变图像对比度
    stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    return stateless_random_brightness
image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
Min and max pixel values: 14.000048 253.28577
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

在这里插入图片描述

参考上文的 preprocess_image 函数,将 aug_img 函数嵌入到 preprocess_image 函数中,在数据预处理时完成数据增强就OK啦。

六、总结

学习了不同的数据增强方法并且了解如何用增强方法预处理数据和引入模型训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/886611.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[ RK3566-Android11 ] 关于移植 RK628F 驱动以及后HDMI-IN图像延迟/无声等问题

问题描述 由前一篇文章https://blog.csdn.net/jay547063443/article/details/142059700?fromshareblogdetail&sharetypeblogdetail&sharerId142059700&sharereferPC&sharesourcejay547063443&sharefromfrom_link,移植HDMI-IN部分驱动后出现&a…

硬件-开关电源-结构组成及元件作用

文章目录 一:开关电源组成1.1 开关电源是什么?1.2 开关电源六个组成部分 二:六个组成部分的作用2.1 EMC区域2.2 输入整流滤波区域2.3 控制区域2.4 变压器2.5 输出整流滤波区域2.6 反馈电路区域道友:勿以小恶弃人大美,勿以小怨忘人…

【C++】——list的介绍和模拟实现

P. S.:以下代码均在VS2019环境下测试,不代表所有编译器均可通过。 P. S.:测试代码均未展示头文件stdio.h的声明,使用时请自行添加。 博主主页:Yan. yan.                        …

ARM 架构、cpu

一、ARM的架构 ARM是一种基于精简指令集(RISC)的处理器架构. 1、ARM芯片特点 ARM芯片的主要特点有以下几点: 精简指令集:ARM芯片使用精简指令集,即每条指令只完成一项简单的操作,从而提高指令的执行效率…

EasyCVR视频汇聚平台:解锁视频监控核心功能,打造高效安全监管体系

随着科技的飞速发展,视频监控技术已成为现代社会安全、企业管理、智慧城市构建等领域不可或缺的一部分。EasyCVR视频汇聚平台作为一款高性能的视频综合管理平台,凭借其强大的视频处理、汇聚与融合能力,在构建智慧安防/视频监控系统中展现出了…

Qt Quick 3D 入门:QML 3D场景详解

随着 Qt 6 的发布,QtQuick3D 模块带来了新的 3D 渲染和交互能力,使得在 Qt 中创建 3D 场景变得更加简单和直观。本文将带您从一个简单的 QML 3D 应用开始,详细讲解各个相关领域的概念、代码实现以及功能特点。 什么是 Qt Quick 3D&#xff1…

关于 JVM 个人 NOTE

目录 1、JVM 的体系结构 2、双亲委派机制 3、堆内存调优 4、关于GC垃圾回收机制 4.1 GC中的复制算法 4.2 GC中的标记清除算法 1、JVM 的体系结构 "堆"中存在垃圾而"栈"中不存在垃圾的原因: 堆(Heap) 用途:堆主要用于存储对象实例和数组。在Java中…

Linux --入门学习笔记

文章目录 Linux概述基础篇Linux 的安装教程 ⇒ 太简单了,百度一搜一大堆。此处略……Linux 的目录结构常用的连接 linux 的开源软件vi 和 vim 编辑器Linux 的关机、开机、重启用户登录和注销用户管理添加用户 ⇒ ( useradd 用户名 ) ( useradd -d 制定目…

【Unity踩坑】Unity更新Google Play结算库

一、问题描述: 在Google Play上提交了app bundle后,提示如下错误。 我使用的是Unity 2022.01.20f1,看来用的Play结算库版本是4.0 查了一下文档,Google Play结算库的维护周期是两年。现在需要更新到至少6.0。 二、更新过程 1. 下…

Python | Leetcode Python题解之第454题四数相加II

题目: 题解: class Solution:def fourSumCount(self, A: List[int], B: List[int], C: List[int], D: List[int]) -> int:countAB collections.Counter(u v for u in A for v in B)ans 0for u in C:for v in D:if -u - v in countAB:ans countAB…

C++ | Leetcode C++题解之第454题四数相加II

题目&#xff1a; 题解&#xff1a; class Solution { public:int fourSumCount(vector<int>& A, vector<int>& B, vector<int>& C, vector<int>& D) {unordered_map<int, int> countAB;for (int u: A) {for (int v: B) {count…

Python并发编程(1)——Python并发编程的几种实现方式

更多精彩内容&#xff0c;请关注同名公众&#xff1a;一点sir&#xff08;alittle-sir&#xff09; Python 并发编程是指在 Python 中编写能够同时执行多个任务的程序。并发编程在任何一门语言当中都是比较难的&#xff0c;因为会涉及各种各样的问题&#xff0c;在Python当中也…

C0010.Qt5.15.2下载及安装方法

1. 下载及安装 Qt 添加链接描述下载地址&#xff1a;http://download.qt.io/ 选择 archive 目录 安装Qt **注意&#xff1a;**本人使用的是Qt5.15.2版本&#xff0c;可以按如下方法找到该版本&#xff1b;

Android Studio 新版本 Logcat 的使用详解

点击进入官方Logcat介绍 一个好的Android程序员要会使用AndroidStudio自带的Logcat查看日志&#xff0c;会Log定位也是查找程序bug的第一关键。同时Logcat是一个查看和处理日志消息的工具&#xff0c;它可以更快的帮助开发者调试应用程序。 步入正题&#xff0c;看图说话。 点…

msys2+gdb-multiarch+jlinkGDBServer的nrf52调试环境搭建

前言 刚拿到一块nrf52840的板子&#xff0c;为了方便以后的开发&#xff0c;先搭建一个调试环境&#xff0c;为方便以后回忆记录一下过程。 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 1.msys2命令行调用jlink工具 将jlink工具路径加入msys2的PAT…

华为云LTS日志上报至观测云最佳实践

华为云LTS简介 华为云云日志服务&#xff08;Log Tank Service&#xff0c;简称 LTS&#xff09;&#xff0c;用于收集来自主机和云服务的日志数据&#xff0c;通过海量日志数据的分析与处理&#xff0c;可以将云服务和应用程序的可用性和性能最大化&#xff0c;为您提供实时、…

【51单片机】点亮LED之经典流水灯

开发环境 开发板&#xff1a;普中51-单核-A2单片机&#xff1a;STC89C52RC&#xff08;双列直插40引脚 DIP40&#xff09;Keil uVision5 v9.61 最新版破解方法自行百度&#xff0c;相关文档和视频资料很多&#xff0c;我自己将这一操作记录下来当做博客发布&#xff0c;CSDN以…

【数据结构强化】应用题打卡

应用题打卡 数组的应用 对称矩阵的压缩存储 注意&#xff1a; 1. 2.上三角的行优先存储及下三角的列优先存储与数组的下表对应 上/下三角矩阵的压缩存储 注意&#xff1a; 上/下三角压缩存储是将0元素统一压缩存储&#xff0c;而不是将对角线元素统一压缩存储 三对角矩阵的…

King3399 SDK(ubuntu文件系统)编译简明教程

该文章仅供参考&#xff0c;编写人不对任务实验设备、人员及测量结果负责&#xff01;&#xff01;&#xff01; 0 引言 文章主要介绍King3399&#xff08;瑞芯微rk3399开发板&#xff0c;荣品&#xff09;官方SDK&#xff08;Ubuntu文件系统&#xff09;编译过程&#xff0c…

GaussDB关键技术原理:高弹性(六)

书接上文GaussDB关键技术原理&#xff1a;高弹性&#xff08;五&#xff09;从日志多流和事务相关方面对hashbucket扩容技术进行了解读&#xff0c;本篇将从扩容实践方面继续介绍GaussDB高弹性技术。 5 扩容实践 5.1 工具介绍 5.1.1 TPC-C TPC-C(全称Transaction Proces…