Source code for openprotein.fold.esmfold

"""Community-based ESMFold model."""

import warnings
from collections.abc import Sequence
from typing import Sequence

from openprotein.base import APISession
from openprotein.common import ModelMetadata
from openprotein.fold.common import normalize_inputs, serialize_input
from openprotein.molecules import DNA, RNA, Ligand, Protein, Complex

from . import api
from .future import FoldResultFuture
from .models import FoldModel


class ESMFoldModel(FoldModel):
    """
    Class providing inference endpoints for Facebook's ESMFold structure prediction models.
    """

    model_id: str = "esmfold"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        """
        Initialize the model by delegating to the parent ``FoldModel``.

        Parameters
        ----------
        session : APISession
            Session forwarded to the parent constructor.
        model_id : str
            Identifier of the model, forwarded to the parent constructor.
        metadata : ModelMetadata | None
            Optional model metadata, forwarded to the parent constructor.
        """
        super().__init__(session=session, model_id=model_id, metadata=metadata)

    def fold(
        self,
        sequences: Sequence[Complex | Protein | str | bytes],
        num_recycles: int | None = None,
    ) -> FoldResultFuture:
        """
        Fold sequences using this model.

        The inputs are normalized with ``normalize_inputs``. Any DNA, RNA or
        ligand chain found in a normalized complex is removed from that
        complex (a warning is emitted for each removed chain) before the
        request is submitted.

        Parameters
        ----------
        sequences : Sequence[Complex | Protein | str | bytes]
            sequences to fold
        num_recycles : int | None
            number of times to recycle models

        Returns
        -------
        FoldResultFuture
        """
        normalized_complexes = normalize_inputs(sequences)
        for model_complex in normalized_complexes:
            # Iterate over a snapshot of the chains: entries are deleted from
            # the underlying mapping inside the loop, and mutating a dict
            # while iterating its live view raises RuntimeError.
            for chain_id, chain in list(model_complex.get_chains().items()):
                if not isinstance(chain, (DNA, RNA, Ligand)):
                    continue
                with warnings.catch_warnings():
                    warnings.simplefilter("always")  # Force warning to always show
                    warnings.warn(
                        "ESMFold does not support ligand/DNA/RNA input. These extra chains will be ignored in the output."
                    )
                # NOTE(review): this mutates the normalized complex in place;
                # if normalize_inputs returns the caller's own objects, the
                # caller's complexes lose these chains too - confirm.
                del model_complex._chains[chain_id]
        _complexes = serialize_input(
            self.session, normalized_complexes, needs_msa=False
        )
        result = FoldResultFuture(
            session=self.session,
            job=api.fold_models_post(
                session=self.session,
                model_id=self.model_id,
                sequences=_complexes,
                num_recycles=num_recycles,
            ),
            complexes=normalized_complexes,
        )
        return result