Siamese Networks: Algorithms, Applications, and PyTorch Implementation

Since Siamese networks are becoming increasingly popular in deep learning research and applications, I will explain what they are and walk through a simple Siamese CNN in PyTorch as an example.

What Is a Siamese Network?

A Siamese network is a neural network that contains two or more identical subnetwork components. A Siamese network might look like this:

[Figure: example of a Siamese network]

Importantly, not only must the subnetworks have identical architectures, they must also share their weights; this is what makes the network "Siamese". The main idea behind Siamese networks is that they can learn useful descriptors of their data, which can then be used to compare the inputs of the individual subnetworks. The inputs can therefore be numerical data (in which case the subnetworks are usually built from fully connected layers), image data (with CNNs as subnetworks), or even sequential data such as sentences or time signals (with RNNs as subnetworks).
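A minimal PyTorch sketch of the weight-sharing idea (illustrative only; the names TinySiamese and encoder are made up here, and this is not the network built later in this post). Applying the same module to both inputs automatically shares all weights between the two branches:

import torch
from torch import nn

class TinySiamese(nn.Module):
    """Minimal Siamese sketch: one shared encoder applied to both inputs."""
    def __init__(self):
        super().__init__()
        # the shared subnetwork (here a tiny fully connected encoder)
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU())

    def forward(self, x1, x2):
        # calling the same module on both inputs means the weights are shared
        f1 = self.encoder(x1.view(x1.shape[0], -1))
        f2 = self.encoder(x2.view(x2.shape[0], -1))
        return f1, f2  # descriptors that can be compared downstream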

Typically, a Siamese network performs binary classification at its output. Different loss functions can therefore be used during training. One of the most popular is the binary cross-entropy loss, which can be computed as

L(y, p) = -(y · log(p) + (1 - y) · log(1 - p))

where L is the loss function, y is the class label (0 or 1), and p is the prediction. To train the network to distinguish between similar and dissimilar objects, we can give it one positive and one negative example at a time and add up the losses:

L_total = L(1, p_pos) + L(0, p_neg)

where p_pos and p_neg are the network's predictions for the positive and the negative example, respectively.
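A small PyTorch sketch of this summed loss (hedged: it uses a single-logit formulation with binary_cross_entropy_with_logits and made-up logit values; the full script below instead uses an equivalent two-class cross-entropy):

import torch
import torch.nn.functional as F

# raw network outputs (logits) for one positive and one negative pair -- illustrative values
logit_pos = torch.tensor([2.1])
logit_neg = torch.tensor([-1.3])

loss_pos = F.binary_cross_entropy_with_logits(logit_pos, torch.ones(1))   # target y = 1: "same"
loss_neg = F.binary_cross_entropy_with_logits(logit_neg, torch.zeros(1))  # target y = 0: "different"
loss = loss_pos + loss_neg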

Another option is to use the triplet loss:

L = max(d(a, p) - d(a, n) + m, 0)

where d is a distance function (e.g., the L2 distance), a is a sample from the dataset, p is a random positive sample, and n is a negative sample. m is an arbitrary margin used to push the positive and negative scores further apart.
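PyTorch ships this loss as triplet_margin_loss; a minimal sketch (the embedding shapes and values here are illustrative):

import torch
import torch.nn.functional as F

# illustrative embedding batches of shape [batch, features] for anchor a, positive p, negative n
anchor = torch.randn(8, 128)
positive = torch.randn(8, 128)
negative = torch.randn(8, 128)

# p=2 uses the Euclidean (L2) distance as d; margin corresponds to m above
loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2)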

Applications of Siamese Networks

Siamese networks have a wide range of applications. Here are a few:

  • One-shot learning. In this learning scenario, a new training dataset with only one sample per class is given to an already trained (classification) network. The classification performance for this new data is then measured on a separate test dataset. Since Siamese networks first learn discriminative features on a large, specific dataset, they can be used to generalize this knowledge to entirely new classes and distributions. In (Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. "Siamese neural networks for one-shot image recognition." ICML Deep Learning Workshop. Vol. 2. 2015.), the authors use this for one-shot learning on the MNIST dataset with a network trained on the Omniglot dataset (a completely different image dataset).
  • Pedestrian tracking for video surveillance. In this work, a Siamese CNN is combined with size and position features of image patches to track multiple people in a camera's field of view, learning the associations between detections in consecutive video frames and computing the resulting trajectories.
  • Cosegmentation (Mukherjee, Prerana, Brejesh Lall, and Snehith Lattupally. "Object cosegmentation using deep Siamese network." arXiv preprint arXiv:1803.02555 (2018).).
  • Matching résumés to jobs. In this application, the network tries to find matching job postings for applicants. To do so, a trained Siamese CNN extracts deep contextual information from the postings and the résumés and computes their semantic similarity. The assumption is that matching résumé-posting pairs rank higher in similarity than non-matching ones.

Example: Classifying MNIST Image Pairs with a Siamese Network in PyTorch

Having explained the fundamentals of Siamese networks, we will now build one in PyTorch that classifies whether a pair of MNIST images shows the same digit. We will use the binary cross-entropy loss as our training loss function and evaluate the network on the test dataset with the accuracy measure. Here is the full Python code for this post:

import codecs
import errno
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision.datasets.mnist
from torchvision.datasets.mnist import read_image_file, read_label_file
from torchvision import transforms
from tqdm import tqdm

do_learn = True
save_frequency = 2
batch_size = 16
lr = 0.001
num_epochs = 10
weight_decay = 0.0001
load_model_path = 'siamese_008.pt'  # checkpoint loaded in prediction mode (last one saved with the settings above)

class BalancedMNISTPair(torch.utils.data.Dataset):
    """Dataset that on each iteration provides two random pairs of
    MNIST images. One pair is of the same number (positive sample), one
    is of two different numbers (negative sample).
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))

            # group the images by digit class
            train_labels_class = []
            train_data_class = []
            for i in range(10):
                indices = torch.squeeze((self.train_labels == i).nonzero())
                train_labels_class.append(torch.index_select(self.train_labels, 0, indices))
                train_data_class.append(torch.index_select(self.train_data, 0, indices))

            # generate balanced pairs
            self.train_data = []
            self.train_labels = []
            lengths = [x.shape[0] for x in train_labels_class]
            for i in range(10):
                for j in range(500):  # create 500 pairs
                    rnd_cls = random.randint(0, 8)  # choose random class that is not the same class
                    if rnd_cls >= i:
                        rnd_cls = rnd_cls + 1

                    rnd_dist = random.randint(0, 100)

                    # each entry holds an anchor image, a positive image (same digit)
                    # and a negative image (different digit)
                    self.train_data.append(torch.stack([train_data_class[i][j],
                                                        train_data_class[i][j + rnd_dist],
                                                        train_data_class[rnd_cls][j]]))
                    self.train_labels.append([1, 0])

            self.train_data = torch.stack(self.train_data)
            self.train_labels = torch.tensor(self.train_labels)

        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

            # group the images by digit class
            test_labels_class = []
            test_data_class = []
            for i in range(10):
                indices = torch.squeeze((self.test_labels == i).nonzero())
                test_labels_class.append(torch.index_select(self.test_labels, 0, indices))
                test_data_class.append(torch.index_select(self.test_data, 0, indices))

            # generate balanced pairs
            self.test_data = []
            self.test_labels = []
            lengths = [x.shape[0] for x in test_labels_class]
            for i in range(10):
                for j in range(500):  # create 500 pairs
                    rnd_cls = random.randint(0, 8)  # choose random class that is not the same class
                    if rnd_cls >= i:
                        rnd_cls = rnd_cls + 1

                    rnd_dist = random.randint(0, 100)

                    self.test_data.append(torch.stack([test_data_class[i][j],
                                                       test_data_class[i][j + rnd_dist],
                                                       test_data_class[rnd_cls][j]]))
                    self.test_labels.append([1, 0])

            self.test_data = torch.stack(self.test_data)
            self.test_labels = torch.tensor(self.test_labels)

    def __getitem__(self, index):
        if self.train:
            imgs, target = self.train_data[index], self.train_labels[index]
        else:
            imgs, target = self.test_data[index], self.test_labels[index]

        img_ar = []
        for i in range(len(imgs)):
            img = Image.fromarray(imgs[i].numpy(), mode='L')
            if self.transform is not None:
                img = self.transform(img)
            img_ar.append(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img_ar, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # shared convolutional feature extractor
        self.conv1 = nn.Conv2d(1, 64, 7)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.conv3 = nn.Conv2d(128, 256, 5)
        # 2304 = 256 channels * 3 * 3 spatial positions for 28x28 MNIST inputs
        self.linear1 = nn.Linear(2304, 512)

        self.linear2 = nn.Linear(512, 2)

    def forward(self, data):
        res = []
        for i in range(2):  # Siamese nets; sharing weights
            x = data[i]
            x = self.conv1(x)
            x = F.relu(x)
            x = self.pool1(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = self.conv3(x)
            x = F.relu(x)

            x = x.view(x.shape[0], -1)
            x = self.linear1(x)
            res.append(F.relu(x))

        # element-wise absolute difference of the two descriptors, then a 2-class head
        res = torch.abs(res[1] - res[0])
        res = self.linear2(res)
        return res

def train(model, device, train_loader, epoch, optimizer):
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        for i in range(len(data)):
            data[i] = data[i].to(device)

        optimizer.zero_grad()
        # data is [anchor, positive, negative]; data[:2] is the positive pair,
        # data[0:3:2] (anchor + negative) is the negative pair
        output_positive = model(data[:2])
        output_negative = model(data[0:3:2])

        target = target.type(torch.LongTensor).to(device)
        target_positive = torch.squeeze(target[:, 0])
        target_negative = torch.squeeze(target[:, 1])

        loss_positive = F.cross_entropy(output_positive, target_positive)
        loss_negative = F.cross_entropy(output_negative, target_negative)

        loss = loss_positive + loss_negative
        loss.backward()

        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx * batch_size / len(train_loader.dataset),
                loss.item()))

def test(model, device, test_loader):
    model.eval()

    with torch.no_grad():
        accurate_labels = 0
        all_labels = 0
        loss = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            for i in range(len(data)):
                data[i] = data[i].to(device)

            output_positive = model(data[:2])
            output_negative = model(data[0:3:2])

            target = target.type(torch.LongTensor).to(device)
            target_positive = torch.squeeze(target[:, 0])
            target_negative = torch.squeeze(target[:, 1])

            loss_positive = F.cross_entropy(output_positive, target_positive)
            loss_negative = F.cross_entropy(output_negative, target_negative)

            loss = loss + loss_positive + loss_negative

            accurate_labels_positive = torch.sum(torch.argmax(output_positive, dim=1) == target_positive).cpu()
            accurate_labels_negative = torch.sum(torch.argmax(output_negative, dim=1) == target_negative).cpu()

            accurate_labels = accurate_labels + accurate_labels_positive + accurate_labels_negative
            all_labels = all_labels + len(target_positive) + len(target_negative)

        accuracy = 100. * accurate_labels / all_labels
        print('Test accuracy: {}/{} ({:.3f}%)\tLoss: {:.6f}'.format(accurate_labels, all_labels, accuracy, loss))

def oneshot(model, device, data):
    model.eval()

    with torch.no_grad():
        for i in range(len(data)):
            data[i] = data[i].to(device)

        output = model(data)
        return torch.squeeze(torch.argmax(output, dim=1)).cpu().item()

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])

    model = Net().to(device)

    if do_learn:  # training mode
        train_loader = torch.utils.data.DataLoader(
            BalancedMNISTPair('../data', train=True, download=True, transform=trans),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            BalancedMNISTPair('../data', train=False, download=True, transform=trans),
            batch_size=batch_size, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        for epoch in range(num_epochs):
            train(model, device, train_loader, epoch, optimizer)
            test(model, device, test_loader)
            if epoch % save_frequency == 0:  # '%' (not '&'): save every save_frequency epochs
                # save the state dict so it can be restored with load_state_dict() below
                torch.save(model.state_dict(), 'siamese_{:03}.pt'.format(epoch))
    else:  # prediction
        prediction_loader = torch.utils.data.DataLoader(
            BalancedMNISTPair('../data', train=False, download=True, transform=trans),
            batch_size=1, shuffle=True)
        model.load_state_dict(torch.load(load_model_path))
        data = []
        data.extend(next(iter(prediction_loader))[0][:3:2])  # take the first and third image of the triplet
        same = oneshot(model, device, data)

        if same > 0:
            print('These two images are of the same number')
        else:
            print('These two images are not of the same number')


if __name__ == '__main__':
    main()

As you can see, most of the code goes into building a suitable Dataset class that provides us with random image samples. To train the network, it is important to have a balanced dataset containing both positive and negative samples, so we supply both in every iteration. The dataset code is long but ultimately simple: for each digit (class) 0-9, we provide a positive pair (another image of the same digit) and a negative pair (an image of a random different digit).
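To see what the dataset delivers, one can peek at a single batch (a hedged sketch that assumes the imports and the BalancedMNISTPair class from the script above; the shapes in the comments are what that code should produce):

# illustrative only: inspect one batch from the paired dataset
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
loader = torch.utils.data.DataLoader(
    BalancedMNISTPair('../data', train=False, download=True, transform=trans),
    batch_size=4, shuffle=True)

imgs, target = next(iter(loader))
# imgs is a list of three tensors [anchor, positive, negative], each of shape [4, 1, 28, 28];
# target has shape [4, 2] and holds [1, 0] for every sample (positive label, negative label)
print(len(imgs), imgs[0].shape, target.shape)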

The network itself, defined in the Net class, is a Siamese convolutional neural network consisting of two identical subnetworks, each with three convolutional layers (kernel sizes 7, 5, and 5) and a pooling layer in between. After the convolutional layers, we let the network build a one-dimensional descriptor of each input by flattening the feature maps and passing them through a linear layer with 512 output features. Note that the layers in both subnetworks share the same weights. This allows the network to learn meaningful descriptors for each input and makes the output symmetric (the order of the inputs should be irrelevant to our goal).
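For reference, the in_features of linear1 follow from the 28x28 MNIST input: 28x28 becomes 22x22 after the 7x7 convolution, 11x11 after pooling, then 7x7 and 3x3 after the two 5x5 convolutions, so the flattened descriptor has 256 · 3 · 3 = 2304 values. A quick, purely illustrative shape check of the Net defined above:

# illustrative sanity check: two dummy MNIST-sized inputs in, two logits out
net = Net()
dummy_pair = [torch.zeros(1, 1, 28, 28), torch.zeros(1, 1, 28, 28)]
print(net(dummy_pair).shape)  # expected: torch.Size([1, 2])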

The key step of the whole procedure is the next one: computing the element-wise absolute difference between the two feature vectors. In principle, we could train the network with a triplet loss on this difference. However, I got better results (faster convergence) with the binary cross-entropy loss. Therefore, we attach a linear layer with two output features ("same digit" and "different digit") to the network to obtain the logits.

There are three main functions of interest in the code: the train function, the test function, and the prediction function.

In the train function, we feed the network one positive and one negative sample (two image pairs) at a time: data[:2] is the anchor image together with its positive partner, and data[0:3:2] is the anchor together with the negative image (see the short sketch below). We compute a loss for each pair and add them up; the target for the positive pair is 1 and the target for the negative pair is 0.
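The pair slicing is the only non-obvious part; with a list of three batches it behaves like this (the strings are illustrative stand-ins for the image batches):

# how the slicing in train()/test() selects the two pairs
data = ['anchor', 'positive', 'negative']
print(data[:2])     # ['anchor', 'positive'] -> the positive pair (target 1)
print(data[0:3:2])  # ['anchor', 'negative'] -> the negative pair (target 0)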

The test function measures the network's accuracy on the test dataset. We run it after every training epoch to monitor training progress and guard against overfitting.

The prediction function simply predicts whether a given pair of MNIST images belongs to the same class. It can be used after training by setting the global variable do_learn to False.

With the implementation above, I was able to reach 96% accuracy on the MNIST test dataset.

