欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/132561225
- OpenFold: 重新训练 AlphaFold2 揭示对于学习机制和泛化能力的新见解
OpenFold Multimer 是基于 OpenFold 的开源框架,支持预测蛋白质复合物的结构。蛋白质复合物是由多个蛋白质亚基组成的大分子,在生命过程中发挥着重要的功能。OpenFold Multimer 的开发目的是为了实现类似于 DeepMind 的 AlphaFold-Multimer 的功能,即利用深度学习和结构优化的方法,根据蛋白质序列和亚基间的接触信息,预测出复合物的三维结构。目前还处于开发阶段,是 OpenFold 的一个分支,可以使用 DeepMind 提供的 Multimer v3 模型,也可以使用 OpenFold 训练的模型。
GitHub 源码:aqlaboratory/openfold
核心函数位于 train_openfold.py,关键类是 data_module,即 OpenFoldMultimerDataModule:
# ...
# Multimer config presets use the dedicated multimer data module.
if "multimer" in args.config_preset:
    data_module = OpenFoldMultimerDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args))
# ...
data_module.prepare_data()  # NOTE(review): prepare_data() appears unimplemented here
data_module.setup()
# ...
# Lightning drives the train/val loops through the data module's dataloaders.
trainer.fit(
    model_module,
    datamodule=data_module,
    ckpt_path=ckpt_path)
# ...
在 OpenFoldMultimerDataModule 中,核心逻辑是 setup() 初始化与 _gen_dataloader() 生成 DataLoader。其中,setup() 重写父类方法,被直接调用;_gen_dataloader() 在父类 OpenFoldDataModule 中被调用。
def train_dataloader(self):
    """Return the DataLoader for the training split."""
    return self._gen_dataloader("train")
def val_dataloader(self):
    """Return the DataLoader for the validation split.

    Returns None when no validation dataset was configured, which tells
    Lightning to skip the validation loop.
    """
    if self.eval_dataset is not None:
        return self._gen_dataloader("eval")
    return None
def predict_dataloader(self):
    """Return the DataLoader for the prediction (inference) split."""
    return self._gen_dataloader("predict")
核心函数是 _gen_dataloader(),不同阶段的数据使用不同的 dataset,训练集(train)使用 reroll() 重新进行概率采样,即:
def _gen_dataloader(self, stage):
    """Create the DataLoader for `stage` ("train", "eval" or "predict").

    For the training stage, the dataset is re-rolled so that each epoch
    draws a fresh probability-weighted sample of datapoints.

    Raises:
        ValueError: if `stage` is not one of the three known stages.
    """
    # Seed batching for reproducibility when a batch seed is configured.
    generator = None
    if self.batch_seed is not None:
        generator = torch.Generator()
        generator.manual_seed(self.batch_seed)

    if stage == "train":
        dataset = self.train_dataset
        # Re-sample datapoints according to the configured probabilities.
        dataset.reroll()
    elif stage == "eval":
        dataset = self.eval_dataset
    elif stage == "predict":
        dataset = self.predict_dataset
    else:
        raise ValueError("Invalid stage")

    dl = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        # BUG FIX: the generator was created above but never used; wire it
        # into the DataLoader so the batch seed actually takes effect.
        generator=generator,
        num_workers=self.config.data_module.data_loaders.num_workers,
    )
    # BUG FIX: the message was hard-coded to "training" even for the
    # eval/predict stages (and used an f-string with no placeholder).
    print(f"generated {stage} dataloader")
    return dl
train_dataset 来源于 setup() 函数,在其中初始化 OpenFoldMultimerDataset 类。在 OpenFoldMultimerDataset 类中,核心方法是 __getitem__,通过 dataset_idx 与 datapoint_idx 选择样本,即:
# setup()
# Build the weighted training dataset: entries of `datasets` are drawn
# according to `probabilities` for `train_epoch_len` samples per epoch.
# NOTE(review): `_roll_at_init=True` presumably performs the first
# probability re-sample at construction time — confirm against reroll().
self.train_dataset = OpenFoldMultimerDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
generator=generator,
_roll_at_init=True,)
# ...
# OpenFoldMultimerDataset
def __getitem__(self, idx):
    """Resolve flat epoch index `idx` into a (dataset, datapoint) pair and
    fetch the sample from the corresponding member dataset."""
    dataset_idx, datapoint_idx = self.datapoints[idx]
    return self.datasets[dataset_idx][datapoint_idx]
train_dataset 有两类来源:一类是 PDB(train_dataset),另一类是 distillation_dataset,由 train_dataset 直接生成,逻辑位于 setup() 中:
# Training set (PDB structures).
train_dataset = dataset_gen(
    data_dir=self.train_data_dir,
    mmcif_data_cache_path=self.train_mmcif_data_cache_path,
    alignment_dir=self.train_alignment_dir,
    filter_path=self.train_filter_path,
    max_template_hits=self.config.train.max_template_hits,
    shuffle_top_k_prefiltered=
        self.config.train.shuffle_top_k_prefiltered,
    treat_pdb_as_distillation=False,
    mode="train",
    alignment_index=self.alignment_index)
# ...
if distillation_dataset is not None:
    # Mix PDB and distillation data; the distillation set is sampled
    # with probability d_prob, the PDB set with 1 - d_prob.
    datasets = [train_dataset, distillation_dataset]
    d_prob = self.config.train.distillation_prob
    probabilities = [1. - d_prob, d_prob]
核心类是 OpenFoldSingleMultimerDataset,其核心逻辑包括:data_pipeline.process_fasta() 提取特征;feature_pipeline.process_features() 处理特征(与 Multimer 相关)。即:
def __getitem__(self, idx):
    """Load and featurize the protein complex at `idx`.

    In "train"/"eval" mode the structure is parsed from an mmCIF file on
    disk; in any other mode (inference) features are built directly from a
    FASTA file. Returns the processed feature dict, or the raw parsed data
    when `self._output_raw` is set.

    Raises:
        ValueError: if no supported structure file exists for the entry, or
            the found extension has no parsing branch yet.
    """
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    print(f"mmcif_id is: {mmcif_id}, idx: {idx} and has {len(chains)} chains")
    alignment_index = None
    if self.mode == 'train' or self.mode == 'eval':
        path = os.path.join(self.data_dir, f"{mmcif_id}")
        # Find which supported structure extension exists on disk.
        ext = None
        for e in self.supported_exts:
            if os.path.exists(path + e):
                ext = e
                break
        if ext is None:
            raise ValueError("Invalid file type")
        # TODO: Add pdb and core exts to data_pipeline for multimer
        path += ext
        if ext == ".cif":
            data = self._parse_mmcif(
                path, mmcif_id, self.alignment_dir, alignment_index)
        else:
            raise ValueError("Extension branch missing")
    else:
        # Inference mode: build features from the FASTA sequence alone.
        path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
        data = self.data_pipeline.process_fasta(
            fasta_path=path,
            alignment_dir=self.alignment_dir)
    if self._output_raw:
        return data
    # process all_chain_features
    data = self.feature_pipeline.process_features(
        data, mode=self.mode, is_multimer=True)
    # Tag every residue position with the sample index so batches can be
    # traced back to their source entry (torch.full replaces the previous
    # equivalent list-comprehension build).
    data["batch_idx"] = torch.full(
        (data["aatype"].shape[-1],),
        idx,
        dtype=torch.int64,
        device=data["aatype"].device)
    return data
初始化逻辑:
# Each chain is featurized individually by the monomer DataPipeline;
# the multimer pipeline wraps it and merges chains at the complex level.
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
self.data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor
)
# Post-processes raw features into model-ready tensors.
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
DataPipelineMultimer 类中的 process_fasta 是标准的 AF2 逻辑:
def process_fasta(self,
                  fasta_path: str,
                  alignment_dir: str,
                  ) -> FeatureDict:
    """Create multimer features from a (possibly multi-chain) FASTA file.

    Each unique sequence is featurized once and cached; duplicate chains
    reuse a deep copy of the cached features. Per-chain features are then
    merged into assembly-level features and paired across chains.
    """
    with open(fasta_path) as f:
        input_fasta_str = f.read()
    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
    all_chain_features = {}
    # Cache of per-sequence features keyed by the raw sequence string.
    sequence_features = {}
    # A single unique sequence means homomer (or monomer) pairing logic.
    is_homomer_or_monomer = len(set(input_seqs)) == 1
    for desc, seq in zip(input_descs, input_seqs):
        if seq in sequence_features:
            # Identical chain already processed: reuse its features.
            all_chain_features[desc] = copy.deepcopy(
                sequence_features[seq]
            )
            continue
        chain_features = self._process_single_chain(
            chain_id=desc,
            sequence=seq,
            description=desc,
            chain_alignment_dir=os.path.join(alignment_dir, desc),
            is_homomer_or_monomer=is_homomer_or_monomer
        )
        chain_features = convert_monomer_features(
            chain_features,
            chain_id=desc
        )
        all_chain_features[desc] = chain_features
        sequence_features[seq] = chain_features

    all_chain_features = add_assembly_features(all_chain_features)
    np_example = feature_processing_multimer.pair_and_merge(
        all_chain_features=all_chain_features,
    )
    # Pad MSA to avoid zero-sized extra_msa.
    np_example = pad_msa(np_example, 512)
    return np_example
FeaturePipeline 类比较简单,主要是调用 np_example_to_features,对特征进行后处理,其中包括 Multimer 的相关逻辑:
def process_features(
    self,
    raw_features: FeatureDict,
    mode: str = "train",
    is_multimer: bool = False,
) -> FeatureDict:
    """Post-process raw features into model-ready features.

    Thin wrapper around `np_example_to_features`; `is_multimer` switches
    on the multimer-specific processing path.
    """
    # NOTE(review): an earlier guard rejected multimer training here
    # ("Multimer mode is not currently trainable"); it is intentionally
    # disabled now that multimer training is supported.
    return np_example_to_features(
        np_example=raw_features,
        config=self.config,
        mode=mode,
        is_multimer=is_multimer,
    )
在 np_example_to_features 中,先将 np_example 转换成 tensor_dict,再转换成 features,同时增加 Multimer 的处理逻辑。