import gzip
import operator
from collections.abc import Mapping, Sequence
from functools import reduce
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING, Literal, overload
import gemmi
import numpy as np
import numpy.typing as npt
import openprotein.utils.chain_id as _chain_id_utils
import openprotein.utils.cif as _cif_utils
from .chains import DNA, RNA, Ligand
from .protein import Protein
if TYPE_CHECKING:
from .template import Template
# TODO: add a note to the deserialization docs explaining that pLDDT is parsed per residue
class Complex:
    """
    A molecular complex: a collection of chains keyed by chain id.

    Every chain is a ``Protein``, ``DNA``, ``RNA`` or ``Ligand``. The chain mapping
    is re-sorted by chain id whenever it is modified. A complex can additionally
    carry an optional name and a sequence of structure-prediction templates.
    """

    def __init__(
        self,
        chains: Mapping[str, Protein | DNA | RNA | Ligand] | None = None,
        name: bytes | str | None = None,
    ):
        """
        Create a complex.

        Args:
            chains: Mapping from chain id to chain object. The mapping is copied
                into a new dict sorted by chain id; ``None`` means no chains.
            name: Optional name. ``bytes`` values are decoded to ``str``.
        """
        self._chains = dict(sorted(chains.items())) if chains is not None else {}
        self._templates: "Sequence[Protein | Complex | Template]" = ()
        # Goes through the property setter so that bytes are decoded.
        self.name = name

    @property
    def name(self) -> str | None:
        """The name of this complex, or ``None`` if it has none."""
        return self._name

    @name.setter
    def name(self, x: bytes | str | None) -> None:
        self._name = x.decode() if isinstance(x, bytes) else x

    def get_name(self) -> str | None:
        """Return the name of this complex (same value as the ``name`` property)."""
        return self._name

    def set_name(self, x: bytes | str | None) -> "Complex":
        """Set the name (decoding ``bytes``) and return ``self`` for chaining."""
        self.name = x
        return self

    @property
    def templates(self) -> "Sequence[Protein | Complex | Template]":
        """Templates for guiding the structure prediction of this molecular complex.

        Stored as a tuple; assigning any sequence converts it to a tuple.
        """
        return self._templates

    @templates.setter
    def templates(self, templates: "Sequence[Protein | Complex | Template]") -> None:
        self._templates = tuple(templates)

    def get_templates(self) -> "Sequence[Protein | Complex | Template]":
        """Return the templates (same value as the ``templates`` property)."""
        return self.templates

    def set_templates(
        self, templates: "Sequence[Protein | Complex | Template]"
    ) -> "Complex":
        """Replace the templates and return ``self`` for chaining."""
        self.templates = templates
        return self

    def get_chains(self) -> Mapping[str, Protein | DNA | RNA | Ligand]:
        """
        Return a read-only view of all chains, ordered by chain id.

        The view wraps the current internal dict. ``set_chain`` and ``&`` rebind
        that dict, so a view obtained earlier may not reflect later changes.
        """
        return MappingProxyType(self._chains)

    def get_proteins(self) -> Mapping[str, Protein]:
        """Return a read-only snapshot of the chains that are ``Protein`` instances."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, Protein)}
        )

    def get_protein(self, chain_id: str) -> Protein:
        """
        Return the protein stored under ``chain_id``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not a ``Protein``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, Protein)
        return chain

    def get_dnas(self) -> Mapping[str, DNA]:
        """Return a read-only snapshot of the chains that are ``DNA`` instances."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, DNA)}
        )

    def get_dna(self, chain_id: str) -> DNA:
        """
        Return the DNA chain stored under ``chain_id``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not a ``DNA``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, DNA)
        return chain

    def get_rnas(self) -> Mapping[str, RNA]:
        """Return a read-only snapshot of the chains that are ``RNA`` instances."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, RNA)}
        )

    def get_rna(self, chain_id: str) -> RNA:
        """
        Return the RNA chain stored under ``chain_id``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not an ``RNA``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, RNA)
        return chain

    def get_ligands(self) -> Mapping[str, Ligand]:
        """Return a read-only snapshot of the chains that are ``Ligand`` instances."""
        return MappingProxyType(
            {k: v for k, v in self._chains.items() if isinstance(v, Ligand)}
        )

    def get_ligand(self, chain_id: str) -> Ligand:
        """
        Return the ligand stored under ``chain_id``.

        Raises:
            KeyError: If there is no chain with this id.
            AssertionError: If the chain is not a ``Ligand``.
        """
        chain = self._chains[chain_id]
        assert isinstance(chain, Ligand)
        return chain

    def set_chain(
        self, chain_id: str, value: Protein | DNA | RNA | Ligand
    ) -> "Complex":
        """
        Add or replace the chain stored under ``chain_id``.

        The chain mapping is re-sorted by chain id afterwards. Returns ``self``.
        """
        self._chains[chain_id] = value
        self._chains = dict(sorted(self._chains.items()))
        return self

    def __rand__(self, left: "Complex | Protein | str") -> "Complex":
        """Support ``left & complex``; a ``str`` operand is parsed as a protein."""
        if isinstance(left, str):
            left = Protein.from_expr(expr=left)
        return left & self

    def __and__(self, right: "Complex | Protein | str") -> "Complex":
        """
        Combine multiple objects into a single Complex.

        NOTE: this modifies ``self`` in place and returns it; no new Complex is
        created.

        Args:
            right: A ``Complex`` whose chains are merged into this one, a
                ``Protein`` that is added under a newly generated chain id, or a
                ``str`` that is first converted with ``Protein.from_expr``.

        Raises:
            AssertionError: If ``right`` is not a Complex, Protein or str.
            ValueError: If ``right`` is a Complex sharing a chain id with ``self``.
        """
        assert (
            isinstance(right, Complex)
            or isinstance(right, Protein)
            or isinstance(right, str)
        )
        if isinstance(right, Complex):
            if (
                len(overlapping_chain_ids := self._chains.keys() & right._chains.keys())
                > 0
            ):
                raise ValueError(
                    f"Trying to combine two sets of chains with overlapping chain ids: {overlapping_chain_ids}"
                )
            self._chains = dict(sorted((self._chains | right._chains).items()))
            return self
        if isinstance(right, str):
            right = Protein.from_expr(right)
        # The generator is given the ids already in use; presumably it yields
        # ids that do not collide with them (see openprotein.utils.chain_id).
        id_gen = _chain_id_utils.id_generator(list(self._chains.keys()))
        self.set_chain(chain_id=next(id_gen), value=right)
        return self

    @overload
    def rmsd(
        self,
        tgt: "Complex",
        backbone_only: bool | str | Sequence[str] = False,
        return_transform: Literal[False] = False,
    ) -> float: ...
    @overload
    def rmsd(
        self,
        tgt: "Complex",
        backbone_only: bool | str | Sequence[str] = False,
        return_transform: Literal[True] = True,
    ) -> tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]: ...
    def rmsd(
        self,
        tgt: "Complex",
        backbone_only: bool | str | Sequence[str] = False,
        return_transform: bool = False,
    ) -> float | tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]:
        """
        Compute the RMSD between this complex and ``tgt``.

        The protein chains of each complex are combined with ``+`` in chain-id
        order and the computation is delegated to ``Protein.rmsd``.

        Args:
            tgt: Complex to compare against. It must have the same chain ids and
                the same per-chain lengths as this complex.
            backbone_only: Forwarded unchanged to ``Protein.rmsd``.
            return_transform: Forwarded unchanged to ``Protein.rmsd``. When true,
                the result is a 3-tuple whose last two items are the ``R`` and
                ``t`` used by ``superimpose_onto``.

        Raises:
            AssertionError: If either complex has a non-protein chain, or if the
                chain ids or chain lengths differ.
        """
        assert all(
            isinstance(v, Protein) for v in self._chains.values()
        ), "rmsd supported only for Protein chains, not supported for non-protein chains"
        assert all(
            isinstance(v, Protein) for v in tgt._chains.values()
        ), "rmsd supported only for Protein chains, not supported for non-protein chains"
        src_proteins, tgt_proteins = self.get_proteins(), tgt.get_proteins()
        assert tgt_proteins.keys() == src_proteins.keys()
        # Both mappings are sorted by chain id, so equal key sets line up pairwise.
        assert [len(x) for x in src_proteins.values()] == [
            len(x) for x in tgt_proteins.values()
        ]
        src_protein: Protein = reduce(operator.add, src_proteins.values())
        tgt_protein: Protein = reduce(operator.add, tgt_proteins.values())
        return src_protein.rmsd(
            tgt_protein,
            backbone_only=backbone_only,
            return_transform=return_transform,
        )

    def transform(
        self,
        R: npt.NDArray[np.floating] | None = None,
        t: npt.NDArray[np.floating] | None = None,
    ) -> "Complex":
        """
        Apply ``Protein.transform(R=R, t=t)`` to every chain and return ``self``.

        The per-chain return value is discarded, so ``Protein.transform`` is
        presumably an in-place operation (confirm in ``Protein``).

        Raises:
            AssertionError: If any chain is not a ``Protein``.
        """
        assert all(
            isinstance(v, Protein) for v in self._chains.values()
        ), "transform supported only for Protein chains, not supported for non-protein chains"
        for protein in self.get_proteins().values():
            protein.transform(R=R, t=t)
        return self

    def superimpose_onto(
        self, tgt: "Complex", backbone_only: bool | str | Sequence[str] = False
    ) -> "Complex":
        """
        Transform this complex with the ``(R, t)`` from ``tgt.rmsd(self, ...)``.

        Returns ``self`` (via ``transform``).
        """
        # NOTE(review): the transform is taken from tgt.rmsd(self), not
        # self.rmsd(tgt); confirm against Protein.rmsd that this is the
        # direction that maps self onto tgt.
        _, R, t = tgt.rmsd(self, backbone_only=backbone_only, return_transform=True)
        return self.transform(R=R, t=t)

    def to_string(self, format: Literal["cif", "pdb"] = "cif") -> str:
        """
        Serialize this Complex to a string. Note that format="pdb" may not serialize all
        aspects of this object, so format="cif", the default, is preferred.

        Raises:
            ValueError: If ``format`` is neither "cif" nor "pdb".
        """
        if format == "cif":
            return self._make_cif_string()
        if format == "pdb":
            return self._make_pdb_string()
        raise ValueError(format)

    @staticmethod
    def from_filepath(
        path: Path | str,
        use_bfactor_as_plddt: bool | None = None,
        model_idx: int = 0,
        verbose: bool = True,
    ) -> "Complex":
        """
        Load a Complex from a ``.cif``, ``.pdb``, ``.cif.gz`` or ``.pdb.gz`` file.

        The format is chosen from the file name (case-sensitive). The resulting
        complex is named after the file name with that extension removed.

        Args:
            path: File to read.
            use_bfactor_as_plddt: Forwarded to ``from_string``.
            model_idx: Forwarded to ``from_string``.
            verbose: Forwarded to ``from_string``.

        Raises:
            ValueError: If the file extension is not one of the supported ones.
        """
        path = Path(path)
        fmt: Literal["cif", "pdb"]
        if path.suffix == ".gz":
            if path.name.endswith(".cif.gz"):
                ext, fmt = ".cif.gz", "cif"
            elif path.name.endswith(".pdb.gz"):
                ext, fmt = ".pdb.gz", "pdb"
            else:
                raise ValueError(f"unsupported format: {path}")
            with gzip.open(path, "rb") as f:
                data = f.read()
        else:
            ext = path.suffix
            if ext == ".cif":
                fmt = "cif"
            elif ext == ".pdb":
                fmt = "pdb"
            else:
                # Raise the same error as the gzip branch rather than relying on
                # an assert, which is stripped under ``python -O``.
                raise ValueError(f"unsupported format: {path}")
            data = path.read_bytes()
        return Complex.from_string(
            filestring=data,
            format=fmt,
            use_bfactor_as_plddt=use_bfactor_as_plddt,
            model_idx=model_idx,
            verbose=verbose,
        ).set_name(path.name.removesuffix(ext))

    @staticmethod
    def from_string(
        filestring: bytes | str,
        format: Literal["pdb", "cif"],
        use_bfactor_as_plddt: bool | None = None,
        model_idx: int = 0,
        verbose: bool = True,
    ) -> "Complex":
        """
        Parse a Complex from the contents of a structure file.

        Args:
            filestring: File contents.
            format: Format of ``filestring``, "pdb" or "cif".
            use_bfactor_as_plddt: Forwarded to ``Protein._from_structure_block``.
            model_idx: Index of the model to read from the parsed structure.
            verbose: Forwarded to ``Protein._from_structure_block``.
        """
        structure_block = _cif_utils.StructureCIFBlock(
            filestring=filestring, format=format
        )
        return Complex._from_structure_block(
            structure_block=structure_block,
            use_bfactor_as_plddt=use_bfactor_as_plddt,
            model_idx=model_idx,
            verbose=verbose,
        )

    def copy(self) -> "Complex":
        """
        Return a new Complex with a ``copy()`` of every chain and the same name.

        NOTE(review): templates set on this complex are not carried over to the
        copy; confirm that this is intended.
        """
        return Complex(
            chains={k: v.copy() for k, v in self._chains.items()}, name=self._name
        )

    def _assert_valid_templates(self):
        """
        Validate complex-level and per-protein templates against their targets.

        Objects that are not already ``Template`` instances are wrapped in one
        first. Complex-level templates are validated against ``self``; templates
        attached to a protein chain are validated against a single-chain Complex
        containing only that protein.
        """
        # Imported here rather than at module level, where the import is only
        # available under TYPE_CHECKING.
        from .template import Template

        for template in self.templates:
            (
                template if isinstance(template, Template) else Template(template)
            ).validate_for_target(self)
        for chain_id, protein in self.get_proteins().items():
            for template in protein.templates:
                (
                    template
                    if isinstance(template, Template)
                    else Template(template, mapping=chain_id)
                ).validate_for_target(Complex({chain_id: protein}))

    @staticmethod
    def _from_structure_block(
        structure_block: _cif_utils.StructureCIFBlock,
        use_bfactor_as_plddt: bool | None = None,
        model_idx: int = 0,
        verbose: bool = True,
    ) -> "Complex":
        """
        Build a Complex from a parsed structure block.

        Subchains whose entity is a peptide, DNA or RNA polymer, or a non-polymer,
        become ``Protein``, ``DNA``, ``RNA`` or ``Ligand`` chains respectively.
        Water, other polymer types and other entity types are skipped silently.

        Args:
            structure_block: Parsed file; its ``block`` and ``structure`` are used.
            use_bfactor_as_plddt: Forwarded to ``Protein._from_structure_block``.
            model_idx: Index of the model to read from the structure.
            verbose: Forwarded to ``Protein._from_structure_block``.

        Raises:
            ValueError: If a non-empty subchain has no entity, or if a non-polymer
                entity is encountered in a structure read from PDB format.
            AssertionError: If an empty or missing subchain does not match exactly
                one entity, or if a PDB polymer subchain id is malformed or
                produces a duplicate chain id.
        """
        block, structure = structure_block.block, structure_block.structure
        model = structure[model_idx] if len(structure) > 0 else None
        # Use block info directly so that we can get chains with empty struct info
        subchain_ids = list(block.find_loop("_struct_asym.id"))
        if len(subchain_ids) == 0 and model is not None:
            # Try to get actual chain IDs from the structure
            subchain_ids = [subchain.subchain_id() for subchain in model.subchains()]
        is_pdb = structure.input_format == gemmi.CoorFormat.Pdb

        chains: dict[str, Protein | DNA | RNA | Ligand] = {}
        for subchain_id in sorted(subchain_ids):
            subchain = model.get_subchain(subchain_id) if model is not None else None
            # Get the entity for this chain to determine its type
            if subchain is not None and len(subchain) > 0:
                entity = structure.get_entity_of(subchain)
                if entity is None:
                    raise ValueError(f"Could not find entity for chain {subchain_id}")
            else:
                # No residues to look the entity up from, so fall back to the
                # entity table's own list of subchains.
                matching_entities = [
                    e for e in structure.entities if subchain_id in e.subchains
                ]
                assert len(matching_entities) == 1, (
                    f"expected only one entity to match {subchain_id=}, but found "
                    f"{len(matching_entities)}: {matching_entities}"
                )
                entity = matching_entities[0]

            # Determine chain type based on entity type and polymer type
            entity_type = entity.entity_type
            if entity_type == gemmi.EntityType.Polymer:
                if is_pdb:
                    # Polymer subchain ids from PDB input are expected to carry
                    # an "xp" suffix, which is stripped to get the chain id.
                    assert subchain_id.endswith("xp")
                    chain_id = subchain_id.removesuffix("xp")
                    assert chain_id not in chains
                else:
                    chain_id = subchain_id
                polymer_type = entity.polymer_type
                if polymer_type in (
                    gemmi.PolymerType.PeptideL,
                    gemmi.PolymerType.PeptideD,
                ):
                    chains[chain_id] = Protein._from_structure_block(
                        structure_block=structure_block,
                        chain_id=subchain_id,
                        use_bfactor_as_plddt=use_bfactor_as_plddt,
                        model_idx=model_idx,
                        verbose=verbose,
                    )
                elif polymer_type == gemmi.PolymerType.Dna:
                    chains[chain_id] = DNA._from_structure_block(
                        structure_block=structure_block,
                        chain_id=subchain_id,
                        model_idx=model_idx,
                    )
                elif polymer_type == gemmi.PolymerType.Rna:
                    chains[chain_id] = RNA._from_structure_block(
                        structure_block=structure_block,
                        chain_id=subchain_id,
                        model_idx=model_idx,
                    )
                # Any other polymer type is skipped.
            elif entity_type == gemmi.EntityType.NonPolymer:
                if is_pdb:
                    raise ValueError("ligands from pdb files not supported yet")
                chains[subchain_id] = Ligand._from_structure_block(
                    structure_block=structure_block,
                    chain_id=subchain_id,
                    model_idx=model_idx,
                )
            # Water and any other entity type are skipped.
        return Complex(chains=chains, name=structure.name)

    def _make_cif_string(self) -> str:
        """
        Serialize this complex as an mmCIF string.

        A gemmi structure is built first and converted to an mmCIF block; each
        chain then appends its own rows to the sequence and atom loops.
        """
        structure = self._make_structure()
        block = structure.make_mmcif_block(
            groups=gemmi.MmcifOutputGroups(True, chem_comp=False)
        )
        sequence_loop, atom_loop = _cif_utils.init_loops(block=block)
        for chain_id, chain in self._chains.items():
            chain._append_loop_data(
                chain_id=chain_id, sequence_loop=sequence_loop, atom_loop=atom_loop
            )
        return block.as_string()

    def _make_pdb_string(self) -> str:
        """Serialize this complex as a PDB string using gemmi's minimal options."""
        structure = self._make_structure()
        return structure.make_pdb_string(gemmi.PdbWriteOptions(minimal=True))

    def _make_structure(self) -> gemmi.Structure:
        """
        Build a ``gemmi.Structure`` holding every chain of this complex.

        Each chain adds itself to the structure through its own
        ``_make_structure``. Entities are then set up and renamed "1", "2", ...
        in order, and the structure takes this complex's name if it has one.

        Raises:
            AssertionError: If the ligands come from more than one structure block.
        """
        assert (
            len({x._structure_block for x in self.get_ligands().values()}) <= 1
        ), "can only serialize ligands if they all originate from the same structure file"
        structure = gemmi.Structure()
        for chain_id, chain in self._chains.items():
            structure = chain._make_structure(
                structure=structure,
                model_idx=0,
                chain_id=chain_id,
                entity_name=str(len(structure.entities) + 1),
            )
        structure.setup_entities()  # this should deduplicate polymer entities
        # Renumber because deduplication may have changed the set of entities.
        for entity_idx, entity in enumerate(structure.entities):
            entity.name = str(entity_idx + 1)
        if self._name is not None:
            structure.name = self._name
        return structure