# Source code for openprotein.fold.boltz

"""Community-based Boltz models for complex structure prediction with ligands/dna/rna."""

import re
import string
from typing import Any

from pydantic import BaseModel, Field, TypeAdapter, model_validator

from openprotein.align import AlignAPI, MSAFuture
from openprotein.base import APISession
from openprotein.chains import DNA, RNA, Ligand
from openprotein.common import ModelMetadata
from openprotein.protein import Protein

from . import api
from .future import FoldComplexResultFuture
from .models import FoldModel

# Chain ids are either 1-5 uppercase ASCII letters or 1-5 ASCII digits.
# `[0-9]` is used instead of `\d` because `\d` in a str pattern also matches
# non-ASCII Unicode decimal digits (e.g. Arabic-Indic "٣").
valid_id_pattern = re.compile(r"^[A-Z]{1,5}$|^[0-9]{1,5}$")


def is_valid_id(id_str: str) -> bool:
    """
    Check whether `id_str` is a valid chain id.

    A valid id is 1-5 uppercase ASCII letters (e.g. ``"A"``, ``"AB"``) or
    1-5 ASCII digits (e.g. ``"1"``, ``"12345"``).

    Parameters
    ----------
    id_str : str
        Candidate chain id.

    Returns
    -------
    bool
        True if `id_str` has the chain id format, False otherwise (including
        for empty strings and non-string input).
    """
    if not isinstance(id_str, str) or not id_str or len(id_str) > 5:
        return False
    return valid_id_pattern.fullmatch(id_str) is not None


def id_generator(used_ids: list[str] | None = None, max_alpha_len=5, max_numeric=99999):
    """
    Yields new chain IDs, skipping any in 'used_ids'.
    First A..Z, AA..ZZ, … up to max_alpha_len, then '1','2',… up to max_numeric.
    """
    used = set(tuple(used_ids or []))
    letters = list(string.ascii_uppercase)

    # --- Alphabetic IDs ---
    curr_len = 1
    curr_indices = [0] * curr_len  # start at 'A'

    def bump_indices():
        # lexicographically increment curr_indices; return False on overflow
        for i in reversed(range(len(curr_indices))):
            if curr_indices[i] < len(letters) - 1:
                curr_indices[i] += 1
                for j in range(i + 1, len(curr_indices)):
                    curr_indices[j] = 0
                return True
        return False

    while curr_len <= max_alpha_len:
        candidate = "".join(letters[i] for i in curr_indices)
        if candidate not in used:
            used.add(candidate)
            yield candidate
        # bump
        if not bump_indices():
            curr_len += 1
            if curr_len > max_alpha_len:
                break
            curr_indices = [0] * curr_len

    # --- Numeric IDs ---
    num = 1
    while num <= max_numeric:
        candidate = str(num)
        num += 1
        if candidate not in used:
            used.add(candidate)
            yield candidate

    # exhausted
    raise RuntimeError("exhausted all possible IDs")


class BoltzModel(FoldModel):
    """
    Class providing inference endpoints for Boltz structure prediction models.
    """

    model_id: str = "boltz"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        super().__init__(session, model_id, metadata)

    @staticmethod
    def _chain_ids_of(entity: Any) -> list[str]:
        """Return the chain id(s) already assigned to `entity` (empty list if unset)."""
        chain_id = getattr(entity, "chain_id", None)
        if isinstance(chain_id, str):
            return [chain_id]
        if isinstance(chain_id, list):
            return list(chain_id)
        return []

    def _proteins_from_msa(self, msa: MSAFuture, id_gen) -> list[Protein]:
        """
        Expand the query sequences of `msa` into `Protein` objects tagged with `msa`.

        Repeated query sequences are merged into a single `Protein` that receives
        one chain id per copy (drawn from `id_gen`).
        """
        align_api = getattr(self.session, "align", None)
        if not isinstance(align_api, AlignAPI):
            # explicit check rather than `assert`, which is stripped under `python -O`
            raise RuntimeError(
                "Session does not provide an `AlignAPI` as `.align`; "
                "cannot resolve the query sequences of the MSA."
            )
        seed = align_api.get_seed(job_id=msa.job.job_id)
        # count copies of each query sequence; dicts keep first-seen order
        cardinality: dict[str, int] = {}
        for seq in seed.split(":"):
            cardinality[seq] = cardinality.get(seq, 0) + 1
        proteins: list[Protein] = []
        for seq, copies in cardinality.items():
            protein = Protein(sequence=seq)
            protein.chain_id = (
                next(id_gen) if copies == 1 else [next(id_gen) for _ in range(copies)]
            )
            protein.msa = msa
            proteins.append(protein)
        return proteins

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        force_single_sequence_mode: bool = False,
        **kwargs,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to boltz model.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials.
        constraints : list[dict] | None
            List of constraints. Each entry must validate as a `BoltzConstraint`
            (exactly one of `bond`, `pocket` or `contact`).
        force_single_sequence_mode : bool
            If True, every protein is submitted with a null `msa_id` (single
            sequence mode), ignoring any `msa` attached to it.
        **kwargs
            Additional fields forwarded unchanged to the fold request.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.

        Raises
        ------
        ValueError
            If no proteins, DNA, RNA or ligands are supplied, or if a protein has
            no `msa` set while `force_single_sequence_mode` is False.
        RuntimeError
            If `proteins` is an `MSAFuture` but the session has no `AlignAPI`.
        pydantic.ValidationError
            If `constraints` do not validate.
        """
        # validate constraints before doing any other work
        if constraints is not None:
            TypeAdapter(list[BoltzConstraint]).validate_python(constraints)

        # collect user-assigned chain ids so generated ids never collide with them
        used_ids: list[str] = []
        if isinstance(proteins, list):
            for protein in proteins:
                if isinstance(protein, Protein):
                    used_ids.extend(self._chain_ids_of(protein))
        for entity in (*(dnas or []), *(rnas or []), *(ligands or [])):
            used_ids.extend(self._chain_ids_of(entity))
        id_gen = id_generator(used_ids)

        # an MSAFuture stands for all of its query sequences as one multimer
        if isinstance(proteins, MSAFuture):
            proteins = self._proteins_from_msa(proteins, id_gen)

        # build the sequences input in the boltz request format
        sequences: list[dict[str, Any]] = []
        for protein in proteins or []:
            msa = protein.msa
            if force_single_sequence_mode:
                # a null msa_id requests single sequence mode
                msa_id = None
            elif msa is None:
                raise ValueError(
                    "Expected all protein sequences to have `.msa` set with an `MSAFuture` or `Protein.single_sequence_mode` for single sequence mode."
                )
            elif isinstance(msa, str):
                msa_id = msa
            elif isinstance(msa, MSAFuture):
                msa_id = msa.id
            else:
                # e.g. `Protein.single_sequence_mode`: no MSA -> null msa_id
                msa_id = None
            protein_entry: dict[str, Any] = {
                "id": protein.chain_id or next(id_gen),
                "msa_id": msa_id,
                "sequence": protein.sequence.decode(),
            }
            if protein.cyclic:
                protein_entry["cyclic"] = protein.cyclic
            sequences.append({"protein": protein_entry})
        # DNA and RNA share the same request shape, keyed by the entity kind
        for kind, entities in (("dna", dnas), ("rna", rnas)):
            for entity in entities or []:
                entry: dict[str, Any] = {
                    "id": entity.chain_id or next(id_gen),
                    "sequence": entity.sequence,
                }
                if entity.cyclic:
                    entry["cyclic"] = entity.cyclic
                sequences.append({kind: entry})
        for ligand in ligands or []:
            ligand_entry: dict[str, Any] = {"id": ligand.chain_id or next(id_gen)}
            if ligand.ccd:
                ligand_entry["ccd"] = ligand.ccd
            if ligand.smiles:
                ligand_entry["smiles"] = ligand.smiles
            sequences.append({"ligand": ligand_entry})

        if not sequences:
            raise ValueError("Expected proteins, dna, rna or ligands")

        return FoldComplexResultFuture.create(
            session=self.session,
            job=api.fold_models_post(
                session=self.session,
                model_id=self.model_id,
                sequences=sequences,
                diffusion_samples=diffusion_samples,
                recycling_steps=recycling_steps,
                sampling_steps=sampling_steps,
                step_scale=step_scale,
                constraints=constraints,
                use_potentials=use_potentials,
                **kwargs,
            ),
            model_id=self.model_id,
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
        )


class Boltz2Model(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-2 structure prediction model
    which jointly models complex structures and binding affinities.
    """

    model_id = "boltz-2"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
        templates: list[dict] | None = None,
        properties: list[dict] | None = None,
        method: str | None = None,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to Boltz-2 model.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials. Defaults to False.
        constraints : list[dict] | None
            List of constraints.
        templates : list[dict] | None
            List of templates to use for structure prediction. Not yet
            supported; passing anything other than None raises ValueError.
        properties : list[dict] | None
            List of additional properties to predict. Each entry must validate
            as a `BoltzProperty`. An affinity property's `binder` must be the
            chain id of a ligand that has a single (str) chain id.
        method : str | None
            The experimental method or supervision source used for the
            prediction. Defaults to None.
            Supported values (case-insensitive) include: 'MD', 'X-RAY DIFFRACTION', 'ELECTRON MICROSCOPY', 'SOLUTION NMR', 'SOLID-STATE NMR', 'NEUTRON DIFFRACTION', 'ELECTRON CRYSTALLOGRAPHY', 'FIBER DIFFRACTION', 'POWDER DIFFRACTION', 'INFRARED SPECTROSCOPY', 'FLUORESCENCE TRANSFER', 'EPR', 'THEORETICAL MODEL', 'SOLUTION SCATTERING', 'OTHER', 'AFDB', 'BOLTZ-1'. View the documentation on Boltz for upstream details.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding result.

        Raises
        ------
        ValueError
            If `templates` is given, or if an affinity property's binder does
            not name a single-chain ligand.
        """
        if templates is not None:
            raise ValueError("`templates` not yet supported!")
        # validate properties
        if properties is not None:
            props = TypeAdapter(list[BoltzProperty]).validate_python(properties)
            # Affinity binders must refer to a ligand with a single (str) chain_id.
            # Ligands with several chain ids are only an error when one of their
            # chains is actually used as an affinity binder.
            single_chain_ligand_ids: set[str] = set()
            multi_chain_ligands: dict[str, Ligand] = {}
            for ligand in ligands or []:
                if isinstance(ligand.chain_id, str):
                    single_chain_ligand_ids.add(ligand.chain_id)
                elif isinstance(ligand.chain_id, list):
                    for chain_id in ligand.chain_id:
                        multi_chain_ligands[chain_id] = ligand
            for prop in props:
                affinity = getattr(prop, "affinity", None)
                if affinity is None:
                    continue
                binder_id = affinity.binder
                if binder_id in single_chain_ligand_ids:
                    continue
                if binder_id in multi_chain_ligands:
                    ligand = multi_chain_ligands[binder_id]
                    raise ValueError(
                        f"Ligand {ligand} has multiple chain_ids ({ligand.chain_id}); only single (str) chain_id allowed for affinity."
                    )
                raise ValueError(
                    f"Affinity property binder '{binder_id}' does not match any ligand chain_id (must be a ligand with a single chain_id)."
                )
        return super().fold(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=use_potentials,
            constraints=constraints,
            templates=templates,
            properties=properties,
            method=method,
        )
class Boltz1xModel(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-1x open-source structure
    prediction model, which adds the use of inference potentials to improve
    performance.
    """

    model_id = "boltz-1x"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        constraints: list[dict] | None = None,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to Boltz-1x model. Uses potentials with Boltz-1 model.

        Potentials are always enabled for this model (`use_potentials=True` is
        passed to the base implementation and cannot be overridden).

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        constraints : list[dict] | None
            List of constraints.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.
        """
        return super().fold(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=True,
            constraints=constraints,
        )
class Boltz1Model(BoltzModel, FoldModel):
    """
    Class providing inference endpoints for Boltz-1 open-source structure
    prediction model.
    """

    model_id = "boltz-1"

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        dnas: list[DNA] | None = None,
        rnas: list[RNA] | None = None,
        ligands: list[Ligand] | None = None,
        diffusion_samples: int = 1,
        recycling_steps: int = 3,
        sampling_steps: int = 200,
        step_scale: float = 1.638,
        use_potentials: bool = False,
        constraints: list[dict] | None = None,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to Boltz-1 model.

        Parameters
        ----------
        proteins : list[Protein] | MSAFuture | None
            List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        dnas : list[DNA] | None
            List of DNA sequences to include in folded output.
        rnas : list[RNA] | None
            List of RNA sequences to include in folded output.
        ligands : list[Ligand] | None
            List of ligands to include in folded output.
        diffusion_samples : int
            Number of diffusion samples to use.
        recycling_steps : int
            Number of recycling steps to use.
        sampling_steps : int
            Number of sampling steps to use.
        step_scale : float
            Scaling factor for diffusion steps.
        use_potentials : bool
            Whether or not to use potentials. Defaults to False.
        constraints : list[dict] | None
            List of constraints.

        Returns
        -------
        FoldComplexResultFuture
            Future for the folding complex result.
        """
        return super().fold(
            proteins=proteins,
            dnas=dnas,
            rnas=rnas,
            ligands=ligands,
            diffusion_samples=diffusion_samples,
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            step_scale=step_scale,
            use_potentials=use_potentials,
            constraints=constraints,
        )
class BondConstraint(BaseModel):
    """
    Constraint specifying a covalent bond between two atoms.

    Attributes
    ----------
    atom1 : list of (str or int)
        The first atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME].
    atom2 : list of (str or int)
        The second atom, specified as [CHAIN_ID, RES_IDX, ATOM_NAME].
    """

    atom1: list[str | int]
    atom2: list[str | int]


class PocketConstraint(BaseModel):
    """
    Constraint specifying a ligand pocket.

    Attributes
    ----------
    binder : str
        The chain ID of the binder.
    contacts : list of list of (str or int)
        List of contacts, each specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    max_distance : float
        Maximum distance in angstroms for the pocket constraint.
    """

    binder: str
    contacts: list[list[str | int]]
    max_distance: float


class ContactConstraint(BaseModel):
    """
    Constraint specifying a contact between two tokens.

    Attributes
    ----------
    token1 : list of (str or int)
        The first token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    token2 : list of (str or int)
        The second token, specified as [CHAIN_ID, RES_IDX/ATOM_NAME].
    max_distance : float
        Maximum distance in angstroms for the contact constraint.
    """

    token1: list[str | int]
    token2: list[str | int]
    max_distance: float


class BoltzConstraint(BaseModel):
    """
    Possible constraints for Boltz. Exactly one field must be set.

    Attributes
    ----------
    bond : BondConstraint or None, optional
        Covalent bond constraint.
    pocket : PocketConstraint or None, optional
        Pocket constraint.
    contact : ContactConstraint or None, optional
        Contact constraint.
    """

    bond: BondConstraint | None = None
    pocket: PocketConstraint | None = None
    contact: ContactConstraint | None = None

    # "after" model validators receive the validated instance, so this is a
    # plain instance method (pydantic's documented form) rather than (cls, self).
    @model_validator(mode="after")
    def check_exactly_one(self) -> "BoltzConstraint":
        fields = [self.bond, self.pocket, self.contact]
        if sum(x is not None for x in fields) != 1:
            raise ValueError(
                "Exactly one of 'bond', 'pocket', or 'contact' must be set."
            )
        return self


class AffinityProperty(BaseModel):
    """
    Property specifying affinity computation.

    Attributes
    ----------
    binder : str
        The chain ID of the ligand for which to compute affinity.
    """

    binder: str


class BoltzProperty(BaseModel):
    """
    Properties (additionally) requested for computation.

    Attributes
    ----------
    affinity : AffinityProperty
        Affinity property specification.
    """

    # TODO: handle more than one property type
    affinity: AffinityProperty


class BoltzConfidence(BaseModel):
    """
    Model representing the aggregated confidence scores for a prediction sample.

    Attributes
    ----------
    confidence_score : float
        Aggregated score used to sort the predictions, corresponds to
        0.8 * complex_plddt + 0.2 * iptm (ptm for single chains).
    ptm : float
        Predicted TM score for the complex.
    iptm : float
        Predicted TM score when aggregating at the interfaces.
    ligand_iptm : float
        ipTM but only aggregating at protein-ligand interfaces.
    protein_iptm : float
        ipTM but only aggregating at protein-protein interfaces.
    complex_plddt : float
        Average pLDDT score for the complex.
    complex_iplddt : float
        Average pLDDT score when upweighting interface tokens.
    complex_pde : float
        Average PDE score for the complex.
    complex_ipde : float
        Average PDE score when aggregating at interfaces.
    chains_ptm : dict[str, float]
        Predicted TM score within each chain, keyed by chain index as a string.
    pair_chains_iptm : dict[str, dict[str, float]]
        Predicted (interface) TM score between each pair of chains, keyed by
        chain indices as strings.
    """

    confidence_score: float
    ptm: float
    iptm: float
    ligand_iptm: float
    protein_iptm: float
    complex_plddt: float
    complex_iplddt: float
    complex_pde: float
    complex_ipde: float
    chains_ptm: dict[str, float]
    pair_chains_iptm: dict[str, dict[str, float]]


class BoltzAffinity(BaseModel):
    """
    Output schema for Boltz affinity ensemble predictions.

    Attributes
    ----------
    affinity_pred_value : float
        Predicted binding affinity from the ensemble model.
    affinity_probability_binary : float
        Predicted binding likelihood from the ensemble model.
    per_model : dict of str to float
        Dictionary containing predictions from each individual model in the
        ensemble. Keys are of the form 'affinity_pred_valueN' and
        'affinity_probability_binaryN', where N is the model index
        (e.g., 1, 2, 3, ...).

    Notes
    -----
    Use the `parse_obj_with_models` class method to construct this object from
    a raw output dictionary, which will automatically separate ensemble-level
    and per-model predictions.
    """

    affinity_pred_value: float
    affinity_probability_binary: float
    # Catch all other per-model fields
    per_model: dict[str, float] = Field(default_factory=dict)

    @classmethod
    def parse_obj_with_models(cls, obj: dict) -> "BoltzAffinity":
        """
        Build a `BoltzAffinity` from a raw affinity output dictionary.

        The two ensemble-level keys become top-level fields; every remaining
        key is collected into `per_model`. The input dictionary is not modified.

        Raises
        ------
        KeyError
            If `affinity_pred_value` or `affinity_probability_binary` is missing.
        """
        # pop from a shallow copy so the caller's dict is left untouched
        remaining = dict(obj)
        fixed = {
            "affinity_pred_value": remaining.pop("affinity_pred_value"),
            "affinity_probability_binary": remaining.pop(
                "affinity_probability_binary"
            ),
        }
        # Everything else goes into per_model
        return cls(**fixed, per_model=remaining)