Triplet loss最初是谷歌在 FaceNet: A Unified Embedding for Face Recognition and Clustering 论文中提出的,可以学到较好的人脸的embedding
Triplet Loss 是一种用于训练特征嵌入(feature embedding)的损失函数,广泛应用于人脸识别、图像检索等需要度量相似性的任务。其核心思想是通过学习将同类样本的嵌入距离拉近,不同类样本的嵌入距离推远。
- 三元组定义
Triplet Loss 的输入由三部分组成:
Anchor(基准样本)
Positive(与 Anchor 同类的样本)
Negative(与 Anchor 不同类的样本)
- 目标函数
目标是最小化 Anchor 与 Positive 的距离,同时最大化 Anchor 与 Negative 的距离。公式如下:
L = max ( d ( a , p ) − d ( a , n ) + margin , 0 ) L = \max\left( d(a, p) - d(a, n) + \text{margin}, 0 \right) L=max(d(a,p)−d(a,n)+margin,0)
d(a,p):Anchor 和 Positive 的欧氏距离
d(a,n):Anchor 和 Negative 的欧氏距离
margin:超参数,控制正负样本的最小间隔
- 核心思想
当 d(a,p)+margin<d(a,n) 时,损失为 0,无需优化。
否则,通过反向传播调整特征嵌入,使同类更近、异类更远。
优化的目标是让loss越小越好,使得 d ( d(a,p)−d(a,n)+margin 越小越好,直到d(a,p)−d(a,n)+margin小于等于0,就不优化了
核心代码实现:
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
# 计算欧氏距离
distance_ap = torch.sqrt(torch.sum((anchor - positive)**2, dim=1)) # [batch]
distance_an = torch.sqrt(torch.sum((anchor - negative)**2, dim=1)) # [batch]
# 计算损失
losses = torch.relu(distance_ap - distance_an + self.margin)
return torch.mean(losses)
完整代码演示:
(1) 三元组数据生成
class TripletDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
self.labels = [label for _, label in dataset]
self.classes = list(set(self.labels))
def __getitem__(self, index):
# 锚点样本
anchor, anchor_label = self.dataset[index]
# 正样本:随机选择同类别样本
positive_indices = [i for i, label in enumerate(self.labels) if label == anchor_label]
positive_index = np.random.choice(positive_indices)
positive = self.dataset[positive_index][0]
# 负样本:随机选择不同类别样本
negative_label = np.random.choice([c for c in self.classes if c != anchor_label])
negative_indices = [i for i, label in enumerate(self.labels) if label == negative_label]
negative_index = np.random.choice(negative_indices)
negative = self.dataset[negative_index][0]
return anchor, positive, negative
def __len__(self):
return len(self.dataset)
(2) 嵌入模型(示例:CNN)
class EmbeddingNet(nn.Module):
def __init__(self, embedding_dim=128):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 32, 5),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.fc = nn.Sequential(
nn.Linear(64 * 4 * 4, 256),
nn.ReLU(),
nn.Linear(256, embedding_dim)
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
(3) Triplet Loss实现
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
# 计算欧氏距离
distance_ap = torch.sqrt(torch.sum((anchor - positive)**2, dim=1)) # [batch]
distance_an = torch.sqrt(torch.sum((anchor - negative)**2, dim=1)) # [batch]
# 计算损失
losses = torch.relu(distance_ap - distance_an + self.margin)
return torch.mean(losses)
(4) 训练流程
# 数据加载
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
triplet_dataset = TripletDataset(train_dataset)
train_loader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)
# 模型初始化
model = EmbeddingNet()
criterion = TripletLoss(margin=1.0)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for epoch in range(10):
for batch_idx, (anchor, pos, neg) in enumerate(train_loader):
optimizer.zero_grad()
# 前向传播
anchor_emb = model(anchor)
pos_emb = model(pos)
neg_emb = model(neg)
# 计算损失
loss = criterion(anchor_emb, pos_emb, neg_emb)
# 反向传播
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch: {
epoch}, Batch: {
batch_idx}, Loss: {
loss.item():.4f}')
关键代码解析
三元组采样:确保正样本与锚点同类,负样本不同类
距离计算:使用欧氏距离(L2距离)
损失计算:通过torch.relu实现max(⋅,0)
梯度更新:Adam优化器更新嵌入网络参数
在 Triplet Loss 中,margin 的设置直接影响模型对样本间距的敏感度,通常需要根据具体任务和数据特性进行调整。
常见值:margin 通常设置在 [0.2, 1.0] 之间。
人脸识别任务常用 margin=0.2(如 FaceNet)。
通用嵌入学习任务可能使用 margin=1.0(如 ResNet 特征提取)。
任务类型 | 推荐 margin | 说明 |
---|---|---|
人脸识别(FaceNet) | 0.2 | 同类样本特征高度相似 |
细粒度图像分类 | 1.0 | 不同类别特征易混淆 |
通用图像检索 | 0.5 | 平衡多样性和紧凑性 |
文本匹配 | 0.3 | 短文本语义差异较小 |
参考文献:
Triplet Loss原理及实现
深度学习从入门到放飞自我:完全解析triplet loss
Triplet-Loss原理及其实现、应用