# Source code for openprotein.fold.alphafold2
"""Community-based AlphaFold 2 model running using ColabFold."""
import io
import warnings
from typing import Any, Sequence
from openprotein.align import AlignAPI, MSAFuture
from openprotein.base import APISession
from openprotein.common import ModelMetadata
from openprotein.fold.common import normalize_inputs, serialize_input
from openprotein.fold.complex import id_generator
from openprotein.molecules import Protein, DNA, RNA, Ligand, Complex
from . import api
from .future import FoldResultFuture
from .models import FoldModel
# [docs]  (Sphinx viewcode link marker left over from HTML extraction; not part of the module)
class AlphaFold2Model(FoldModel):
    """
    Class providing inference endpoints for AlphaFold2 structure prediction models, based on the implementation by ColabFold.
    """

    # Default model identifier for this model class.
    model_id: str = "alphafold2"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        """
        Initialize the model by forwarding all arguments to ``FoldModel``.

        Parameters
        ----------
        session : APISession
            Session used to issue API requests.
        model_id : str
            Identifier of the model to run.
        metadata : ModelMetadata | None
            Optional model metadata, passed through unchanged.
        """
        super().__init__(session=session, model_id=model_id, metadata=metadata)

    def fold(
        self,
        sequences: Sequence[Complex | Protein | str] | MSAFuture | None = None,
        num_recycles: int | None = None,
        num_models: int = 1,
        num_relax: int = 0,
        **kwargs,
    ) -> FoldResultFuture:
        """
        Post sequences to alphafold model.

        Parameters
        ----------
        sequences : Sequence[Complex | Protein | str] | MSAFuture
            Sequences to include in folded output. `Protein` objects must be
            tagged with an `msa`, which can be a `Protein.single_sequence_mode`
            for single sequence mode. Alternatively, supply an `MSAFuture` to
            use all query sequences as a multimer: the MSA seed is split on
            ``":"`` and each part becomes one protein chain of a single complex.
        num_recycles : int | None
            Number of times to recycle models. Forwarded to the fold request.
        num_models : int
            Number of models to run. Forwarded to the fold request.
        num_relax : int
            Maximum number of iterations for relax. Forwarded to the fold
            request.
        **kwargs
            Only the legacy ``msa`` keyword is read. When present it replaces
            ``sequences`` and a warning is emitted.

        Returns
        -------
        FoldResultFuture
            Future wrapping the submitted fold job and the normalized
            complexes that were sent.

        Raises
        ------
        TypeError
            If neither ``sequences`` nor the legacy ``msa`` keyword is given,
            if ``msa`` is not an ``MSAFuture``, or if the serialized input is
            empty.
        """
        if "msa" in kwargs:
            warnings.warn(
                "Inputs to AlphaFold 2 have been updated. 'msa' should be supplied as 'sequences' argument. Support will be dropped in the future."
            )
            sequences = kwargs["msa"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(sequences, MSAFuture):
                raise TypeError("Expected msa to be an MSAFuture")
        if sequences is None:
            raise TypeError("Expected 'sequences' argument")

        if isinstance(sequences, MSAFuture):
            # Build a single multimer complex from the MSA seed.
            msa = sequences
            id_gen = id_generator()
            align_api = getattr(self.session, "align", None)
            # Internal invariant (also narrows the type for checkers).
            assert isinstance(align_api, AlignAPI)
            seed = align_api.get_seed(job_id=msa.job.job_id)
            msa_chains: dict[str, Protein] = {}
            for seq in seed.split(":"):
                protein = Protein(sequence=seq)
                chain_id = next(id_gen)
                protein.msa = msa.id
                msa_chains[chain_id] = protein
            normalized_complexes = [Complex(chains=msa_chains)]
        else:
            normalized_complexes = normalize_inputs(sequences)
            # The MSA branch above only ever builds Protein chains, so
            # unsupported chain types can only appear on this path.
            for cplx in normalized_complexes:
                # Iterate over a snapshot: entries are deleted from
                # `_chains` below, and deleting from a dict while iterating
                # its live view raises RuntimeError.
                # NOTE(review): assumes get_chains() may expose `_chains`
                # directly - the snapshot is safe either way.
                for chain_id, chain in list(cplx.get_chains().items()):
                    if isinstance(chain, (DNA, RNA, Ligand)):
                        with warnings.catch_warnings():
                            # Force warning to always show
                            warnings.simplefilter("always")
                            warnings.warn(
                                "AlphaFold-2 does not support ligand/DNA/RNA input. These extra chains will be ignored in the output."
                            )
                        del cplx._chains[chain_id]

        _complexes = serialize_input(
            self.session, normalized_complexes, needs_msa=True
        )
        if len(_complexes) == 0:
            raise TypeError(
                "Expected either non-empty list of proteins/models/sequences or MSAFuture"
            )

        job = api.fold_models_post(
            session=self.session,
            model_id=self.model_id,
            sequences=_complexes,
            num_recycles=num_recycles,
            num_models=num_models,
            num_relax=num_relax,
        )
        return FoldResultFuture(
            session=self.session,
            job=job,
            complexes=normalized_complexes,
        )