# Source code for openprotein.api.predict

from typing import Optional, List, Union, Any, Dict, Literal
from openprotein.pydantic import BaseModel, root_validator

from openprotein.base import APISession
from import AsyncJobFuture
from import ResultsParser, Job, register_job_type, JobType, JobStatus
from openprotein.errors import InvalidParameterError, APIError
from openprotein.futures import FutureFactory, FutureBase

class SequenceData(BaseModel):
    """Container for the single sequence submitted to the single-site predict endpoint."""

    # Sent as the "sequence" entry of the request payload.
    sequence: str

class SequenceDataset(BaseModel):
    """Container for the list of sequences submitted to the predict endpoint."""

    # Sent as the "sequences" entry of the request payload.
    sequences: List[str]

class _Prediction(BaseModel):
    """Flattened prediction details for a single property.

    The incoming data nests the statistics as
    ``properties[<property name>] = {"y_mu": ..., "y_var": ...}``. The
    pre-validator below lifts the first property's name, mean and variance
    into the top-level ``name`` / ``y_mu`` / ``y_var`` fields.
    """

    # Must be registered as a pre root-validator: without the decorator
    # pydantic never calls this method and the flattened fields stay unset.
    @root_validator(pre=True)
    def extract_pred(cls, values):
        """Move the first entry of ``properties`` into flat fields.

        Raises ``KeyError`` when ``properties`` (or ``y_mu`` / ``y_var``) is
        missing and ``IndexError`` when ``properties`` is empty.
        """
        props = values.pop("properties")
        # Only the first property is surfaced; the model has room for one.
        name = list(props)[0]
        stats = props[name]
        # Write into ``values`` (not into the popped dict) so that the
        # declared fields are actually populated.
        values["name"] = name
        values["y_mu"] = stats["y_mu"]
        values["y_var"] = stats["y_var"]
        return values

    model_id: str
    model_name: str
    y_mu: Optional[float] = None
    y_var: Optional[float] = None
    # Explicit default for consistency with y_mu / y_var.
    name: Optional[str] = None

class Prediction(BaseModel):
    """Prediction details for one model."""

    model_id: str
    model_name: str
    # Keyed by property name; each value is a dict of float statistics.
    # The result formatters in this file read its "y_mu" and "y_var" entries.
    properties: Dict[str, Dict[str, float]]

class PredictJobBase(Job):
    """Fields shared by the predict job models defined in this module."""

    # might be none if just fetching
    job_id: Optional[str] = None
    job_type: str
    status: JobStatus

# NOTE(review): `register_job_type` is imported by this module but no usage is
# visible here; confirm whether this class is meant to be registered with it.
class PredictJob(PredictJobBase):
    """Properties about predict job returned via API."""

    # Must be a pre root-validator so it runs on the raw payload before the
    # ``result`` field is validated.
    @root_validator(pre=True)
    def extract_pred(cls, values):
        """Reduce each raw result entry to its ``sequence`` and ``predictions``.

        An absent or ``None`` ``result`` is left untouched, since the field is
        optional on this model.
        """
        raw = values.get("result")
        if raw is None:
            return values
        # Keep only the two keys SequencePrediction declares.
        values["result"] = [
            {"sequence": entry["sequence"], "predictions": entry["predictions"]}
            for entry in raw
        ]
        return values

    class SequencePrediction(BaseModel):
        """Sequence prediction."""

        sequence: str
        predictions: List[Prediction] = []

    result: Optional[List[SequencePrediction]] = None
    job_type: str

class PredictSingleSiteJob(PredictJobBase):
    """Properties about single-site prediction job returned via API."""

    class SequencePrediction(BaseModel):
        """Prediction for one amino-acid substitution at one position."""

        position: int
        amino_acid: str
        # sequence: str
        predictions: List[Prediction] = []

    result: Optional[List[SequencePrediction]] = None
    # The member name is spelled "worflow" in JobType (it is referenced with
    # the same spelling elsewhere in this module); it must match exactly.
    job_type: Literal[JobType.worflow_predict_single_site] = (
        JobType.worflow_predict_single_site
    )

def _create_predict_job(
    session: APISession,
    endpoint: str,
    payload: dict,
    model_ids: Optional[List[str]] = None,
    train_job_id: Optional[str] = None,
) -> FutureBase:
    """
    Creates a Predict request and returns the job object.

    This function makes a post request to the specified endpoint with the payload.
    Either 'model_ids' or 'train_job_id' should be provided but not both.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    endpoint : str
        The endpoint to which the post request is to be made.
        either predict or predict/single_site
    payload : dict
        The payload to be sent in the post request. It is copied, so the
        caller's dict is not modified.
    model_ids : List[str], optional
        The list of model ids to be used for Predict. Default is None.
    train_job_id : str, optional
        The id of the train job to be used for Predict. Default is None.

    Returns
    -------
    FutureBase
        The job object representing the Predict job.

    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job_id' is provided.
    InvalidParameterError
        If both 'model_ids' and 'train_job_id' are provided.

    Notes
    -----
    Failures of the post request itself, or of turning the response into a
    job object, surface from the session / ``FutureFactory``; the exact
    exception types are not visible in this module.
    """
    if model_ids is None and train_job_id is None:
        raise InvalidParameterError(
            "Either a list of model IDs or a train job ID must be provided"
        )

    if model_ids is not None and train_job_id is not None:
        raise InvalidParameterError(
            "Only a list of model IDs OR a train job ID must be provided, not both"
        )

    # Work on a copy so the selector key is not leaked into the caller's dict.
    body = dict(payload)
    # Exactly one selector is set at this point; send only that one.
    if model_ids is not None:
        body["model_id"] = model_ids
    else:
        body["train_job_id"] = train_job_id

    response = session.post(endpoint, json=body)
    return FutureFactory.create_future(session=session, response=response)

def create_predict_job(
    session: APISession,
    sequences: SequenceDataset,
    train_job: Optional[Any] = None,
    model_ids: Optional[List[str]] = None,
) -> FutureBase:
    """
    Creates a predict job with a given set of sequences and a train job.

    This function will use the sequences and train job ID to create a new Predict job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    sequences : SequenceDataset
        The dataset containing the sequences to predict
    train_job : Any
        The Train job: this model will be used for making Predicts.
        Only its ``id`` attribute is read here.
    model_ids: List[str]
        specific IDs for models. A single string is accepted and wrapped
        into a one-element list.

    Returns
    -------
    FutureBase
        The job object representing the created Predict job.

    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job' is provided.
    InvalidParameterError
        If BOTH `model_ids` and `train_job` is provided

    Notes
    -----
    Request or response-parsing failures surface from the helper that
    performs the post request; their exception types are not visible here.
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    endpoint = "v1/workflow/predict"
    payload = {"sequences": sequences.sequences}
    # The train job is optional; forward its id only when one was supplied.
    train_job_id = train_job.id if train_job is not None else None
    return _create_predict_job(
        session, endpoint, payload, model_ids=model_ids, train_job_id=train_job_id
    )
def create_predict_single_site(
    session: APISession,
    sequence: SequenceData,
    train_job: Any,
    model_ids: Optional[List[str]] = None,
) -> FutureBase:
    """
    Creates a predict job for single site mutants with a given sequence and a train job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    sequence : SequenceData
        The sequence for which single site mutants predictions will be made.
    train_job : Any
        The train job whose model will be used for making Predicts.
        Only its ``id`` attribute is read here; ``None`` is tolerated so that
        `model_ids` can be used instead.
    model_ids: List[str]
        specific IDs for models. A single string is accepted and wrapped
        into a one-element list.

    Returns
    -------
    FutureBase
        The job object representing the created Predict job.

    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job' is provided.
    InvalidParameterError
        If BOTH `model_ids` and `train_job` is provided

    Notes
    -----
    Request or response-parsing failures surface from the helper that
    performs the post request; their exception types are not visible here.
    """
    # Same normalisation as create_predict_job, for consistent behaviour.
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    endpoint = "v1/workflow/predict/single_site"
    payload = {"sequence": sequence.sequence}
    # Without forwarding the id, a call relying on the train job alone would
    # always be rejected as "neither model IDs nor train job provided".
    train_job_id = train_job.id if train_job is not None else None
    return _create_predict_job(
        session, endpoint, payload, model_ids=model_ids, train_job_id=train_job_id
    )
def get_prediction_results(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = None,
) -> PredictJob:
    """
    Retrieves the results of a Predict job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    job_id : str
        The ID of the job whose results are to be retrieved.
    page_size : Optional[int], default is None
        The number of results to be returned per page. If None, the
        parameter is omitted from the request.
    page_offset : Optional[int], default is None
        The number of results to skip. If None, the parameter is omitted
        from the request.

    Returns
    -------
    PredictJob
        The job object representing the Predict job.

    Notes
    -----
    A failing GET request surfaces from the session; the exact exception
    type is not visible in this module.
    """
    endpoint = f"v1/workflow/predict/{job_id}"
    # Only send the paging parameters the caller actually set.
    params = {}
    if page_size is not None:
        params["page_size"] = page_size
    if page_offset is not None:
        params["page_offset"] = page_offset

    response = session.get(endpoint, params=params)
    # get results to assemble into list
    return ResultsParser.parse_obj(response.json())

def get_single_site_prediction_results(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = None,
) -> PredictSingleSiteJob:
    """
    Retrieves the results of a single site Predict job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    job_id : str
        The ID of the job whose results are to be retrieved.
    page_size : Optional[int], default is None
        The number of results to be returned per page. If None, the
        parameter is omitted from the request.
    page_offset : Optional[int], default is None
        The number of results to skip. If None, the parameter is omitted
        from the request.

    Returns
    -------
    PredictSingleSiteJob
        The job object representing the single site Predict job.

    Notes
    -----
    A failing GET request surfaces from the session; the exact exception
    type is not visible in this module.
    """
    endpoint = f"v1/workflow/predict/single_site/{job_id}"
    # Only send the paging parameters the caller actually set.
    params = {}
    if page_size is not None:
        params["page_size"] = page_size
    if page_offset is not None:
        params["page_offset"] = page_offset

    response = session.get(endpoint, params=params)
    # Parse the decoded JSON body, as get_prediction_results does, rather
    # than the raw response object.
    # NOTE(review): assumes ResultsParser.parse_obj expects the decoded
    # payload - confirm against its definition.
    return ResultsParser.parse_obj(response.json())

class PredictFutureMixin:
    """
    Class to retrieve results from a Predict job.

    Attributes
    ----------
    session : APISession
        APIsession with auth
    job : PredictJob
        The job object that represents the current Predict job.
    id : Optional[str]
        Identifier passed to the results endpoints; must be set before
        `get_results` is called.

    Methods
    -------
    get_results(page_size: Optional[int] = None, page_offset: Optional[int] = None) -> Union[PredictSingleSiteJob, PredictJob]
        Retrieves results from a Predict job.
    """

    session: APISession
    job: PredictJob
    id: Optional[str] = None

    def get_results(
        self, page_size: Optional[int] = None, page_offset: Optional[int] = None
    ) -> Union[PredictSingleSiteJob, PredictJob]:
        """
        Retrieves results from a Predict job.

        it uses the appropriate method to retrieve the results based on job_type.

        Parameters
        ----------
        page_size : Optional[int], default is None
            The number of results to be returned per page. If None, the
            parameter is omitted from the request.
        page_offset : Optional[int], default is None
            The number of results to skip. If None, the parameter is omitted
            from the request.

        Returns
        -------
        Union[PredictSingleSiteJob, PredictJob]
            The job object representing the Predict job. The exact type of job depends on the job type.

        Raises
        ------
        AssertionError
            If ``self.id`` is None.

        Notes
        -----
        A failing GET request surfaces from the session; the exact exception
        type is not visible in this module.
        """
        assert self.id is not None
        # Single-site jobs are served by a dedicated endpoint.
        if "single_site" in self.job.job_type:
            return get_single_site_prediction_results(
                self.session, self.id, page_size, page_offset
            )
        else:
            return get_prediction_results(self.session, self.id, page_size, page_offset)

class PredictFuture(PredictFutureMixin, AsyncJobFuture, FutureBase):  # type: ignore
    """Future Job for manipulating results"""

    job_type = [JobType.workflow_predict, JobType.worflow_predict_single_site]

    def __init__(self, session: APISession, job: PredictJob, page_size=1000):
        """Store the session and job, plus the page size used by `get`."""
        super().__init__(session, job)
        self.page_size = page_size

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self) -> str:
        return repr(self.job)

    @property
    def id(self):
        """Job id of the wrapped job."""
        return self.job.job_id

    @staticmethod
    def _collect_by_property(results, key_of):
        """Build ``{property: {key_of(result): {"mean", "variance"}}}``.

        The set of properties is taken from the first result only, so every
        result is expected to carry the same properties.
        """
        # Nothing fetched (e.g. the first request failed): nothing to format.
        if not results:
            return {}
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        properties = dict.fromkeys(
            list(pred["properties"].keys())[0]
            for pred in results[0].dict()["predictions"]
        )
        formatted = {}
        for prop in properties:
            per_key = formatted[prop] = {}
            for r in results:
                # First prediction that reports this property.
                stats = [
                    pred.properties[prop]
                    for pred in r.predictions
                    if prop in pred.properties
                ][0]
                per_key[key_of(r)] = {
                    "mean": stats["y_mu"],
                    "variance": stats["y_var"],
                }
        return formatted

    def _fmt_results(self, results):
        """Format predict results keyed by property, then by sequence."""
        return self._collect_by_property(results, lambda r: r.sequence)

    def _fmt_ssp_results(self, results):
        """Format single-site results keyed by property, then by 1-based position + amino acid."""
        return self._collect_by_property(
            results, lambda r: f"{r.position+1}{r.amino_acid}"
        )

    def get(self, verbose: bool = False) -> Dict:
        """
        Get all the results of the predict job.

        Results are fetched page by page (``self.page_size`` per request)
        until a short page is returned. An APIError stops the paging; the
        results gathered up to that point are still formatted and returned.

        Args:
            verbose (bool, optional): If True, print a message when a page
                request fails. Defaults False.

        Returns:
            Dict: property name -> {sequence or mutation code ->
                {"mean": ..., "variance": ...}}.
        """
        step = self.page_size

        results: List = []
        num_returned = step
        offset = 0

        while num_returned >= step:
            try:
                response = self.get_results(page_offset=offset, page_size=step)
            except APIError as exc:
                if verbose:
                    print(f"Failed to get results: {exc}")
                # Retrying the same page immediately would loop forever on a
                # persistent failure; stop and return what we have.
                break
            assert isinstance(response.result, list)
            results += response.result
            num_returned = len(response.result)
            offset += num_returned
            # Guards against an endless loop when page_size is not positive.
            if num_returned == 0:
                break

        if self.job.job_type == JobType.workflow_predict:
            return self._fmt_results(results)
        else:
            return self._fmt_ssp_results(results)
class PredictService:
    """interface for calling Predict endpoints"""

    def __init__(self, session: APISession):
        """
        Initialize a new instance of the PredictService class.

        Parameters
        ----------
        session : APISession
            APIsession with auth
        """
        self.session = session

    @staticmethod
    def _check_sequence_lengths(train_job: Any, sequences: List) -> None:
        """Raise InvalidParameterError if any sequence length differs from the training assay data."""
        metadata = train_job.assaymetadata
        if metadata is None or metadata.sequence_length is None:
            return
        expected = metadata.sequence_length
        # Report the length that actually mismatches, not just the first one.
        mismatched = next((len(s) for s in sequences if len(s) != expected), None)
        if mismatched is not None:
            raise InvalidParameterError(
                f"Predict sequences length {mismatched} != training assaydata ({expected})"
            )

    def create_predict_job(
        self,
        sequences: List,
        train_job: Optional[Any] = None,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Creates a new Predict job for a given list of sequences and a trained model.

        Parameters
        ----------
        sequences : List
            The list of sequences to be used for the Predict job.
        train_job : Any
            The train job object representing the trained model.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.

        Returns
        -------
        PredictFuture
            The job object representing the Predict job.

        Raises
        ------
        InvalidParameterError
            If the sequences are not of the same length as the assay data.
        InvalidParameterError
            If BOTH train_job and model_ids are specified
        InvalidParameterError
            If NEITHER train_job or model_ids is specified
        APIError
            If the backend refuses the job (due to sequence length or invalid inputs)
        """
        if train_job is not None:
            self._check_sequence_lengths(train_job, sequences)
            # An unfinished train job is only warned about, not rejected.
            if not train_job.done():
                print(f"WARNING: training job has status {train_job.status}")

        sequence_dataset = SequenceDataset(sequences=sequences)
        return create_predict_job(
            self.session, sequence_dataset, train_job, model_ids=model_ids  # type: ignore
        )

    def create_predict_single_site(
        self,
        sequence: str,
        train_job: Any,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Creates a new Predict job for single site mutation analysis with a trained model.

        Parameters
        ----------
        sequence : str
            The sequence for single site analysis.
        train_job : Any
            The train job object representing the trained model.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.

        Returns
        -------
        PredictFuture
            The job object representing the Predict job.

        Raises
        ------
        InvalidParameterError
            If the sequence is not of the same length as the assay data.
        InvalidParameterError
            If BOTH train_job and model_ids are specified
        InvalidParameterError
            If NEITHER train_job or model_ids is specified
        APIError
            If the backend refuses the job (due to sequence length or invalid inputs)
        """
        # model_ids is a documented alternative to train_job, so tolerate a
        # missing train job instead of failing on attribute access.
        if train_job is not None:
            self._check_sequence_lengths(train_job, [sequence])
            train_job.refresh()
            # An unfinished train job is only warned about, not rejected.
            if not train_job.done():
                print(f"WARNING: training job has status {train_job.status}")

        sequence_dataset = SequenceData(sequence=sequence)
        return create_predict_single_site(
            self.session, sequence_dataset, train_job, model_ids=model_ids  # type: ignore
        )