ChatGLM Efficient Tuning源码解析src/utils/seq2seq.py (二)

class Seq2SeqTrainerForChatGLM(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def save_predictions(
            self,
            predict_results: PredictionOutput,
            tokenizer: PreTrainedTokenizer
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)

        preds = preds[:, labels.shape[1]:] # remove prompts
        preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
        labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for pred, label in zip(preds, labels):
                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
            writer.write("\n".join(res))

这是一个Python代码段,定义了一个名为Seq2SeqTrainerForChatGLM的类,该类继承自PeftTrainer。这个类特定用于处理序列到序列模型的训练,以及生成指标(如BLEU和ROUGE)的计算。

接下来,我们将逐行解释这段代码:

  1. class Seq2SeqTrainerForChatGLM(PeftTrainer): - 这行定义了一个新的类,Seq2SeqTrainerForChatGLM,它继承自PeftTrainer类。

  2. def save_predictions(self, predict_results: PredictionOutput, tokenizer: PreTrainedTokenizer) -> None: - 这是一个save_predictions方法,它接受两个参数,一个是predict_results预测结果,类型为PredictionOutput),另一个是tokenizer(预训练的分词器,类型为PreTrainedTokenizer)。

  3. if not self.is_world_process_zero(): return - 这行检查当前进程是否为主进程(进程编号为0)。如果不是,这个方法就结束运行。

  4. preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) - 这行创建了一个新的预测值数组,其中非忽略索引的预测值被保留,忽略索引的预测值被替换为填充符号的ID。

  5. labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) - 这行创建了一个新的标签数组,其中非忽略索引的标签值被保留,忽略索引的标签值被替换为填充符号的ID。

  6. preds = preds[:, labels.shape[1]:] - 这行移除了预测值中的提示,提示是在预测数组的开始部分

  7. preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds] - 这行使用分词器解码预测值,移除特殊符号,并删除前后的空白。

  8. labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels] - 这行使用分词器解码标签,移除特殊符号,并删除前后的空白。

  9. output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - 这行定义了预测结果的输出文件路径。

  10. logger.info(f"Saving prediction results to {output_prediction_file}") - 这行向日志发送一条信息,表明预测结果将被保存到哪个文件。

  11. with open(output_prediction_file, "w", encoding="utf-8") as writer: - 这行打开预测结果输出文件,以写入模式,准备写入预测结果。

  12. res: List[str] = [] - 这行初始化一个空的列表,准备收集每一行的预测结果。

  13. for pred, label in zip(preds, labels): - 这行开始一个循环,遍历所有的预测值和标签。

  14. res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) - 在循环中,这行将每对预测和标签以JSON格式转换为字符串,然后添加到结果列表中。

  15. writer.write("\n".join(res)) - 这行将结果列表连接成一个字符串,每个结果之间用换行符隔开,然后写入到文件中。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131459455