"""Fold prediction results represented as futures."""
from typing import TYPE_CHECKING, Literal
import numpy as np
from pydantic.type_adapter import TypeAdapter
from typing_extensions import Self
from openprotein import config
from openprotein.base import APISession
from openprotein.chains import DNA, RNA, Ligand
from openprotein.jobs import Future, MappedFuture
from openprotein.protein import Protein
from . import api
from .schemas import FoldJob
if TYPE_CHECKING:
from .boltz import BoltzAffinity, BoltzConfidence
[docs]
class FoldResultFuture(MappedFuture, Future):
    """
    Fold results represented as a future.

    Results are addressed by the submitted sequences: :meth:`keys` returns
    those sequences and :meth:`get_item` fetches the result for one of them.

    Attributes
    ----------
    job : FoldJob
        The fold job associated with this future.
    """

    job: FoldJob

    def __init__(
        self,
        session: APISession,
        job: FoldJob,
        sequences: list[bytes] | None = None,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """
        Initialize a FoldResultFuture instance.

        Parameters
        ----------
        session : APISession
            The API session to use for requests.
        job : FoldJob
            The fold job associated with this future.
        sequences : list[bytes], optional
            List of sequences submitted for the fold request. If None, the
            sequences are fetched from the API using the job ID.
        max_workers : int, optional
            Maximum number of concurrent workers. Default is
            config.MAX_CONCURRENT_WORKERS.
        """
        super().__init__(session, job, max_workers)
        if sequences is None:
            sequences = api.fold_get_sequences(self.session, job_id=job.job_id)
        self._sequences = sequences

    @classmethod
    def create(
        cls: type[Self],
        session: APISession,
        job: FoldJob,
        **kwargs,
    ) -> "Self | FoldComplexResultFuture":
        """
        Factory method to create a FoldResultFuture or FoldComplexResultFuture.

        The model ID of the job is looked up through the API; jobs whose model
        ID starts with ``"boltz"`` or ``"alphafold"`` produce a
        FoldComplexResultFuture, every other model produces an instance of
        ``cls``.

        Parameters
        ----------
        session : APISession
            The API session to use for requests.
        job : FoldJob
            The fold job associated with this future.
        **kwargs
            Additional keyword arguments forwarded to the chosen constructor.

        Returns
        -------
        FoldResultFuture or FoldComplexResultFuture
            An instance of FoldResultFuture or FoldComplexResultFuture
            depending on the model.
        """
        model_id = api.fold_get(session=session, job_id=job.job_id).model_id
        # str.startswith accepts a tuple of prefixes, so one call covers both.
        if model_id.startswith(("boltz", "alphafold")):
            return FoldComplexResultFuture(session=session, job=job, **kwargs)
        return cls(session=session, job=job, **kwargs)

    @property
    def sequences(self) -> list[bytes]:
        """
        Get the sequences submitted for the fold request.

        Returns
        -------
        list[bytes]
            List of sequences.
        """
        # __init__ normally populates the cache; this is a fallback in case
        # it has been reset to None.
        if self._sequences is None:
            self._sequences = api.fold_get_sequences(
                self.session, job_id=self.job.job_id
            )
        return self._sequences

    @property
    def model_id(self) -> str:
        """
        Get the model ID used for the fold request.

        The value is fetched from the API on first access and cached on the
        instance afterwards.

        Returns
        -------
        str
            Model ID.
        """
        # `_model_id` is not assigned in this class's __init__, so a plain
        # attribute read could raise AttributeError on first access. Read it
        # defensively so the lazy fetch works whether or not a base class
        # defines the attribute.
        cached = getattr(self, "_model_id", None)
        if cached is None:
            cached = api.fold_get(
                session=self.session, job_id=self.job.job_id
            ).model_id
            self._model_id = cached
        return cached

    @property
    def id(self):
        """
        Get the ID of the fold request.

        Returns
        -------
        str
            Fold job ID.
        """
        return self.job.job_id

    def keys(self):
        """
        Get the list of sequences submitted for the fold request.

        Returns
        -------
        list[bytes]
            List of sequences.
        """
        return self.sequences

    def get(self, verbose=False) -> list[tuple[str, bytes]]:
        """
        Retrieve the fold results as a list of (sequence, result) tuples.

        This simply delegates to the parent implementation.

        Parameters
        ----------
        verbose : bool, optional
            If True, print verbose output. Default is False.

        Returns
        -------
        list[tuple[str, bytes]]
            List of tuples pairing each sequence with its fold result.
            NOTE(review): ``keys()`` yields ``bytes`` sequences, so the first
            tuple element may actually be ``bytes`` - confirm against the
            parent ``get`` implementation.
        """
        return super().get(verbose=verbose)

    def get_item(self, sequence: bytes) -> bytes:
        """
        Get fold results for a specified sequence.

        Parameters
        ----------
        sequence : bytes
            Sequence to fetch results for.

        Returns
        -------
        bytes
            Fold result for the specified sequence.
        """
        return api.fold_get_sequence_result(self.session, self.job.job_id, sequence)
[docs]
class FoldComplexResultFuture(Future):
    """
    Future for manipulating results of a fold complex request.

    Chain inputs (proteins, DNAs, RNAs, ligands) and the extra predicted
    properties (PAE, PDE, pLDDT, confidence, affinity) are fetched lazily from
    the API on first access and cached on the instance.

    Attributes
    ----------
    job : FoldJob
        The fold job associated with this future.
    """

    job: FoldJob

    # Model IDs for which the extra predicted properties are exposed.
    _BOLTZ_MODEL_IDS = frozenset({"boltz-1", "boltz-1x", "boltz-2"})

    def __init__(
        self,
        session: APISession,
        job: FoldJob,
        model_id: str | None = None,
        proteins: list[Protein] | None = None,
        ligands: list[Ligand] | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
    ):
        """
        Initialize a FoldComplexResultFuture instance.

        Parameters
        ----------
        session : APISession
            The API session to use for requests.
        job : FoldJob
            The fold job associated with this future.
        model_id : str, optional
            Model ID used for the fold request. Fetched lazily if omitted.
        proteins : list[Protein], optional
            List of proteins submitted for fold request.
        ligands : list[Ligand], optional
            List of ligands submitted for fold request.
        dnas : list[DNA], optional
            List of DNAs submitted for fold request.
        rnas : list[RNA], optional
            List of RNAs submitted for fold request.
        """
        super().__init__(session, job)
        self._model_id = model_id
        self._proteins = proteins
        self._ligands = ligands
        self._dnas = dnas
        self._rnas = rnas
        # The chains count as initialized as soon as the caller supplied any
        # of them; only when all four are None are they fetched on demand.
        self._initialized = any(
            chains is not None for chains in (proteins, ligands, dnas, rnas)
        )
        self._pae: np.ndarray | None = None
        self._pde: np.ndarray | None = None
        self._plddt: np.ndarray | None = None
        self._confidence: list["BoltzConfidence"] | None = None
        self._affinity: "BoltzAffinity | None" = None

    def _require_boltz(self, name: str) -> None:
        """Raise AttributeError if ``name`` is requested for a non-Boltz model."""
        if self.model_id not in self._BOLTZ_MODEL_IDS:
            raise AttributeError(f"{name} not supported for non-Boltz model")

    @property
    def model_id(self) -> str:
        """
        Get the model ID used for the fold request.

        Returns
        -------
        str
            Model ID.
        """
        if self._model_id is None:
            self._model_id = api.fold_get(
                session=self.session, job_id=self.job.job_id
            ).model_id
        return self._model_id

    def __get_chains(self):
        """
        Internal method to initialize chain objects (proteins, dnas, rnas, ligands)
        from the fold job arguments.

        Each entry of ``args["sequences"]`` is a dict mapping a chain type
        (``"protein"``, ``"dna"``, ``"rna"`` or ``"ligand"``) to that chain's
        info dict. Entries with any other chain type are ignored.
        """
        args = api.fold_get(session=self.session, job_id=self.job.job_id).args
        assert args is not None and "sequences" in args
        for chain in args["sequences"]:
            assert isinstance(chain, dict)
            # Iterate key/value pairs: iterating the dict itself yields only
            # the keys, which cannot be unpacked into (type, info).
            for chain_type, chain_info in chain.items():
                if chain_type == "protein":
                    self._proteins = self._proteins or []
                    protein = Protein(sequence=chain_info["sequence"])
                    protein.chain_id = chain_info.get("id")
                    protein.msa = chain_info.get("msa_id")
                    self._proteins.append(protein)
                elif chain_type == "dna":
                    self._dnas = self._dnas or []
                    dna = DNA(sequence=chain_info["sequence"])
                    dna.chain_id = chain_info.get("id")
                    self._dnas.append(dna)
                elif chain_type == "rna":
                    self._rnas = self._rnas or []
                    rna = RNA(sequence=chain_info["sequence"])
                    rna.chain_id = chain_info.get("id")
                    self._rnas.append(rna)
                elif chain_type == "ligand":
                    self._ligands = self._ligands or []
                    ligand = Ligand(
                        chain_id=chain_info.get("id"),
                        ccd=chain_info.get("ccd"),
                        smiles=chain_info.get("smiles"),
                    )
                    self._ligands.append(ligand)
        self._initialized = True

    @property
    def proteins(self) -> list[Protein] | None:
        """
        Get the proteins submitted for the fold request.

        Returns
        -------
        list[Protein] or None
            List of Protein objects or None.
        """
        if not self._initialized:
            self.__get_chains()
        return self._proteins

    @property
    def dnas(self) -> list[DNA] | None:
        """
        Get the DNAs submitted for the fold request.

        Returns
        -------
        list[DNA] or None
            List of DNA objects or None.
        """
        if not self._initialized:
            self.__get_chains()
        return self._dnas

    @property
    def rnas(self) -> list[RNA] | None:
        """
        Get the RNAs submitted for the fold request.

        Returns
        -------
        list[RNA] or None
            List of RNA objects or None.
        """
        if not self._initialized:
            self.__get_chains()
        return self._rnas

    @property
    def ligands(self) -> list[Ligand] | None:
        """
        Get the ligands submitted for the fold request.

        Returns
        -------
        list[Ligand] or None
            List of Ligand objects or None.
        """
        if not self._initialized:
            self.__get_chains()
        return self._ligands

    @property
    def pae(self) -> np.ndarray:
        """
        Get the Predicted Aligned Error (PAE) matrix.

        Returns
        -------
        np.ndarray
            PAE matrix.

        Raises
        ------
        AttributeError
            If PAE is not supported for the model.
        """
        self._require_boltz("pae")
        if self._pae is None:
            pae = api.fold_get_complex_extra_result(
                session=self.session, job_id=self.job.job_id, key="pae"
            )
            assert isinstance(pae, np.ndarray)
            self._pae = pae
        return self._pae

    @property
    def pde(self) -> np.ndarray:
        """
        Get the Predicted Distance Error (PDE) matrix.

        Returns
        -------
        np.ndarray
            PDE matrix.

        Raises
        ------
        AttributeError
            If PDE is not supported for the model.
        """
        self._require_boltz("pde")
        if self._pde is None:
            pde = api.fold_get_complex_extra_result(
                session=self.session, job_id=self.job.job_id, key="pde"
            )
            assert isinstance(pde, np.ndarray)
            self._pde = pde
        return self._pde

    @property
    def plddt(self) -> np.ndarray:
        """
        Get the Predicted Local Distance Difference Test (pLDDT) scores.

        Returns
        -------
        np.ndarray
            pLDDT scores.

        Raises
        ------
        AttributeError
            If pLDDT is not supported for the model.
        """
        self._require_boltz("plddt")
        if self._plddt is None:
            plddt = api.fold_get_complex_extra_result(
                session=self.session, job_id=self.job.job_id, key="plddt"
            )
            assert isinstance(plddt, np.ndarray)
            self._plddt = plddt
        return self._plddt

    @property
    def confidence(self) -> list["BoltzConfidence"]:
        """
        Retrieve the confidences of the structure prediction.

        Note
        ----
        This is only currently supported for Boltz models.

        Returns
        -------
        list[BoltzConfidence]
            List of BoltzConfidence objects.

        Raises
        ------
        AttributeError
            If confidence is not supported for the model.
        """
        self._require_boltz("confidence")
        if self._confidence is None:
            # The module-level import of BoltzConfidence is guarded by
            # TYPE_CHECKING, so the name must be imported here to exist at
            # runtime.
            from .boltz import BoltzConfidence

            confidence = api.fold_get_complex_extra_result(
                session=self.session, job_id=self.job.job_id, key="confidence"
            )
            assert isinstance(confidence, list)
            self._confidence = TypeAdapter(list[BoltzConfidence]).validate_python(
                confidence
            )
        return self._confidence

    @property
    def affinity(self) -> "BoltzAffinity":
        """
        Retrieve the predicted binding affinities.

        Note
        ----
        This is only currently supported for Boltz models.

        Returns
        -------
        BoltzAffinity
            BoltzAffinity object containing the predicted affinities.

        Raises
        ------
        AttributeError
            If affinity is not supported for the model.
        """
        self._require_boltz("affinity")
        if self._affinity is None:
            # The module-level import of BoltzAffinity is guarded by
            # TYPE_CHECKING, so the name must be imported here to exist at
            # runtime.
            from .boltz import BoltzAffinity

            affinity = api.fold_get_complex_extra_result(
                session=self.session, job_id=self.job.job_id, key="affinity"
            )
            assert isinstance(affinity, dict)
            self._affinity = BoltzAffinity.parse_obj_with_models(affinity)
        return self._affinity

    @property
    def id(self):
        """
        Get the ID of the fold request.

        Returns
        -------
        str
            Fold job ID.
        """
        return self.job.job_id

    def get(self, format: Literal["pdb", "mmcif"] = "mmcif", verbose=False) -> bytes:
        """
        Retrieve the fold results as a single bytestring.

        Defaults to mmCIF for complexes. Additional predicted properties like
        plddt and pae should be accessed from their respective properties,
        i.e. `.plddt` and `.pae`.

        Parameters
        ----------
        format : {'pdb', 'mmcif'}, optional
            Output format. Default is 'mmcif'.
        verbose : bool, optional
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        bytes
            Fold result as a bytestring.
        """
        return api.fold_get_complex_result(
            session=self.session, job_id=self.id, format=format
        )