前提 本地模式,file文件模式
channel =训练模式/验证模式
_prepare_for_training
本地情况下
保证entry_point能够找到,如果source_dir存在,则保证 source_dir 中存在entry_point这个文件
这一点,也是很正确的,因为我们确实也就在镜像中创建了这个文件夹
保证 entry_point 的存在
start_new()
inputs和model_uri需要是本地地址,并且存在file判定是否正确
job._load_config
加载job的配置文件
对输入进行核对
如果是RecordSet, FileSystemRecordSet这两种类型如何处理
RecordSet类型:
{self.channel : s3_input}
{self.channel: self.file_system_input}
如果就是字符串类型
首先判断是不是s3类型
其次判断是不是file本地类型
如果是的话,转换成对应的数据类型传送回来
如果传入的是字典,那么就可以一次性处理training与vaild.
如果传入的是列表,那么只能是RecordSet, FileSystemRecordSet
经过处理以后,
input_dict
{
'training':s3_input,
}
即config={
"input_config": input_config,
"role": role,
"output_config": output_config,
"resource_config": resource_config,
"stop_condition": stop_condition,
"vpc_config": vpc_config,
}
加载超参数
hyperparameters
生成train_args 这个训练参数
sagemaker_session.train
使用train_args生成traun_request
train_request = {
"AlgorithmSpecification": {"TrainingInputMode": input_mode},
"OutputDataConfig": output_config,
"TrainingJobName": job_name,
"StoppingCondition": stop_condition,
"ResourceConfig": resource_config,
"RoleArn": role,
}
train_request["AlgorithmSpecification"]["TrainingImage"] = image
train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn
train_request["InputDataConfig"] = input_config
train_request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions
train_request["AlgorithmSpecification"][
"EnableSageMakerMetricsTimeSeries"
] = enable_sagemaker_metrics
train_request["HyperParameters"] = hyperparameters
....
sagemaker_client.create_training_job = local_session().create_training_job
sagemaker.local.image._SageMakerContainer
创建容器类
判定docker-compose有没有安装
self.hosts 则存储的容器类的文件夹,一个容器一个文件夹
sagemaker.local.image._SageMakerContainer.train
创建output目录
创建好容器类之后,需要使用这个容器类进行训练
为每个目录创建
创建docker-compose文件
Sagemaker源码解析
猜你喜欢
转载自blog.csdn.net/qq_41861526/article/details/110002917
今日推荐
周排行