"""Fold prediction results represented as futures."""
import copy
import typing
from typing import TYPE_CHECKING, Iterator, Literal
import numpy as np
import pandas as pd
from pydantic.type_adapter import TypeAdapter
from typing_extensions import Self
from openprotein import config
from openprotein.base import APISession
from openprotein.fold.complex import id_generator
from openprotein.jobs import JobsAPI, MappedFuture
from openprotein.molecules import DNA, RNA, Complex, Ligand, Protein, Structure
from openprotein.utils.numpy import readonly_view
from . import api
from .schemas import FoldJob, FoldMetadata
if TYPE_CHECKING:
from .boltz import BoltzAffinity, BoltzConfidence
# Union of every value type a fold future can hand back: a predicted
# structure, an array output, a tabular output, or one of the Boltz result
# objects. Written as a string so the Boltz names, which are only imported
# under TYPE_CHECKING, are not required at runtime.
FoldResult: typing.TypeAlias = (
"Structure | np.ndarray | pd.DataFrame | BoltzAffinity | list[BoltzConfidence]"
)
[docs]
class FoldResultFuture(
    MappedFuture[
        # NOTE(review): the key type is declared as ``bytes`` while
        # ``__keys__`` returns integer indices and ``get_item`` takes an
        # ``int`` index - confirm against the MappedFuture definition.
        bytes,
        FoldResult,
    ]
):
    """
    Fold results represented as a future.

    Individual results are addressed by the integer position of the
    submitted sequence (see ``__keys__`` and ``get_item``). An optional
    ``key`` selects an additional output such as ``"pae"`` or
    ``"affinity"`` instead of the predicted structure.

    Attributes
    ----------
    job : FoldJob
        The fold job associated with this future.
    """

    job: FoldJob
def __init__(
    self,
    session: APISession,
    job: FoldJob | None = None,
    metadata: FoldMetadata | None = None,
    sequences: list[bytes] | None = None,
    complexes: list[Complex] | None = None,
    max_workers: int = config.MAX_CONCURRENT_WORKERS,
):
    """
    Build a future from a fold job, its metadata, or both.

    Whichever of ``job`` / ``metadata`` is missing is fetched through the
    session; the submitted sequences are fetched as well when they are
    not supplied.

    :meta private:
    """
    # Resolve the metadata first: it is needed to look the job up when only
    # metadata was supplied.
    if metadata is None:
        if job is not None and job.job_id is not None:
            metadata = api.fold_get(session=session, job_id=job.job_id)
        else:
            raise ValueError("Expected fold metadata or job")
    self._metadata = metadata

    if job is None:
        jobs_api = getattr(session, "jobs", None)
        assert isinstance(jobs_api, JobsAPI)
        fetched_job = jobs_api.get_job(job_id=metadata.job_id)
        job = FoldJob.create(fetched_job)

    self._sequences = (
        sequences
        if sequences is not None
        else api.fold_get_sequences(session=session, job_id=job.job_id)
    )
    self._complexes = complexes
    # Map each submitted sequence back to its position in the request.
    self.reverse_map = {
        sequence: position for position, sequence in enumerate(self._sequences)
    }
    super().__init__(session, job, max_workers)
[docs]
@classmethod
def create(
    cls: type[Self],
    session: APISession,
    job: FoldJob | None = None,
    metadata: FoldMetadata | None = None,
    **kwargs,
) -> "Self":
    """
    Factory method to create a FoldResultFuture.

    Parameters
    ----------
    session : APISession
        The API session to use for requests.
    job : FoldJob, optional
        The fold job associated with this future.
    metadata : FoldMetadata, optional
        The fold job metadata. May be given instead of, or together with,
        ``job``.
    **kwargs
        Additional keyword arguments forwarded to the constructor.

    Returns
    -------
    FoldResultFuture
        An instance of FoldResultFuture.

    Raises
    ------
    ValueError
        If neither ``job`` nor ``metadata`` is provided.
    """
    if job is None and metadata is None:
        raise ValueError("Expected fold metadata or job")
    # NOTE: a model-specific future subclass could be selected here; not
    # used for now.
    # Forward the metadata too: the constructor rejects the call when it
    # receives neither a job nor metadata.
    return cls(session=session, job=job, metadata=metadata, **kwargs)
@property
def sequences(self) -> list[bytes]:
    """
    Get the sequences submitted for the fold request.

    Every access emits a warning, because for jobs that folded complexes
    the returned values only contain the ':'-delimited protein sequences.

    Returns
    -------
    list[bytes]
        List of sequences.
    """
    import warnings

    # stacklevel=2 attributes the warning to the code that accessed the
    # property rather than to this line.
    warnings.warn(
        "`sequences` for fold jobs of complexes will show ':'-delimited protein sequences but omit the ligands and other chain entities",
        stacklevel=2,
    )
    if self._sequences is None:
        self._sequences = api.fold_get_sequences(self.session, self.job.job_id)
    return self._sequences
@property
def complexes(self) -> list[Complex]:
    """
    Get the molecular complexes submitted for the fold request.

    The complexes are rebuilt from the job metadata on first access (or
    from the plain sequences when the metadata carries no entity
    description) and cached. A deep copy is returned on every access.

    Returns
    -------
    list[Complex]
        List of complexes.
    """
    if self._complexes is not None:
        return copy.deepcopy(self._complexes)

    built: list[Complex] = []
    entity_specs = self.metadata.sequences
    if entity_specs is None:
        # No entity description: treat every ':'-separated monomer of each
        # submitted sequence as a protein chain.
        id_gen = id_generator()
        for seq in self.sequences:
            built.append(
                Complex(
                    chains={
                        next(id_gen): Protein(sequence=monomer)
                        for monomer in seq.split(b":")
                    }
                )
            )
    else:
        # First pass: gather the chain ids that were given explicitly and
        # hand them to the id generator.
        taken_ids = []
        for group in entity_specs:
            for spec in group:
                for entity in spec.values():
                    declared = entity.get("id")
                    if isinstance(declared, str):
                        taken_ids.append(declared)
                    elif isinstance(declared, list):
                        taken_ids.extend(declared)
        id_gen = id_generator(taken_ids)

        # Second pass: build one Complex per group of entity descriptions.
        for group in entity_specs:
            chains: dict = {}
            for spec in group:
                for kind, entity in spec.items():
                    if kind not in ("protein", "dna", "rna", "ligand"):
                        # Unrecognized entity types are ignored.
                        continue
                    chain_id = entity.get("id") or next(id_gen)
                    if kind == "protein":
                        molecule = Protein(sequence=entity["sequence"])
                        msa_id = entity.get("msa_id")
                        if msa_id is not None:
                            molecule.msa = msa_id
                    elif kind == "dna":
                        molecule = DNA(
                            sequence=entity["sequence"],
                            cyclic=entity.get("cyclic"),
                        )
                    elif kind == "rna":
                        molecule = RNA(
                            sequence=entity["sequence"],
                            cyclic=entity.get("cyclic"),
                        )
                    else:
                        molecule = Ligand(
                            smiles=entity.get("smiles"),
                            ccd=entity.get("ccd"),
                        )
                    # A list of ids maps the same molecule onto several chains.
                    targets = chain_id if isinstance(chain_id, list) else [chain_id]
                    for cid in targets:
                        chains[cid] = molecule
            built.append(Complex(chains=chains))

    self._complexes = built
    return copy.deepcopy(self._complexes)
@property
def id(self):
    """
    Identifier of the fold request.

    Returns
    -------
    str
        Fold job ID, read from the underlying job.
    """
    fold_job = self.job
    return fold_job.job_id
@property
def metadata(self) -> FoldMetadata:
    """Return the fold metadata stored on this future."""
    stored = self._metadata
    return stored
@property
def model_id(self) -> str:
    """Return the identifier of the fold model, as recorded in the metadata."""
    stored = self._metadata
    return stored.model_id
def __keys__(self):
    """
    Get the keys used to address the individual fold results.

    Returns
    -------
    list of int
        Zero-based positions, one per submitted sequence.
    """
    return [position for position, _ in enumerate(self._sequences)]
# In the overloads, `key` is required for every non-None literal so that a
# call without `key` matches only the Structure overload. Overloads that all
# accept `key=None` overlap with incompatible return types.
@typing.overload
def get_item(
    self,
    index: int,
    key: None = None,
) -> Structure: ...
@typing.overload
def get_item(
    self,
    index: int,
    key: Literal["pae", "pde", "plddt", "ptm"],
) -> np.ndarray: ...
@typing.overload
def get_item(
    self,
    index: int,
    key: Literal["affinity"],
) -> "BoltzAffinity": ...
@typing.overload
def get_item(
    self,
    index: int,
    key: Literal["confidence"],
) -> "list[BoltzConfidence]": ...
@typing.overload
def get_item(
    self,
    index: int,
    key: Literal["score", "metrics"],
) -> pd.DataFrame: ...
def get_item(
    self,
    index: int,
    key: (
        Literal[
            "pae",
            "pde",
            "plddt",
            "ptm",
            "confidence",
            "affinity",
            "score",
            "metrics",
        ]
        | None
    ) = None,
) -> FoldResult:
    """
    Get the fold result for the sequence at a given index.

    Parameters
    ----------
    index : int
        Position of the submitted sequence to fetch results for.
    key : str or None, optional
        Which output to fetch. ``None`` (the default) fetches the predicted
        structure; any other value fetches the named extra output.

    Returns
    -------
    Structure
        When ``key`` is ``None``: the structure parsed from the returned
        CIF text.
    BoltzAffinity
        When ``key`` is ``"affinity"``.
    list[BoltzConfidence]
        When ``key`` is ``"confidence"``.
    np.ndarray or pd.DataFrame
        For the remaining keys, the value returned by the API is passed
        through unchanged (array for ``"pae"``/``"pde"``/``"plddt"``/
        ``"ptm"``, data frame for ``"score"``/``"metrics"``, as declared
        by the overloads).
    """
    if key is None:
        data = api.fold_get_sequence_result(self.session, self.job.job_id, index)
        model = Structure.from_string(data.decode(), format="cif")
        return model

    data = api.fold_get_extra_result(self.session, self.job.job_id, index, key)
    # The Boltz outputs are validated into their models; imported lazily so
    # the Boltz module is only needed when these keys are requested.
    if key == "affinity":
        from .boltz import BoltzAffinity

        data = TypeAdapter(BoltzAffinity).validate_python(data)
    elif key == "confidence":
        from .boltz import BoltzConfidence

        data = TypeAdapter(list[BoltzConfidence]).validate_python(data)
    return data  # type: ignore - converted by adapter
# In the overloads, `key` is required for every non-None literal so that a
# call without `key` matches only the Structure overload. Overloads that all
# accept `key=None` overlap with incompatible return types.
@typing.overload
def stream(
    self,
    key: None = None,
) -> Iterator[Structure]: ...
@typing.overload
def stream(
    self,
    key: Literal["pae", "pde", "plddt", "ptm"],
) -> Iterator[np.ndarray]: ...
@typing.overload
def stream(
    self,
    key: Literal["affinity"],
) -> "Iterator[BoltzAffinity]": ...
@typing.overload
def stream(
    self,
    key: Literal["confidence"],
) -> "Iterator[list[BoltzConfidence]]": ...
@typing.overload
def stream(
    self,
    key: Literal["score", "metrics"],
) -> Iterator[pd.DataFrame]: ...
# NOTE: the parent stream yields (key, value) tuples; only the value is
# passed on to the caller.
def stream(
    self,
    key: (
        Literal[
            "pae",
            "pde",
            "plddt",
            "ptm",
            "confidence",
            "affinity",
            "score",
            "metrics",
        ]
        | None
    ) = None,
) -> "Iterator[Structure] | Iterator[np.ndarray] | Iterator[pd.DataFrame] | Iterator[BoltzAffinity] | Iterator[list[BoltzConfidence]]":
    """
    Stream the fold results one at a time.

    Parameters
    ----------
    key : str or None, optional
        Which output to stream. ``None`` (the default) streams the
        predicted structures; any other value streams the named extra
        output.

    Yields
    ------
    Structure, np.ndarray, pd.DataFrame, BoltzAffinity or list[BoltzConfidence]
        The result values produced by the parent ``stream`` for ``key``,
        without the key they are paired with.
    """
    for _, v in super().stream(key=key):
        yield v  # type: ignore - homogenous
# NOTE(review): every overload below also accepts ``key=None``, so a type
# checker cannot pick the return type from ``key`` alone. Tightening this
# would need keyword-only ``key`` overloads - confirm no caller passes ``key``
# positionally before changing.
@typing.overload
def get(self, verbose: bool = False, key: None = None) -> list[Structure]: ...
@typing.overload
def get(
    self,
    verbose: bool = False,
    key: Literal["pae", "pde", "plddt", "ptm"] | None = None,
) -> list[np.ndarray]: ...
@typing.overload
def get(
    self,
    verbose: bool = False,
    key: Literal["affinity"] | None = None,
) -> "list[BoltzAffinity]": ...
@typing.overload
def get(
    self,
    verbose: bool = False,
    key: Literal["confidence"] | None = None,
) -> "list[list[BoltzConfidence]]": ...
@typing.overload
def get(
    self,
    verbose: bool = False,
    key: Literal["score", "metrics"] | None = None,
) -> list[pd.DataFrame]: ...
def get(
    self,
    verbose: bool = False,
    key: (
        Literal[
            "pae",
            "pde",
            "plddt",
            "ptm",
            "confidence",
            "affinity",
            "score",
            "metrics",
        ]
        | None
    ) = None,
) -> "list[Structure] | list[np.ndarray] | list[pd.DataFrame] | list[list[BoltzConfidence]] | list[BoltzAffinity]":
    """
    Retrieve all fold results for the requested output.

    Parameters
    ----------
    verbose : bool, optional
        Passed through to the parent ``get``.
    key : str or None, optional
        Which output to retrieve. ``None`` (the default) retrieves the
        predicted structures; any other value retrieves the named extra
        output.

    Returns
    -------
    list
        Whatever the parent ``get`` returns for ``key``; the element type
        follows the overloads above.
    """
    results = super().get(verbose, key=key)
    return results  # type: ignore - homogenous
[docs]
def get_pae(self) -> list[np.ndarray]:
    """
    Get the Predicted Aligned Error (PAE) matrix for all outputs.

    The matrices are fetched on the first call and kept on the instance;
    each call returns them wrapped by ``readonly_view``.

    Returns
    -------
    list[np.ndarray]
        PAE matrix.

    Raises
    ------
    AttributeError
        If PAE is not supported for the model.
    """
    supported_models = {
        "boltz-1",
        "boltz-1x",
        "boltz-2",
        "alphafold2",
        "esmfold",
    }
    if self.model_id not in supported_models:
        raise AttributeError("pae not supported for this model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._pae
    except AttributeError:
        cached = self._pae = None
    if cached is None:
        cached = self._pae = self.get(key="pae")
    return [readonly_view(matrix) for matrix in cached]
[docs]
def get_pde(self) -> list[np.ndarray]:
    """
    Get the Predicted Distance Error (PDE) matrix.

    The matrices are fetched on the first call and kept on the instance;
    each call returns them wrapped by ``readonly_view``.

    Returns
    -------
    list[np.ndarray]
        PDE matrix.

    Raises
    ------
    AttributeError
        If PDE is not supported for the model.
    """
    supported_models = {"boltz-1", "boltz-1x", "boltz-2"}
    if self.model_id not in supported_models:
        raise AttributeError("pde not supported for this model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._pde
    except AttributeError:
        cached = self._pde = None
    if cached is None:
        cached = self._pde = self.get(key="pde")
    return [readonly_view(matrix) for matrix in cached]
[docs]
def get_plddt(self) -> list[np.ndarray]:
    """
    Get the Predicted Local Distance Difference Test (pLDDT) scores.

    The scores are fetched on the first call and kept on the instance;
    each call returns them wrapped by ``readonly_view``.

    Returns
    -------
    list[np.ndarray]
        pLDDT scores.

    Raises
    ------
    AttributeError
        If pLDDT is not supported for the model.
    """
    supported_models = {"boltz-1", "boltz-1x", "boltz-2", "alphafold2"}
    if self.model_id not in supported_models:
        raise AttributeError("plddt not supported for this model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._plddt
    except AttributeError:
        cached = self._plddt = None
    if cached is None:
        cached = self._plddt = self.get(key="plddt")
    return [readonly_view(scores) for scores in cached]
[docs]
def get_ptm(self) -> list[np.ndarray]:
    """
    Get the Predicted TM (pTM) scores.

    The scores are fetched on the first call and kept on the instance;
    each call returns them wrapped by ``readonly_view``.

    Returns
    -------
    list[np.ndarray]
        pTM scores.

    Raises
    ------
    AttributeError
        If pTM is not supported for the model.
    """
    supported_models = {"alphafold2"}
    if self.model_id not in supported_models:
        raise AttributeError("ptm not supported for this model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._ptm
    except AttributeError:
        cached = self._ptm = None
    if cached is None:
        cached = self._ptm = self.get(key="ptm")
    return [readonly_view(scores) for scores in cached]
[docs]
def get_score(self) -> list[pd.DataFrame]:
    """
    Get the predicted scores.

    The scores are fetched on the first call and kept on the instance;
    each call returns a deep copy of the cached value.

    Returns
    -------
    list[pd.DataFrame]
        Structure prediction scores.

    Raises
    ------
    AttributeError
        If score is not supported for the model.
    """
    supported_models = {"rosettafold-3"}
    if self.model_id not in supported_models:
        raise AttributeError("score not supported for this model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._score
    except AttributeError:
        cached = self._score = None
    if cached is None:
        cached = self._score = self.get(key="score")
    return copy.deepcopy(cached)
[docs]
def get_metrics(self) -> list[pd.DataFrame]:
    """
    Get the predicted metrics.

    The metrics are fetched on the first call and kept on the instance;
    each call returns a deep copy of the cached value.

    Returns
    -------
    list[pd.DataFrame]
        Structure prediction metrics.

    Raises
    ------
    AttributeError
        If metrics is not supported for the model.
    """
    supported_models = {"rosettafold-3"}
    if self.model_id not in supported_models:
        raise AttributeError("metrics not supported for this model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._metrics
    except AttributeError:
        cached = self._metrics = None
    if cached is None:
        cached = self._metrics = self.get(key="metrics")
    return copy.deepcopy(cached)
[docs]
def get_confidence(self) -> list[list["BoltzConfidence"]]:
    """
    Retrieve the confidences of the structure prediction.

    The confidences are fetched on the first call and kept on the
    instance; each call returns a deep copy of the cached value.

    Note
    ----
    This is only currently supported for Boltz models.

    Returns
    -------
    list[list[BoltzConfidence]]
        List of list of BoltzConfidence objects.

    Raises
    ------
    AttributeError
        If confidence is not supported for the model.
    """
    boltz_models = {"boltz-1", "boltz-1x", "boltz-2"}
    if self.model_id not in boltz_models:
        raise AttributeError("confidence not supported for non-Boltz model")
    # The cache attribute is created lazily on first use.
    try:
        cached = self._confidence
    except AttributeError:
        cached = self._confidence = None
    if cached is None:
        cached = self._confidence = self.get(key="confidence")
    return copy.deepcopy(cached)
[docs]
def get_affinity(self) -> list["BoltzAffinity"]:
    """
    Retrieve the predicted binding affinities.

    The affinities are fetched on the first call and kept on the instance;
    each call returns a deep copy of the cached value.

    Note
    ----
    This is only currently supported for Boltz models.

    Returns
    -------
    list[BoltzAffinity]
        BoltzAffinity objects containing the predicted affinities, as
        returned by ``get(key="affinity")``.

    Raises
    ------
    AttributeError
        If affinity is not supported for the model.
    """
    if self.model_id not in {"boltz-1", "boltz-1x", "boltz-2"}:
        raise AttributeError("affinity not supported for non-Boltz model")
    # The cache attribute is created lazily on first use.
    if not hasattr(self, "_affinity"):
        self._affinity = None
    if self._affinity is None:
        self._affinity = self.get(key="affinity")
    return copy.deepcopy(self._affinity)