【学习】使用PyTorch训练与评估自己的ResNet网络教程

参考:保姆级使用PyTorch训练与评估自己的ResNet网络教程_训练自己的图像分类网络resnet101 pytorch-CSDN博客

项目地址:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

视频手把手教程:我将维护一个集成各主干网络的图像分类项目_哔哩哔哩_bilibili

主要是复现和训练测试自己的数据集

复现部分

0.环境问题

pytorch官网里面找个合适的CUDA11.0安装一下,然后把requirements.txt安装一下

pip install -r requirements.txt

 参考版本:

pip list
Package                Version
---------------------- ---------------
certifi                2021.5.30
cycler                 0.11.0
dataclasses            0.8
importlib-resources    5.4.0
joblib                 1.1.1
kiwisolver             1.3.1
matplotlib             3.3.4
mkl-fft                1.3.0
mkl-random             1.1.1
mkl-service            2.3.0
numpy                  1.19.2
olefile                0.46
opencv-contrib-python  4.0.1.24
opencv-python          4.0.1.24
opencv-python-headless 4.0.1.24
packaging              21.3
Pillow                 8.4.0
pip                    21.3.1
pyparsing              3.0.7
python-dateutil        2.9.0.post0
scikit-learn           0.24.2
scipy                  1.5.4
setuptools             36.4.0
six                    1.16.0
terminaltables         3.1.10
threadpoolctl          3.1.0
torch                  1.7.1
torchaudio             0.7.0a0+a853dff
torchvision            0.8.2
tqdm                   4.64.1
typing_extensions      4.1.1
wheel                  0.37.1
zipp                   3.6.0

  • 下载MobileNetV3-Small权重至datas
  • 利用项目里的猫狗图片检验一下安装情况
    python tools/single_test.py datas/cat-dog.png models/mobilenet/mobilenet_v3_small.py --classes-map datas/imageNet1kAnnotation.txt
    

    成功的话大概这样:

 1.数据集问题

 先下载花卉数据集(0zat):flower_photos.zip_免费高速下载|百度网盘-分享无限制 (baidu.com)

 原始地址在项目的资料部分:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

 目录结构,按照花卉类型存放

├─flower_photos
│  ├─daisy
│  │      100080576_f52e8ee070_n.jpg
│  │      10140303196_b88d3d6cec.jpg
│  │      ...
│  ├─dandelion
│  │      10043234166_e6dd915111_n.jpg
│  │      10200780773_c6051a7d71_n.jpg
│  │      ...
│  ├─roses
│  │      10090824183_d02c613f10_m.jpg
│  │      102501987_3cdb8e5394_n.jpg
│  │      ...
│  ├─sunflowers
│  │      1008566138_6927679c8a.jpg
│  │      1022552002_2b93faf9e7_n.jpg
│  │      ...
│  └─tulips
│  │      100930342_92e8746431_n.jpg
│  │      10094729603_eeca3f2cb6.jpg
│  │      ...
  • datas/中创建标签文件annotations.txt,按行将类别名的索引写入文件(应该已经写好了);即
    daisy 0
    dandelion 1
    roses 2
    sunflowers 3
    tulips 4
    

    之后进行数据集划分,随机分为训练和测试集。

  • 在tools/split_data.py中修改原始数据集地址和划分后的数据集地址。(new_datasets最好别更改)

    init_dataset = './flower_photos'
    new_dataset = './Awesome-Backbones/datasets'
    

    终端使用命令:

    python tools/split_data.py
    

    划分后的数据集格式大概为:

    ├─...
    ├─datasets
    │  ├─test
    │  │  ├─daisy
    │  │  ├─dandelion
    │  │  ├─roses
    │  │  ├─sunflowers
    │  │  └─tulips
    │  └─train
    │      ├─daisy
    │      ├─dandelion
    │      ├─roses
    │      ├─sunflowers
    │      └─tulips
    ├─...
    

    查看tools/get_annotation.py,看看路径要不要更改:

  • datasets_path   = '你的数据集路径'
    

 终端使用命令:

python tools/get_annotation.py

 该命令应该会在datas/下形成train.txt和test.txt,里面是具体照片的位置

2.修改配置文件

/models下有许多的模型配置文件

 以resnet为例

 挑一个顺眼的改改

以resnet101为例

# model settings

model_cfg = dict(
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=5,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),))

# dataloader pipeline
img_lighting_cfg = dict(
    eigval=[55.4625, 4.7940, 1.1475],
    eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203]],
    alphastd=0.1,
    to_rgb=True)
policies = [
    dict(type='AutoContrast', prob=0.5),
    dict(type='Equalize', prob=0.5),
    dict(type='Invert', prob=0.5),
    dict(
        type='Rotate',
        magnitude_key='angle',
        magnitude_range=(0, 30),
        pad_val=0,
        prob=0.5,
        random_negative_prob=0.5),
    dict(
        type='Posterize',
        magnitude_key='bits',
        magnitude_range=(0, 4),
        prob=0.5),
    dict(
        type='Solarize',
        magnitude_key='thr',
        magnitude_range=(0, 256),
        prob=0.5),
    dict(
        type='SolarizeAdd',
        magnitude_key='magnitude',
        magnitude_range=(0, 110),
        thr=128,
        prob=0.5),
    dict(
        type='ColorTransform',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Contrast',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Brightness',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Sharpness',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='horizontal',
        random_negative_prob=0.5),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='vertical',
        random_negative_prob=0.5),
    dict(
        type='Cutout',
        magnitude_key='shape',
        magnitude_range=(1, 41),
        pad_val=0,
        prob=0.5),
    dict(
        type='Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='horizontal',
        random_negative_prob=0.5,
        interpolation='bicubic'),
    dict(
        type='Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='vertical',
        random_negative_prob=0.5,
        interpolation='bicubic')
]
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandAugment',
        policies=policies,
        num_policies=2,
        magnitude_level=12),
    dict(
        type='RandomResizedCrop',
        size=224,
        efficientnet_style=True,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='Lighting', **img_lighting_cfg),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=False),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=224,
        efficientnet_style=True,
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

# train
data_cfg = dict(
    batch_size = 32,
    num_workers = 0,
    train = dict(
        pretrained_flag = False,
        pretrained_weights = '',
        freeze_flag = False,
        freeze_layers = ('backbone',),
        epoches = 150,
    ),
    test=dict(
        ckpt = './logs/ResNet/2024-06-26-10-37-00/Last_Epoch150.pth',
        metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'confusion'],
        metric_options = dict(
            topk = (1,5),
            thrs = None,
            average_mode='none'
    )
    )
)

# optimizer
optimizer_cfg = dict(
    type='SGD',
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-4)

# learning 
lr_config = dict(type='StepLrUpdater', step=[30, 60, 90])

主要改model_cfg里面的num_classes,data_cfg里的batch_size与num_workers

若有预训练权重则可以将pretrained_weights设置为True并将预训练的路径赋值给pretrained_weights

optimizer_cfg中修改初始学习率,根据batch_size调试

3.训练

终端运行

python tools/train.py models/resnet/resnet101.py

 运行结果

4.评估

在实际使用的配置文件中将ckpt修改

ckpt = '你的训练权重路径'

终端运行

python tools/evaluation.py models/resnet/resnet101.py

 运行结果

 我跑出来的准确率不高哈

5.测试

单张测试

python tools/single_test.py datasets/test/dandelion/14283011_3e7452c5b2_n.jpg models/resnet/resnet101.py

多张测试

使用batch_test.py,路径使用文件夹路径。

----------------------------------------------------------------------------------------------

使用自己的数据集

1.数据集准备

2.配置文件

3.训练

4.评估

5.测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/758344.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HBase Shell命令详解

HBase Shell命令 一、 命名空间 命名空间是 HBase 中用于组织表的一种逻辑容器,类似于文件系统中的文件夹。 Namespace允许用户在 HBase 中更好地管理和组织表,以及提供了隔离和命名约定。 1. 创建命名空间 命令: create_namespace name…

【scrapy】1.scrapy爬虫入门

一、scrapy爬虫框架 Scrapy 框架是一个基于Twisted的一个异步处理爬虫框架,应用范围非常的广泛,常用于数据采集、网络监测,以及自动化测试等。 scrapy框架包括5个主要的组件: Scheduler:事件调度器,它负…

机器学习引领教育革命:智能教育的新时代

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀目录 📒1. 引言📙2. 机器学习在教育中的应用🌞个性化学习🌙评估与反馈的智能化⭐教学资源的优…

Lua: 轻量级多用途脚本语言

Lua 是一种高效而轻量级的脚本语言,具备强大的扩展性和灵活性,广泛应用于游戏开发、嵌入式系统、Web 应用等多个领域。本文将深入探讨 Lua 的特性、应用场景以及如何使用 Lua 进行开发。 1. Lua 的起源与发展 Lua 的发展始于上世纪90年代初,…

Java单体架构项目_云霄外卖-特殊点

项目介绍: 定位: 专门为餐饮企业(餐厅、饭店)定制的一款软件商品 分为: 管理端:外卖商家使用 用户端(微信小程序):点餐用户使用。 功能架构: &#xff08…

哎呀呀 又迟到了,还被抓住了,面面相觑 害怕

网络编程 我应该迟点来,唠嗑到35嘿嘿 心疼自己早起呜呜呜,幸运的是35开讲 计算机网络 分4层 应用层(Application Layer): 应用层是用户接口和网络应用程序的接口。它允许用户访问网络服务,并支持各种应用程…

Windows系统下文件夹权限详解

文章目录 问题描述文件夹属性 问题描述 今天在Win10系统下,实现文件夹设置权限,具体的方案的涉及到我们公司内部的一款加密软件,不太方便透漏,借此机会,我也重新的回顾下windows系统下的文件夹权限 文件夹属性 打开…

[C++][设计模式][中介者模式]详细讲解

目录 1.动机2.模式定义3.要点总结 1.动机 在软件构建过程中,经常会出现多个对象相互关联的情况,对象之间常常会维持一种复杂的引用关系,如果遇到一些需求的更改,这种直接的引用关系将面临不断的变化在这种情况下,可以…

【小沐学AI】Python实现语音识别(whisper+HuggingFace)

文章目录 1、简介1.1 whisper 2、HuggingFace2.1 安装transformers2.2 Pipeline 简介2.3 Tasks 简介2.3.1 sentiment-analysis2.3.2 zero-shot-classification2.3.3 text-generation2.3.4 fill-mask2.3.5 ner2.3.6 question-answering2.3.7 summarization2.3.8 translation 3、…

PyTorch Tensor进阶操作指南(二):深度学习中的关键技巧

本文主要讲tensor的裁剪、索引、降维和增维 Tensor与numpy互转、Tensor运算等,请看这篇文章 目录 9.1、首先看torch.squeeze()函数: 示例9.1:(基本的使用) 小技巧1:如何看维数 示例9.2:&a…

ISO15765-2 道路车辆——通过控制器局域网(CAN)进行诊断通信 (翻译版)(万字长文)

ISO15765-2 道路车辆——通过控制器局域网(CAN)进行诊断通信 (翻译版)(万字长文) 文章目录 ISO15765-2 道路车辆——通过控制器局域网(CAN)进行诊断通信 (翻译版)(万字长文)第二部分:传输协议和网络层服务前言Foreword…

在navicat对mysql声明无符号字段

1.无符号设置 在 MySQL 中,我们可以使用 UNSIGNED 属性来设置列的无符号属性,这意味着该列只能存储非负整数值。对于一些需要存储正整数的列,比如年龄、数量等,使用 UNSIGNED 属性可以提高数据存储和查询的效率,并且能…

浅谈一下VScode如何配置C环境

1.这几天突然发现在VScode写C程序比在DevC效果更好,因为在VScode中写代码有代码补全功能。所以我突然对在VScode中配置C环境变量产生了兴趣。 2.不过在VScode中配置C环境要是按照官方的来配置有点麻烦。 3.我这里有一个直接配置VScode中C环境变量的应用。 前提是…

原来“山水博客“的分类也是可以拖动排序的

这二天一直用“山水博客”写文章,发现一个问题,好象它的分类不能调整位置,这可是个大bug。首先,界面上没发现拖动相关按钮;如果按住分类拖动,会成这样: 后来仔细看了它的文档,发现它…

智能社区服务小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,房屋信息管理,住户信息管理,家政服务管理,家政预约管理,报修信息管理 微信端账号功能包括:系统首页,房屋信…

水果品牌网站开展如何拓宽渠道

对大多数人来说,零售买水果只在乎是买什么水果、哪个产地、价格等因此,对品牌的依赖度相对较低。但对于水果品牌公司来说,货好仅是基本,还需要将品牌发展出去、能获取准属性客户和转化路径。 与零售不同,批发生意或是…

在vs上远程连接Linux写服务器项目并启动后,可以看到服务启动了,但是通过浏览器访问该服务提示找不到页面

应该是被防火墙挡住了,查看这个如何检查linux服务器被防火墙挡住 • Worktile社区 和这个关于Linux下Nginx服务启动,通过浏览器无法访问的问题_linux无法访问nginx-CSDN博客 的提示之后,知道防火墙开了,想着可能是我写的服务器的…

大数据面试题之Spark(1)

目录 Spark的任务执行流程 Spark的运行流程 Spark的作业运行流程是怎么样的? Spark的特点 Spark源码中的任务调度 Spark作业调度 Spark的架构 Spark的使用场景 Spark on standalone模型、YARN架构模型(画架构图) Spark的yarn-cluster涉及的参数有哪些? Spark提交jo…

c++类成员指针用法

1)C入门级小知识,分享给将要学习或者正在学习C开发的同学。 2)内容属于原创,若转载,请说明出处。 3)提供相关问题有偿答疑和支持。 c中新增类成员指针操作,为了访问方便,他是指…

Spring Boot项目如何配置跨域

1、通过SpringSecurity进行配置 2、前端跨域配置:proxy配置项用于设置代理规则,用于前端开发中与后端API交互时使用。