分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 教程案例

pytorch实现squeezenet

发布时间:2023-09-06 02:16责任编辑:顾先生关键词:暂无标签

squeezenet是16年发布的一款轻量级网络模型,模型很小,只有4.8M,可用于移动设备,嵌入式设备。

关于squeezenet的原理可自行阅读论文或查找博客,这里主要解读下pytorch对squeezenet的官方实现。

地址:https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py

首先定义fire模块,这是squeezenet的核心所在,降低3X3卷积的数量。

class Fire(nn.Module): ???def __init__(self, inplanes, squeeze_planes, ????????????????expand1x1_planes, expand3x3_planes): ???????super(Fire, self).__init__() ???????self.inplanes = inplanes ???????self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)#定义压缩层,1X1卷积 ???????self.squeeze_activation = nn.ReLU(inplace=True) ???????self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,#定义扩展层,1X1卷积 ??????????????????????????????????kernel_size=1) ???????self.expand1x1_activation = nn.ReLU(inplace=True) ???????self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,#定义扩展层,3X3卷积 ??????????????????????????????????kernel_size=3, padding=1) ???????self.expand3x3_activation = nn.ReLU(inplace=True) ???def forward(self, x): ???????x = self.squeeze_activation(self.squeeze(x)) ???????return torch.cat([ ???????????self.expand1x1_activation(self.expand1x1(x)), ???????????self.expand3x3_activation(self.expand3x3(x)) ???????], 1)

可以看到首先定义压缩层与两个扩展层,压缩层用的是1X1卷积,扩展层是1X1卷积和3X3卷积的混合使用,网络inference的脉络是先经过压缩层,然后并行经过两个扩展层,最后将扩展层串联。

定义完核心模块,来看网络整体。

class SqueezeNet(nn.Module): ???def __init__(self, version=1.0, num_classes=1000): ???????super(SqueezeNet, self).__init__() ???????if version not in [1.0, 1.1]: ???????????raise ValueError("Unsupported SqueezeNet version {version}:" ????????????????????????????"1.0 or 1.1 expected".format(version=version)) ???????self.num_classes = num_classes ???????if version == 1.0: ???????????self.features = nn.Sequential( ???????????????nn.Conv2d(3, 96, kernel_size=7, stride=2), ???????????????nn.ReLU(inplace=True), ???????????????nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), ???????????????Fire(96, 16, 64, 64), ???????????????Fire(128, 16, 64, 64), ???????????????Fire(128, 32, 128, 128), ???????????????nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), ???????????????Fire(256, 32, 128, 128), ???????????????Fire(256, 48, 192, 192), ???????????????Fire(384, 48, 192, 192), ???????????????Fire(384, 64, 256, 256), ???????????????nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), ???????????????Fire(512, 64, 256, 256), ???????????) ???????else: ???????????self.features = nn.Sequential( ???????????????nn.Conv2d(3, 64, kernel_size=3, stride=2), ???????????????nn.ReLU(inplace=True), ???????????????nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), ???????????????Fire(64, 16, 64, 64), ???????????????Fire(128, 16, 64, 64), ???????????????nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), ???????????????Fire(128, 32, 128, 128), ???????????????Fire(256, 32, 128, 128), ???????????????nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), ???????????????Fire(256, 48, 192, 192), ???????????????Fire(384, 48, 192, 192), ???????????????Fire(384, 64, 256, 256), ???????????????Fire(512, 64, 256, 256), ???????????) ???????# Final convolution is initialized differently form the rest ???????final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) ???????self.classifier = nn.Sequential( ???????????nn.Dropout(p=0.5), ???????????final_conv, ???????????nn.ReLU(inplace=True), ???????????nn.AvgPool2d(13, stride=1) ???????) ???????for m in self.modules(): ???????????if isinstance(m, nn.Conv2d): ???????????????if m is final_conv: ???????????????????init.normal_(m.weight, mean=0.0, std=0.01) ???????????????else: ???????????????????init.kaiming_uniform_(m.weight) ???????????????if m.bias is not None: ???????????????????init.constant_(m.bias, 0) ???def forward(self, x): ???????x = self.features(x) ???????x = self.classifier(x) ???????return x.view(x.size(0), self.num_classes)

首先依然是定义网络层,在这里有两个版本,差别不大,都是fire模块的堆积,最后经过全局平均池化输出1000类。这里对卷积层采用了不同的初始化策略,我还没仔细研究过,就不说了。

pytorch实现squeezenet

原文地址:https://www.cnblogs.com/wzyuan/p/9710565.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved