一、相关工作
LSTM是一种特殊的RNN,用于解决序列的处理问题。
对于每一个时刻,lstm单元接收上一个时刻的cell_state (即c(t-1))和h_state( 即h(t-1)),由公式可见,使用h(t-1)和当前的输入x(t)产生当前时刻的输出o(t),使用上一时刻的c(t-1)生成c(t),c(t)和o(t)一起用来生成h(t)
二、基本思想
CNN+RNN CNN用的是VGG6 RNN用的是LSTM.
三、模型结构
四、代码分析:
首先是训练的部分
(1)准备数据
对COCO数据集中的caption和image根据长度(最大不能超过20个单词,包括句号)和词库(caption中的所有单词必须在词库内)筛选两次caption 准备好数据集。
(2)建立模型
分为两个部分
1)CNN: vgg16模型,最后的隐藏层输出4096维特征,最后的输出层改为1000,在Imagenet上做微调,防止过拟合。4096维特征reshape成 [1,8,512],输入后续的LSTM.
2)LSTM
LSTM的隐含层向量是512维。每个词汇库里的单词可以使用预训练好的word embedding,也可以联合训练。对于一个句子,每个时刻输入一个单词的word embedding。最初的输入的word embedding (x(0)是image feature 的平均值 [1,512]
对于每个时刻的lstm,有:
output, state = lstm(current_input, last_state)
current对应word embedding,last_state包括两部分:c(t-1)和h(t-1),输出output和c(t)、h(t)。
对于每一个时刻512维度的输出,接一个5000维的全连接层加sotmax分类器。从而得到预测结果。每预测一个词,就根据logits和真实的label计算一次交叉熵。预测正确,返回一个1,否则返回一个0.
对于一个句子,最大不超过20个,因此会循环20次上述过程,返回的交叉熵和1或0均被收集到一个长度为20的列表。
for idx in range(num_steps): #20
# Embed the last word
## for 1st LSTM the input is the image
if idx == 0:
word_embed = image_emb #第一个时刻输入图像的特征,后面不再输入图像
else:
with tf.variable_scope("word_embedding"):
word_embed = tf.nn.embedding_lookup(embedding_matrix,
last_word) #如果不是第一个时刻,返回上一个单词对应的word_embedding
#tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。
#tf.nn.embedding_lookup(tensor, id):tensor就是输入张量,id就是张量对应的索引
# Apply the LSTM
with tf.variable_scope("lstm"):
current_input = word_embed
output, state = lstm(current_input, last_state) #lstm每一个时刻接受当前的输入和上一个时刻的隐藏态
#output 代表最后时刻的输出 [batchsize,512] output并不等于当前的隐藏态h_state,事实上,某一时刻的h_state要用该时刻的output和cell_state来求
#lstm中的state有两个:cell_state 和h_state
#state:LSTMStateTuple(c :shape=(1, 512) dtype=float32>, h:shape=(1, 512) dtype=float32>)
memory, _ = state
# Decode the expanded output of LSTM into a word
with tf.variable_scope("decode"):
expanded_output = output
## Logits is of size vocab
logits = self.decode(expanded_output) #通过一个全连接层 返回output是5000个单词的概率[batchsize,5000]
probs = tf.nn.softmax(logits)
## Prediction is the index of the word the predicted in the vocab
prediction = tf.argmax(logits, 1)
predictionsArr.append(prediction) #用softmax分类器做预测
self.probs = probs
if self.is_train:
# Compute the loss for this step, if necessary
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels = sentences[:, idx], #sentence [1,20] labels 是一个单独的数字
logits = logits)
masked_cross_entropy = cross_entropy * masks[:, idx] #句子中第idx个词预测结果的 交叉熵
cross_entropies.append(masked_cross_entropy)
ground_truth = tf.cast(sentences[:, idx], tf.int64)
prediction_correct = tf.where(
tf.equal(prediction, ground_truth),
tf.cast(masks[:, idx], tf.float32),
tf.cast(tf.zeros_like(prediction), tf.float32))
predictions_correct.append(prediction_correct) #如果预测正确 prediction_correct返回mask[:,idx],否则返回和prediction一样大小的zero tensor
训练过程中,得到c(t) h(t) 下一次的x(t)是句子中下一个单词的word embedding。然后进行下一个时刻,直到20次循环完,一个batch结束。
一个batch结束后,把交叉熵的列表元素相加,除以句子的长度,加上正则化的损失,得到总损失。并计算预测精度。
total_loss = cross_entropy_loss + reg_loss
predictions_correct = tf.stack(predictions_correct, axis=1)
accuracy = tf.reduce_sum(predictions_correct) \
/ tf.reduce_sum(masks)
评价部分:
原本以为会有对句子长度的控制,但是看了源码,并没有。训练的时候,对于每一句,循环20次,然后把预测的每个单词放在列表里最后返回。测试的时候也是直接得到这个列表。然后把列表里对应的单词拼接起来,得到完整的句子。
for k in tqdm(list(range(eval_data.num_batches)), desc='batch'):
#for k in range(1):
batch = eval_data.next_batch()
#caption_data = self.beam_search(sess, batch, vocabulary)
images = self.image_loader.load_images(batch)
caption_data, scores = sess.run([self.predictions, self.probs], feed_dict={self.images: images})
fake_cnt = 0 if k<eval_data.num_batches-1 \
else eval_data.fake_count
for l in range(eval_data.batch_size-fake_cnt):
## self.predictions will return the indexes of words, we need to find the corresponding word from it.
word_idxs = caption_data[l]
## get_sentence will return a sentence till there is a end delimiter which is '.'
caption = str(vocabulary.get_sentence(word_idxs))
results.append({'image_id': int(eval_data.image_ids[idx]),
'caption': caption})
#print(results)
idx += 1
在所有要评价的数据都生成了caption以后,计算评价指标。
fp = open(config.eval_result_file, 'w')
json.dump(results, fp)
fp.close()
# Evaluate these captions
eval_result_coco = eval_gt_coco.loadRes(config.eval_result_file)
scorer = COCOEvalCap(eval_gt_coco, eval_result_coco)
scorer.evaluate()
print("Evaluation complete.")