Source code for openprotein.fold.alphafold2

"""Community-based AlphaFold 2 model running using ColabFold."""

import warnings
from collections import Counter

from openprotein.align import MSAFuture
from openprotein.base import APISession
from openprotein.common import ModelMetadata
from openprotein.protein import Protein

from . import api
from .future import FoldComplexResultFuture
from .models import FoldModel


class AlphaFold2Model(FoldModel):
    """
    Class providing inference endpoints for AlphaFold2 structure prediction
    models, based on the implementation by ColabFold.
    """

    model_id: str = "alphafold2"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        """
        Initialize the model by delegating to the ``FoldModel`` constructor.

        Parameters
        ----------
        session : APISession
            Session later used by :meth:`fold` to reach the align and fold
            endpoints.
        model_id : str
            Identifier of the model; sent as ``model_id`` when folding.
        metadata : ModelMetadata | None
            Optional model metadata, passed through to the base class.
        """
        super().__init__(session=session, model_id=model_id, metadata=metadata)

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        num_recycles: int | None = None,
        num_models: int = 1,
        num_relax: int = 0,
        **kwargs,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to alphafold model.

        Parameters
        ----------
        proteins : List[Protein] | MSAFuture
            List of protein sequences to fold. `Protein` objects must be
            tagged with an `msa`, all of them must reference the same MSA,
            and together they must account for every sequence of that MSA's
            seed. Alternatively, supply an `MSAFuture` to use all query
            sequences as a multimer.
        num_recycles : int | None
            Number of times to recycle models. Forwarded unchanged to the
            fold endpoint.
        num_models : int
            Number of models to run - best model will be used.
        num_relax : int
            Forwarded unchanged to the fold endpoint as ``num_relax``.
        **kwargs
            ``msa`` is accepted as a deprecated alias for ``proteins``.
            ``ligands``, ``dnas`` and ``rnas`` trigger a warning and are
            ignored. No other keyword argument is forwarded to the API.

        Returns
        -------
        FoldComplexResultFuture
            Future wrapping the submitted fold job.

        Raises
        ------
        TypeError
            If no proteins are supplied, or `proteins` is neither a list nor
            an `MSAFuture`.
        ValueError
            If the list of proteins is empty, a protein has no MSA, more
            than one MSA is referenced, or the proteins do not exactly make
            up the MSA seed.
        """
        if "msa" in kwargs:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(
                "Inputs to AlphaFold 2 have been updated. 'msa' should be supplied as 'proteins' argument. Support will be dropped in the future.",
                stacklevel=2,
            )
            proteins = kwargs["msa"]

        if "ligands" in kwargs or "dnas" in kwargs or "rnas" in kwargs:
            with warnings.catch_warnings():
                warnings.simplefilter("always")  # Force warning to always show
                warnings.warn(
                    "Alphafold 2 only supports proteins. \nAll other chains will be ignored",
                    stacklevel=2,
                )

        if proteins is None:
            raise TypeError("Expected 'proteins' argument")

        if isinstance(proteins, list):
            msa_id = self._msa_id_from_proteins(proteins)
        elif isinstance(proteins, MSAFuture):
            msa_id = proteins.id
        else:
            raise TypeError("Expected either list of Proteins or MSAFuture")

        return FoldComplexResultFuture.create(
            session=self.session,
            job=api.fold_models_post(
                self.session,
                model_id=self.model_id,
                msa_id=msa_id,
                num_recycles=num_recycles,
                num_models=num_models,
                num_relax=num_relax,
            ),
        )

    def _msa_id_from_proteins(self, proteins: list[Protein]) -> str:
        """Validate that `proteins` exactly make up one MSA seed; return its id."""
        if not proteins:
            # Previously fell through to an opaque IndexError further down.
            raise ValueError("Expected at least one protein when using AlphaFold 2")

        # Kept as a function-local import, as in the original implementation.
        from openprotein.align import AlignAPI

        msa_to_seed: dict[str, Counter] = {}
        for protein in proteins:
            msa = protein.msa
            if msa is None:
                raise ValueError("Expected msa for protein when using AlphaFold 2")
            msa_id = msa.id if isinstance(msa, MSAFuture) else msa
            seeds = msa_to_seed.get(msa_id)
            if seeds is None:
                align_api = getattr(self.session, "align", None)
                assert isinstance(align_api, AlignAPI)
                # The seed is split on ':' into its individual sequences.
                seed = align_api.get_seed(job_id=msa_id)
                # need a counter so we can make sure later that the proteins
                # make up the msa completely
                seeds = Counter(seed.split(":"))
                msa_to_seed[msa_id] = seeds
            # check that this protein is in the seed
            if protein.sequence.decode() not in seeds:
                raise ValueError(
                    f"Expected specified msa_id {msa_id} for protein {protein.sequence} to contain the sequence as part of its seed/query"
                )

        # now make sure we only have one msa
        if len(msa_to_seed) > 1:
            raise ValueError("Expected only 1 unique msa when using AlphaFold 2")
        msa_id, seeds = next(iter(msa_to_seed.items()))

        # now check that the list of proteins completely make up the msa
        for protein in proteins:
            sequence = protein.sequence.decode()
            chain_id = protein.chain_id
            # make sure to account for multimers: a list of chain ids means
            # the same sequence is used once per chain
            seeds[sequence] -= len(chain_id) if isinstance(chain_id, list) else 1
            # handle when too many of a sequence in the list of proteins
            if seeds[sequence] < 0:
                raise ValueError(
                    "List of proteins does not completely make up the MSA seed"
                )
        if seeds.total() != 0:
            # handle when overall mismatch - 1 and -1 case is handled above
            raise ValueError(
                "List of proteins does not completely make up the MSA seed"
            )
        return msa_id