AllenNLP源码拓展——自动完成一批训练任务

前几天突然想到,AllenNLP的train命令是根据一个配置进行一项训练任务,如果可以把train命令打包到一个for循环里,那不是可以自动完成多个训练任务嘛。如果完成一个训练任务需要一个小时,那么一个晚上可以设置不同的参数,或者对于不同的模型,完成十几次的训练任务,充分地利用了晚上的时间。

于是就在前几天,趁着实验室的机器还没有修好,尝试实现了一下我的想法,参考train命令的代码,自己加了一个group-train命令,完成了几次测试,并上传到了我fork下来的github代码仓库里。group-train地址(注意,在__init__.py里,也加了几行代码)

等到实验室机器配置好,就可以用它迅速完成大量实验喽,开心。

昨天我在allennlp的Issues说了我做的这个东西以及具体的实现,早晨收到负责人的回复。看到他的建议,我才明白,这个方法的缺点是适用范围不广泛,我使用的机器是单个GPU(租的服务器也是单GPU),于是我只是根据我现有的条件去实现我的想法,没有考虑到多GPU,多CPU,云计算等情况(贫穷限制了我的想象力)。所以,如果读者有更好的实验条件,可以忽略我的代码(很简陋),根据需要实现自己的功能。

使用指南

很简单,很直接。

1.把若干配置文件放到一个文件夹里
在这里插入图片描述
2.命令行输入

allennlp group-train /home/wxy/PycharmProjects/allennlp-as-a-library/experiments 
-s ./group_save --include-package my_library

(第一个需要绝对路径,因为我用的listdir…)

3.开始之前,在这个文件夹创建一个training_progress.json文件,用来记录工作进度。一开始是这样的:

{"venue_classifier_boe_adam": false, 
"venue_classifier.json": false, 
"venue_classifier_boe.json": false}

4.一项训练结束,修改这个文件

{"venue_classifier_boe_adam": true, 
"venue_classifier.json": false, 
"venue_classifier_boe.json": false}

这个训练任务中保存的所有文件,在-s ./group_save的子目录下,文件夹名字和配置文件的一致。

5.接下来,程序在完成一个配置文件的任务后,自动进行下一个,直到所有训练任务结束,在输出路径下,得到三个文件夹。
在这里插入图片描述
6.如果训练从某处中断,后面的不会继续(如果可以跳过这一项会更好,但是我不知道怎么实现,如果读者知道的话,麻烦给点提示)。
但是,可以输入同样的命令,恢复训练,程序读取training_progress.json文件,就知道上次从哪里中断的,于是这一项设置recover,后面的正常进行。

自定义命令

1.在allennlp.commands.__init__的main函数里的subcommands字典中,加一个"group-train": Group_Train()。加一个from allennlp.commands.group_train import Group_Train。完事儿~

2.新建一个group_train.py文件。参考train.py,设置命令行的命令,这里我只设置了三个属性:
param_path(配置文件所在的文件夹,绝对路径)
–serialization-dir(输出路径)
–file-friendly-logging(字面意思说降低进度条刷新速度,但是这一项我没有实际设置过)

3.从命令行获得参数,定义我们自己的功能。核心就只有下面这一个函数,在循环中使用train命令的训练函数train_model_from_file

def train_model_from_files(param_path: str,
                          serialization_dir: str,
                          file_friendly_logging: bool = False) -> None:
    """
    load the params from a file and train a model, then the next.

    Parameters
    ----------
    param_path : ``str``
        A dir contains json parameter files specifying a group of AllenNLP experiment.
        'training_progress.json', which record training progress, will be created here.
    serialization_dir : ``str``
        The directory in which to save results and logs. A parameter file corresponds
        to the sub_serial_path with the same name.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we make our output more friendly to saved model files.  We just pass this
        along to :func:`train_model`.
    """
    if os.path.isabs(param_path) == False:
        logger.warn(f"param_path must be a absolute path")
        return

    if os.path.isfile(serialization_dir):
        logger.warn(f"serialization_dir must be a path, but you input a file")
        return

    param_files = os.listdir(param_path)  #listdir must use absolute path
    if len(param_files) == 0:
        logger.warn(f"At least one parameter file in the param_path directory")
        return

    for file in param_files:
        if os.path.isdir(file):
            param_files.remove(file) # exclude dir name
        if os.path.splitext(file)[1] not in ['.json', '.jsonnet']:
            param_files.remove(file)  # only include json and jsonnet
        if file == 'training_progress.json':
            param_files.remove(file)  # if training_progress.json exists, remove it.

    start_position  = check_progress_file(param_path)

    if start_position == 'All Finished':
        logger.info(f"param_path must be a absolute path")
        return

    if start_position is not None:
        idx = param_files.index(start_position)
        param_files = param_files[idx:]

    for file in param_files:
        if check_file_for_train(param_path, file):
            continue
        param_file_name = param_path + os.sep + file
        sub_serial_path = serialization_dir + os.sep + os.path.splitext(file)[0]  # XXX.json/.jsonnet
        if start_position is None:
            train_model_from_file(parameter_filename = param_file_name,
                                  serialization_dir = sub_serial_path,
                                  file_friendly_logging = file_friendly_logging)
        else:
            train_model_from_file(parameter_filename=param_file_name,
                                  serialization_dir=sub_serial_path,
                                  file_friendly_logging=file_friendly_logging,
                                  recover=True)
            start_position = None  # Only the first False needs recover, others train as usual

        update_progress_file(param_path) # update training_progress.json to record training state.

还有就是一些创建training_progress.json,更新它,检查是不是所有训练都已经完成的函数。

def check_progress_file(param_path) -> None:
    """
    check training_progress.json. It contains a Dict, the key
    is the name of the parameter file, and the value is whether
    (true/false) the training has been completed.
    If it exists, determine which file to recover training.
    If it doesn't exist, create it.

    Parameters
    ----------
    param_path : ``str``
        Where to check training_progress.json.
    """
    param_file_name = param_path + os.sep + 'training_progress.json'
    if os.path.exists(param_file_name) is not True:   # create training_progress.json at first time
        logger.info(f"Creating training_progress.json at {param_path}.")

        files = os.listdir(param_path)
        progress = {str(file):False for file in files}
        with open(param_file_name, 'w') as f:
            json.dump(progress, f)
        return None
    else:           # continue from which file
        logger.info(f"Recover group train.")
        with open(param_file_name, 'r') as f:
            progress = json.load(f)
            for file, bFinished in progress.items():
                if bFinished == False:
                    return file
        return 'All Finished'


def update_progress_file(param_path) -> None:
    """
    Change a training_progress.json item after a train.

    Parameters
    ----------
    param_path : ``str``
        Where to update training_progress.json.
    """
    progress = {}
    param_file_name = param_path + os.sep + 'training_progress.json'
    with open(param_file_name, 'r') as f:
        progress = json.load(f)
    for file, bFinished in progress.items():
        if bFinished == False:
            progress[file] = True
            break
    with open(param_file_name, 'w') as f:
        json.dump(progress, f)

def check_file_for_train(param_path, file) -> bool:
    """
    Check if the training task corresponding to this file has been completed.

    Parameters
    ----------
    param_path : ``str``
        Where to check training_progress.json.
    file : ``str``
        File name to check
    """
    param_file_name = param_path + os.sep + 'training_progress.json'
    with open(param_file_name, 'r') as f:
        progress = json.load(f)
    return progress[file]

最后,欢迎收藏或fork我的仓库,使用完整的代码(不过上面这些差不多就是全部了…),如果有bug或者其他想法,会更新上去。

猜你喜欢

转载自blog.csdn.net/m0_38133212/article/details/88365758