Siamese-pytorch孿生網(wǎng)絡(luò)實(shí)現(xiàn)評(píng)價(jià)圖像相似度
來(lái)源:機(jī)器學(xué)習(xí)AI算法工程
什么是孿生神經(jīng)網(wǎng)絡(luò)
簡(jiǎn)單來(lái)說(shuō),孿生神經(jīng)網(wǎng)絡(luò)(Siamese network)就是“連體的神經(jīng)網(wǎng)絡(luò)”,神經(jīng)網(wǎng)絡(luò)的“連體”是通過(guò)共享權(quán)值來(lái)實(shí)現(xiàn)的,如下圖所示。
所謂權(quán)值共享就是當(dāng)神經(jīng)網(wǎng)絡(luò)有兩個(gè)輸入的時(shí)候,這兩個(gè)輸入使用的神經(jīng)網(wǎng)絡(luò)的權(quán)值是共享的(可以理解為使用了同一個(gè)神經(jīng)網(wǎng)絡(luò))。
很多時(shí)候,我們需要去評(píng)判兩張圖片的相似性,比如比較兩張人臉的相似性,我們可以很自然的想到去提取這個(gè)圖片的特征再進(jìn)行比較,自然而然的,我們又可以想到利用神經(jīng)網(wǎng)絡(luò)進(jìn)行特征提取。
如果使用兩個(gè)神經(jīng)網(wǎng)絡(luò)分別對(duì)圖片進(jìn)行特征提取,提取到的特征很有可能不在一個(gè)域中,此時(shí)我們可以考慮使用一個(gè)神經(jīng)網(wǎng)絡(luò)進(jìn)行特征提取再進(jìn)行比較。這個(gè)時(shí)候我們就可以理解孿生神經(jīng)網(wǎng)絡(luò)為什么要進(jìn)行權(quán)值共享了。
孿生神經(jīng)網(wǎng)絡(luò)有兩個(gè)輸入(Input1 and Input2),利用神經(jīng)網(wǎng)絡(luò)將輸入映射到新的空間,形成輸入在新的空間中的表示。通過(guò)Loss的計(jì)算,評(píng)價(jià)兩個(gè)輸入的相似度。
孿生神經(jīng)網(wǎng)絡(luò)的實(shí)現(xiàn)思路
一、預(yù)測(cè)部分
1、主干網(wǎng)絡(luò)介紹
孿生神經(jīng)網(wǎng)絡(luò)的主干特征提取網(wǎng)絡(luò)的功能是進(jìn)行特征提取,各種神經(jīng)網(wǎng)絡(luò)都可以適用,本文使用的神經(jīng)網(wǎng)絡(luò)是VGG16
這是一個(gè)VGG被用到爛的圖,但確實(shí)很好的反應(yīng)了VGG的結(jié)構(gòu):
1、一張?jiān)紙D片被resize到指定大小,本文使用105x105。
2、conv1包括兩次[3,3]卷積網(wǎng)絡(luò),一次2X2最大池化,輸出的特征層為64通道。
3、conv2包括兩次[3,3]卷積網(wǎng)絡(luò),一次2X2最大池化,輸出的特征層為128通道。
4、conv3包括三次[3,3]卷積網(wǎng)絡(luò),一次2X2最大池化,輸出的特征層為256通道。
5、conv4包括三次[3,3]卷積網(wǎng)絡(luò),一次2X2最大池化,輸出的特征層為512通道。
6、conv5包括三次[3,3]卷積網(wǎng)絡(luò),一次2X2最大池化,輸出的特征層為512通道。
2、比較網(wǎng)絡(luò)
在獲得主干特征提取網(wǎng)絡(luò)之后,我們可以獲取到一個(gè)多維特征,我們可以使用flatten的方式將其平鋪到一維上,這個(gè)時(shí)候我們就可以獲得兩個(gè)輸入的一維向量了。
將這兩個(gè)一維向量進(jìn)行相減,再進(jìn)行絕對(duì)值求和,相當(dāng)于求取了兩個(gè)特征向量插值的L1范數(shù)。也就相當(dāng)于求取了兩個(gè)一維向量的距離。
然后對(duì)這個(gè)距離再進(jìn)行兩次全連接,第二次全連接到一個(gè)神經(jīng)元上,對(duì)這個(gè)神經(jīng)元的結(jié)果取sigmoid,使其值在0-1之間,代表兩個(gè)輸入圖片的相似程度。
實(shí)現(xiàn)代碼如下:
二、訓(xùn)練部分
1、數(shù)據(jù)集的格式
本文所使用的數(shù)據(jù)集為Omniglot數(shù)據(jù)集。
其包含來(lái)自 50不同字母(語(yǔ)言)的1623 個(gè)不同手寫字符。每一個(gè)字符都是由 20個(gè)不同的人通過(guò)亞馬遜的 Mechanical Turk 在線繪制的。
相當(dāng)于每一個(gè)字符有20張圖片,然后存在1623個(gè)不同的手寫字符,我們需要利用神經(jīng)網(wǎng)絡(luò)進(jìn)行學(xué)習(xí),去區(qū)分這1623個(gè)不同的手寫字符,比較輸入進(jìn)來(lái)的字符的相似性。
最后一級(jí)的文件夾用于分辨不同的字體,同一個(gè)文件夾里面的圖片屬于同一文字。在不同文件夾里面存放的圖片屬于不同文字。
上兩個(gè)圖為
.\images_background\Alphabet_of_the_Magi\character01里的兩幅圖。它們兩個(gè)屬于同一個(gè)字。
上一個(gè)圖為
.\images_background\Alphabet_of_the_Magi\character02里的一幅圖。它和上面另兩幅圖不屬于同一個(gè)字。
2、Loss計(jì)算
對(duì)于孿生神經(jīng)網(wǎng)絡(luò)而言,其具有兩個(gè)輸入。
當(dāng)兩個(gè)輸入指向同一個(gè)類型的圖片時(shí),此時(shí)標(biāo)簽為1。
當(dāng)兩個(gè)輸入指向不同類型的圖片時(shí),此時(shí)標(biāo)簽為0。
然后將網(wǎng)絡(luò)的輸出結(jié)果和真實(shí)標(biāo)簽進(jìn)行交叉熵運(yùn)算,就可以作為最終的loss了。
本文所使用的Loss為binary_crossentropy。
當(dāng)我們輸入如下兩個(gè)字體的時(shí)候,我們希望網(wǎng)絡(luò)的輸出為1。
我們會(huì)將預(yù)測(cè)結(jié)果和1求交叉熵。
當(dāng)我們輸入如下兩個(gè)字體的時(shí)候,我們希望網(wǎng)絡(luò)的輸出為0。
我們會(huì)將預(yù)測(cè)結(jié)果和0求交叉熵。
訓(xùn)練自己的孿生神經(jīng)網(wǎng)絡(luò)
1、訓(xùn)練本文所使用的Omniglot例子
下載數(shù)據(jù)集,放在根目錄下的dataset文件夾下。
運(yùn)行train.py開始訓(xùn)練。
2、訓(xùn)練自己相似性比較的模型
如果大家想要訓(xùn)練自己的數(shù)據(jù)集,可以將數(shù)據(jù)集按照如下格式進(jìn)行擺放。
每一個(gè)chapter里面放同類型的圖片。
之后將train.py當(dāng)中的train_own_data設(shè)置成True,即可開始訓(xùn)練。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。