动手学深度学习(pytorch)学习记录28-使用块的网络(VGG)[学习记录]

news/2024/9/18 23:14:17 标签: 深度学习, pytorch, 学习

目录

  • VGG块
  • VGG网络
  • 训练模型

VGG块

定义了一个名为vgg_block的函数来实现一个VGG块

import torch
from torch import nn
from d2l import torch as d2l
def vgg_block(num_convs, in_channels, out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
    return nn.Sequential(*layers)

VGG网络

与AlexNet、LeNet一样,VGG网络可以分为两部分:第一部分主要由卷积层和汇聚层组成,第二部分由全连接层组成。
VGG神经网络连接的几个VGG块(在vgg_block函数中定义)。其中有超参数变量conv_arch。该变量指定了每个VGG块里卷积层个数和输出通道数。全连接模块则与AlexNet中的相同。

原始VGG网络有5个卷积块,其中前两个块各有一个卷积层,后三个块各包含两个卷积层。 第一个模块有64个输出通道,每个后续模块将输出通道数量翻倍,直到该数字达到512。由于该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11。

conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

实现VGG11

def vgg(conv_arch):
    conv_blks = []
    in_channels = 1
    # 卷积层部分
    for (num_convs, out_channels) in conv_arch:
        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
        in_channels = out_channels

    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # 全连接层部分
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 10))

net = vgg(conv_arch)

构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状。

X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 64, 112, 112])
Sequential output shape:	 torch.Size([1, 128, 56, 56])
Sequential output shape:	 torch.Size([1, 256, 28, 28])
Sequential output shape:	 torch.Size([1, 512, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
Flatten output shape:	 torch.Size([1, 25088])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])

训练模型

# 由于Fashion-MNIST数据集比较简单,缩放前面几个块的通道数,足够训练这个数据集
ratio = 4 # 设置比例
small_conv_arch = [(pair[0], pair[1] // ratio) for pair in conv_arch]
net = vgg(small_conv_arch)
X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 16, 112, 112])
Sequential output shape:	 torch.Size([1, 32, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 28, 28])
Sequential output shape:	 torch.Size([1, 128, 14, 14])
Sequential output shape:	 torch.Size([1, 128, 7, 7])
Flatten output shape:	 torch.Size([1, 6272])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])
lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.181, train acc 0.933, test acc 0.919
378.9 examples/sec on cuda:0

在这里插入图片描述

· 本文使用了大量d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码
封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。


http://www.niftyadmin.cn/n/5664678.html

相关文章

Go 1.19.4 路径和目录-Day 15

1. 路径介绍 存储设备保存着数据,但是得有一种方便的模式让用户可以定位资源位置,操作系统采用一种路径字符 串的表达方式,这是一棵倒置的层级目录树,从根开始。 相对路径:不是以根目录开始的路径,例如 a/b…

ant-design表格自动合并相同内容的单元格

表格自动合并相同内容的单元格 合并hooks import { TableColumnProps } from antdexport const useAutoMergeTableCell <T extends object>(dataSource: Array<T>,columns: Array<TableColumnProps> | Array<keyof T> ): Map<keyof T, Array<…

【运维方案】软件运维服务方案(word)

1.项目情况 2.服务简述 2.1服务内容 2.2服务方式 2.3服务要求 2.4服务流程 2.5工作流程 2.6业务关系 2.7培训 3.资源提供 3.1项目组成员 3.2服务保障 进主页学习更多获取更多资料&#xff5e;

字节飞书-测开日常实习-部分手撕代码题

之前的文章提到了一道高频题&#xff1a;最长不重复的字串&#xff0c;用到动态窗口。解法就在之前的文章。这篇文章从牛客上找了一些手撕题&#xff0c;在这里记录分享一下。 1.将给定的字符串中的每个单词的首字母转化为大小字母【简单】 首字母大写__牛客网 不难 就是考察…

虚拟DOM介绍

工作流程 虚拟 DOM 并不直接发生在用户界面构建之前&#xff0c;而是作为构建用户界面过程中的一个重要部分。具体来说&#xff0c;虚拟 DOM 的工作流程如下&#xff1a; 初始化阶段&#xff1a; 组件定义&#xff1a;在应用程序开发过程中&#xff0c;开发者首先定义组件和它…

CMakeLists.txt的学习了解

CMakeLists.txt 是 CMake 构建系统中的配置文件&#xff0c;用于定义项目的编译规则和依赖关系。CMake 是一种跨平台的构建系统&#xff0c;支持从源代码生成编译脚本&#xff08;如 Makefile 或 Visual Studio 工程文件&#xff09;。CMakeLists.txt 通过指定项目信息、源文件…

【AI学习笔记】初学机器学习西瓜书概要记录(二)常用的机器学习方法篇

初学机器学习西瓜书的概要记录&#xff08;一&#xff09;机器学习基础知识篇(已完结) 初学机器学习西瓜书的概要记录&#xff08;二&#xff09;常用的机器学习方法篇(持续更新) 初学机器学习西瓜书的概要记录&#xff08;三&#xff09;进阶知识篇(待更) 文字公式撰写不易&am…

上汽集团社招入职SHL测评:语言理解及数字推理高分攻略、真题题库

上汽集团社招待遇 上汽集团作为国内领先的汽车制造企业&#xff0c;其社招待遇和面试问题一直是求职者关注的焦点。以下是根据最新信息整理的上汽集团社招待遇及面试问题概览&#xff1a; 工资待遇&#xff1a;上汽集团的工资待遇在国内汽车行业中属于较高水平。根据不同职位和…