Source code for openprotein.fold.boltz

"""Community-based Boltz models for complex structure prediction with ligands/dna/rna."""

import warnings
from typing import Sequence

from pydantic import BaseModel, Field, TypeAdapter, model_validator

from openprotein.align import AlignAPI, MSAFuture
from openprotein.base import APISession
from openprotein.common import ModelMetadata
from openprotein.fold.common import normalize_inputs, serialize_input
from openprotein.molecules import Complex, Ligand, Protein

from . import api
from .complex import id_generator
from .future import FoldResultFuture
from .models import FoldModel


class BoltzModel(FoldModel):
    """
    Class providing inference endpoints for Boltz structure prediction models.
    """

    model_id: str = "boltz"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        super().__init__(session, model_id, metadata)

    def fold(
        self,
        sequences: Sequence[Complex | Protein | str | bytes] | MSAFuture,
        diffusion_samples: int = 1,
        num_recycles: int = 3,
        num_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        **kwargs,
    ) -> FoldResultFuture:
        """
        Request structure prediction with boltz model.

        Parameters
        ----------
        sequences : Sequence[Complex | Protein | str | bytes] | MSAFuture
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        diffusion_samples: int
            Number of diffusion samples to use
        num_recycles : int
            Number of recycling steps to use
        num_steps : int
            Number of sampling steps to use
        step_scale : float
            Scaling factor for diffusion steps.
        constraints : Optional[List[dict]]
            List of constraints.

        Returns
        -------
        FoldResultFuture
            Future for the folding complex result.
        """
        # migrate old parameter
        if (recycling_steps := kwargs.get("recycling_steps")) is not None:
            num_recycles = recycling_steps
            warnings.warn(
                "`recycling_steps` has been updated to `num_recycles`. The parameter will be auto-corrected for now but raise an exception in the future."
            )
        if (sampling_steps := kwargs.get("sampling_steps")) is not None:
            num_steps = sampling_steps
            warnings.warn(
                "`sampling_steps` has been updated to `num_steps`. The parameter will be auto-corrected for now but raise an exception in the future."
            )
        # validate constraints
        if constraints is not None:
            TypeAdapter(list[BoltzConstraint]).validate_python(constraints)

        # build the normalized_models from msa
        if isinstance(sequences, MSAFuture):
            id_gen = id_generator()
            align_api = getattr(self.session, "align", None)
            assert isinstance(align_api, AlignAPI)
            msa = sequences  # rename
            seed = align_api.get_seed(job_id=msa.job.job_id)
            _proteins: dict[str, Protein] = {}
            for seq in seed.split(":"):
                protein = Protein(sequence=seq)
                id = next(id_gen)
                protein.msa = msa.id
                _proteins[id] = protein
            normalized_complexes = [Complex(chains=_proteins)]

        else:
            normalized_complexes = normalize_inputs(sequences)

        _complexes = serialize_input(self.session, normalized_complexes, needs_msa=True)

        if len(_complexes) == 0:
            raise TypeError(
                "Expected either non-empty list of proteins/models/sequences or MSAFuture"
            )

        return FoldResultFuture(
            session=self.session,
            job=api.fold_models_post(
                session=self.session,
                model_id=self.model_id,
                sequences=_complexes,
                diffusion_samples=diffusion_samples,
                num_recycles=num_recycles,
                num_steps=num_steps,
                step_scale=step_scale,
                constraints=constraints,
                use_potentials=use_potentials,
                **kwargs,
            ),
            complexes=normalized_complexes,
        )


[docs] class Boltz2Model(BoltzModel, FoldModel): """ Class providing inference endpoints for Boltz-2 structure prediction model which jointly models complex structures and binding affinities. """ model_id = "boltz-2"
[docs] def fold( self, sequences: Sequence[Complex | Protein | str | bytes] | MSAFuture, diffusion_samples: int = 1, num_recycles: int = 3, num_steps: int = 200, step_scale: float = 1.638, use_potentials: bool = False, constraints: list[dict] | None = None, templates: list[dict] | None = None, properties: list[dict] | None = None, method: str | None = None, ) -> FoldResultFuture: """ Request structure prediction with Boltz-2 model. Parameters ---------- sequences : Sequence[Complex | Protein | str | bytes] | MSAFuture List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer. diffusion_samples: int Number of diffusion samples to use num_recycles : int Number of recycling steps to use num_steps : int Number of sampling steps to use step_scale : float Scaling factor for diffusion steps. use_potentials: bool = False. Whether or not to use potentials. constraints : list[dict] | None = None List of constraints. templates: list[dict] | None = None List of templates to use for structure prediction. properties: list[dict] | None = None List of additional properties to predict. Should match the `BoltzProperties` method: str | None The experimental method or supervision source used for the prediction. Defults to None. Supported values (case-insensitive) include: 'MD', 'X-RAY DIFFRACTION', 'ELECTRON MICROSCOPY', 'SOLUTION NMR', 'SOLID-STATE NMR', 'NEUTRON DIFFRACTION', 'ELECTRON CRYSTALLOGRAPHY', 'FIBER DIFFRACTION', 'POWDER DIFFRACTION', 'INFRARED SPECTROSCOPY', 'FLUORESCENCE TRANSFER', 'EPR', 'THEORETICAL MODEL', 'SOLUTION SCATTERING', 'OTHER', 'AFDB', 'BOLTZ-1'. View the documentation on Boltz for upstream details. Returns ------- FoldResultFuture Future for the folding result. """ if templates is not None: raise ValueError("`templates` not yet supported!") # validate properties if properties is not None: props = TypeAdapter(list[BoltzProperty]).validate_python(properties) # Only allow affinity for ligands, and check binder refers to a ligand chain_id (str, not list) ligand_chain_ids = set() if isinstance(sequences, list): for protein in sequences: if isinstance(protein, Complex): complex = protein for id, chain in complex.get_chains().items(): if isinstance(chain, Ligand): ligand_chain_ids.add(id) for prop in props: if hasattr(prop, "affinity") and prop.affinity is not None: binder_id = prop.affinity.binder if binder_id not in ligand_chain_ids: raise ValueError( f"Affinity property binder '{binder_id}' does not match any ligand chain_id (must be a ligand with a single chain_id)." ) return super().fold( sequences=sequences, diffusion_samples=diffusion_samples, num_recycles=num_recycles, num_steps=num_steps, step_scale=step_scale, use_potentials=use_potentials, constraints=constraints, templates=templates, properties=properties, method=method, )
[docs] class Boltz1Model(BoltzModel, FoldModel): """ Class providing inference endpoints for Boltz-1 open-source structure prediction model. """ model_id = "boltz-1"
[docs] def fold( self, sequences: Sequence[Complex | Protein | str | bytes] | MSAFuture, diffusion_samples: int = 1, num_recycles: int = 3, num_steps: int = 200, step_scale: float = 1.638, use_potentials: bool = False, constraints: list[dict] | None = None, ) -> FoldResultFuture: """ Request structure prediction with Boltz-1 model. Parameters ---------- sequences : Sequence[Complex | Protein | str | bytes] | MSAFuture List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer. diffusion_samples: int Number of diffusion samples to use num_recycles : int Number of recycling steps to use num_steps : int Number of sampling steps to use step_scale : float Scaling factor for diffusion steps. use_potentials: bool = False. Whether or not to use potentials. constraints : Optional[List[dict]] List of constraints. Returns ------- FoldResultFuture Future for the folding complex result. """ if constraints is not None: pocket_constraints = [] for constraint in constraints: if "contact" in constraint: raise ValueError("Boltz-1(x) doesn't support contact constraints") if "pocket" in constraint: pocket_constraint = constraint["pocket"] if len(pocket_constraints) > 0: msg = f"Only one pocket binders is supported in Boltz-1!" raise ValueError(msg) max_distance = constraint["pocket"].get("max_distance", 6.0) if max_distance != 6.0: msg = f"Max distance != 6.0 is not supported in Boltz-1!" raise ValueError(msg) pocket_constraints.append(pocket_constraint) return super().fold( sequences=sequences, diffusion_samples=diffusion_samples, num_recycles=num_recycles, num_steps=num_steps, step_scale=step_scale, use_potentials=use_potentials, constraints=constraints, )
[docs] class Boltz1xModel(Boltz1Model, BoltzModel, FoldModel): """ Class providing inference endpoints for Boltz-1x open-source structure prediction model, which adds the use of inference potentials to improve performance. """ model_id = "boltz-1x"
[docs] def fold( self, sequences: Sequence[Complex | Protein | str | bytes] | MSAFuture, diffusion_samples: int = 1, num_recycles: int = 3, num_steps: int = 200, step_scale: float = 1.638, constraints: list[dict] | None = None, ) -> FoldResultFuture: """ Request structure prediction with Boltz-1x model. Uses potentials with Boltz-1 model. Parameters ---------- sequences : Sequence[Complex | Protein | str | bytes] | MSAFuture List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer. diffusion_samples: int Number of diffusion samples to use num_recycles : int Number of recycling steps to use num_steps : int Number of sampling steps to use step_scale : float Scaling factor for diffusion steps. constraints : Optional[List[dict]] List of constraints. Returns ------- FoldResultFuture Future for the folding complex result. """ return super().fold( sequences=sequences, diffusion_samples=diffusion_samples, num_recycles=num_recycles, num_steps=num_steps, step_scale=step_scale, use_potentials=True, constraints=constraints, )
class BondConstraint(BaseModel): """ Constraint specifying a covalent bond between two atoms. Attributes ---------- atom1 : list of (str or int) The first atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME]. atom2 : list of (str or int) The second atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME]. """ atom1: list[str | int] atom2: list[str | int] class PocketConstraint(BaseModel): """ Constraint specifying a ligand pocket. Attributes ---------- binder : str The chain ID of the binder. contacts : list of list of (str or int) List of contacts, each specified as [CHAIN_ID, RES_IDX/ATOM_NAME]. max_distance : float Maximum distance in angstroms for the pocket constraint. """ binder: str contacts: list[list[str | int]] max_distance: float class ContactConstraint(BaseModel): """ Constraint specifying a contact between two tokens. Attributes ---------- token1 : list of (str or int) The first token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME]. token2 : list of (str or int) The second token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME]. max_distance : float Maximum distance in angstroms for the contact constraint. """ token1: list[str | int] token2: list[str | int] max_distance: float class BoltzConstraint(BaseModel): """ Possible constraints for Boltz. Attributes ---------- bond : BondConstraint or None, optional Covalent bond constraint. pocket : PocketConstraint or None, optional Pocket constraint. contact : ContactConstraint or None, optional Contact constraint. """ bond: BondConstraint | None = None pocket: PocketConstraint | None = None contact: ContactConstraint | None = None @model_validator(mode="after") def check_exactly_one(cls, self): fields = [self.bond, self.pocket, self.contact] if sum(x is not None for x in fields) != 1: raise ValueError( "Exactly one of 'bond', 'pocket', or 'contact' must be set." ) return self class AffinityProperty(BaseModel): """ Property specifying affinity computation. Attributes ---------- binder : str The chain ID of the ligand for which to compute affinity. """ binder: str class BoltzProperty(BaseModel): """ Properties (additionally) requested for computation. Attributes ---------- affinity : AffinityProperty Affinity property specification. """ # TODO handle more than more property affinity: AffinityProperty class BoltzConfidence(BaseModel): """ Model representing the aggregated confidence scores for a prediction sample. Attributes ---------- confidence_score : float Aggregated score used to sort the predictions, corresponds to 0.8 * complex_plddt + 0.2 * iptm (ptm for single chains). ptm : float Predicted TM score for the complex. iptm : float Predicted TM score when aggregating at the interfaces. ligand_iptm : float ipTM but only aggregating at protein-ligand interfaces. protein_iptm : float ipTM but only aggregating at protein-protein interfaces. complex_plddt : float Average pLDDT score for the complex. complex_iplddt : float Average pLDDT score when upweighting interface tokens. complex_pde : float Average PDE score for the complex. complex_ipde : float Average PDE score when aggregating at interfaces. chains_ptm : dict[str, float] Predicted TM score within each chain, keyed by chain index as a string. pair_chains_iptm : dict[str, dict[str, float]] Predicted (interface) TM score between each pair of chains, keyed by chain indices as strings. """ confidence_score: float ptm: float iptm: float ligand_iptm: float protein_iptm: float complex_plddt: float complex_iplddt: float complex_pde: float complex_ipde: float chains_ptm: dict[str, float] pair_chains_iptm: dict[str, dict[str, float]] class BoltzAffinity(BaseModel): """ Output schema for Boltz affinity ensemble predictions. Attributes ---------- affinity_pred_value : float Predicted binding affinity from the ensemble model. affinity_probability_binary : float Predicted binding likelihood from the ensemble model. **kwargs: Extra keys of the form 'affinity_pred_valueN' and 'affinity_probability_binaryN', where N is the model index (e.g., 1, 2, 3, ...). """ affinity_pred_value: float affinity_probability_binary: float class Config: extra = "allow" # Allow extra fields