PSP - Protein Structure Prediction: Feature Preprocessing in the OpenFold Multimer Training Pipeline

Welcome to follow my CSDN: https://spike.blog.csdn.net/
This article: https://spike.blog.csdn.net/article/details/132561225


Paper: OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization


OpenFold Multimer is an open-source framework built on OpenFold that supports predicting the structures of protein complexes. A protein complex is a macromolecule composed of multiple protein subunits and plays important roles in biological processes. OpenFold Multimer aims to reproduce the functionality of DeepMind's AlphaFold-Multimer: using deep learning and structural optimization to predict the three-dimensional structure of a complex from the protein sequences and inter-subunit contact information. It is still under development, lives on a branch of OpenFold, and can run either with DeepMind's Multimer v3 weights or with models trained by OpenFold.

GitHub source: aqlaboratory/openfold

The entry point is train_openfold.py; the key class is the data_module, i.e. OpenFoldMultimerDataModule:

# ...
if "multimer" in args.config_preset:
    data_module = OpenFoldMultimerDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args))
# ...
data_module.prepare_data()  # this function is not implemented (no-op)
data_module.setup()
# ...
trainer.fit(
    model_module, 
    datamodule=data_module,
    ckpt_path=ckpt_path)
# ...

In OpenFoldMultimerDataModule, the core logic is setup(), which initializes the datasets, and _gen_dataloader(), which builds the DataLoader:

  • setup() is called directly and overrides the parent-class method.
  • _gen_dataloader() is called from the parent class OpenFoldDataModule.
def train_dataloader(self):
    return self._gen_dataloader("train") 
def val_dataloader(self):
    if(self.eval_dataset is not None):
        return self._gen_dataloader("eval")
    return None
def predict_dataloader(self):
    return self._gen_dataloader("predict") 

The core function is _gen_dataloader(). Each stage uses its own dataset; the training set (train) additionally calls reroll() to resample datapoints by probability:

def _gen_dataloader(self, stage):
    if self.batch_seed is not None:
        generator = torch.Generator()
        generator.manual_seed(self.batch_seed)

    if stage == "train":
        dataset = self.train_dataset
        # Filter the dataset, if necessary
        dataset.reroll()  # probability-based resampling
    elif stage == "eval":
        dataset = self.eval_dataset
    elif stage == "predict":
        dataset = self.predict_dataset
    else:
        raise ValueError("Invalid stage")

    dl = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        num_workers=self.config.data_module.data_loaders.num_workers,
    )
    print(f"generated {stage} dataloader")
    return dl
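In _gen_dataloader() above, the seeded torch.Generator is what makes sampling reproducible across runs; a minimal illustration of that determinism:

```python
import torch

# Two generators seeded identically produce identical draws, which is
# why seeding with batch_seed makes the sampling pipeline reproducible.
g1 = torch.Generator()
g1.manual_seed(42)
g2 = torch.Generator()
g2.manual_seed(42)

a = torch.randint(0, 100, (5,), generator=g1)
b = torch.randint(0, 100, (5,), generator=g2)
print(torch.equal(a, b))  # → True
```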

train_dataset is created in the setup() function, which initializes the OpenFoldMultimerDataset class:

  • In the OpenFoldMultimerDataset class, the core method is __getitem__, which selects a sample via dataset_idx and datapoint_idx.

That is:

# setup()
self.train_dataset = OpenFoldMultimerDataset(
    datasets=datasets,
    probabilities=probabilities,
    epoch_len=self.train_epoch_len,
    generator=generator,
    _roll_at_init=True,)

# ...
# OpenFoldMultimerDataset
def __getitem__(self, idx):
    dataset_idx, datapoint_idx = self.datapoints[idx]
    return self.datasets[dataset_idx][datapoint_idx]
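The excerpt does not show reroll() itself, but based on the (dataset_idx, datapoint_idx) pairs consumed by __getitem__, its behavior can be sketched roughly as follows. The function name, signature, and uniform per-dataset sampling below are illustrative assumptions, not the actual OpenFold API:

```python
import torch

def reroll(dataset_lengths, probabilities, epoch_len, generator=None):
    # Hypothetical sketch: for each of the epoch_len slots in an epoch,
    # first draw a dataset index with the given mixing probabilities,
    # then draw a datapoint index uniformly from that dataset.
    probs = torch.tensor(probabilities)
    dataset_choices = torch.multinomial(
        probs, num_samples=epoch_len, replacement=True, generator=generator)
    datapoints = []
    for dataset_idx in dataset_choices.tolist():
        n = dataset_lengths[dataset_idx]
        datapoint_idx = torch.randint(n, (1,), generator=generator).item()
        datapoints.append((dataset_idx, datapoint_idx))
    return datapoints
```

Calling reroll() once per epoch (as _gen_dataloader() does for the train stage) gives each epoch a fresh mixture of PDB and distillation samples.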

train_dataset has two kinds of sources: one is PDB data (train_dataset), the other is the distillation_dataset, which is constructed right alongside train_dataset; the logic lives in setup():

# Training set
train_dataset = dataset_gen(
    data_dir=self.train_data_dir,
    mmcif_data_cache_path=self.train_mmcif_data_cache_path,
    alignment_dir=self.train_alignment_dir,
    filter_path=self.train_filter_path,
    max_template_hits=self.config.train.max_template_hits,
    shuffle_top_k_prefiltered=
        self.config.train.shuffle_top_k_prefiltered,
    treat_pdb_as_distillation=False,
    mode="train",
    alignment_index=self.alignment_index)
# ...
if(distillation_dataset is not None):
    datasets = [train_dataset, distillation_dataset]
    d_prob = self.config.train.distillation_prob
    probabilities = [1. - d_prob, d_prob]

The core class is OpenFoldSingleMultimerDataset:

  • Core logic: data_pipeline.process_fasta(), which extracts the features.
  • Core logic: feature_pipeline.process_features(), which post-processes the features (Multimer-specific).

That is:

def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    print(f"mmcif_id is: {mmcif_id}, idx: {idx} and has {len(chains)} chains")

    alignment_index = None
    if self.mode == 'train' or self.mode == 'eval':
        path = os.path.join(self.data_dir, f"{mmcif_id}")
        ext = None
        for e in self.supported_exts:
            if os.path.exists(path + e):
                ext = e
                break
        if ext is None:
            raise ValueError("Invalid file type")

        # TODO: Add pdb and core exts to data_pipeline for multimer
        path += ext
        if ext == ".cif":
            data = self._parse_mmcif(
                path, mmcif_id, self.alignment_dir, alignment_index)
        else:
            raise ValueError("Extension branch missing")
    else:
        path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
        data = self.data_pipeline.process_fasta(
            fasta_path=path,
            alignment_dir=self.alignment_dir)

    if self._output_raw:
        return data

    # process all_chain_features
    data = self.feature_pipeline.process_features(
        data, mode=self.mode, is_multimer=True)

    # if it's inference mode, only need all_chain_features
    data["batch_idx"] = torch.tensor(
        [idx for _ in range(data["aatype"].shape[-1])],
        dtype=torch.int64,
        device=data["aatype"].device)

    return data

初始化逻辑:

data_processor = data_pipeline.DataPipeline(
    template_featurizer=template_featurizer,
)
self.data_pipeline = data_pipeline.DataPipelineMultimer(
    monomer_data_pipeline=data_processor
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

In the DataPipelineMultimer class, process_fasta follows the standard AlphaFold-Multimer logic:

def process_fasta(self,
                  fasta_path: str,
                  alignment_dir: str,
                  ) -> FeatureDict:
    """Creates features."""
    with open(fasta_path) as f:
        input_fasta_str = f.read()

    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

    all_chain_features = {}
    sequence_features = {}
    is_homomer_or_monomer = len(set(input_seqs)) == 1
    for desc, seq in zip(input_descs, input_seqs):
        if seq in sequence_features:
            all_chain_features[desc] = copy.deepcopy(
                sequence_features[seq]
            )
            continue

        chain_features = self._process_single_chain(
            chain_id=desc,
            sequence=seq,
            description=desc,
            chain_alignment_dir=os.path.join(alignment_dir, desc),
            is_homomer_or_monomer=is_homomer_or_monomer
        )

        chain_features = convert_monomer_features(
            chain_features,
            chain_id=desc
        )
        all_chain_features[desc] = chain_features
        sequence_features[seq] = chain_features

    all_chain_features = add_assembly_features(all_chain_features)

    np_example = feature_processing_multimer.pair_and_merge(
        all_chain_features=all_chain_features,
    )

    # Pad MSA to avoid zero-sized extra_msa.
    np_example = pad_msa(np_example, 512)

    return np_example
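pad_msa is not shown in the excerpt. A plausible sketch of what it does is below: zero-pad the MSA-shaped features up to a minimum number of sequences, so extra_msa is never zero-sized. The exact list of padded features is an assumption for illustration:

```python
import numpy as np

def pad_msa(np_example, min_num_seq=512):
    # Hypothetical sketch: pad all MSA-shaped features with zeros along
    # the sequence axis until the MSA has at least min_num_seq rows.
    np_example = dict(np_example)
    num_seq = np_example["msa"].shape[0]
    if num_seq >= min_num_seq:
        return np_example
    pad = min_num_seq - num_seq
    for feat in ("msa", "deletion_matrix", "bert_mask",
                 "msa_mask", "cluster_bias_mask"):
        if feat in np_example:
            x = np_example[feat]
            # Pad only the leading (sequence) dimension.
            np_example[feat] = np.pad(
                x, ((0, pad),) + ((0, 0),) * (x.ndim - 1))
    return np_example
```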

The FeaturePipeline class is fairly simple: it mainly calls np_example_to_features to post-process the features, including the Multimer-specific logic:

def process_features(
    self,
    raw_features: FeatureDict,
    mode: str = "train",
    is_multimer: bool = False,
) -> FeatureDict:
    # if(is_multimer and mode != "predict"):
    #     raise ValueError("Multimer mode is not currently trainable")

    return np_example_to_features(
        np_example=raw_features,
        config=self.config,
        mode=mode,
        is_multimer=is_multimer,
    )

In np_example_to_features, the np_example is first converted into a tensor_dict and then into the final features, with the Multimer-specific processing added along the way.
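The numpy-to-tensor step can be sketched like this. This is a simplified stand-in, not the real conversion, which also applies a per-feature dtype mapping from the config:

```python
import numpy as np
import torch

def np_example_to_tensor_dict(np_example):
    # Hypothetical helper: convert a FeatureDict of numpy arrays into
    # torch tensors, keeping integer features as int64 (indices such as
    # aatype) and everything else as float32.
    tensor_dict = {}
    for name, value in np_example.items():
        arr = np.asarray(value)
        if np.issubdtype(arr.dtype, np.integer):
            tensor_dict[name] = torch.tensor(arr, dtype=torch.int64)
        else:
            tensor_dict[name] = torch.tensor(arr, dtype=torch.float32)
    return tensor_dict
```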

Reposted from blog.csdn.net/u012515223/article/details/132561225