ResNet——CNN經典網絡模型詳解(pytorch實現)

建議大家可以實踐下,代碼都很詳細,有不清楚的地方評論區見~

1、前言

ResNet(Residual Neural Network)由微軟研究院的Kaiming He等四名華人提出,通過使用ResNet Unit成功訓練出了152層的神經網絡,並在ILSVRC2015比賽中取得冠軍,在top5上的錯誤率為3.57%,同時參數量比VGGNet低,效果非常突出。ResNet的結構可以極快的加速神經網絡的訓練,模型的準確率也有比較大的提升。同時ResNet的推廣性非常好,甚至可以直接用到InceptionNet網絡中。

下圖是ResNet34層模型的結構簡圖。

ResNet——CNN經典網絡模型詳解(pytorch實現)

2、ResNet詳解

在ResNet網絡中有如下幾個亮點:

  • 提出residual結構(殘差結構),並搭建超深的網絡結構(突破1000層)
  • 使用Batch Normalization加速訓練(丟棄dropout)

在ResNet網絡提出之前,傳統的卷積神經網絡都是通過將一系列卷積層與下采樣層進行堆疊得到的。但是當堆疊到一定網絡深度時,就會出現兩個問題。

  1. 梯度消失或梯度爆炸。
  2. 退化問題(degradation problem)。

在ResNet論文中說通過數據的預處理以及在網絡中使用BN(Batch Normalization)層能夠解決梯度消失或者梯度爆炸問題。如果不瞭解BN層可參考這個鏈接 。但是對於退化問題(隨著網絡層數的加深,效果還會變差,如下圖所示)並沒有很好的解決辦法。

ResNet——CNN經典網絡模型詳解(pytorch實現)

所以ResNet論文提出了residual結構(殘差結構)來減輕退化問題。下圖是使用residual結構的卷積網絡,可以看到隨著網絡的不斷加深,效果並沒有變差,反而變的更好了。

ResNet——CNN經典網絡模型詳解(pytorch實現)

殘差結構(residual)

殘差指的是什麼? 其中ResNet提出了兩種mapping:一種是identity mapping,指的就是下圖中”彎彎的曲線”,另一種residual mapping,指的就是除了”彎彎的曲線“那部分,所以最後的輸出是 y=F(x)+x

  • identity mapping

顧名思義,就是指本身,也就是公式中的x,而residual mapping指的是“差”,也就是y−x,所以殘差指的就是F(x)部分。

下圖是論文中給出的兩種殘差結構。左邊的殘差結構是針對層數較少網絡,例如ResNet18層和ResNet34層網絡。右邊是針對網絡層數較多的網絡,例如ResNet101,ResNet152等。為什麼深層網絡要使用右側的殘差結構呢。因為,右側的殘差結構能夠減少網絡參數與運算量。同樣輸入一個channel為256的特徵矩陣,如果使用左側的殘差結構需要大約1170648個參數,但如果使用右側的殘差結構只需要69632個參數。明顯搭建深層網絡時,使用右側的殘差結構更合適。

ResNet——CNN經典網絡模型詳解(pytorch實現)

我們先對左側的殘差結構(針對ResNet18/34)進行一個分析。

如下圖所示,該殘差結構的主分支是由兩層3x3的卷積層組成,而殘差結構右側的連接線是shortcut分支也稱捷徑分支(注意為了讓主分支上的輸出矩陣能夠與我們捷徑分支上的輸出矩陣進行相加,必須保證這兩個輸出特徵矩陣有相同的shape)。如果剛剛仔細觀察了ResNet34網絡結構圖的同學,應該能夠發現圖中會有一些虛線的殘差結構。在原論文中作者只是簡單說了這些虛線殘差結構有降維的作用,並在捷徑分支上通過1x1的卷積核進行降維處理。而下圖右側給出了詳細的虛線殘差結構,注意下每個卷積層的步距stride,以及捷徑分支上的卷積核的個數(與主分支上的卷積核個數相同)。

ResNet——CNN經典網絡模型詳解(pytorch實現)

接著我們再來分析下針對ResNet50/101/152的殘差結構,如下圖所示。在該殘差結構當中,主分支使用了三個卷積層,第一個是1x1的卷積層用來壓縮channel維度,第二個是3x3的卷積層,第三個是1x1的卷積層用來還原channel維度(注意主分支上第一層卷積層和第二次卷積層所使用的卷積核個數是相同的,第三次是第一層的4倍)。該殘差結構所對應的虛線殘差結構如下圖右側所示,同樣在捷徑分支上有一層1x1的卷積層,它的卷積核個數與主分支上的第三層卷積層卷積核個數相同,注意每個卷積層的步距。

ResNet——CNN經典網絡模型詳解(pytorch實現)

為什麼殘差學習相對更容易,從直觀上看殘差學習需要學習的內容少,因為殘差一般會比較小,學習難度小點。不過我們可以從數學的角度來分析這個問題,首先殘差單元可以表示為:

ResNet——CNN經典網絡模型詳解(pytorch實現)

其中 XL和 XL+1分別表示的是第L個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結構。 F是殘差函數,表示學習到的殘差,而 h(XL)=XL表示恆等映射, F是ReLU激活函數。基於上式,我們求得從淺層 l到深層 L 的學習特徵為:

ResNet——CNN經典網絡模型詳解(pytorch實現)

式子的第一個因子表示的損失函數到達L的梯度,小括號中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那麼巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導並不是嚴格的證明。

下面這幅圖是原論文給出的不同深度的ResNet網絡結構配置,注意表中的殘差結構給出了主分支上卷積核的大小與卷積核個數,表中的xN表示將該殘差結構重複N次。那到底哪些殘差結構是虛線殘差結構呢。

ResNet——CNN經典網絡模型詳解(pytorch實現)

對於我們ResNet18/34/50/101/152,表中conv3_x, conv4_x, conv5_x所對應的一系列殘差結構的第一層殘差結構都是虛線殘差結構。因為這一系列殘差結構的第一層都有調整輸入特徵矩陣shape的使命(將特徵矩陣的高和寬縮減為原來的一半,將深度channel調整成下一層殘差結構所需要的channel)。為了方便理解,下面給出了ResNet34的網絡結構圖,圖中簡單標註了一些信息。

ResNet——CNN經典網絡模型詳解(pytorch實現)

對於我們ResNet50/101/152,其實在conv2_x所對應的一系列殘差結構的第一層也是虛線殘差結構。因為它需要調整輸入特徵矩陣的channel,根據表格可知通過3x3的max pool之後輸出的特徵矩陣shape應該是[56, 56, 64],但我們conv2_x所對應的一系列殘差結構中的實線殘差結構它們期望的輸入特徵矩陣shape是[56, 56, 256](因為這樣才能保證輸入輸出特徵矩陣shape相同,才能將捷徑分支的輸出與主分支的輸出進行相加)。所以第一層殘差結構需要將shape從[56, 56, 64] --> [56, 56, 256]。注意,這裡只調整channel維度,高和寬不變(而conv3_x, conv4_x, conv5_x所對應的一系列殘差結構的第一層虛線殘差結構不僅要調整channel還要將高和寬縮減為原來的一半)。

代碼

注:

  1. 本次訓練集下載在AlexNet博客有詳細解說:https://blog.csdn.net/weixin_44023658/article/details/105798326
  2. 使用遷移學習方法實現收錄在我的這篇blog中: 遷移學習 TransferLearning—通俗易懂地介紹(pytorch實例)
<code>#model.py

import torch.nn as nn
import torch

#18/34
class BasicBlock(nn.Module):
    expansion = 1 #每一個conv的卷積核個數的倍數

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):#downsample對應虛線殘差結構
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)#BN處理
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x #捷徑上的輸出值
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

#50,101,152
class Bottleneck(nn.Module):
    expansion = 4#4倍

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,#輸出*4
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):#block殘差結構 include_top為了之後搭建更加複雜的網絡
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)自適應
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)/<code> 
<code>#train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#來自官網參數
    "val": transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}


data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
#net = resnet34()
net = resnet34(num_classes=5)
# load pretrain weights

# model_weight_path = "./resnet34-pre.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#載入模型參數

# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure

# inchannel = net.fc.in_features
# net.fc = nn.Linear(inchannel, 5)


net.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')/<code>
ResNet——CNN經典網絡模型詳解(pytorch實現)

在這裡插入圖片描述

<code>#predict.py

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img = Image.open("./roses.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()/<code>
ResNet——CNN經典網絡模型詳解(pytorch實現)

在這裡插入圖片描述


分享到:


相關文章: