Source code for openprotein.molecules.complex

import gzip
import operator
from collections.abc import Mapping, Sequence
from functools import reduce
from pathlib import Path
from types import MappingProxyType
from typing import Literal, overload

import numpy as np
import numpy.typing as npt

import gemmi

import openprotein.utils.chain_id as _chain_id_utils
import openprotein.utils.cif as _cif_utils

from .chains import DNA, RNA, Ligand
from .protein import Protein


# TODO: deserialization note about plddt parsed per residue
class Complex:
    """
    A named collection of molecular chains keyed by chain id.

    Each chain is a ``Protein``, ``DNA``, ``RNA`` or ``Ligand``. The internal
    chain mapping is always kept sorted by chain id.
    """

    def __init__(
        self,
        chains: Mapping[str, Protein | DNA | RNA | Ligand] | None = None,
        name: bytes | str | None = None,
    ):
        """
        Create a complex.

        Args:
            chains: Mapping of chain id to chain object. The items are copied
                into a new dict sorted by chain id; ``None`` gives an empty
                complex.
            name: Optional name; ``bytes`` values are decoded to ``str``.
        """
        self._chains = dict(sorted(chains.items())) if chains is not None else {}
        self.name = name

    @property
    def name(self) -> str | None:
        """The name of this complex, or ``None`` if unset."""
        return self._name

    @name.setter
    def name(self, x: bytes | str | None) -> None:
        # Normalize bytes to str so the stored name is always ``str | None``.
        self._name = x.decode() if isinstance(x, bytes) else x

    def get_name(self) -> str | None:
        """Return the name of this complex, or ``None`` if unset."""
        return self._name

    def set_name(self, x: bytes | str | None) -> "Complex":
        """Set the name (decoding ``bytes``) and return ``self`` for chaining."""
        self.name = x
        return self

    def get_chains(self) -> Mapping[str, Protein | DNA | RNA | Ligand]:
        """Return a read-only view of all chains keyed by chain id."""
        return MappingProxyType(self._chains)

    def get_proteins(self) -> Mapping[str, Protein]:
        """Return a read-only mapping of only the ``Protein`` chains."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, Protein)}
        )

    def get_protein(self, chain_id: str) -> Protein:
        """
        Return the chain ``chain_id``, asserting that it is a ``Protein``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not a ``Protein``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, Protein)
        return chain

    def get_dnas(self) -> Mapping[str, DNA]:
        """Return a read-only mapping of only the ``DNA`` chains."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, DNA)}
        )

    def get_dna(self, chain_id: str) -> DNA:
        """
        Return the chain ``chain_id``, asserting that it is a ``DNA``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not a ``DNA``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, DNA)
        return chain

    def get_rnas(self) -> Mapping[str, RNA]:
        """Return a read-only mapping of only the ``RNA`` chains."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, RNA)}
        )

    def get_rna(self, chain_id: str) -> RNA:
        """
        Return the chain ``chain_id``, asserting that it is an ``RNA``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not an ``RNA``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, RNA)
        return chain

    def get_ligands(self) -> Mapping[str, Ligand]:
        """Return a read-only mapping of only the ``Ligand`` chains."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, Ligand)}
        )

    def get_ligand(self, chain_id: str) -> Ligand:
        """
        Return the chain ``chain_id``, asserting that it is a ``Ligand``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not a ``Ligand``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, Ligand)
        return chain

    def set_chain(
        self, chain_id: str, value: Protein | DNA | RNA | Ligand
    ) -> "Complex":
        """
        Add or replace the chain stored under ``chain_id``.

        The chain mapping is re-sorted by chain id afterwards. Returns
        ``self`` for chaining.
        """
        self._chains[chain_id] = value
        self._chains = dict(sorted(self._chains.items()))
        return self

    def __rand__(self, left: "Complex | Protein | str") -> "Complex":
        """
        Support ``left & complex`` when ``left`` does not handle ``&`` itself.

        A ``str`` is first converted with ``Protein.from_expr``; the result is
        then delegated to ``left & self``.
        """
        if isinstance(left, str):
            left = Protein.from_expr(expr=left)
        return left & self

    def __and__(self, right: "Complex | Protein | str") -> "Complex":
        """
        Combine multiple objects into a single Complex.

        Note that this operates in place: ``self`` is mutated and returned.

        Args:
            right: A ``Protein`` (added under the next id produced by
                ``_chain_id_utils.id_generator``), a ``str`` (converted with
                ``Protein.from_expr`` and then treated as a ``Protein``), or
                another ``Complex`` whose chains are merged into this one.

        Raises:
            ValueError: If ``right`` is a ``Complex`` sharing a chain id with
                this one.
        """
        assert isinstance(right, (Complex, Protein, str))
        # presumably yields chain ids not already in use -- confirm against
        # openprotein.utils.chain_id.id_generator
        id_gen = _chain_id_utils.id_generator(list(self._chains.keys()))
        if isinstance(right, str):
            right = Protein.from_expr(right)
        if isinstance(right, Protein):
            self.set_chain(chain_id=next(id_gen), value=right)
        else:
            if (
                len(overlapping_chain_ids := self._chains.keys() & right._chains.keys())
                > 0
            ):
                raise ValueError(
                    f"Trying to combine two sets of chains with overlapping chain ids: {overlapping_chain_ids}"
                )
            self._chains = dict(sorted((self._chains | right._chains).items()))
        return self

    @overload
    def rmsd(
        self,
        tgt: "Complex",
        backbone_only: bool | str | Sequence[str] = False,
        return_transform: Literal[False] = False,
    ) -> float: ...

    @overload
    def rmsd(
        self,
        tgt: "Complex",
        backbone_only: bool | str | Sequence[str] = False,
        return_transform: Literal[True] = True,
    ) -> tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]: ...

    def rmsd(
        self,
        tgt: "Complex",
        backbone_only: bool | str | Sequence[str] = False,
        return_transform: bool = False,
    ) -> float | tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]:
        """
        Compute the RMSD between this complex and ``tgt``.

        Both complexes must contain only ``Protein`` chains, with identical
        chain ids and matching per-chain lengths. The chains of each complex
        are concatenated with ``+`` (in chain-id order) and the computation
        is delegated to ``Protein.rmsd``.

        Args:
            tgt: The complex to compare against.
            backbone_only: Forwarded unchanged to ``Protein.rmsd``.
            return_transform: Forwarded unchanged to ``Protein.rmsd``; when
                true the result is a 3-tuple instead of a float.

        Raises:
            AssertionError: If either complex has a non-protein chain, or the
                chain ids or chain lengths differ.
        """
        assert all(
            isinstance(v, Protein) for v in self._chains.values()
        ), "rmsd supported only for Protein chains, not supported for non-protein chains"
        assert all(
            isinstance(v, Protein) for v in tgt._chains.values()
        ), "rmsd supported only for Protein chains, not supported for non-protein chains"
        src_proteins, tgt_proteins = self.get_proteins(), tgt.get_proteins()
        assert tgt_proteins.keys() == src_proteins.keys()
        assert [len(x) for x in src_proteins.values()] == [
            len(x) for x in tgt_proteins.values()
        ]
        src_protein: Protein = reduce(operator.add, src_proteins.values())
        tgt_protein: Protein = reduce(operator.add, tgt_proteins.values())
        return src_protein.rmsd(
            tgt_protein,
            backbone_only=backbone_only,
            return_transform=return_transform,
        )

    def transform(
        self,
        R: npt.NDArray[np.floating] | None = None,
        t: npt.NDArray[np.floating] | None = None,
    ) -> "Complex":
        """
        Apply ``Protein.transform(R=R, t=t)`` to every chain.

        Only complexes made up entirely of ``Protein`` chains are supported.
        Returns ``self`` for chaining.

        Raises:
            AssertionError: If any chain is not a ``Protein``.
        """
        assert all(
            isinstance(v, Protein) for v in self._chains.values()
        ), "transform supported only for Protein chains, not supported for non-protein chains"
        for protein in self.get_proteins().values():
            protein.transform(R=R, t=t)
        return self

    def superimpose_onto(
        self, tgt: "Complex", backbone_only: bool | str | Sequence[str] = False
    ) -> "Complex":
        """
        Transform this complex using the transform from ``tgt.rmsd(self)``.

        The rotation and translation returned by
        ``tgt.rmsd(self, return_transform=True)`` are passed to
        :meth:`transform`. Returns ``self``.
        """
        # NOTE(review): assumes the transform returned by tgt.rmsd(self) maps
        # self onto tgt -- confirm against Protein.rmsd.
        _, R, t = tgt.rmsd(self, backbone_only=backbone_only, return_transform=True)
        return self.transform(R=R, t=t)

    def to_string(self, format: Literal["cif", "pdb"] = "cif") -> str:
        """
        Serialize this Complex to a string.

        Note that format="pdb" may not serialize all aspects of this object,
        so format="cif", the default, is preferred.

        Raises:
            ValueError: If ``format`` is neither ``"cif"`` nor ``"pdb"``.
        """
        if format == "cif":
            return self._make_cif_string()
        elif format == "pdb":
            return self._make_pdb_string()
        else:
            raise ValueError(format)

    @staticmethod
    def from_filepath(
        path: Path | str,
        use_bfactor_as_plddt: bool | None = None,
        model_idx: int = 0,
        verbose: bool = True,
    ) -> "Complex":
        """
        Load a complex from a ``.cif``, ``.pdb``, ``.cif.gz`` or ``.pdb.gz`` file.

        The complex is named after the file name with the recognized
        extension removed.

        Args:
            path: Location of the structure file.
            use_bfactor_as_plddt: Forwarded to :meth:`from_string`.
            model_idx: Forwarded to :meth:`from_string`.
            verbose: Forwarded to :meth:`from_string`.

        Raises:
            ValueError: If the file extension is not one of the supported ones.
        """
        path = Path(path)
        file_format: Literal["cif", "pdb"]
        if path.suffix == ".gz":
            if path.name.endswith(".cif.gz"):
                ext, file_format = ".cif.gz", "cif"
            elif path.name.endswith(".pdb.gz"):
                ext, file_format = ".pdb.gz", "pdb"
            else:
                raise ValueError(f"unsupported format: {path}")
            with gzip.open(path, "rb") as f:
                data = f.read()
        else:
            ext = path.suffix
            if ext == ".cif":
                file_format = "cif"
            elif ext == ".pdb":
                file_format = "pdb"
            else:
                # Raise (rather than assert) so the check survives ``python -O``
                # and matches the error raised for unsupported gzip files.
                raise ValueError(f"unsupported format: {path}")
            data = path.read_bytes()
        return Complex.from_string(
            filestring=data,
            format=file_format,
            use_bfactor_as_plddt=use_bfactor_as_plddt,
            model_idx=model_idx,
            verbose=verbose,
        ).set_name(path.name.removesuffix(ext))

    @staticmethod
    def from_string(
        filestring: bytes | str,
        format: Literal["pdb", "cif"],
        use_bfactor_as_plddt: bool | None = None,
        model_idx: int = 0,
        verbose: bool = True,
    ) -> "Complex":
        """
        Parse a complex from the contents of a PDB or mmCIF file.

        The contents are wrapped in a ``_cif_utils.StructureCIFBlock`` and
        handed to :meth:`_from_structure_block`.

        Args:
            filestring: File contents.
            format: ``"pdb"`` or ``"cif"``.
            use_bfactor_as_plddt: Forwarded to ``Protein._from_structure_block``.
            model_idx: Index of the model to read from the structure.
            verbose: Forwarded to ``Protein._from_structure_block``.
        """
        structure_block = _cif_utils.StructureCIFBlock(
            filestring=filestring, format=format
        )
        return Complex._from_structure_block(
            structure_block=structure_block,
            use_bfactor_as_plddt=use_bfactor_as_plddt,
            model_idx=model_idx,
            verbose=verbose,
        )

    def copy(self) -> "Complex":
        """Return a new complex with the same name and a ``copy()`` of each chain."""
        return Complex(
            chains={k: v.copy() for k, v in self._chains.items()}, name=self._name
        )

    @staticmethod
    def _from_structure_block(
        structure_block: _cif_utils.StructureCIFBlock,
        use_bfactor_as_plddt: bool | None = None,
        model_idx: int = 0,
        verbose: bool = True,
    ) -> "Complex":
        """
        Build a complex from an already-parsed structure block.

        Subchain ids are taken from the ``_struct_asym.id`` loop of the CIF
        block, falling back to the subchains of the selected model when that
        loop is empty. Each subchain is dispatched on its entity type and
        polymer type to the matching ``_from_structure_block`` constructor.
        Waters and unsupported entity or polymer types are skipped.

        Raises:
            ValueError: If no entity is found for a non-empty subchain, or a
                non-polymer (ligand) entity is read from PDB input.
            AssertionError: If an empty subchain does not match exactly one
                entity, or a PDB polymer subchain id has an unexpected form.
        """
        block, structure = structure_block.block, structure_block.structure
        model = structure[model_idx] if len(structure) > 0 else None
        # Use block info directly so that we can get chains with empty struct info
        subchain_ids = list(block.find_loop("_struct_asym.id"))
        if len(subchain_ids) == 0 and model is not None:
            # Try to get actual chain IDs from the structure
            subchain_ids = [subchain.subchain_id() for subchain in model.subchains()]
        # collect chains
        chains = {}
        for subchain_id in sorted(subchain_ids):
            subchain = model.get_subchain(subchain_id) if model is not None else None
            # Get the entity for this chain to determine its type
            if subchain is not None and len(subchain) > 0:
                entity = structure.get_entity_of(subchain)
                if entity is None:
                    raise ValueError(f"Could not find entity for chain {subchain_id}")
            else:
                # No residues to look the entity up from, so match on the
                # subchain ids each entity declares.
                matching_entities = [
                    e for e in structure.entities if subchain_id in e.subchains
                ]
                # Report subchain_id here: chain_id is not assigned yet at
                # this point in the iteration.
                assert len(matching_entities) == 1, (
                    f"expected only one entity to match {subchain_id=}, but found "
                    f"{len(matching_entities)}: {matching_entities}"
                )
                entity = matching_entities[0]
                del matching_entities
            # Determine chain type based on entity type and polymer type
            if (entity_type := entity.entity_type) == gemmi.EntityType.Polymer:
                if structure.input_format == gemmi.CoorFormat.Pdb:
                    # For PDB input the polymer subchain id is expected to be
                    # the chain id followed by "xp"; strip that suffix.
                    assert subchain_id.endswith("xp")
                    chain_id = subchain_id.removesuffix("xp")
                    assert chain_id not in chains
                else:
                    chain_id = subchain_id
                if (polymer_type := entity.polymer_type) in (
                    gemmi.PolymerType.PeptideL,
                    gemmi.PolymerType.PeptideD,
                ):
                    chains[chain_id] = Protein._from_structure_block(
                        structure_block=structure_block,
                        chain_id=subchain_id,
                        use_bfactor_as_plddt=use_bfactor_as_plddt,
                        model_idx=model_idx,
                        verbose=verbose,
                    )
                elif polymer_type == gemmi.PolymerType.Dna:
                    chains[chain_id] = DNA._from_structure_block(
                        structure_block=structure_block,
                        chain_id=subchain_id,
                        model_idx=model_idx,
                    )
                elif polymer_type == gemmi.PolymerType.Rna:
                    chains[chain_id] = RNA._from_structure_block(
                        structure_block=structure_block,
                        chain_id=subchain_id,
                        model_idx=model_idx,
                    )
                else:
                    # Unsupported polymer types are skipped silently.
                    continue
            elif entity_type == gemmi.EntityType.NonPolymer:
                if structure.input_format == gemmi.CoorFormat.Pdb:
                    raise ValueError("ligands from pdb files not supported yet")
                chain_id = subchain_id
                chains[chain_id] = Ligand._from_structure_block(
                    structure_block=structure_block,
                    chain_id=subchain_id,
                    model_idx=model_idx,
                )
            elif entity_type == gemmi.EntityType.Water:
                continue
            else:
                # Unsupported entity types are skipped silently.
                continue
        return Complex(chains=chains, name=structure.name)

    def _make_cif_string(self) -> str:
        """
        Serialize this complex as an mmCIF string.

        An mmCIF block is created from the assembled ``gemmi.Structure``, then
        each chain appends its own rows to the sequence and atom loops
        returned by ``_cif_utils.init_loops``.
        """
        structure = self._make_structure()
        block = structure.make_mmcif_block(
            groups=gemmi.MmcifOutputGroups(True, chem_comp=False)
        )
        sequence_loop, atom_loop = _cif_utils.init_loops(block=block)
        for chain_id, chain in self._chains.items():
            chain._append_loop_data(
                chain_id=chain_id, sequence_loop=sequence_loop, atom_loop=atom_loop
            )
        return block.as_string()

    def _make_pdb_string(self) -> str:
        """Serialize this complex as a PDB string using minimal write options."""
        structure = self._make_structure()
        return structure.make_pdb_string(gemmi.PdbWriteOptions(minimal=True))

    def _make_structure(self) -> gemmi.Structure:
        """
        Assemble a ``gemmi.Structure`` containing every chain of this complex.

        Each chain adds itself through its own ``_make_structure``; entities
        are then set up and renamed ``"1"``, ``"2"``, ... in order, and the
        structure takes this complex's name when one is set.

        Raises:
            AssertionError: If the ligands do not all share one
                ``_structure_block``.
        """
        assert (
            len({x._structure_block for x in self.get_ligands().values()}) <= 1
        ), "can only serialize ligands if they all originate from the same structure file"
        structure = gemmi.Structure()
        for chain_id, chain in self._chains.items():
            structure = chain._make_structure(
                structure=structure,
                model_idx=0,
                chain_id=chain_id,
                entity_name=str(len(structure.entities) + 1),
            )
        structure.setup_entities()  # this should deduplicate polymer entities
        for entity_idx, entity in enumerate(structure.entities):
            entity.name = str(entity_idx + 1)
        if self._name is not None:
            structure.name = self._name
        return structure