"""Source code for ``openprotein.app.models.embeddings.poet``."""

import warnings
from typing import TYPE_CHECKING

from openprotein.api import embedding
from openprotein.api.deprecated import poet
from openprotein.base import APISession
from openprotein.schemas import ModelMetadata, ReductionType
from openprotein.schemas.deprecated.poet import (
    PoetGenerateJob,
    PoetScoreJob,
    PoetSSPJob,
)

from ..align import PromptFuture
from ..assaydata import AssayDataset, AssayMetadata
from .base import EmbeddingModel
from .future import (
    EmbeddingsGenerateFuture,
    EmbeddingsResultFuture,
    EmbeddingsScoreFuture,
)

if TYPE_CHECKING:
    from ..deprecated import PoetGenerateFuture, PoetScoreFuture, PoetSingleSiteFuture
    from ..predictor import PredictorModel
    from ..svd import SVDModel
    from ..umap import UMAPModel


class PoETModel(EmbeddingModel):
    """
    Class for OpenProtein's foundation model PoET - NB. PoET functions are
    dependent on a prompt supplied via the align endpoints.

    Examples
    --------
    View specific model details (inc supported tokens) with the `?` operator.

    .. code-block:: python

        import openprotein
        session = openprotein.connect(username="user", password="password")
        session.embedding.poet.<embeddings_method>
    """

    model_id = "poet"

    # Lazily-built handle to the legacy PoET interface; populated on first
    # access of the `deprecated` property, `None` until then.
    _deprecated: "Deprecated | None" = None

    # TODO - Add model to explicitly require prompt_id
[docs] def __init__( self, session: APISession, model_id: str, metadata: ModelMetadata | None = None ): self.session = session self.id = model_id self._metadata = metadata
    # TODO: consider accepting a default prompt at construction time, so callers
    # need not pass one to every method.
[docs] def embed( self, prompt: str | PromptFuture, sequences: list[bytes], reduction: ReductionType | None = ReductionType.MEAN, ) -> EmbeddingsResultFuture: """ Embed sequences using this model. Parameters ---------- prompt: Union[str, PromptFuture] prompt from an align workflow to condition Poet model sequence : bytes Sequence to embed. reduction: str embeddings reduction to use (e.g. mean) Returns ------- EmbeddingResultFuture A future object that returns the embeddings of the submitted sequences. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return super().embed( sequences=sequences, reduction=reduction, prompt_id=prompt_id )
[docs] def logits( self, prompt: str | PromptFuture, sequences: list[bytes], ) -> EmbeddingsResultFuture: """ logit embeddings for sequences using this model. Parameters ---------- prompt: Union[str, PromptFuture] prompt from an align workflow to condition Poet model sequence : bytes Sequence to analyse. Returns ------- EmbeddingResultFuture A future object that returns the logits of the submitted sequences. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return super().logits( sequences=sequences, prompt_id=prompt_id, )
    def attn(self):
        """Not Available for Poet."""
        # Attention maps are not exposed for PoET; ValueError is kept (rather
        # than NotImplementedError) so existing callers' handlers still match.
        raise ValueError("Attn not yet supported for poet")
[docs] def score( self, prompt: str | PromptFuture, sequences: list[bytes] ) -> EmbeddingsScoreFuture: """ Score query sequences using the specified prompt. Parameters ---------- prompt: str | PromptFuture Prompt or prompt_id or prompt from an align workflow to condition Poet model sequence: list[bytes] Sequences to score. Returns ------- EmbeddingsScoreFuture A future object that returns the scores of the submitted sequences. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return EmbeddingsScoreFuture.create( session=self.session, job=embedding.request_score_post( session=self.session, model_id=self.id, prompt_id=prompt_id, sequences=sequences, ), )
[docs] def single_site( self, prompt: str | PromptFuture, sequence: bytes ) -> EmbeddingsScoreFuture: """ Score all single substitutions of the query sequence using the specified prompt. Parameters ---------- prompt: str | PromptFuture Prompt or prompt_id or prompt from an align workflow to condition Poet model sequence: bytes Sequence to analyse. Returns ------- EmbeddingsScoreFuture A future object that returns the scores of the mutated sequence. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return EmbeddingsScoreFuture.create( session=self.session, job=embedding.request_score_single_site_post( session=self.session, model_id=self.id, base_sequence=sequence, prompt_id=prompt_id, ), )
[docs] def generate( self, prompt: str | PromptFuture, num_samples: int = 100, temperature: float = 1.0, topk: float | None = None, topp: float | None = None, max_length: int = 1000, seed: int | None = None, ) -> EmbeddingsScoreFuture: """ Generate protein sequences conditioned on a prompt. Parameters ---------- prompt: Union[str, PromptFuture] prompt from an align workflow to condition Poet model num_samples: int, optional The number of samples to generate, by default 100. temperature: float, optional The temperature for sampling. Higher values produce more random outputs, by default 1.0. topk: int, optional The number of top-k residues to consider during sampling, by default None. topp: float, optional The cumulative probability threshold for top-p sampling, by default None. max_length: int, optional The maximum length of generated proteins, by default 1000. seed: int, optional Seed for random number generation, by default a random number. Returns ------- EmbeddingsGenerateFuture A future object representing the status and information about the generation job. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return EmbeddingsGenerateFuture.create( session=self.session, job=embedding.request_generate_post( session=self.session, model_id=self.id, num_samples=num_samples, temperature=temperature, topk=topk, topp=topp, max_length=max_length, random_seed=seed, prompt_id=prompt_id, ), )
[docs] def fit_svd( self, prompt: str | PromptFuture, sequences: list[bytes] | list[str] | None = None, assay: AssayDataset | None = None, n_components: int = 1024, reduction: ReductionType | None = None, ) -> "SVDModel": """ Fit an SVD on the embedding results of PoET. This function will create an SVDModel based on the embeddings from this model \ as well as the hyperparameters specified in the args. Parameters ---------- prompt: Union[str, PromptFuture] prompt from an align workflow to condition Poet model sequences : List[bytes] sequences to SVD n_components: int number of components in SVD. Will determine output shapes reduction: str embeddings reduction to use (e.g. mean) Returns ------- SVDModel A future that represents the fitted SVD model. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return super().fit_svd( sequences=sequences, assay=assay, n_components=n_components, reduction=reduction, prompt_id=prompt_id, )
[docs] def fit_umap( self, prompt: str | PromptFuture, sequences: list[bytes] | list[str] | None = None, assay: AssayDataset | None = None, n_components: int = 2, reduction: ReductionType | None = ReductionType.MEAN, ) -> "UMAPModel": """ Fit a UMAP on assay using PoET and hyperparameters. This function will create a UMAP based on the embeddings from this PoET model \ as well as the hyperparameters specified in the args. Parameters ---------- prompt: Union[str, PromptFuture] prompt from an align workflow to condition Poet model sequences : list[bytes] | None Optional sequences to fit UMAP with. Either use sequences or assay. sequences is preferred. assay: AssayDataset | None Optional assay containing sequences to fit UMAP with. Either use sequences or assay. Ignored if sequences are provided. n_components: int Number of components in UMAP fit. Will determine output shapes. Defaults to 2. reduction: ReductionType | None Embeddings reduction to use (e.g. mean). Defaults to MEAN. Returns ------- UMAPModel A future that represents the fitted UMAP model. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return super().fit_umap( sequences=sequences, assay=assay, n_components=n_components, reduction=reduction, prompt_id=prompt_id, )
[docs] def fit_gp( self, prompt: str | PromptFuture, assay: AssayMetadata | AssayDataset | str, properties: list[str], **kwargs, ) -> "PredictorModel": """ Fit a GP on assay using this embedding model and hyperparameters. Parameters ---------- assay : AssayMetadata | str Assay to fit GP on. properties: list[str] Properties in the assay to fit the gp on. reduction : str Type of embedding reduction to use for computing features. PLM must use reduction. Returns ------- PredictorModel A future that represents the trained predictor model. """ prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt return super().fit_gp( assay=assay, properties=properties, prompt_id=prompt_id, **kwargs )
    @property
    def deprecated(self):
        # Lazily build and cache the legacy interface, warning on first access
        # so existing callers know to migrate.
        if self._deprecated is None:
            warnings.warn(
                "The old interface to PoET is deprecated! Support will be dropped in the future. Please migrate your code to use the new interface."
            )
            # NOTE(review): these names are imported but not used here —
            # presumably for import side effects or availability; confirm
            # before removing.
            from ..deprecated import (
                PoetGenerateFuture,
                PoetScoreFuture,
                PoetSingleSiteFuture,
            )

            self._deprecated = self.Deprecated(session=self.session)
        return self._deprecated

    class Deprecated:
        # Legacy PoET interface: wraps the old score/single-site/generate
        # endpoints against the same API session.

        def __init__(self, session: APISession):
            self.session = session

        def score(
            self,
            prompt: str | PromptFuture,
            sequences: list[bytes],
        ) -> "PoetScoreFuture":
            """
            (Deprecated) Score query sequences using the specified prompt.

            Parameters
            ----------
            prompt: Union[str, PromptFuture]
                Prompt or prompt_id of prompt from an align workflow to
                condition Poet model
            sequences : list[bytes]
                Sequences to score.

            Returns
            -------
            PoetScoreFuture
                A future object that returns the scores of the submitted
                sequences.
            """
            from ..deprecated import PoetScoreFuture

            prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt
            # HACK - manually construct the job and future since job types have been overwritten
            return PoetScoreFuture(
                session=self.session,
                job=PoetScoreJob(
                    **poet.poet_score_post(
                        session=self.session,
                        prompt_id=prompt_id,
                        queries=sequences,
                    ).model_dump()
                ),
            )

        def single_site(
            self, prompt: str | PromptFuture, sequence: bytes
        ) -> "PoetSingleSiteFuture":
            """
            (Deprecated) Score query sequences using the specified prompt.

            Parameters
            ----------
            prompt: str | PromptFuture
                Prompt or prompt_id of prompt from an align workflow to
                condition Poet model
            sequence: bytes
                Sequence to analyse.

            Returns
            -------
            PoetSingleSiteFuture
                A future object that returns the scores of the mutated
                sequence.
            """
            from ..deprecated import PoetSingleSiteFuture

            prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt
            # HACK - manually construct the job and future since job types have been overwritten
            return PoetSingleSiteFuture(
                session=self.session,
                job=PoetSSPJob(
                    **poet.poet_single_site_post(
                        session=self.session,
                        prompt_id=prompt_id,
                        variant=sequence,
                    ).model_dump()
                ),
            )

        def generate(
            self,
            prompt: str | PromptFuture,
            num_samples: int = 100,
            temperature: float = 1.0,
            topk: float | None = None,
            topp: float | None = None,
            max_length: int = 1000,
            seed: int | None = None,
        ) -> "PoetGenerateFuture":
            """
            (Deprecated) Generate protein sequences conditioned on a prompt.

            Parameters
            ----------
            prompt: Union[str, PromptFuture]
                prompt from an align workflow to condition Poet model
            num_samples : int, optional
                The number of samples to generate, by default 100.
            temperature : float, optional
                The temperature for sampling. Higher values produce more random
                outputs, by default 1.0.
            topk : int, optional
                The number of top-k residues to consider during sampling,
                by default None.
            topp : float, optional
                The cumulative probability threshold for top-p sampling,
                by default None.
            max_length : int, optional
                The maximum length of generated proteins, by default 1000.
            seed : int, optional
                Seed for random number generation, by default a random number.

            Returns
            -------
            PoetGenerateFuture
                A future object representing the status and information about
                the generation job.
            """
            from ..deprecated import PoetGenerateFuture

            prompt_id = prompt.id if isinstance(prompt, PromptFuture) else prompt
            # HACK - manually construct the job and future since job types have been overwritten
            return PoetGenerateFuture(
                session=self.session,
                job=PoetGenerateJob(
                    **poet.poet_generate_post(
                        session=self.session,
                        prompt_id=prompt_id,
                        num_samples=num_samples,
                        temperature=temperature,
                        topk=topk,
                        topp=topp,
                        max_length=max_length,
                        random_seed=seed,
                    ).model_dump()
                ),
            )