import enum
import gzip
import io
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Type, TypeVar, cast, overload
import gemmi
import numpy as np
import numpy.typing as npt
import openprotein.utils.chain_id as _chain_id_utils
import openprotein.utils.cif as _cif_utils
import openprotein.utils.numpy as _numpy_utils
import openprotein.utils.sequence as _sequence_utils
from .. import fasta
if TYPE_CHECKING:
from ..align.msa import MSAFuture
from .complex import Complex
# Unconstrained type variable (not referenced by the code visible in this file).
V = TypeVar("V")
# Enum base whose members are also `str` instances (mixes in `str`).
class StrEnum(str, enum.Enum): ...
@enum.unique
class Binding(StrEnum):
    """Per-residue binding annotation, encoded as a single-character string."""

    # TODO: should we use any/X/?/* or something else instead of unknown?
    UNKNOWN = "U"
    BINDING = "B"
    NOT_BINDING = "N"
# TODO: deserialization note about plddt parsed per residue
[docs]
class Protein:
"""
Represents a protein with an optional name.
This class supports partial or complete information: users may create a Protein
with only a sequence, only a structure, or both. The class ensures that all
provided fields have consistent residue-level lengths and provides convenient
methods for indexing, masking, and structural comparisons.
Conventions:
- Missing or unknown residues in the sequence are denoted by b"X".
- Missing structural data (coordinates or pLDDT) are represented by NaN.
- Residue indices are 1-indexed for user-facing methods whose names end in `at`,
e.g. `.at` and `mask_sequence_at`.
Examples:
Create a Protein from sequence only:
Protein(sequence="ACDEFGHIK")
Create a Protein from sequence and name:
Protein(sequence="ACDEFGHIK", name="my_protein")
"""
def __init__(self, sequence: bytes | str, name: bytes | str | None = None):
    """
    Create a sequence-only Protein.

    Args:
        sequence: residues as `str` or `bytes`; every character must be one of
            the amino-acid letters or extra tokens known to the sequence utils.
        name: optional name; `bytes` are decoded by the `name` setter.

    Raises:
        AssertionError: if the sequence contains a character outside the
            allowed alphabet.
    """
    text = sequence if isinstance(sequence, str) else sequence.decode()
    allowed = set(_sequence_utils.AMINO_ACIDS + _sequence_utils.EXTRA_TOKENS)
    assert set(text).issubset(
        allowed
    ), "Expected only amino acids or the mask token 'X' or the variable length token '?'\nHint: Use Protein.from_expr if using a sequence expression"
    # The sequence is always stored as bytes.
    self._sequence = sequence.encode() if isinstance(sequence, str) else sequence
    self.name = name
    # sequence-level properties
    self._cyclic: bool = False
    self._msa: "str | MSAFuture | None | Type[Protein.NullMSA]" = None
    # per-residue arrays, created lazily by the private array properties
    self._data: dict[str, npt.NDArray] = {}
@property
def name(self) -> str | None:
return self._name
@name.setter
def name(self, x: bytes | str | None) -> None:
self._name = x.decode() if isinstance(x, bytes) else x
def set_name(self, x: bytes | str | None) -> "Protein":
self.name = x
return self
@property
def sequence(self) -> bytes:
    """The residue sequence as bytes."""
    return self._sequence

@sequence.setter
def sequence(self, x: bytes | str) -> None:
    # Same alphabet check as the constructor, then require an unchanged length
    # before storing the new sequence as bytes.
    text = x if isinstance(x, str) else x.decode()
    allowed = set(_sequence_utils.AMINO_ACIDS + _sequence_utils.EXTRA_TOKENS)
    assert set(text).issubset(
        allowed
    ), "Expected only amino acids or the mask token 'X' or the variable length token '?'\nHint: Use Protein.from_expr if using a sequence expression"
    assert len(x) == len(self)
    self._sequence = x.encode() if isinstance(x, str) else x

def set_sequence(self, x: bytes | str) -> "Protein":
    """Assign `sequence` and return self so calls can be chained."""
    self.sequence = x
    return self
@property
def coordinates(self) -> npt.NDArray[np.float32]:
    """Coordinate array, wrapped by `_numpy_utils.readonly_view`."""
    view = _numpy_utils.readonly_view(self._coordinates)
    return view

@property
def plddt(self) -> npt.NDArray[np.float32]:
    """pLDDT array, wrapped by `_numpy_utils.readonly_view`."""
    view = _numpy_utils.readonly_view(self._plddt)
    return view
@property
def cyclic(self) -> bool:
    """Cyclic flag of this protein."""
    return self._cyclic

@cyclic.setter
def cyclic(self, cyclic: bool) -> None:
    self._cyclic = cyclic

def get_cyclic(self) -> bool:
    """Method-style accessor for the cyclic flag."""
    return self._cyclic

def set_cyclic(self, x: bool) -> "Protein":
    """Set the cyclic flag and return self so calls can be chained."""
    self._cyclic = x
    return self
class NullMSA:
    """Marker class; the class object itself is used as an `msa` value."""

# Alias for the marker above, referred to as "single sequence mode".
single_sequence_mode = NullMSA

@property
def msa(self) -> "str | MSAFuture | None | Type[NullMSA]":
    """A reference identifier to the MSA associated to this protein."""
    return self._msa

@msa.setter
def msa(self, msa: "str | MSAFuture | None | Type[NullMSA]") -> None:
    # NB: no defensive copy of msa b/c we don't want to copy things like session
    # objects, but msa should really be immutable anyways...
    self._msa = msa

def get_msa(self) -> "str | MSAFuture | None | Type[NullMSA]":
    """Method-style accessor for the stored MSA reference."""
    return self._msa

def set_msa(self, x: "str | MSAFuture | None | Type[NullMSA]") -> "Protein":
    """Store the MSA reference and return self so calls can be chained."""
    self._msa = x
    return self
def __len__(self):
    """Return the number of residues, i.e. the length of `sequence`."""
    return len(self.sequence)
def __getitem__(
    self, idx: int | slice | Sequence[int] | npt.NDArray[np.integer]
) -> "Protein":
    """
    Return a new Protein object indexing into residues by `idx`.

    Args:
        idx: 0-indexed selector. May be a single integer (Python or numpy), a
            slice, an iterable of integers, or a numpy index array.

    Returns:
        A new Protein holding the selected residues, with copies of every
        per-residue array and the same name, msa and cyclic flag.

    Notes:
        Slices follow standard Python semantics, including negative bounds,
        negative steps and empty ranges such as `[:0]`.
    """
    if isinstance(idx, (int, np.integer)):
        idx = np.array([idx], dtype=int)
    elif isinstance(idx, slice):
        # slice.indices() resolves None, negative and out-of-range bounds against
        # the length. Defaulting with `start or 0` / `stop or len(self)` treated
        # an explicit 0 as "missing" and mishandled negative bounds and steps.
        idx = np.arange(*idx.indices(len(self)), dtype=int)
    elif not isinstance(idx, np.ndarray):
        idx = np.fromiter(idx, dtype=int)
    new = Protein(
        sequence=np.frombuffer(self.sequence, dtype=np.uint8)[idx].tobytes(),
        name=self.name,
    )
    # TODO: check msa compatible?
    new = new.set_msa(self._msa).set_cyclic(self._cyclic)
    new._data = {k: v[idx].copy() for k, v in self._data.items()}
    return new
def __radd__(self, left: "Protein | str") -> "Protein":
    """Handle `left + self`; a string `left` is parsed as a sequence expression."""
    assert isinstance(left, (Protein, str))
    lhs = self.from_expr(expr=left) if isinstance(left, str) else left
    return lhs + self
def __add__(self, right: "Protein | str") -> "Protein":
    """
    Return a new Protein object by concatenating with another Protein.

    A string operand is first parsed with `from_expr`. Both operands must have
    equal msa and cyclic settings. The result keeps a name when both names are
    equal or only one operand is named, and is unnamed otherwise.
    """
    assert isinstance(right, (Protein, str))
    if isinstance(right, str):
        right = self.from_expr(right)
    # TODO: if either cyclic, should we actually disable adding?
    assert right._msa == self._msa and right._cyclic == self._cyclic
    if self.name == right.name or right.name is None:
        merged_name = self.name
    elif self.name is None:
        merged_name = right.name
    else:
        merged_name = None
    joined = Protein(sequence=self.sequence + right.sequence, name=merged_name)
    joined = joined.set_msa(self._msa).set_cyclic(self._cyclic)
    # Going through the private properties (rather than `_data` directly) fills in
    # default arrays for keys that only one operand has materialised.
    merged: dict[str, npt.NDArray] = {}
    for key in self._data.keys() | right._data.keys():
        merged[key] = np.concatenate(
            (getattr(self, f"_{key}"), getattr(right, f"_{key}"))
        )
    joined._data = merged
    return joined
def __rand__(self, left: "Complex | Protein | str") -> "Complex":
    """Handle `left & self`; a string `left` is parsed as a sequence expression."""
    lhs = self.from_expr(expr=left) if isinstance(left, str) else left
    return lhs & self
def __and__(self, right: "Complex | Protein | str") -> "Complex":
    """
    Combine multiple objects into a single Complex.

    A string operand is parsed with `from_expr`. Two proteins receive the first
    two ids produced by the chain-id generator; a Complex operand is delegated
    to via `right & self`.
    """
    # Imported here rather than at module level (module-level import is only
    # done under TYPE_CHECKING).
    from .complex import Complex

    assert isinstance(right, (Complex, Protein, str))
    if isinstance(right, str):
        right = self.from_expr(right)
    if not isinstance(right, Protein):
        return right & self
    id_gen = _chain_id_utils.id_generator()
    first_id = next(id_gen)
    second_id = next(id_gen)
    return Complex({first_id: self, second_id: right})
[docs]
def at(self, positions: Sequence[int] | npt.NDArray[np.integer]) -> "Protein":
    """
    Return a new Protein object containing residues at given 1-indexed positions.
    """
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    # Shift to the 0-indexed convention used by __getitem__.
    return self[one_based - 1]
[docs]
def mask_sequence(self) -> "Protein":
    """Mask entire sequence."""
    # Keeping no positions masks every residue.
    return self.mask_sequence_except_at([])
[docs]
def mask_sequence_at(
    self, positions: Sequence[int] | npt.NDArray[np.integer]
) -> "Protein":
    """
    Mask sequence at given 1-indexed positions.

    The selected residues are replaced by "X" on this object (via
    `set_sequence`), which is then returned.
    """
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    # frombuffer gives a read-only view of the bytes, so work on a copy.
    residues = np.frombuffer(self.sequence, dtype=np.uint8).copy()
    residues[one_based - 1] = ord(b"X")
    return self.set_sequence(residues.tobytes())
[docs]
def mask_sequence_except_at(
    self, positions: Sequence[int] | npt.NDArray[np.integer]
) -> "Protein":
    """
    Mask sequence at all positions except the given 1-indexed positions.

    Residues outside `positions` are replaced by "X" on this object (via
    `set_sequence`), which is then returned.
    """
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    residues = np.frombuffer(self.sequence, dtype=np.uint8).copy()
    keep = np.zeros(residues.shape, dtype=bool)
    keep[one_based - 1] = True
    residues[~keep] = ord(b"X")
    return self.set_sequence(residues.tobytes())
[docs]
def mask_structure(self, side_chain_only: bool = False) -> "Protein":
    """Mask entire structure."""
    # Keeping no positions masks the structure at every residue.
    return self.mask_structure_except_at([], side_chain_only=side_chain_only)
[docs]
def mask_structure_at(
    self,
    positions: Sequence[int] | npt.NDArray[np.integer],
    side_chain_only: bool = False,
) -> "Protein":
    """
    Mask structure at given 1-indexed positions.

    Args:
        positions: 1-indexed residue positions to mask.
        side_chain_only: if True, only the atoms listed in
            `_SIDE_CHAIN_ATOM_IDXS` are set to NaN and pLDDT is left untouched;
            otherwise all atoms and the pLDDT at those positions become NaN.

    Returns:
        This object, modified in place.
    """
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    residue_idxs = one_based - 1
    if side_chain_only:
        atom_idxs = _SIDE_CHAIN_ATOM_IDXS
    else:
        atom_idxs = np.arange(len(_ATOM_TYPES))
    # np.ix_ selects the full (residue x atom) grid rather than paired indices.
    self._coordinates[np.ix_(residue_idxs, atom_idxs)] = np.nan
    if not side_chain_only:
        self._plddt[residue_idxs] = np.nan
    return self
[docs]
def mask_structure_except_at(
    self,
    positions: Sequence[int] | npt.NDArray[np.integer],
    side_chain_only: bool = False,
) -> "Protein":
    """Mask structure at all positions except the given 1-indexed positions."""
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    to_mask = np.ones(len(self), dtype=bool)
    to_mask[one_based - 1] = False
    # Convert the boolean selection back to 1-indexed positions.
    complement = np.flatnonzero(to_mask) + 1
    return self.mask_structure_at(
        positions=complement, side_chain_only=side_chain_only
    )
[docs]
def get_structure_mask(self) -> npt.NDArray[np.bool_]:
    """
    Computes the structure mask of the protein. The structure mask is a boolean
    array indicating, at each position, whether the structure is undefined at that
    position (every coordinate of every atom is NaN).
    """
    missing = np.isnan(self._coordinates)
    return missing.all(axis=(1, 2))
@property
def has_structure(self) -> bool:
    """Whether or not the structure is known at any position in the protein."""
    # Unknown only if both the coordinates and the pLDDT are entirely NaN.
    all_unknown = np.isnan(self._coordinates).all() and np.isnan(self._plddt).all()
    return not all_unknown
def get_group_at(
    self, positions: Sequence[int] | npt.NDArray[np.integer]
) -> npt.NDArray[np.int_]:
    """Return the group values at the given 1-indexed positions (read-only)."""
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    selected = self._group[one_based - 1]
    return _numpy_utils.readonly_view(selected)

def set_group_at(
    self,
    positions: Sequence[int] | npt.NDArray[np.integer],
    value: int | Sequence[int],
) -> "Protein":
    """Assign `value` to the group at the given 1-indexed positions; returns self."""
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    self._group[one_based - 1] = value
    return self

@property
def group(self) -> npt.NDArray[np.integer]:
    """Per-residue group array, wrapped by `_numpy_utils.readonly_view`."""
    return _numpy_utils.readonly_view(self._group)

def get_group(self) -> npt.NDArray[np.integer]:
    """Method-style accessor equivalent to the `group` property."""
    return _numpy_utils.readonly_view(self._group)

def set_group(self, value: int) -> "Protein":
    """Assign `value` to the group of every residue; returns self."""
    self._group[:] = value
    return self
def get_binding_at(
    self, positions: Sequence[int] | npt.NDArray[np.integer]
) -> npt.NDArray[np.str_]:
    """Return the binding codes at the given 1-indexed positions (read-only)."""
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    selected = self._binding[one_based - 1]
    return _numpy_utils.readonly_view(selected)

def set_binding_at(
    self,
    positions: Sequence[int] | npt.NDArray[np.integer],
    value: Binding | str | Sequence[Binding | str],
) -> "Protein":
    """
    Assign binding annotations at the given 1-indexed positions; returns self.

    `value` may be a `Binding` member, its string code, or a sequence of either;
    it is normalised to string codes by `_enum_to_str` before assignment.
    """
    one_based = (
        positions
        if isinstance(positions, np.ndarray)
        else np.fromiter(positions, dtype=int)
    )
    codes = _enum_to_str(value=value, enum_type=Binding)
    self._binding[one_based - 1] = codes
    return self

@property
def binding(self) -> npt.NDArray[np.str_]:
    """Per-residue binding codes, wrapped by `_numpy_utils.readonly_view`."""
    return _numpy_utils.readonly_view(self._binding)

def get_binding(self) -> npt.NDArray[np.str_]:
    """Method-style accessor equivalent to the `binding` property."""
    return _numpy_utils.readonly_view(self._binding)
# Typing-only overloads for `rmsd`: the return type depends on `return_transform`.
# NOTE(review): both overloads give `return_transform` a default, so a call that
# omits it matches both signatures — confirm type checkers resolve this as intended.
@overload
def rmsd(
    self,
    tgt: "Protein",
    backbone_only: bool | str | Sequence[str] = False,
    return_transform: Literal[False] = False,
) -> float: ...
@overload
def rmsd(
    self,
    tgt: "Protein",
    backbone_only: bool | str | Sequence[str] = False,
    return_transform: Literal[True] = True,
) -> tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]: ...
[docs]
def rmsd(
    self,
    tgt: "Protein",
    backbone_only: bool | str | Sequence[str] = False,
    return_transform: bool = False,
) -> float | tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]:
    """
    Compute the root-mean-square deviation (RMSD) between this Protein and a target
    Protein.

    Only atoms that are present (i.e., not NaN) in both structures are included in
    the calculation.

    Args:
        tgt: The target Protein to compare against.
        backbone_only: Specifies which atoms to include in the RMSD calculation.
            - If False (default), all atom types are included.
            - If True, only backbone atoms ("N", "CA", "C") are included.
            - If a string, it must be a single atom type (e.g., "CA").
            - If a sequence of strings, it must be a non-empty list of atom types
              (e.g., ["CA", "CB", "O"]). All specified atom types must be valid.
        return_transform: If True, returns both the rmsd and the transformation that
            should be applied to `tgt` to superimpose it onto this Protein. If False
            (default), returns only the rmsd value.

    Returns:
        If `return_transform` is False (default):
            The RMSD value (float).
        If `return_transform` is True:
            A tuple `(float, np.ndarray, np.ndarray)` containing the RMSD value,
            the rotation matrix, and the translation vector.

    Raises:
        ValueError: if `backbone_only` is none of the supported kinds.

    Notes:
        This method assumes that sequences of `self` and `tgt` are already aligned.
    """
    # Resolve the atom selection into indices along the atom axis.
    if backbone_only is False:
        atom_idxs = np.arange(len(_ATOM_TYPES))
    elif backbone_only is True:
        atom_idxs = np.arange(3)
    elif isinstance(backbone_only, str):
        atom_idxs = [_ATOM_TYPE_TO_IDX[backbone_only]]
    elif isinstance(backbone_only, Sequence):
        assert len(backbone_only) > 0 and isinstance(next(iter(backbone_only)), str)
        atom_idxs = [_ATOM_TYPE_TO_IDX[x] for x in backbone_only]
    else:
        raise ValueError(backbone_only)
    mine = self._coordinates[:, atom_idxs]
    theirs = tgt._coordinates[:, atom_idxs]
    # An atom counts as known only if none of its three coordinates is NaN.
    unknown_in_either = np.isnan(mine).any(axis=-1) | np.isnan(theirs).any(axis=-1)
    shared = ~unknown_in_either
    result = _calc_rmsd_and_transform(mine[shared], theirs[shared])
    if return_transform:
        return result
    return result[0]
def transform(
self,
R: npt.NDArray[np.floating] | None = None,
t: npt.NDArray[np.floating] | None = None,
) -> "Protein":
if R is None:
R = np.eye(3, dtype=np.float32)
if t is None:
t = np.zeros(3, dtype=np.float32)
return self._set_coordinates(
self._coordinates @ R.T.astype(np.float32) + t.astype(np.float32)
)
def superimpose_onto(
    self, tgt: "Protein", backbone_only: bool | str | Sequence[str] = False
) -> "Protein":
    """
    Superimpose this Protein onto `tgt` and return this object.

    The transform is obtained from `tgt.rmsd(self, ...)`, whose returned transform
    is the one to apply to its argument (here `self`), and applied via `transform`.
    """
    fit = tgt.rmsd(self, backbone_only=backbone_only, return_transform=True)
    rotation, translation = fit[1], fit[2]
    return self.transform(R=rotation, t=translation)
[docs]
def to_string(self, format: Literal["cif", "pdb"] = "cif") -> str:
    """
    Serialize this Protein to a string. Note that format="pdb" may not serialize all
    aspects of this object, so format="cif", the default, is preferred.

    Raises:
        ValueError: if `format` is neither "cif" nor "pdb".
    """
    if format == "cif":
        return self._make_cif_string()
    if format == "pdb":
        return self._make_pdb_string()
    raise ValueError(format)
def make_cif_string(self) -> str:
    """Deprecated: emits a FutureWarning, then returns the CIF serialization."""
    message = (
        "`make_cif_string()` is deprecated and will be removed in v0.11. "
        "Use `to_string()` instead."
    )
    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(message, FutureWarning, stacklevel=2)
    return self._make_cif_string()

def make_pdb_string(self) -> str:
    """Deprecated: emits a FutureWarning, then returns the PDB serialization."""
    message = (
        "`make_pdb_string()` is deprecated and will be removed in v0.11. "
        'Use `to_string(format="pdb")` instead.'
    )
    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(message, FutureWarning, stacklevel=2)
    return self._make_pdb_string()
def make_fasta_bytes(self) -> bytes:
    """
    Return a single FASTA record (`>name`, newline, sequence, newline) as bytes.

    Raises:
        AssertionError: if the protein has no name.
    """
    assert self.name is not None
    header = b">" + self.name.encode()
    return b"".join((header, b"\n", self.sequence, b"\n"))
[docs]
@staticmethod
def from_expr(expr: str | int, name: str | None = None) -> "Protein":
    """
    Create a Protein from a sequence expression.

    A sequence expression allows you to define protein sequences using a concise
    notation that mixes fixed sequences, design regions, and length ranges.
    Useful for creating a design :py:class:`~openprotein.prompt.Query`.

    Args:
        expr: Sequence expression string or integer
            - Fixed sequences: "ACGT" (literal amino acids)
            - Design regions: "6" or 6 (any 6 amino acids)
            - Length ranges: "3..5" (between 3-5 amino acids)
            - Combined: "AAAA6C3..5" (AAAA + 6 design + C + 3-5 design)
        name: Optional name for the protein

    Returns:
        Protein object with the parsed sequence

    Examples:
        >>> # Fixed sequence with 6 flexible positions and fixed end
        >>> Protein.from_expr("MKLL6VVAA").sequence
        b'MKLLXXXXXXVVAA'
        >>> # Design region of any 15 amino acids
        >>> Protein.from_expr(15).sequence
        b'XXXXXXXXXXXXXXX'
        >>> # Variable length region between 10-20 residues
        >>> Protein.from_expr("10..20").sequence
        b'XXXXXXXXXX??????????'
    """
    # An integer is shorthand for the same expression written as a string.
    text = str(expr) if isinstance(expr, int) else expr
    parsed = _sequence_utils.SequenceExpr.parse(text)
    return Protein(sequence=parsed.to_protein_sequence(), name=name)
[docs]
@staticmethod
def from_filepath(
    path: Path | str,
    chain_id: str,
    use_bfactor_as_plddt: bool | None = None,
    model_idx: int = 0,
    verbose: bool = True,
) -> "Protein":
    """
    Create a Protein from a structure file.

    If the structure file has multiple conformers, the first conformer is always
    used. The resulting Protein is named after the file name with its
    format extension (`.cif`, `.pdb`, `.cif.gz` or `.pdb.gz`) removed.

    Args:
        path: path to structure file (e.g. pdb or cif file)
        chain_id: id of the chain in the structure file to use
        use_bfactor_as_plddt: whether or not to use bfactors as pLDDTs. If None,
            this parameter will be determined based on heuristics. These heuristics
            may change over time.
        model_idx: index of the model in the structure file to use
        verbose: whether or not to print debugging information such as oddities in
            the structure e.g. missing atoms

    Raises:
        ValueError: if a `.gz` file is neither `.cif.gz` nor `.pdb.gz`.
        AssertionError: if an uncompressed file is neither `.cif` nor `.pdb`.
    """
    path = Path(path)
    if path.suffix == ".gz":
        # Compressed input: the format comes from the double extension.
        for ext, fmt in ((".cif.gz", "cif"), (".pdb.gz", "pdb")):
            if path.name.endswith(ext):
                break
        else:
            raise ValueError(f"unsupported format: {path}")
        with gzip.open(path, "rb") as f:
            data = f.read()
    else:
        ext = path.suffix
        fmt = ext.removeprefix(".")
        assert fmt == "cif" or fmt == "pdb"
        data = path.read_bytes()
    protein = Protein.from_string(
        filestring=data,
        format=fmt,
        chain_id=chain_id,
        use_bfactor_as_plddt=use_bfactor_as_plddt,
        model_idx=model_idx,
        verbose=verbose,
    )
    return protein.set_name(path.name.removesuffix(ext))
@staticmethod
def from_string(
    filestring: bytes | str,
    format: Literal["pdb", "cif"],
    chain_id: str,
    use_bfactor_as_plddt: bool | None = None,
    model_idx: int = 0,
    verbose: bool = True,
) -> "Protein":
    """
    Create a Protein from the contents of a structure file.

    The contents are wrapped in a `StructureCIFBlock` and handed to
    `_from_structure_block`; the remaining arguments are forwarded unchanged.

    Args:
        filestring: contents of the structure file, as `str` or `bytes`.
        format: format of `filestring`, "pdb" or "cif".
        chain_id: id of the chain to use.
        use_bfactor_as_plddt: whether to use bfactors as pLDDTs; None defers to
            a heuristic.
        model_idx: index of the model to use.
        verbose: whether to print debugging information.
    """
    return Protein._from_structure_block(
        structure_block=_cif_utils.StructureCIFBlock(
            filestring=filestring, format=format
        ),
        chain_id=chain_id,
        use_bfactor_as_plddt=use_bfactor_as_plddt,
        model_idx=model_idx,
        verbose=verbose,
    )
def __str__(self):
    # NOTE(review): `formatted` is not defined in the part of the class visible
    # here — confirm it exists, otherwise `str(protein)` raises AttributeError.
    return self.formatted(include=("sequence",))

def copy(self) -> "Protein":
    """Return a copy produced by full-slice indexing (`self[:]`)."""
    return self[:]
@staticmethod
def _from_structure_block(
    structure_block: _cif_utils.StructureCIFBlock,
    chain_id: str,
    use_bfactor_as_plddt: bool | None = None,
    model_idx: int = 0,
    verbose: bool = True,
) -> "Protein":
    """
    Build a Protein for one chain of a parsed structure block.

    Args:
        structure_block: parsed structure providing `structure`,
            `full_sequences` and the raw CIF `block`.
        chain_id: chain to extract. For PDB input it is mapped to the subchain
            id of the chain's polymer unless it already ends with "xp".
        use_bfactor_as_plddt: whether to read pLDDT from the CA B-factor. If
            None, `_use_bfactor_as_plddt` decides; in that case values that all
            lie at or below 10 are treated as not being pLDDTs and replaced by
            100.
        model_idx: index of the model to read.
        verbose: whether to print oddities (sequence mismatches, unknown
            residues, missing backbone atoms).

    Returns:
        A Protein with sequence, coordinates, pLDDT and any OpenProtein loop
        annotations. If the structure has no model or the subchain has no
        residues, a sequence-only Protein is built from the entity sequence.
    """
    structure = structure_block.structure
    maybe_use_bfactor_as_plddt = use_bfactor_as_plddt is None
    if use_bfactor_as_plddt is None:
        use_bfactor_as_plddt = _use_bfactor_as_plddt(structure=structure)
    model = structure[model_idx] if len(structure) > 0 else None
    subchain_id = chain_id
    if (
        structure.input_format == gemmi.CoorFormat.Pdb
        and not subchain_id.endswith("xp")
        and model is not None
    ):
        subchain_id = model.find_chain(chain_id).get_polymer().subchain_id()
    if model is None or len(model.get_subchain(subchain_id)) == 0:
        # No atoms for this chain: fall back to the sequence recorded on the entity.
        matching_entities = [
            e for e in structure.entities if subchain_id in e.subchains
        ]
        assert len(matching_entities) == 1, (
            f"expected only one entity to match {chain_id=}, but found "
            f"{len(matching_entities)}: {matching_entities}"
        )
        entity = matching_entities[0]
        assert (
            entity.entity_type == gemmi.EntityType.Polymer
        ), f"expected entity type polymer, got {entity.entity_type}"
        assert (
            entity.polymer_type == gemmi.PolymerType.PeptideL
        ), f"expected polymer type PeptideL, got {entity.polymer_type}"
        if len(entity.full_sequence) > 0:
            chain_seq = entity.full_sequence
        else:
            chain_seq, _ = structure_block.full_sequences[entity.name]
        chain_seq = _extract_one_letter_from_full_sequence(chain_seq)
        protein = Protein(
            sequence="".join(chain_seq),
            name=structure.name if structure.name != "" else None,
        )
        protein._set_loop_data(
            structure_block=structure_block, chain_id=subchain_id
        )
        return protein
    model = structure[model_idx]
    polymer = model.get_subchain(subchain_id)
    assert len(polymer) > 0
    # extract sequence
    entity = structure.get_entity_of(polymer)
    residues = list(polymer.first_conformer())
    # TODO: consider utilizing polymer.make_one_letter_sequence() here or elsewhere
    del polymer
    if len(entity.full_sequence) > 0:
        chain_seq, label_seq_offset = entity.full_sequence, 1
    elif entity.name in structure_block.full_sequences:
        chain_seq, label_seq_offset = structure_block.full_sequences[entity.name]
    else:
        chain_seq, label_seq_offset = _extract_full_sequence_from_residues(
            residues=residues
        )
    chain_seq = _extract_one_letter_from_full_sequence(full_sequence=chain_seq)
    # extract coordinates and plddt
    coordinates = np.full((len(chain_seq), _N_ATOM, 3), np.nan, dtype=np.float32)
    plddt = np.full(len(chain_seq), np.nan, dtype=np.float32)
    for residue_idx, residue in enumerate(residues):
        # Position in the full sequence: label_seq when present, otherwise the
        # running residue index.
        i = (
            residue.label_seq - label_seq_offset
            if residue.label_seq is not None
            else residue_idx
        )
        code = gemmi.find_tabulated_residue(residue.name).one_letter_code
        code = code.upper() if code != " " else "X"
        if code != chain_seq[i]:
            if verbose:
                # TODO: can this ever happen...? probably want to have this regardless i guess
                # TODO: improve this message?
                print(
                    f"Amino acid mismatch at position {i + 1}: SEQRES {chain_seq[i]} Structure {code}"
                )
            # The residue observed in the structure wins over the declared sequence.
            chain_seq[i] = code
        if verbose and code == "X" and residue.name != "UNK":
            print(f"Unknown amino acid at position {i + 1}: {residue.name}")
        if verbose:
            for j, atom_name in enumerate(_BACKBONE_ATOM_TYPES):
                if atom_name not in residue:
                    print(
                        f"Residue at position {i + 1} missing backbone atom={atom_name}"
                    )
        for atom in residue.first_conformer():
            atom_name = atom.name
            if residue.name == "MSE" and atom_name == "SE":
                # Store the selenium of MSE in the "SD" slot.
                atom_name = "SD"
            # Look up the (possibly renamed) atom name. Looking up `atom.name`
            # would always skip "SE" and make the rename above ineffective.
            if (j := _ATOM_TYPE_TO_IDX.get(atom_name)) is None:
                continue
            coordinates[i, j] = atom.pos.tolist()
            if use_bfactor_as_plddt and atom_name == "CA":
                # _NAN_BFACTOR_VALUE is the sentinel for "no pLDDT".
                plddt[i] = (
                    atom.b_iso if atom.b_iso != _NAN_BFACTOR_VALUE else np.nan
                )
        # TODO: we should experiment and see if this is the behavior we want
        if (
            not use_bfactor_as_plddt
            and np.isfinite(coordinates[i, _ATOM_TYPE_TO_IDX["CA"]]).all()
        ):
            plddt[i] = 100.0
    if (
        maybe_use_bfactor_as_plddt
        and not np.isnan(plddt).all()
        and np.nanmax(plddt) <= 10
    ):
        plddt[~np.isnan(plddt)] = 100.0  # these were almost surely not plddts
    assert np.isnan(plddt).all() or (
        (np.nanmin(plddt) >= 0) and (np.nanmax(plddt) <= 100)
    )
    protein = Protein(
        sequence="".join(chain_seq),
        name=structure.name if structure.name != "" else None,
    )
    protein._coordinates, protein._plddt = coordinates, plddt
    protein._set_loop_data(structure_block=structure_block, chain_id=subchain_id)
    return protein
def _set_loop_data(
    self, structure_block: _cif_utils.StructureCIFBlock, chain_id: str
):
    """
    Load the OpenProtein annotations for `chain_id` from the CIF block.

    Reads `_openprotein_sequence.` rows (cyclic flag, msa id) and
    `_openprotein_atom.` rows (per-residue group and binding) whose
    `label_asym_id` equals `chain_id`; other rows are ignored.
    """
    # TODO: chain id overload, extract this into its own method
    # TODO: for all tables, support optional columns being missing
    sequence_columns = ["label_asym_id", "?cyclic", "?msa_id"]
    sequence_table = structure_block.block.find(
        "_openprotein_sequence.", sequence_columns
    )
    if len(sequence_table) > 0:
        assert all(
            sequence_table.has_column(i) for i in range(len(sequence_columns))
        )
        for row_chain_id, cyclic, msa_id in sequence_table:
            if row_chain_id == chain_id:
                self._cyclic = cyclic == "1"
                # "." selects single-sequence mode, "?" leaves the msa unset, and
                # anything else is stored as the msa id.
                if msa_id == ".":
                    self._msa = Protein.single_sequence_mode
                elif msa_id != "?":
                    self._msa = msa_id
    atom_columns = [
        "label_asym_id",
        "label_seq_id",
        "label_atom_id",
        "?group",
        "?binding",
    ]
    atom_table = structure_block.block.find("_openprotein_atom.", atom_columns)
    if len(atom_table) > 0:
        assert all(atom_table.has_column(i) for i in range(len(atom_columns)))
        for row_chain_id, seq_id, atom_id, group, binding in atom_table:
            if row_chain_id != chain_id:
                continue
            # label_seq_id is 1-indexed in the file.
            residue_idx = int(seq_id) - 1
            assert residue_idx < len(self)
            assert atom_id == ".", "atom level not supported yet"
            self._group[residue_idx] = int(group)
            self._binding[residue_idx] = binding
@property
def _coordinates(self) -> npt.NDArray[np.float32]:
    # Created on first access as an all-NaN (len, _N_ATOM, 3) float32 array.
    try:
        return self._data["coordinates"]
    except KeyError:
        created = np.full((len(self), _N_ATOM, 3), np.nan, dtype=np.float32)
        self._data["coordinates"] = created
        return created

@_coordinates.setter
def _coordinates(self, x: npt.NDArray[np.float32]) -> None:
    assert x.dtype == np.float32 and x.shape == (len(self), _N_ATOM, 3)
    self._data["coordinates"] = x

@property
def _plddt(self) -> npt.NDArray[np.float32]:
    # Created on first access as an all-NaN float32 array, one entry per residue.
    try:
        return self._data["plddt"]
    except KeyError:
        created = np.full(len(self), np.nan, dtype=np.float32)
        self._data["plddt"] = created
        return created

@_plddt.setter
def _plddt(self, x: npt.NDArray[np.float32]) -> None:
    assert x.dtype == np.float32 and x.shape == (len(self),)
    self._data["plddt"] = x

@property
def _group(self) -> npt.NDArray[np.int_]:
    # Created on first access as all zeros, one entry per residue.
    try:
        return self._data["group"]
    except KeyError:
        created = np.zeros(len(self), dtype=int)
        self._data["group"] = created
        return created

@_group.setter
def _group(self, x: npt.NDArray[np.int_]) -> None:
    assert np.issubdtype(x.dtype, np.integer) and x.shape == (len(self),)
    self._data["group"] = x.astype(int, copy=False)

@property
def _binding(self) -> npt.NDArray[np.str_]:
    # Created on first access with every residue set to Binding.UNKNOWN.
    try:
        return self._data["binding"]
    except KeyError:
        created = np.full(len(self), Binding.UNKNOWN.value, dtype="<U1")
        self._data["binding"] = created
        return created

@_binding.setter
def _binding(self, x: npt.NDArray[np.str_]) -> None:
    assert x.dtype == "<U1" and x.shape == (len(self),)
    assert set(x).issubset({e.value for e in Binding})
    self._data["binding"] = x
def _make_structure(
    self,
    structure: gemmi.Structure | None = None,
    model_idx: int = 1,
    chain_id: str = "A",
    entity_name: str = "1",
) -> gemmi.Structure:
    """
    Write this protein as one chain of a gemmi Structure and return the structure.

    Args:
        structure: structure to add to; a new one (named after this protein, if
            it has a name) is created when None.
        model_idx: used to look up the model when the structure already has
            models, and as the new model's name otherwise.
        chain_id: name of the chain, also used as subchain id.
        entity_name: name of the polymer entity created for this chain.

    Raises:
        AssertionError: if an atom has only some coordinates NaN, or a residue
            has a pLDDT but no CA coordinates.
    """
    # TODO: add note about _NAN_BFACTOR_VALUE
    # Create an empty structure and add a model with a default chain.
    if structure is None:
        structure = gemmi.Structure()
        if self.name is not None:
            structure.name = self.name
    # Get existing model or create new one
    # NOTE(review): with existing models `model_idx` is passed to
    # `structure[...]`, otherwise it becomes the model name — confirm the
    # default of 1 is intended for both uses.
    if len(structure) > 0:
        model = structure[model_idx]
    else:
        model = structure.add_model(gemmi.Model(str(model_idx)))  # type: ignore - gemmi 0.6 needs str
    # Process the sequence.
    # TODO: handle optional token...?
    resnames = gemmi.expand_one_letter_sequence(
        self.sequence.decode(), gemmi.ResidueKind.AA
    )
    entity = gemmi.Entity(entity_name)
    entity.full_sequence = resnames
    entity.entity_type = gemmi.EntityType.Polymer
    entity.polymer_type = gemmi.PolymerType.PeptideL
    entity.subchains = [chain_id]
    structure.entities.append(entity)
    # Process the coordinates.
    coords = self._coordinates
    plddt = self._plddt
    n_nan_coords = np.isnan(coords).sum(axis=2)
    assert (
        (n_nan_coords == 0) | (n_nan_coords == 3)
    ).all(), "either all coords of an atom must be nan, or none are"
    # Process the plddt.
    assert (
        np.isnan(plddt) | (~np.isnan(plddt) & (n_nan_coords[:, 1] == 0))
    ).all(), "if plddt is known, coord of CA must be known"
    # Write the chain
    chain = model.add_chain(gemmi.Chain(chain_id))
    for i, resname in enumerate(resnames[: len(self)]):
        # Add a residue to the chain; note that residue numbering starts at 1.
        number = i + 1
        residue = gemmi.Residue()
        residue.entity_id = entity_name
        residue.entity_type = gemmi.EntityType.Polymer
        residue.subchain = chain_id
        residue.name = resname
        residue.label_seq = number
        residue.seqid = gemmi.SeqId(str(number))
        residue = chain.add_residue(residue, number)
        # A residue's pLDDT is written as the B-factor of each of its atoms;
        # _NAN_BFACTOR_VALUE stands in for an unknown pLDDT.
        residue_plddt = plddt[i]
        # For each residue, add the atoms.
        for j, atom_name in enumerate(_ATOM_TYPES):
            if np.isnan(coords[i, j]).any():
                continue
            atom = gemmi.Atom()
            atom.name = atom_name
            atom.element = gemmi.Element(atom_name[0])
            atom.pos = gemmi.Position(*coords[i, j])
            if np.isnan(residue_plddt):
                atom.b_iso = _NAN_BFACTOR_VALUE
            else:
                atom.b_iso = residue_plddt
            atom = residue.add_atom(atom)
    return structure
def _append_loop_data(
    self, chain_id: str, sequence_loop: gemmi.cif.Loop, atom_loop: gemmi.cif.Loop
):
    """
    Append this protein's OpenProtein annotations to the given CIF loops.

    One sequence row (chain id, cyclic flag, msa id) is written when the protein
    is cyclic or has an msa set; one atom row is written per residue whose
    binding or group differs from the defaults.
    """
    msa = self._msa
    if self._cyclic or msa is not None:
        if msa is None:
            msa_id = "?"  # cif convention for unknown
        elif isinstance(msa, type):
            msa_id = "."  # cif convention for not applicable
        elif isinstance(msa, str):
            msa_id = msa
        else:
            msa_id = msa.id
        cyclic_flag = "1" if self._cyclic else "0"
        sequence_loop.add_row([chain_id, cyclic_flag, msa_id])
    annotations = zip(self._binding, self._group)
    for position, (binding, group) in enumerate(annotations, start=1):
        if binding == Binding.UNKNOWN and group == 0:
            continue  # don't write default
        # "." for atom id indicates residue level annotation
        row = [chain_id, str(position), ".", str(group), binding]
        atom_loop.add_row(row)
def _make_cif_string(self) -> str:
    """Serialize this protein, including OpenProtein loop data, as an mmCIF string."""
    # TODO: make gemmi take into account chain_id
    structure = self._make_structure()
    # NB: gemmi doesn't seem to write the _chem_comp category properly... it says
    # the type is `.`, but is should be something like `L-PEPTIDE LINKING`...
    # see also: https://github.com/project-gemmi/gemmi/discussions/362
    output_groups = gemmi.MmcifOutputGroups(True, chem_comp=False)
    block = structure.make_mmcif_block(groups=output_groups)
    loops = _cif_utils.init_loops(block=block)
    sequence_loop, atom_loop = loops
    self._append_loop_data(
        chain_id="A", sequence_loop=sequence_loop, atom_loop=atom_loop
    )
    return block.as_string()
def _make_pdb_string(self) -> str:
    """Serialize this protein as a PDB string using gemmi's minimal write options."""
    # TODO: make gemmi take into account chain_id
    structure = self._make_structure()
    return structure.make_pdb_string(gemmi.PdbWriteOptions(minimal=True))

def _set_coordinates(self, x: npt.NDArray[np.float32]) -> "Protein":
    """Store a copy of `x` as the coordinates and return self."""
    self._coordinates = x.copy()
    return self
def parse_fasta_as_proteins(path: str | Path) -> list[Protein]:
    """Read the FASTA file at `path` and return one Protein per parsed record."""
    with open(path, "rb") as fp:
        return [
            Protein(name=name, sequence=sequence)
            for name, sequence in fasta.parse_stream(fp)
        ]
T = TypeVar("T", bound=StrEnum)


def _enum_to_str(
    value: T | str | Sequence[T | str], enum_type: Type[T]
) -> str | Sequence[str]:
    """
    Normalise enum members and/or their string codes to plain string values.

    A single member or string yields a single value; any other input is treated
    as an iterable and yields a list. Strings are validated by constructing
    `enum_type` from them.
    """
    if isinstance(value, enum_type):
        return value.value
    if isinstance(value, str):
        return enum_type(value).value
    # Validate every element first, then extract the values.
    members = [enum_type(v) if isinstance(v, str) else v for v in value]
    return [member.value for member in members]
# Atom names in the order used along the atom axis of the coordinate arrays.
# fmt: off
_ATOM_TYPES = (
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
)
# fmt: on
_N_ATOM = len(_ATOM_TYPES)
_ATOM_TYPE_TO_IDX = dict(zip(_ATOM_TYPES, range(_N_ATOM)))
_BACKBONE_ATOM_TYPES = ("N", "CA", "C")
# Every atom index that is neither a backbone atom nor "O".
_NON_SIDE_CHAIN_ATOM_TYPES = _BACKBONE_ATOM_TYPES + ("O",)
_SIDE_CHAIN_ATOM_IDXS = np.array(
    [
        idx
        for idx, atom_type in enumerate(_ATOM_TYPES)
        if atom_type not in _NON_SIDE_CHAIN_ATOM_TYPES
    ]
)
# `_exptl.method` values treated as experimental by `_is_experimental_structure`.
_EXPERIMENTAL_METHODS = {
    "X-RAY DIFFRACTION",
    "ELECTRON MICROSCOPY",
    "SOLUTION NMR",
    "SOLID-STATE NMR",
    "NEUTRON DIFFRACTION",
    "ELECTRON CRYSTALLOGRAPHY",
    "FIBER DIFFRACTION",
    "POWDER DIFFRACTION",
    "INFRARED SPECTROSCOPY",
    "FLUORESCENCE TRANSFER",
    "EPR",
    "SOLUTION SCATTERING",
}
# B-factor written in place of an unknown (NaN) pLDDT.
_NAN_BFACTOR_VALUE = 9999.75  # can't/hard to use 9999.99 due to precision issues
def _calc_rmsd_and_transform(
    xyz1: npt.NDArray[np.floating], xyz2: npt.NDArray[np.floating], eps: float = 1e-6
) -> tuple[float, npt.NDArray[np.floating], npt.NDArray[np.floating]]:
    """
    Calculates RMSD and the rigid transformation (R, t) to superimpose xyz2 onto xyz1.

    Adapted from https://github.com/RosettaCommons/RFdiffusion/blob/b44206a2a79f219bb1a649ea50603a284c225050/rfdiffusion/util.py#L719

    Args:
        xyz1: reference points, one row per point.
        xyz2: points to be moved, paired row-by-row with `xyz1`.
        eps: added to the mean squared deviation before the square root.

    Returns:
        rmsd: Root Mean Square Deviation.
        R: Rotation matrix (3, 3) such that xyz2 @ R.T + t aligns with xyz1.
        t: Translation vector (3,) such that xyz2 @ R.T + t aligns with xyz1.
        R and t are cast to the dtype of `xyz2`.
    """
    # Centre both point sets on their centroids.
    centroid1 = xyz1.mean(axis=0)
    centroid2 = xyz2.mean(axis=0)
    centered1 = xyz1 - centroid1
    centered2 = xyz2 - centroid2
    # Covariance between the centred sets, then its SVD. numpy returns the right
    # factor already transposed.
    covariance = centered2.T @ centered1
    left, _, right_t = np.linalg.svd(covariance)
    # Flip the last column when needed so the result is a proper rotation rather
    # than a reflection.
    handedness = np.ones((3, 3))
    handedness[:, -1] = np.sign(np.linalg.det(left) * np.linalg.det(right_t))
    # Right-multiplying rotation: new_points = old_points @ rotation_right.
    rotation_right = (handedness * left) @ right_t
    aligned_centered = centered2 @ rotation_right
    n_points = aligned_centered.shape[0]
    squared_error = np.sum((aligned_centered - centered1) ** 2, axis=(0, 1))
    rmsd = np.sqrt(squared_error / n_points + eps)
    # (xyz2 - centroid2) @ U + centroid1 == xyz2 @ U + (centroid1 - centroid2 @ U),
    # hence R.T = U and t = centroid1 - centroid2 @ U.
    R = rotation_right.T
    t = centroid1 - centroid2 @ rotation_right
    return rmsd, R.astype(xyz2.dtype), t.astype(xyz2.dtype)
def _is_experimental_structure(structure: gemmi.Structure) -> bool:
    """
    This heuristic decides whether the structure is an experimental structure.
    This heuristic may be changed in the future.

    Currently: a positive resolution, or an `_exptl.method` entry in the
    structure info that is listed in `_EXPERIMENTAL_METHODS`.
    """
    if structure.resolution > 0:
        return True
    if "_exptl.method" not in structure.info:
        return False
    return structure.info["_exptl.method"] in _EXPERIMENTAL_METHODS
def _use_bfactor_as_plddt(structure: gemmi.Structure) -> bool:
    """
    This heuristic decides whether to use B-factor as pLDDT.
    This heuristic may be changed in the future.

    Currently: B-factors are used as pLDDT exactly when the structure is not
    classified as experimental by `_is_experimental_structure`.
    """
    return not _is_experimental_structure(structure=structure)
def _extract_full_sequence_from_residues(
    residues: list[gemmi.Residue],
) -> tuple[list[str], int]:
    """
    Reconstruct a chain's residue-name sequence from its observed residues.

    Returns:
        A pair `(names, offset)`. When every residue has a `label_seq`, `names`
        spans the smallest to the largest label with "UNK" in the gaps and
        `offset` is the smallest label. When no residue has one, `names` lists
        the residues in order and `offset` is 0.

    Raises:
        AssertionError: if only some residues have a `label_seq`.
    """
    if not all(residue.label_seq is not None for residue in residues):
        assert all(residue.label_seq is None for residue in residues), (
            "if entity.full_sequence is blank, then either all residues must "
            "have label_seq or all residues must not have label_seq"
        )
        return [residue.name for residue in residues], 0
    labels = [cast(int, residue.label_seq) for residue in residues]
    lowest, highest = min(labels), max(labels)
    # Positions without an observed residue stay "UNK".
    names = ["UNK"] * (highest - lowest + 1)
    for residue in residues:
        names[cast(int, residue.label_seq) - lowest] = residue.name
    return names, lowest
def _extract_one_letter_from_full_sequence(full_sequence: Sequence[str]) -> list[str]:
    """Convert residue names to upper-case one-letter codes, using "X" for unknowns."""
    one_letter_codes: list[str] = []
    for residue_name in full_sequence:
        # gemmi.Entity.first_mon extracts the first conformer
        first_name = gemmi.Entity.first_mon(residue_name)
        code = gemmi.find_tabulated_residue(first_name).one_letter_code
        # for find_tabulated_residue: lowercase means nonstandard, " " means unknown
        one_letter_codes.append("X" if code == " " else code.upper())
    return one_letter_codes