Siamese Network: A Brief Introduction

Copyright notice: This is an original article by CSDN blogger Rosefun96. https://blog.csdn.net/rosefun96/article/details/88320155

1 Introduction


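Architecture: two inputs are fed through the same subnetwork (the two branches share their weights), and the two resulting output vectors are compared with a distance, here the Euclidean distance.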
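To make the weight sharing concrete, here is a minimal pure-Python sketch, not the article's model: a hypothetical linear embedding `embed` with a single weight matrix `W` (values made up for illustration) is applied to both inputs, and the two outputs are compared by Euclidean distance. Because both branches use the same `W`, the distance is symmetric and is zero for identical inputs.

```python
import math

# Hypothetical shared weights (illustrative values): ONE 2x3 matrix used by both branches.
W = [[0.5, -1.0, 0.25],
     [1.0, 0.0, -0.5]]


def embed(x, weights=W):
    """One 'twin': map a 3-vector to a 2-vector using the shared weights."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def siamese_distance(x1, x2):
    # Both inputs pass through the same function with the same weights.
    return euclidean(embed(x1), embed(x2))


x1 = [1.0, 2.0, 3.0]
x2 = [0.0, 0.0, 1.0]
print(siamese_distance(x1, x2))  # 1.0
```

In the PyTorch model, `forward_once` plays the role of `embed`: `forward` simply calls it on each input.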

Contrastive loss function

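For a pair of inputs $\vec{X}_1, \vec{X}_2$ with label $Y$ ($Y = 0$ for a similar pair, $Y = 1$ for a dissimilar pair) and a margin $m > 0$, the loss is

$$L(W, Y, \vec{X}_1, \vec{X}_2) = (1 - Y)\,\frac{1}{2}\,(D_W)^2 + Y\,\frac{1}{2}\,\{\max(0,\; m - D_W)\}^2$$

where $D_W$ is the Euclidean distance between the network outputs $G_W(\vec{X}_1)$ and $G_W(\vec{X}_2)$:

$$D_W(\vec{X}_1, \vec{X}_2) = \lVert G_W(\vec{X}_1) - G_W(\vec{X}_2) \rVert_2$$

Similar pairs are pulled together, while dissimilar pairs are pushed apart only until their distance reaches the margin $m$.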
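As a worked example of the single-pair contrastive loss of Hadsell et al. (with the factor of 1/2 on each term), the small pure-Python function below evaluates it for three hand-picked distances. The margin of 2.0 is an assumption chosen to match the default of the `ContrastiveLoss` class in section 2; `contrastive_loss` is a hypothetical helper name.

```python
def contrastive_loss(d, y, margin=2.0):
    """Single-pair contrastive loss (Hadsell, Chopra, LeCun 2006).

    d: Euclidean distance D_W between the two embeddings
    y: 0 for a similar pair, 1 for a dissimilar pair
    margin: m > 0; dissimilar pairs farther apart than m cost nothing
    """
    similar_term = (1 - y) * 0.5 * d ** 2
    dissimilar_term = y * 0.5 * max(0.0, margin - d) ** 2
    return similar_term + dissimilar_term


print(contrastive_loss(0.5, 0))  # similar pair, close together: 0.125
print(contrastive_loss(0.5, 1))  # dissimilar pair inside the margin: 1.125
print(contrastive_loss(3.0, 1))  # dissimilar pair beyond the margin: 0.0
```

The third case shows the role of the margin: once a dissimilar pair is farther apart than `m`, it contributes no loss and therefore no gradient.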

2 Implementation

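import torch.nn as nn


class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Convolutional feature extractor shared by both branches. Each
        # ReflectionPad2d(1) + Conv2d(kernel_size=3) pair keeps the spatial size.
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),
        )

        # Fully connected head mapping the flattened features to a
        # 5-dimensional embedding. 8*100*100 assumes single-channel
        # 100x100 input images.
        self.fc1 = nn.Sequential(
            nn.Linear(8*100*100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5)
        )

    def forward_once(self, x):
        # Embed one batch of images: conv features -> flatten -> FC head.
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # Both inputs go through the same layers, so the two branches share weights.
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2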
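A short check of the shape arithmetic behind `nn.Linear(8*100*100, 500)`: for a convolution with no dilation, the output length along one spatial dimension is (size + 2 * padding - kernel) // stride + 1, so a 1-pixel reflection pad followed by a 3x3, stride-1 convolution leaves the size unchanged. The pure-Python helpers below are hypothetical (not part of PyTorch) and apply this to the three blocks of `cnn1`, assuming a 100x100 input.

```python
def conv2d_out(size, kernel, padding=0, stride=1):
    """Output length along one spatial dimension of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def cnn1_output_size(h, w, blocks=3):
    # Each block of cnn1: ReflectionPad2d(1), then Conv2d(kernel_size=3, stride=1).
    for _ in range(blocks):
        h = conv2d_out(h, kernel=3, padding=1)
        w = conv2d_out(w, kernel=3, padding=1)
    return h, w


h, w = cnn1_output_size(100, 100)
channels = 8  # out_channels of the last Conv2d in cnn1
print(h, w, channels * h * w)  # 100 100 80000
```

Since the spatial size is preserved, feeding images of any other size would require changing the `in_features` of the first `nn.Linear` accordingly.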

Contrastive loss:

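import torch
import torch.nn.functional as F


class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # label: 0 for a similar pair, 1 for a dissimilar pair, with the same
        # shape as euclidean_distance, i.e. (batch_size,).
        euclidean_distance = F.pairwise_distance(output1, output2)
        # Similar pairs are penalized by their squared distance; dissimilar
        # pairs only when they are closer than the margin.
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive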
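To see by hand what `ContrastiveLoss.forward` is meant to compute, here is a pure-Python mirror of it for a small batch. It is a sketch under stated assumptions: labels follow the class's convention (0 = similar, 1 = dissimilar), the distance is plain Euclidean (PyTorch's `F.pairwise_distance` additionally adds a small `eps`, ignored here), and the batch values are made up. Note that the class drops the factor of 1/2 from the Hadsell et al. formulation and averages over the batch, so each per-pair term is twice the paper's value; this only rescales the loss and its gradients.

```python
import math


def pairwise_distance(a, b):
    # Plain Euclidean distance (no eps, unlike F.pairwise_distance).
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def batch_contrastive_loss(out1, out2, labels, margin=2.0):
    """Mirror of ContrastiveLoss.forward: mean over the batch, no 1/2 factor."""
    losses = []
    for a, b, y in zip(out1, out2, labels):
        d = pairwise_distance(a, b)
        losses.append((1 - y) * d ** 2 + y * max(0.0, margin - d) ** 2)
    return sum(losses) / len(losses)


# Hypothetical 2-D embeddings for four pairs, with distances 0, 0.5, 0.5 and 5.
out1 = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
out2 = [[0.0, 0.0], [0.0, 0.5], [0.0, 0.5], [3.0, 4.0]]
labels = [0, 0, 1, 1]
print(batch_contrastive_loss(out1, out2, labels))  # (0 + 0.25 + 2.25 + 0) / 4 = 0.625
```

With the paper's 1/2 factors the same batch would give 0.3125, exactly half.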

References:
1. PyTorch community

