【PyTorch】使用容器(Containers)进行网络层管理(Module)

news/2024/9/18 2:35:27 标签: pytorch, 人工智能, python, 深度学习

文章目录

  • 前言
  • 一、Sequential
  • 二、ModuleList
  • 三、ModuleDict
  • 四、ParameterList & ParameterDict
  • 总结


前言

深度学习模型逐渐变得复杂,在编写代码时便会遇到诸多麻烦,此时便需要Containers的帮助。Containers的作用是将一部分网络层模块化,从而更方便地管理和调用。本文介绍PyTorch库常用的nn.Sequential,nn.ModuleList,nn.ModuleDict容器以及nn.ParameterList & ParameterDict参数容器。


一、Sequential

Sequential是最为常用的容器,它的功能也十分简单直接-将多个网络层按照固定的顺序连接,从前往后依次执行。比如在AlexNet中,多次需要conv+relu+maxpool的组合,此时便可以将其放入Sequential容器,便于在forward中调用。
下面来看PyTorch官方代码示例:

python">model = nn.Sequential(
	nn.Conv2d(1,20,5),
	nn.ReLU(),
	nn.Conv2d(20,64,5),
	nn.ReLU()
	)
 # Using Sequential with OrderedDict. This is functionally the
 # same as the above code
 model = nn.Sequential(OrderedDict([
	('conv1', nn.Conv2d(1,20,5)),
	('relu1', nn.ReLU()),
	('conv2', nn.Conv2d(20,64,5)),
	('relu2', nn.ReLU())
	]))

示例中展示了两种Sequential使用方法:1,直接串联各个网络层。2,使用OrderedDict为每个module取名。这两种方法是等效的。


二、ModuleList

"顾名思义"ModuleList的作用如同Python的列表,将各个层存入一个类似于List的结构中,从而可以利用索引来进行调用。
注意这里是类似于list的结构,那为什么我们不直接用list呢?
ModuleList是专门为Pytorch中的神经网络模块(即继承自nn.Module的类)设计的容器。它确保所有添加到其中的模块都会正确地注册到网络中,以便进行参数管理和梯度更新。当模型被保存或加载时,nn.ModuleList中的模块也会相应地被保存或加载。而Python的列表是一个通用的容器,可以存储任意类型的对象。它没有专门为神经网络模块设计,因此不会进行参数的自动注册或管理。
代码示例:

python">class MyModule(nn.Module):
		def __init__(self):
			super(MyModule, self).__init__()
			self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
			# self.linears = [nn.Linear(10, 10) for i in range(10)]    
		def forward(self, x):
 			for sub_layer in self.linears:
 			x = sub_layer(x)
 			return x

三、ModuleDict

ModuleDict是一个类似python字典的容器,相比于ModuleList,它的优点在于可以利用名字来调用网络层,这就避免了必须记住网络层具体元素才能调用的麻烦。
代码示例:

python"> class MyModule2(nn.Module):
        def __init__(self):
            super(MyModule2, self).__init__()
            self.choices = nn.ModuleDict({
                    'conv': nn.Conv2d(3, 16, 5),
                    'pool': nn.MaxPool2d(3)
            })
            self.activations = nn.ModuleDict({
                    'lrelu': nn.LeakyReLU(),
                    'prelu': nn.PReLU()
            })
        def forward(self, x, choice, act):
            x = self.choices[choice](x)
            x = self.activations[act](x)
            return x

四、ParameterList & ParameterDict

除了Module有容器,Parameter也有容器。与ModuleList和ModuleDict类似的,Paramter也有List和Dict,使用方法一样。

python">class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterDict({
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
        })
    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x
 # ParameterList
 class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

这是专门为Pytorch中的参数(如权重和偏置)设计的容器。它确保添加到其中的参数会被正确地注册到网络中,以便进行参数管理和梯度更新。与module类似,参数容器中的参数也会被包含在网络的参数列表中,并在模型保存和加载时被正确处理。


总结

容器是pytorch框架对网络进行组织管理的实用工具,合理运用可以极大提高代码的可读性与可维护性。


http://www.niftyadmin.cn/n/5655928.html

相关文章

【Leetcode算法面试题】-1. 两数之和

文章目录 算法练习题目思路参考答案算法1算法2算法3 算法练习 面试经常会遇到算法题目,今天开启算法专栏,常用算法解析 题目 ** 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数&…

初识c++:入门基础

打字不易,留个赞再走吧~~ 目录 一.第一个c程序二.命名空间 namespace三.C输⼊&输出四.缺省参数 C兼容C语⾔绝⼤多数的语法,所以C语⾔实现的hello world依旧可以运⾏,C中需要把定义⽂件 代码后缀改为.cpp 一.第一个c程序 做好准备我们来写…

Mysql异常断电InnoDB损坏处理

一、mysql启动报错信息收集 1、截图 [ERROR] InnoDB: Database page corruption on disk or a failed file read of page [page id: space0, page number203]. You may have to recover from a backup. Jun 27 13:30:06 localhost mysqld: 2024-06-27T13:30:06.14747208:00 0 …

【编程底层思考】性能监控和优化:JVM参数调优,诊断工具的使用等。JVM 调优和线上问题排查实战经验总结

JVM性能监控和优化是确保Java应用程序高效运行的关键环节。以下是一些JVM性能监控和优化的方法,以及使用诊断工具和实战经验的总结: 一、JVM参数调优: 堆大小设置 : - Xms:设置JVM启动时的初始堆大小。 - -Xmx:设置J…

如何安全,高效,优雅的提升linux的glibc版本

如何安全,高效,优雅的提升linux的glibc版本 一、发现问题二、升级glibc版本1. 下载对应的软件包2. 解压软件包3. 查看新版本glibc安装要求,并查看自己版本是否符合需求4. 升级python版本4.1 下载软件包4.2 解压4.3 编译4.4 确认更新后的pytho…

物联网之ESP32配网方式、蓝牙、WiFi

MENU 前言SmartConfig(智能配网)AP模式(Access Point模式)蓝牙配网Web Server模式WPS配网(Wi-Fi Protected Setup)Provisioning(配网服务)静态配置(硬编码)总结 前言 ESP32配网(Wi-Fi配置)的方式有多种,每种方式都有各自的优缺点。 根据具体项目需求,可以…

C++入门(part 3)

前言 在前文我们讲解了C的诞生与历史,顺便讲解一些C的小语法,本文会继续讲解C的基础语法知识。 1.inline(内联函数) inline是C新加入的关键字,用inline修饰的函数叫做内联函数,编译时C编译器会在调用的地方将函数展开,这样每次…

【6大设计原则】解锁代码的灵活性:深入解析开闭原则的代码实例与应用

1.引言 在软件开发中,设计模式是解决常见问题的经过验证的解决方案。设计模式不仅提供了一种可复用的设计思路,还有助于提高软件的质量和可维护性。设计模式的六大原则是指导我们进行软件设计的基石,其中开闭原则(Open/Closed Pr…