Source code for openprotein.app.models.embeddings.future
"""Future for embeddings-related jobs."""
from collections import namedtuple
from typing import Generator
import numpy as np
from openprotein import config
from openprotein.api import embedding
from openprotein.base import APISession
from openprotein.schemas import (
AttnJob,
EmbeddingsJob,
GenerateJob,
JobType,
LogitsJob,
ScoreJob,
ScoreSingleSiteJob,
)
from ..futures import Future, MappedFuture, StreamingFuture
[docs]
class EmbeddingsResultFuture(MappedFuture, Future):
"""Future for manipulating results for embeddings-related requests."""
job: EmbeddingsJob | AttnJob | LogitsJob
[docs]
def __init__(
self,
session: APISession,
job: EmbeddingsJob | AttnJob | LogitsJob,
sequences: list[bytes] | list[str] | None = None,
max_workers: int = config.MAX_CONCURRENT_WORKERS,
):
super().__init__(session=session, job=job, max_workers=max_workers)
# ensure all list[bytes]
self._sequences = (
[seq.encode() if isinstance(seq, str) else seq for seq in sequences]
if sequences is not None
else sequences
)
[docs]
def get(self, verbose=False) -> list:
return super().get(verbose=verbose)
@property
def sequences(self) -> list[bytes] | list[str]:
if self._sequences is None:
self._sequences = embedding.get_request_sequences(
session=self.session, job_id=self.job.job_id, job_type=self.job.job_type
)
return self._sequences
@property
def id(self):
return self.job.job_id
def keys(self):
return self.sequences
[docs]
def get_item(self, sequence: bytes) -> np.ndarray:
"""
Get embedding results for specified sequence.
Args:
sequence (bytes): sequence to fetch results for
Returns:
np.ndarray: embeddings
"""
data = embedding.request_get_sequence_result(
session=self.session,
job_id=self.job.job_id,
sequence=sequence,
job_type=self.job.job_type,
)
return embedding.result_decode(data)
[docs]
class EmbeddingsScoreFuture(StreamingFuture, Future):
"""Future for manipulating results for embeddings score-related requests."""
job: ScoreJob | ScoreSingleSiteJob
[docs]
def __init__(
self,
session: APISession,
job: ScoreJob | ScoreSingleSiteJob | GenerateJob,
sequences: list[bytes] | list[str] | None = None,
):
super().__init__(session=session, job=job)
self._sequences = sequences
@property
def sequences(self) -> list[bytes] | list[str]:
if self._sequences is None:
self._sequences = embedding.get_request_sequences(
self.session, self.job.job_id
)
return self._sequences
[docs]
def stream(self) -> Generator:
if self.job_type == JobType.poet_generate:
stream = embedding.request_get_generate_result(
session=self.session, job_id=self.id
)
else:
stream = embedding.request_get_score_result(
session=self.session, job_id=self.id
)
# mut_code, ... for ssp
# name, sequence, ... for score
header = next(stream)
score_start_index = 0
for i, col_name in enumerate(header):
if col_name.startswith("score"):
score_start_index = i
break
Score = namedtuple("Score", header[:score_start_index] + ["score"])
for line in stream:
# combine scores into numpy array
scores = np.array([float(s) for s in line[score_start_index:]])
output = Score(*line[:score_start_index], scores)
yield output
[docs]
class EmbeddingsGenerateFuture(EmbeddingsScoreFuture):
"""Future for manipulating results for embeddings generate-related requests."""
job: GenerateJob
@property
def sequences(self):
raise Exception("generate job does not support listing sequences")