class Seq2SeqTrainerForChatGLM(PeftTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def save_predictions(
self,
predict_results: PredictionOutput,
tokenizer: PreTrainedTokenizer
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
preds = preds[:, labels.shape[1]:] # remove prompts
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for pred, label in zip(preds, labels):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))
这是一个Python代码段,定义了一个名为Seq2SeqTrainerForChatGLM
的类,该类继承自PeftTrainer
。这个类特定用于处理序列到序列模型的训练,以及生成指标(如BLEU和ROUGE)的计算。
接下来,我们将逐行解释这段代码:
-
class Seq2SeqTrainerForChatGLM(PeftTrainer):
- 这行定义了一个新的类,Seq2SeqTrainerForChatGLM
,它继承自PeftTrainer
类。 -
def save_predictions(self, predict_results: PredictionOutput, tokenizer: PreTrainedTokenizer) -> None:
- 这是一个save_predictions
方法,它接受两个参数,一个是predict_results
(预测结果,类型为PredictionOutput
),另一个是tokenizer
(预训练的分词器,类型为PreTrainedTokenizer
)。 -
if not self.is_world_process_zero(): return
- 这行检查当前进程是否为主进程(进程编号为0)。如果不是,这个方法就结束运行。 -
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
- 这行创建了一个新的预测值数组,其中非忽略索引的预测值被保留,忽略索引的预测值被替换为填充符号的ID。 -
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
- 这行创建了一个新的标签数组,其中非忽略索引的标签值被保留,忽略索引的标签值被替换为填充符号的ID。 -
preds = preds[:, labels.shape[1]:]
- 这行移除了预测值中的提示,提示是在预测数组的开始部分。 -
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
- 这行使用分词器解码预测值,移除特殊符号,并删除前后的空白。 -
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
- 这行使用分词器解码标签,移除特殊符号,并删除前后的空白。 -
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
- 这行定义了预测结果的输出文件路径。 -
logger.info(f"Saving prediction results to {output_prediction_file}")
- 这行向日志发送一条信息,表明预测结果将被保存到哪个文件。 -
with open(output_prediction_file, "w", encoding="utf-8") as writer:
- 这行打开预测结果输出文件,以写入模式,准备写入预测结果。 -
res: List[str] = []
- 这行初始化一个空的列表,准备收集每一行的预测结果。 -
for pred, label in zip(preds, labels):
- 这行开始一个循环,遍历所有的预测值和标签。 -
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
- 在循环中,这行将每对预测和标签以JSON格式转换为字符串,然后添加到结果列表中。 -
writer.write("\n".join(res))
- 这行将结果列表连接成一个字符串,每个结果之间用换行符隔开,然后写入到文件中。