记一次BART模型复现踩坑经历

问题1:数据集 inputs 和 labels 反了

问题原因

加载时数据集未指定 inputs 和 labels 。

    # Get the column names for input/target.
    # 设置 input/target 的逻辑: 1.指定名称(data_args.text_column, data_args.summary_column) 2.指定数据集(自带名称map) 3.默认为(dataset_columns[0], dataset_columns[1])
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"--text_column' value '{
      
      data_args.text_column}' needs to be one of: {
      
      ', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"--summary_column' value '{
      
      data_args.summary_column}' needs to be one of: {
      
      ', '.join(column_names)}"
            )

设置 input/target 的逻辑:
1.指定名称(data_args.text_column, data_args.summary_column)
2.指定数据集(自带名称map)
3.默认为(dataset_columns[0], dataset_columns[1])

解决方案

这里采用了在读取参数时指定,这里代码写了一种从文件读取参数的方式,需要一个参数文件config.json

run_summarization.py:311 行加载参数文件。

    # main, run_summarization.py:311
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))

参数文件内容

{
    
    
  "model_name_or_path":"fnlp_bart-base-chinese",

  "text_column":"context",
  "summary_column":"response",
  "max_source_length":128,
  "max_target_length":256,

  "dataset_name":"data_douhao",
  "num_train_epochs":160,
  "save_steps":1000,
  "per_device_train_batch_size":32,
  "per_device_eval_batch_size":32,
  "do_train":true,
  "do_eval":true,
  "do_predict":false,
  "include_inputs_for_metrics":true,
  "predict_with_generate":true,
  "output_dir":"checkpoints/model_douhao2",
  "overwrite_output_dir":true
}

这里在参数文件中用 "text_column":"context""summary_column":"response", 指定 inputs 和 labels 签。

其他解决方案

设置 input/target 的逻辑:
1.指定名称 (data_args.text_column, data_args.summary_column)
在加载完数据集后,直接赋值就可以了

  data_args.text_column = "context",
  data_args.summary_column = "response",

2.指定数据集(自带名称map)
run_summarization.py:289 行有个数据集映射表,增加一行数据集:元组映射即可。

summarization_name_mapping = {
    
    
    "amazon_reviews_multi": ("review_body", "review_title"),
    "big_patent": ("description", "abstract"),
    ......
    "wiki_summary": ("article", "highlights"),
    "multi_news": ("document", "summary"),

	# 加一行即可
	# "数据集名称": (input,output)
    "data": ("context", "response"),
}

3.默认为(dataset_columns[0], dataset_columns[1])
(被坑了)

问题2:模型生成长度始终为20。

训练出来的模型生成长度始终为20。

问题原因

transformers 库中加载模型超参数时,有个默认值 max_length = 20,控制生成文本长度,在载入模型config文件时,没设置值,自动加载的默认值。(默认值有点短)

# 下面按照顺序一层一层进入
# --------------------------------
main, run_summarization.py:417  # 这行加载了BART模型config
config = AutoConfig.from_pretrained(
# --------------------------------
from_pretrained, configuration_auto.py:941  # 加载完 config 文件数值去找对应的模型 config 类了
# 这里模型的 config 里写明了 "model_type": "bart"
# 所以载入时 config_dict["model_type"] = "bart"
return config_class.from_dict(config_dict, **unused_kwargs)
# --------------------------------
from_dict, configuration_utils.py:701   # 同样在找找对应的模型 config 类
config = cls(**config_dict)
# --------------------------------
__init__, configuration_bart.py:165 # 找到对应 bart 模型 config 类,进行初始化
super().__init__(
# --------------------------------
__init__, configuration_utils.py:285    # 用的是通用 config 加载
self.max_length = kwargs.pop("max_length", 20)

在最后这个文件configuration_utils.py第285行打个断点就可以看到加载的默认值了。

为什么默认值只给20啊,真的是

解决方案

改的话其实比较好改,加载完 config 后,将变量 config.max_length 改为 256,生成文本长度即可改变。

config.max_length = 256

猜你喜欢

转载自blog.csdn.net/aiaidexiaji/article/details/131063936