Source code for openprotein.svd.svd
"""SVD API providing the interface for creating and using SVD models."""
from openprotein.base import APISession
from openprotein.common import ReductionType
from openprotein.data import AssayDataset
from openprotein.embeddings import EmbeddingsAPI
from . import api
from .models import SVDModel
[docs]
class SVDAPI:
    """SVD API providing the interface for creating and using SVD models."""

    def __init__(
        self,
        session: APISession,
    ) -> None:
        """
        Bind the SVD API to a session.

        Parameters
        ----------
        session : APISession
            The session stored on the instance and passed along with every
            SVD request made through this API.
        """
        self.session = session

    def fit_svd(
        self,
        model_id: str,
        sequences: list[bytes] | list[str] | None = None,
        assay: AssayDataset | None = None,
        n_components: int = 1024,
        reduction: ReductionType | None = None,
        **kwargs,
    ) -> SVDModel:
        """
        Fit an SVD on the sequences with the specified model_id and hyperparameters (n_components).

        The embedding model is looked up through the session's embeddings
        API and the fit itself is delegated to that model's ``fit_svd``.

        Parameters
        ----------
        model_id : str
            The ID of the model to fit the SVD on.
        sequences : list[bytes] or list[str], optional
            The list of sequences to use for the SVD fitting, by default None.
        assay : AssayDataset, optional
            Assay dataset forwarded to the model's ``fit_svd``, by default
            None. NOTE(review): presumably an alternative source of
            sequences to ``sequences`` -- confirm against the embeddings API.
        n_components : int, optional
            The number of components for the SVD, by default 1024.
        reduction : ReductionType, optional
            The reduction method to apply to the embeddings, by default None.
        **kwargs
            Additional keyword arguments forwarded unchanged to the model's
            ``fit_svd``.

        Returns
        -------
        SVDModel
            The model with the SVD fit.

        Raises
        ------
        TypeError
            If the session does not expose an ``EmbeddingsAPI`` instance as
            its ``embedding`` attribute.
        """
        # The embeddings API is resolved dynamically; a session without an
        # ``embedding`` attribute yields None here and is rejected below.
        embeddings_api = getattr(self.session, "embedding", None)
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O (which strips assert statements).
        if not isinstance(embeddings_api, EmbeddingsAPI):
            raise TypeError(
                "session.embedding must be an EmbeddingsAPI instance, got "
                f"{type(embeddings_api).__name__}"
            )
        model = embeddings_api.get_model(model_id)
        return model.fit_svd(
            sequences=sequences,
            assay=assay,
            n_components=n_components,
            reduction=reduction,
            **kwargs,
        )

    def get_svd(self, svd_id: str) -> SVDModel:
        """
        Get SVD job results. Including SVD dimension and sequence lengths.

        Requires a successful SVD job from fit_svd.

        Parameters
        ----------
        svd_id : str
            The ID of the SVD job.

        Returns
        -------
        SVDModel
            The model with the SVD fit.
        """
        metadata = api.svd_get(self.session, svd_id)
        return SVDModel(
            session=self.session,
            metadata=metadata,
        )

    # The double leading underscore name-mangles this method to
    # ``_SVDAPI__delete_svd``, keeping it off the public surface of the class.
    def __delete_svd(self, svd_id: str) -> bool:
        """
        Delete SVD model.

        Parameters
        ----------
        svd_id : str
            The ID of the SVD job.

        Returns
        -------
        bool
            True: successful deletion
        """
        return api.svd_delete(self.session, svd_id)

    def list_svd(self) -> list[SVDModel]:
        """
        List SVD models made by user.

        Takes no args.

        Returns
        -------
        list[SVDModel]
            One SVDModel per metadata entry returned by the SVD listing
            request, each bound to this API's session.
        """
        return [
            SVDModel(
                session=self.session,
                metadata=metadata,
            )
            for metadata in api.svd_list_get(self.session)
        ]