Siamese net and its loss function
2020-12-06 17:41:52 【itread01】
Recently, within the scope of several keywords (**small datasets, unsupervised and semi-supervised learning, image segmentation, SOTA models**), we keep running into the same concept: the **siamese (twin) network**. So I read the related classic papers and blog posts, and then built a simple case to consolidate my understanding. If you would like to discuss, contact me on WX: cyx645016617.

**This introduction to siamese networks is split into two parts. This first part covers the model theory, basic knowledge, and the loss functions unique to siamese networks; the next part explains how to reproduce a simple siamese network in code.**

## 1 The origin of the name

The siamese network is called the Siamese Net, and Siam is the old name for Thailand, so "Siamese" is really the old term for "Thai people". Why does "Siamese" now also mean "twins" or "conjoined" in English? It comes from a historical story:

> A pair of conjoined twins was born in Thailand in the 19th century. Medical technology at the time could not separate the two, so they lived their whole lives joined together. In 1829 they were discovered by a British merchant, joined a circus, and performed all over the world. In 1839 they visited North Carolina in the United States, later became the stars of the "Ringling Circus", and eventually became American citizens. On April 13, 1843 they married a pair of British sisters; Eng fathered 10 children and Chang 12. When the sisters quarreled, the brothers would take turns staying three days at each wife's house. In 1874 Eng died of lung disease, and the other brother died soon after; both left this world at the age of 63. Their livers are still preserved in the Mütter Museum in Philadelphia.

Since then, "Siamese twins" has become the term for conjoined people, and through these twins this particular condition came to the world's attention.
![](https://p1juejin.byteimg.com/toscnik3u1fbpfcp/16f50e602dbb4dbd8e9035a40c9afa83~tplvk3u1fbpfcpwatermark.image)

## 2 Model structure

![](https://p9juejin.byteimg.com/toscnik3u1fbpfcp/9c6591af381047598bc0016a664244dd~tplvk3u1fbpfcpwatermark.image)

There are a few points to understand in this diagram:

- Network1 and Network2 are, in professional terms, **weight-sharing**. To put it bluntly, these two networks are actually one network; in your code you only build a single network.
- In a typical task, each sample goes through the model to get a prediction `pred`, then this `pred` and the ground truth are used to compute the loss, and from that we get the gradients. **The siamese network changes this structure. Suppose the task is image classification: image A is fed into the model to produce an output pred1, then image B is fed into the model to produce another output pred2, and the loss is computed between pred1 and pred2.** Usually the model is run once to produce one loss, but in a siamese net the model has to be run twice to get one loss.
- My personal feeling is that a typical task measures an absolute distance, the distance from a sample to its label, whereas the siamese network measures the distance between samples.

### 2.1 Uses of the siamese network

A siamese net measures the relationship between two inputs, that is, whether the two samples are similar. There is a classic example of such a task: at NIPS in 1993, the paper 《Signature Verification using a 'Siamese' Time Delay Neural Network》 applied it to signature verification on U.S. checks, **checking whether the signature on a check matches the signature on file at the bank**. Convolutional networks were already being used for verification back then... I wasn't even born yet. Later, in 2010, Hinton published 《Rectified Linear Units Improve Restricted Boltzmann Machines》 at ICML, applied to face verification, with very good results.
The input is two faces, and the output is **same or different**.

![](https://p1juejin.byteimg.com/toscnik3u1fbpfcp/d266951700bd4ce6a8ad50e0508c6cdc~tplvk3u1fbpfcpwatermark.image)

As you can imagine, the siamese network can do classification tasks. **In my opinion, the siamese network is not a network architecture, not a structure like ResNet; it is a network framework, and I can use ResNet as the backbone of a siamese network.** Since the siamese backbone (let's call it that; it should be easy to understand) can be a CNN, it can also be an LSTM, so **it can perform semantic similarity analysis on text**. Kaggle once hosted a question-pair competition, measuring whether two questions ask the same thing, and the TOP1 solution used exactly this siamese net structure. Later there were also **visual tracking algorithms** built on siamese networks; I don't know that area yet, and I'll read the paper 《Fully-convolutional siamese networks for object tracking》 if I get the chance. I'll leave that as a hole to fill in later.
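The weight-sharing idea from section 2 can be sketched in PyTorch. This is a minimal sketch under assumptions: the tiny fully-connected backbone, the 28×28 input size, and the embedding dimension are all hypothetical, just to show that one module serves both branches; in practice the backbone could be a ResNet.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNet(nn.Module):
    """A minimal siamese wrapper: ONE backbone, applied to both inputs."""
    def __init__(self, embed_dim=32):
        super().__init__()
        # Hypothetical tiny backbone for 28x28 single-channel images.
        self.backbone = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, embed_dim),
        )

    def forward(self, x1, x2):
        # The same weights process both inputs: this is "weight sharing".
        return self.backbone(x1), self.backbone(x2)

net = SiameseNet()
a = torch.randn(4, 1, 28, 28)   # a batch of images A
b = torch.randn(4, 1, 28, 28)   # a batch of images B
p1, p2 = net(a, b)               # two forward passes, one set of weights
dist = F.pairwise_distance(p1, p2)  # one distance per pair
print(dist.shape)  # torch.Size([4])
```

Because there is only one `backbone`, feeding the same image to both branches necessarily yields identical embeddings, which is exactly the property the loss functions below rely on.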
### 2.2 Pseudo-siamese network

Here comes a question: the siamese network looks like two networks but is really one weight-sharing network. What if we actually give it two separate networks, so that one can be an LSTM and one a CNN, comparing similarity across different modalities? That's right, this exists and is called the pseudo-siamese network. **One input is text and one is an image, to determine whether the text describes the image content; or one is a short title and one is a long article, to judge whether the article matches its title.** (This would be the savior of high-school essay writers who stray off topic all year round. If I told the teacher that the algorithm says my essay is on topic, so there is no need to read it, would the teacher kill me?) However, the code in this article and the next is centered on the siamese network, with a CNN backbone and images as input.

### 2.3 Triplets

Now that we have twins, of course there are also triplets, called the Triplet network, from 《Deep metric learning using Triplet network》. It is said to perform even better than the siamese network; I don't know whether there are quadruplets and quintuplets.

## 3 Loss functions

Classification tasks routinely use softmax plus cross entropy, but it has been argued that a model trained this way discriminates poorly **between classes** and falls immediately to adversarial-sample attacks. **I will explain adversarial attacks later; another hole to dig.** In short: suppose the task is face recognition, then every person is a category, so you are asking one model to solve a task with thousands of categories and very little data per category. Think about it and you can feel how hard that training is.
To address this problem, siamese networks come with two loss functions:

- Contrastive Loss
- Triplet Loss

### 3.1 Contrastive Loss

- Proposed in: 《Dimensionality Reduction by Learning an Invariant Mapping》

By now we know:

- image 1 goes through the model to obtain pred1;
- image 2 goes through the model to obtain pred2;
- pred1 and pred2 are combined to compute the loss.

The paper gives this formula:

![](https://p6juejin.byteimg.com/toscnik3u1fbpfcp/da99a197bead4619ab467260ce1c40f1~tplvk3u1fbpfcpwatermark.image)

First of all, the model outputs pred1 and pred2 are vectors. The process is equivalent to the image passing through a CNN to extract features, producing a latent vector, rather like an encoder. Then the Euclidean distance between these two vectors is computed; this distance (if the model is trained correctly) reflects the correlation between the two input images.

Each time we feed in two images, we need to know in advance **whether these two images are of the same class or of different classes. This acts like a label: the Y in the formula above. If they are the same class, Y = 0; if not, Y = 1.** It is similar to the binary cross-entropy loss. The points to note are:

- when Y = 0, the loss is $(1-Y)\,L_S(D_W^i)$;
- when Y = 1, the loss is $Y\,L_D(D_W^i)$;
- in the paper, $L_S$ and $L_D$ are constants, preset to 0.5;
- $i$ is an exponent; in the common usage of contrastive loss, $i = 2$ is assumed, i.e., the square of the Euclidean distance;
- for pairs labeled 1 (different classes), we naturally want the Euclidean distance between pred1 and pred2 to be as large as possible. But how large is large enough? The loss is pushed in the direction of getting smaller, so what is needed? Add a margin as the maximum distance: if the distance between pred1 and pred2 is greater than the margin, we consider the two samples far enough apart and treat the loss as 0.
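To make the margin behavior concrete, here is a minimal numeric sketch (an assumption of mine, not the paper's code: the constants $L_S$, $L_D$ are folded into the coefficients and the exponent $i = 2$ is used, as described above):

```python
def contrastive_loss(d, y, margin=2.0):
    """Contrastive loss for a single pair.
    d: Euclidean distance between the two embeddings
    y: 0 if the pair is the same class, 1 if different
    """
    same_term = (1 - y) * d ** 2               # pulls similar pairs together
    diff_term = y * max(margin - d, 0.0) ** 2  # pushes dissimilar pairs apart, up to the margin
    return same_term + diff_term

print(contrastive_loss(0.5, 0))  # similar pair, small distance -> small loss: 0.25
print(contrastive_loss(0.5, 1))  # dissimilar pair still close -> penalized: 2.25
print(contrastive_loss(3.0, 1))  # dissimilar pair beyond the margin -> 0.0
```

Once a dissimilar pair is farther apart than the margin, its loss term is exactly zero and contributes no gradient.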
Formally, this is written $\max(margin - distance, 0)$.

In the figure above, I read $W$ as the neural network weights, and $\vec X_1$ as the original input image. So the loss function looks like this:

![](https://p3juejin.byteimg.com/toscnik3u1fbpfcp/120a6d2230cb47cfa549ef260f0c6dff~tplvk3u1fbpfcpwatermark.image)

**To summarize, the thing to pay attention to here is the pair of different images: you need to set a margin, compute a loss when the distance is below the margin, and set the loss to 0 when it is above the margin.**

### 3.2 Contrastive Loss in PyTorch

```python
import torch
import torch.nn.functional as F

# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2) +
            # clamp clips the value from below at 0
            label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive
```

The only thing worth explaining is `torch.nn.functional.pairwise_distance`, which computes the row-wise Euclidean distance. For example:

```python
import torch
import torch.nn.functional as F

a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([[10, 20], [30, 40]])
F.pairwise_distance(a, b)
```

The output is:

![](https://p1juejin.byteimg.com/toscnik3u1fbpfcp/bfd9e1168c9748f69533c285d9d8d29d~tplvk3u1fbpfcpwatermark.image)

Let's check whether this number really is the Euclidean distance:

![](https://p9juejin.byteimg.com/toscnik3u1fbpfcp/36ef8dfd24ad48c3b767deadc56e1754~tplvk3u1fbpfcpwatermark.image)

No problem.

### 3.3 Triplet Loss

- Proposed in: 《FaceNet: A Unified Embedding for Face Recognition and Clustering》

This paper proposed FaceNet and used the Triplet Loss.
Let's look at it in detail.

- Triplet Loss definition: minimize the distance between the anchor and positive samples that share the same identity, and maximize the distance between the anchor and negative samples with different identities. **This is the loss function of the triplet network: three samples are input at the same time, one image, plus one image of the same category and one of a different category.**
- Triplet Loss goal: make features with the same label as close as possible while keeping features with different labels far apart. At the same time, to prevent the sample features from collapsing into a very small space, it requires, for two positive examples of the same class and one negative example, that the negative be at least a margin farther from the positives. As shown below:

![](https://p6juejin.byteimg.com/toscnik3u1fbpfcp/764a41ada7bd4eff8022c4645474305b~tplvk3u1fbpfcpwatermark.image)

In this situation, how do we construct the loss function? We know what we want:

- the Euclidean distance between the anchor and positive vectors should be as small as possible;
- the Euclidean distance between the anchor and negative vectors should be as large as possible.

So we expect the following formula to hold:

![](https://p6juejin.byteimg.com/toscnik3u1fbpfcp/a14af1a4c8be42e0b8557cf2e440f401~tplvk3u1fbpfcpwatermark.image)

In short: the anchor-positive distance should be smaller than the anchor-negative distance, and the gap should be at least $\alpha$. **My personal reading is that the T here is the set of triplets.
For a given dataset you can often build an enormous number of triplets, so I personally feel this kind of task is usually used when there are many categories but little data per category; otherwise the number of triplets would explode.**

### 3.4 Triplet Loss in Keras

Here is a Keras implementation of the triplet loss:

```python
from keras import backend as K

def triplet_loss(y_true, y_pred):
    """Loss function for Triplet Loss."""
    anc, pos, neg = y_pred[:, 0:128], y_pred[:, 128:256], y_pred[:, 256:]
    # squared Euclidean distances
    pos_dist = K.sum(K.square(anc - pos), axis=1, keepdims=True)
    neg_dist = K.sum(K.square(anc - neg), axis=1, keepdims=True)
    # TripletModel.MARGIN is the margin constant defined elsewhere
    basic_loss = pos_dist - neg_dist + TripletModel.MARGIN
    loss = K.maximum(basic_loss, 0.0)
    print("[INFO] model - triplet_loss shape: %s" % str(loss.shape))
    return loss
```

References:

[1] Momentum Contrast for Unsupervised Visual Representation Learning, 2019, Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, Ross Girshick

[2] Dimensionality Reduction by Learning an Invariant Mapping, 2006, Raia Hadsell, Sumit Chopra, Yann LeCun
Copyright notice

This article was created by [itread01]. Please include a link to the original when reposting. Thanks.

https://chowdera.com/2020/12/20201206173919664h.html