『OCR_recognition』CTC loss几种解码方式



前言

预测新的样本输入对应的输出字符串,这涉及到解码。按照最大似然准则,最优的解码结果为:
在这里插入图片描述
示例:
在这里插入图片描述
如上图的例子,按照时间序列展开得到栅格网络,解码的过程相当于空间搜索, 求取穷举的所有可能字符串序列中概率最大的那个。我们可以选择暴力的解码策略:穷举搜索,但时间复杂度是指数级的 N T N^{T} NT,显然不可行。

然而,上式不存在已知的高效解法。下面介绍几种实用的近似破解码方法。

一、贪心搜索 (greedy search)

1.1 原理解释

虽然 p(l|x) 难以有效的计算,但是由于 CTC 的独立性假设,对于某个具体的字符串 π(去 blank 前),却容易计算:
在这里插入图片描述
因此,我们放弃寻找使 p(l|x) 最大的字符串,退而寻找一个使 p(π|x) 最大的字符串,即:
在这里插入图片描述
其中,
在这里插入图片描述
简化后,解码过程(构造 π ⋆ π^⋆ π)变得非常简单(基于独立性假设): 在每个时刻输出概率最大的字符
在这里插入图片描述
Greedy search 是在每一步选择概率最大的输出值,这样就可以得到最终解码的输出序列(如上图例子,最终解码的输出序列 l=blank)。然而,CTC 网络的输出序列只对应了搜索空间的一条路径,一个最终标签可对应搜索空间的 N 条路径,所以概率最大的路径并不等于最终标签的概率最大,即不是最优解(如上图例子,最优解是 p(l=b) 而不是 p(l=blank))。

1.2 图示说明

1.3 代码实现

def remove_blank(labels, blank=0):
import numpy as np

def softmax(logits):
	""" 求每一列(即每个时刻)中最大值对应的 softmax 值 """"
	# 注意这里求 e 的次方时,次方数减去 max_value 其实不影响结果,
	# 因为最后可以化简成教科书上softmax的定义次方数,
	# 加入减 max_value 是因为 e 的 x 次方与 x 的极限(x 趋于无穷)为无穷,很容易溢出,
	# 所以为了计算时不溢出,就加入减 max_value 项
	# 次方数减去 max_value 后,e 的该次方数总是在 0 到 1 范围内。
	max_value = np.max(logits, axis=1, keepdims=True)
	exp = np.exp(logits - max_value)
	exp_sum = np.sum(exp, axis=1, keepdims=True)
	dist = exp / exp_sum
	return dist

def remove_blank(labels, blank=0):
	new_labels = []
	# 合并相同的标签
	previous = None
	for l in labels:
		if l != previous:
			new_labels.append(l)
			previous = l
	# 删除 blank
	new_labels = [l for l in new_labels if l != blank]
	return new_labels

def insert_blank(labels, blank=0):
	new_labels = [blank]
	for l in labels:
		new_labels += [l, blank]
	return new_labels

def greedy_decode(y, blank=0):
	# 按列取最大值,即每个时刻 t 上最大值对应的下标
	raw_rs = np.argmax(y, axis=1)
	# 移除 blank,值为 0 的位置表示这个位置是 blank
	rs = remove_blank(raw_rs, blank)
	return raw_rs, rs

np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
label_have_blank, label_no_blank = greedy_decode(y_test)
print(label_have_blank)
print(label_no_blank)

二、束搜索(Beam Search)

贪心搜索的性能非常受限, 这种方法忽略了一个输出可能对应多个对齐结果。很多时候,如果我们能拿到 nearbest 的路径,后续可以利用其他信息来进一步优化搜索的结果。束搜索能近似找出 top 最优的若干条路径。

2.1 原理解释

基本原理是通过 t i − 1 t_{i−1} ti1beamsize 个序列,每个序列分别连接 t i t_{i} tibeamsize 个节点,得到 beamsize 个新序列及对应的 score,然后按照 score 从大到小的顺序选出前 beamSize 个序列,依次推进。

2.2 图示说明

假设 beamsize=2t=1 时:

这个时候只会将两个概率最大的节点放进路径集合中,即有两条路径。

t=2 时:
上面的两个路径每个路径都会和下一个时间点的每一项组成新的路径,因此一共有 b e a m s i z e × V = 2 ∗ 3 = 6 beamsize\times V=2*3=6 beamsize×V=23=6 个新路径。

然后我们还是只保留概率最大的两条路径(次大的两个路径相等,这里舍弃掉一个)。

t=3 时:

t=2 时类似,又组成了新的 6 条路径。我们还是取概率最大的两条路径。

实际使用该算法时,往往取前 20,这里前 2 只是为了方便举例。

2.3 代码实现

import numpy as np

def softmax(logits):
	""" 求每一列(即每个时刻)中最大值对应的 softmax 值 """"
	# 注意这里求 e 的次方时,次方数减去 max_value 其实不影响结果,
	# 因为最后可以化简成教科书上softmax的定义次方数,
	# 加入减 max_value 是因为 e 的 x 次方与 x 的极限(x 趋于无穷)为无穷,很容易溢出,
	# 所以为了计算时不溢出,就加入减 max_value 项
	# 次方数减去 max_value 后,e 的该次方数总是在 0 到 1 范围内。
	max_value = np.max(logits, axis=1, keepdims=True)
	exp = np.exp(logits - max_value)
	exp_sum = np.sum(exp, axis=1, keepdims=True)
	dist = exp / exp_sum
	return dist

def remove_blank(labels, blank=0):
	new_labels = []
	# 合并相同的标签
	previous = None
	for l in labels:
		if l != previous:
			new_labels.append(l)
			previous = l
	# 删除 blank
	new_labels = [l for l in new_labels if l != blank]
	return new_labels

def insert_blank(labels, blank=0):
	new_labels = [blank]
	for l in labels:
		new_labels += [l, blank]
	return new_labels

def beam_decode(y, beam_size=10):
	T, V = y.shape	# y 是个二维数组,记录了所有时刻的所有项的概率
	# 将所有的 y 中值改为 log 是为了防止溢出,因为最后得到的 p 是 y1..yn 连乘,
	# 且 yi 都在 0 到 1 之间,可能会导致下溢出,
	# 改成 log(y) 以后就变成连加了,这样就防止了下溢出
	log_y = np.log(y)
	beam = [([], 0)]	# 初始的beam
	for t in range(T):	# 遍历所有时刻t
		new_beam = []	# 每个时刻先初始化一个new_beam
		for prefix, score in beam:	# 遍历beam
			# 对于一个时刻中的每一项(一共V项)
			for i in range(V):
				# 记录添加的新项是这个时刻的第几项,对应的概率(log形式的)加上新的这项log形式的概率(本来是乘的,改成log就是加)
				new_prefix = prefix + [i]
				new_score = score + log_y[t, i]
				# new_beam 记录了对于 beam 中某一项,将这个项分别加上新的时刻中的每一项后的概率
				new_beam.append((new_prefix, new_score))
		new_beam.sort(key=lambda x: x[1], reverse=True)	# 给 new_beam 按 score 排序
		beam = new_beam[:beam_size]	# beam 即为 new_beam 中概率最大的 beam_size 个路径
	return beam

np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
beam_chosen = beam_decode(y_test, beam_size=100)
for beam_string, beam_score in beam_chosen[:20]:
	print(remove_blank(beam_string), beam_score)

三、前缀束搜索(Prefix Beam Search)

3.1 原理解释

待理解后补全。。。

3.2 图示说明

3.3 代码实现

import numpy as np
from collections import defaultdict
ninf = float("-inf")

def softmax(logits):
	max_value = np.max(logits, axis=1, keepdims=True)
	exp = np.exp(logits - max_value)
	exp_sum = np.sum(exp, axis=1, keepdims=True)
	dist = exp / exp_sum
	return dist

def remove_blank(labels, blank=0):
	new_labels = []
	previous = None
	for l in labels:
		if l != previous:
			new_labels.append(l)
			previous = l
	new_labels = [l for l in new_labels if l != blank]
	return new_labels

def insert_blank(labels, blank=0):
	new_labels = [blank]
	for l in labels:
		new_labels += [l, blank]
	return new_labels

def _logsumexp(a, b):
	''' np.log(np.exp(a) + np.exp(b)) '''
	if a < b:
		a, b = b, a
	if b == ninf:
		return a
	else:
		return a + np.log(1 + np.exp(b - a))

def logsumexp(*args):
	'''
	from scipy.special import logsumexp
	logsumexp(args)
	'''
	res = args[0]
	for e in args[1:]:
		res = _logsumexp(res, e)
	return res

def prefix_beam_decode(y, beam_size=10, blank=0):
	T, V = y.shape
	log_y = np.log(y)
	# 最后一个字符是 blank 与最后一个字符为 non-blank 两种情况
	beam = [(tuple(), (0, ninf))]
	# 对于每一个时刻t
	for t in range(T):
		# 当我使用普通的字典时,用法一般是 dict={},添加元素的只需要 dict[element]=value 即可,调用的时候也是如此
		# dict[element]=xxx,但前提是 element 字典里,如果不在字典里就会报错
		# defaultdict 的作用是在于,当字典里的 key 不存在但被查找时,返回的不是 keyError 而是一个默认值
		# dict=defaultdict(factory_function)
		# 这个 factory_function 可以是 list、set、str 等等,作用是当 key 不存在时,返回的是函数默认值
		# 这里就是 (ninf, ninf) 是默认值
		new_beam = defaultdict(lambda: (ninf, ninf))
		# 对于 beam 中的每一项
		for prefix, (p_b, p_nb) in beam:
			for i in range(V):
				# beam 的每一项都加上时刻t中的每一项
				p = log_y[t, i]
				# 如果 i 中的这项是 blank
				if i == blank:
					# 将这项直接加入路径中
					new_p_b, new_p_nb = new_beam[prefix]
					new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
					new_beam[prefix] = (new_p_b, new_p_nb)
					continue
				# 如果 i 中的这一项不是 blank
				else:
					end_t = prefix[-1] if prefix else None
					# 判断之前 beam 项中的最后一个元素和 i 的元素是不是一样
					new_prefix = prefix + (i,)
					new_p_b, new_p_nb = new_beam[new_prefix]
					# 如果不一样,则将 i 这项加入路径中
					if i != end_t:
						new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
					else:
						new_p_nb = logsumexp(new_p_nb, p_b + p)
					new_beam[new_prefix] = (new_p_b, new_p_nb)
					# 如果一样,保留现有的路径,但是概率上要加上新的这个 i 项的概率
					if i == end_t:
						new_p_b, new_p_nb = new_beam[prefix]
						new_p_nb = logsumexp(new_p_nb, p_nb + p)
						new_beam[prefix] = (new_p_b, new_p_nb)

		# 给新的 beam 排序并取前 beam_size 个
		beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
		beam = beam[:beam_size]
	return beam

np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
beam_test = prefix_beam_decode(y_test, beam_size=100)
for beam_string, beam_score in beam_test[:20]:
	print(remove_blank(beam_string), beam_score)

参考链接

  1. https://blog.csdn.net/weixin_42615068/article/details/93767781
  2. https://zhuanlan.zhihu.com/p/39266552

猜你喜欢

转载自blog.csdn.net/libo1004/article/details/111717067
CTC