ResNet 殘差網絡論文閱讀及示例代碼

論文閱讀

其實論文的思想在今天看來是不難的,不過在當時 ResNet 提出的時候可是橫掃了各大分類任務,這個網絡解決了隨著網絡的加深,分類的準確率不升反降的問題。通過一個名叫“殘差”的網絡結構(如下圖所示),使作者可以只通過簡單的網絡深度堆疊便可達到提升準確率的目的。

ResNet 殘差網絡論文閱讀及示例代碼

殘差結構

殘差結構的處理過程分成兩個部分,左邊的 F(X) 與右邊的 X,最後結果為兩者相加。其中右邊那根線不會對 X 做任何處理,所以沒有可學習的參數;左邊部分 F(X) 為網絡中負責學習特徵的部分,把整個殘差結構看做是 H(X) 函數的話,則負責學習的部分可以表示為 H(X)=F(X)-X,這個結構學習的其實是輸出結果與輸入的差值,這也是殘差名字的由來。完整的 ResNet 網絡由多個上圖中所示的殘差結構組成,每個結構學習的都是輸出與輸入之間的差值,通過步步逼近,達到了比直接學習輸入好得多的效果。

文中殘差結構的具體實現分為兩種,首先介紹 ResNet-18 與 ResNet-34 使用的殘差結構稱為 Basic Block,如下圖所示,圖中的結構包含了兩個卷積操作用於提取特徵。

ResNet 殘差網絡論文閱讀及示例代碼

Basic Block

對應到代碼中,這是 Pytorch 自帶的 ResNet 實現中的一部分,跟上圖對應起來看更加好理解,我個人比較喜歡論文與代碼結合起來看,因為我除了需要知道原理之外,也要知道如何去使用,而代碼更給我一種一目瞭然的感覺:

class BasicBlock(nn.Module):
 expansion = 1
 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
 base_width=64, dilation=1, norm_layer=None):
 super(BasicBlock, self).__init__()
 if norm_layer is None:
 norm_layer = nn.BatchNorm2d
 if groups != 1 or base_width != 64:
 raise ValueError('BasicBlock only supports groups=1 and base_width=64')
 if dilation > 1:
 raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
 # Both self.conv1 and self.downsample layers downsample the input when stride != 1
 self.conv1 = conv3x3(inplanes, planes, stride)
 self.bn1 = norm_layer(planes)
 self.relu = nn.ReLU(inplace=True)
 self.conv2 = conv3x3(planes, planes)
 self.bn2 = norm_layer(planes)
 self.downsample = downsample
 self.stride = stride
 def forward(self, x):
 identity = x
 out = self.conv1(x)
 out = self.bn1(out)
 out = self.relu(out)
 out = self.conv2(out)
 out = self.bn2(out)
 if self.downsample is not None:
 identity = self.downsample(x)
 out += identity
 out = self.relu(out)
 return out

另一種殘差結構稱為 Bottleneck,就是瓶頸的意思:

ResNet 殘差網絡論文閱讀及示例代碼

作者起名字真的很形象,網絡結構也正如這瓶頸一樣,首先做一個降維,然後做卷積,然後升維,這樣做的好處是可以大大減少計算量,專門用於網絡層數較深的的網絡,ResNet-50 以上的網絡都有這種基礎結構構成(不同層級的輸入輸出維度可能會不一樣,但結構類似):

ResNet 殘差網絡論文閱讀及示例代碼

Pytorch 中的代碼,注意到上圖中為了減少計算量,作者將 256 維的輸入縮小了 4 倍變為 64 進入卷積,在升維時需要升到 256 維,對應代碼中的

expansion 參數:

class Bottleneck(nn.Module):
 expansion = 4
 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
 base_width=64, dilation=1, norm_layer=None):
 super(Bottleneck, self).__init__()
 if norm_layer is None:
 norm_layer = nn.BatchNorm2d
 width = int(planes * (base_width / 64.)) * groups
 # Both self.conv2 and self.downsample layers downsample the input when stride != 1
 self.conv1 = conv1x1(inplanes, width)
 self.bn1 = norm_layer(width)
 self.conv2 = conv3x3(width, width, stride, groups, dilation)
 self.bn2 = norm_layer(width)
 self.conv3 = conv1x1(width, planes * self.expansion)
 self.bn3 = norm_layer(planes * self.expansion)
 self.relu = nn.ReLU(inplace=True)
 self.downsample = downsample
 self.stride = stride
 def forward(self, x):
 identity = x
 out = self.conv1(x)
 out = self.bn1(out)
 out = self.relu(out)
 out = self.conv2(out)
 out = self.bn2(out)
 out = self.relu(out)
 out = self.conv3(out)
 out = self.bn3(out)
 if self.downsample is not None:
 identity = self.downsample(x)
 out += identity
 out = self.relu(out)
 return out

由上面介紹的基本結構再加上池化以及全連接層,就構成了各種完整的網絡:

ResNet 殘差網絡論文閱讀及示例代碼

圖中的網絡在 Pytorch 中都已經集成進去了,而且都是預訓練好的,我們可以在預訓練好的模型上面訓練自己的分類器,大大減少我們的訓練時間。下面介紹一下如何使用 ResNet。

在 Pytorch 中使用 ResNet

Pytorch 是一個對初學者很友好的深度學習框架,入門的話非常推薦,官方提供了一小時入門教程:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

在 Pytorch 中使用 ResNet 只需要 4 行代碼:

from torch import nn
# torchvision 專用於視覺方面
import torchvision 
 
# pretrained :使用在 ImageNet 數據集上預訓練的模型
model = torchvision.models.resnet18(pretrained=True)
# 修改模型的全連接層使其輸出為你需要類型數,這裡是10
# 由於使用了預訓練的模型 而預訓練的模型輸出為1000類,所以要修改全連接層
# 若不使用預訓練的模型可以直接在創建模型時添加參數 num_classes=10 而不需要修改全連接層
model.fc = nn.Linear(model.fc.in_features, 10)


分享到:


相關文章: