人脸识别ArcFace损失函数(代码)

人脸识别ArcFace损失函数(代码)

# 实现方式1
class ArcLoss1(nn.Module):
	"""Additive-angular-margin (ArcFace-style) head returning per-class ratios.

	With c[i, j] = <x_i/|x_i|, w_j/|w_j|> / cos_divisor, t[i, j] =
	s * cos(acos(c[i, j]) + m) and z[i, j] = s * c[i, j], forward returns

	    p[i, j] = exp(t[i, j]) / (exp(t[i, j]) + sum_{k != j} exp(z[i, k]))

	i.e. for every column j the margin is applied as if j were the target
	class. Presumably the caller picks the column of the true label
	(e.g. NLL on log(p)) -- verify against the training loop.

	Args:
		class_num: number of classes (columns of the weight matrix).
		feature_num: length of each input feature vector.
		s: scale applied to the (margined) cosines before exponentiation.
		m: additive angular margin, in radians.
		cos_divisor: the normalised dot product is divided by this value.
			The original code hard-coded 10, which confines c to [-0.1, 0.1]
			(presumably to keep sqrt/acos away from +-1 -- TODO confirm).
			The default keeps that behaviour.
	"""

	def __init__(self, class_num, feature_num, s=10, m=0.1, cos_divisor=10):
		super().__init__()
		self.class_num = class_num
		self.feature_num = feature_num
		self.s = s
		self.m = torch.tensor(m)
		# Shape (feature_num, class_num); columns are normalised in forward().
		self.w = nn.Parameter(torch.rand(feature_num, class_num))
		self.cos_divisor = cos_divisor

	@staticmethod
	def _exclusive_logsumexp(z):
		"""Return e with e[:, j] = log(sum_{k != j} exp(z[:, k])) for a 2-D z.

		Built from a forward and a backward logcumsumexp, so no term is
		exponentiated on its own (no overflow for large scales) and no
		"total minus self" subtraction happens (no cancellation when one
		column dominates its row). A row with a single column yields -inf.
		"""
		pad = z.new_full((z.shape[0], 1), float("-inf"))
		prefix = torch.logcumsumexp(z, dim=1)  # log-sum of z[:, :j+1]
		suffix = torch.logcumsumexp(z.flip(1), dim=1).flip(1)  # log-sum of z[:, j:]
		before = torch.cat([pad, prefix[:, :-1]], dim=1)  # log-sum of z[:, :j]
		after = torch.cat([suffix[:, 1:], pad], dim=1)  # log-sum of z[:, j+1:]
		return torch.logaddexp(before, after)

	def forward(self, feature):
		"""Compute the per-class margin ratios for a batch of features.

		Args:
			feature: assumed 2-D (batch, feature_num) -- normalised along dim 1.

		Returns:
			Tensor of shape (batch, class_num) with entries in [0, 1].
		"""
		x = F.normalize(feature, dim=1)
		w = F.normalize(self.w, dim=0)
		# Keep |c| strictly below 1 so sqrt(1 - c^2) has a finite gradient;
		# inactive with the default divisor, where |c| <= 0.1.
		eps = 1e-6
		cos = (torch.matmul(x, w) / self.cos_divisor).clamp(-1 + eps, 1 - eps)
		sin = torch.sqrt(1.0 - cos * cos)
		# cos(theta + m) via the angle-addition formula, theta = acos(c).
		cos_m = cos * torch.cos(self.m) - sin * torch.sin(self.m)
		target = cos_m * self.s
		others = self._exclusive_logsumexp(cos * self.s)
		# exp(t) / (exp(t) + exp(o)) == sigmoid(t - o), evaluated stably.
		return torch.sigmoid(target - others)

# 实现方式2
class ArcLoss2(nn.Module):
	"""Additive-angular-margin (ArcFace-style) head returning log-ratios.

	With c[i, j] = <x_i/|x_i|, W_j/|W_j|> / cos_divisor and a = acos(c),
	forward returns

	    out[i, j] = log( exp(s*cos(a[i, j] + m)) /
	                     (exp(s*cos(a[i, j] + m)) + sum_{k != j} exp(s*c[i, k])) )

	i.e. for every column j the margin is applied as if j were the target
	class. Presumably the caller picks the column of the true label
	(e.g. NLL loss) -- verify against the training loop.

	Args:
		feature_dim: length of each input feature vector.
		cls_dim: number of classes (columns of the weight matrix).
		cos_divisor: the normalised dot product is divided by this value.
			The original code hard-coded 10, which confines c to [-0.1, 0.1]
			(presumably to keep acos away from +-1 -- TODO confirm). The
			default keeps that behaviour.
	"""

	def __init__(self, feature_dim=2, cls_dim=10, cos_divisor=10):
		super().__init__()
		# Shape (feature_dim, cls_dim); columns are normalised in forward().
		self.W = nn.Parameter(torch.randn(feature_dim, cls_dim))
		self.cos_divisor = cos_divisor

	@staticmethod
	def _exclusive_logsumexp(z):
		"""Return e with e[:, j] = log(sum_{k != j} exp(z[:, k])) for a 2-D z.

		Built from a forward and a backward logcumsumexp, so no term is
		exponentiated on its own (no overflow for large scales) and no
		"total minus self" subtraction happens (no cancellation when one
		column dominates its row). A row with a single column yields -inf.
		"""
		pad = z.new_full((z.shape[0], 1), float("-inf"))
		prefix = torch.logcumsumexp(z, dim=1)  # log-sum of z[:, :j+1]
		suffix = torch.logcumsumexp(z.flip(1), dim=1).flip(1)  # log-sum of z[:, j:]
		before = torch.cat([pad, prefix[:, :-1]], dim=1)  # log-sum of z[:, :j]
		after = torch.cat([suffix[:, 1:], pad], dim=1)  # log-sum of z[:, j+1:]
		return torch.logaddexp(before, after)

	def forward(self, feature, m=1, s=10):
		"""Compute the per-class margin log-ratios for a batch of features.

		Args:
			feature: assumed 2-D (batch, feature_dim) -- normalised along dim 1.
			m: additive angular margin, in radians.
			s: scale applied to the cosines before exponentiation.

		Returns:
			Tensor of shape (batch, cls_dim) with entries <= 0.
		"""
		x = F.normalize(feature, dim=1)
		w = F.normalize(self.W, dim=0)
		# Keep |c| strictly below 1 so acos has a finite gradient; inactive
		# with the default divisor, where |c| <= 0.1.
		eps = 1e-6
		cos = (torch.matmul(x, w) / self.cos_divisor).clamp(-1 + eps, 1 - eps)
		a = torch.acos(cos)
		target = s * torch.cos(a + m)
		# cos(acos(c)) == c, so the non-margined logits use c directly.
		others = self._exclusive_logsumexp(s * cos)
		# log(exp(t) / (exp(t) + exp(o))) == logsigmoid(t - o), evaluated stably.
		return F.logsigmoid(target - others)
代码先放在这儿，后面再来解释。


发布了8 篇原创文章 · 获赞 2 · 访问量 1871

猜你喜欢

转载自blog.csdn.net/leiduifan6944/article/details/103652569