"""Source code for openprotein.api.predict"""

from typing import Optional, List, Union, Any, Dict, Literal
from openprotein.pydantic import BaseModel, root_validator

from openprotein.base import APISession
from openprotein.api.jobs import AsyncJobFuture
from openprotein.jobs import ResultsParser, Job, register_job_type, JobType, JobStatus
from openprotein.errors import InvalidParameterError, APIError
from openprotein.futures import FutureFactory, FutureBase


class SequenceData(BaseModel):
    """Request body wrapper for a single sequence."""

    sequence: str


class SequenceDataset(BaseModel):
    """Request body wrapper for a list of sequences."""

    sequences: List[str]


class _Prediction(BaseModel):
    """Prediction details, flattened for a single property.

    The API nests predictions under a ``properties`` mapping of
    ``{property_name: {"y_mu": ..., "y_var": ...}}``; the pre-validator
    lifts the first property's values into top-level fields.
    """

    @root_validator(pre=True)
    def extract_pred(cls, values):
        # Flatten the nested "properties" payload into top-level fields.
        # Guarded: the original raised KeyError when "properties" was absent.
        # Also avoids mutating the caller's nested dict in place.
        p = values.pop("properties", None)
        if p:
            name = next(iter(p))
            values.update(p)
            values["name"] = name
            values["y_mu"] = p[name]["y_mu"]
            values["y_var"] = p[name]["y_var"]
        return values

    model_id: str
    model_name: str
    y_mu: Optional[float] = None
    y_var: Optional[float] = None
    name: Optional[str]


class Prediction(BaseModel):
    """Prediction details for one model.

    ``properties`` maps a property name to its statistics, e.g.
    ``{"stability": {"y_mu": ..., "y_var": ...}}``.
    """

    model_id: str
    model_name: str
    properties: Dict[str, Dict[str, float]]


class PredictJobBase(Job):
    """Shared fields for predict job payloads."""

    # might be none if just fetching
    job_id: Optional[str] = None
    job_type: str
    status: JobStatus


@register_job_type(JobType.workflow_predict)
class PredictJob(PredictJobBase):
    """Properties about predict job returned via API."""

    @root_validator(pre=True)
    def extract_pred(cls, values):
        # Normalize each result entry to {"sequence": ..., "predictions": ...}.
        # Guarded: ``result`` is Optional, so a status-only payload has no
        # "result" key -- the original unconditional pop raised KeyError.
        v = values.get("result")
        if v is not None:
            values["result"] = [
                {"sequence": i["sequence"], "predictions": i["predictions"]} for i in v
            ]
        return values

    class SequencePrediction(BaseModel):
        """One sequence with its per-model predictions."""

        sequence: str
        predictions: List[Prediction] = []

    result: Optional[List[SequencePrediction]] = None
    job_type: str


@register_job_type(JobType.worflow_predict_single_site)
class PredictSingleSiteJob(PredictJobBase):
    """Properties about single-site prediction job returned via API."""

    class SequencePrediction(BaseModel):
        """Prediction for a single amino-acid substitution at one position."""

        position: int
        amino_acid: str
        # sequence: str
        predictions: List[Prediction] = []

    result: Optional[List[SequencePrediction]] = None
    # NOTE(review): "worflow" (sic) matches the upstream JobType enum member
    # name; fixing the typo here would break the lookup.
    job_type: Literal[JobType.worflow_predict_single_site] = (
        JobType.worflow_predict_single_site
    )


def _create_predict_job(
    session: APISession,
    endpoint: str,
    payload: dict,
    model_ids: Optional[List[str]] = None,
    train_job_id: Optional[str] = None,
) -> FutureBase:
    """
    Creates a Predict request and returns the job object.

    This function makes a post request to the specified endpoint with the payload.
    Either 'model_ids' or 'train_job_id' should be provided but not both.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    endpoint : str
        The endpoint to which the post request is to be made.
        either predict or predict/single_site
    payload : dict
        The payload to be sent in the post request. Not mutated.
    model_ids : List[str], optional
        The list of model ids to be used for Predict. Default is None.
    train_job_id : str, optional
        The id of the train job to be used for Predict. Default is None.

    Returns
    -------
    PredictJob
        The job object representing the Predict job.

    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job_id' is provided.
        If both 'model_ids' and 'train_job_id' are provided.
    HTTPError
        If the post request does not succeed.
    ValidationError
        If the response cannot be parsed into a 'Job' object.
    """
    if model_ids is None and train_job_id is None:
        raise InvalidParameterError(
            "Either a list of model IDs or a train job ID must be provided"
        )

    if model_ids is not None and train_job_id is not None:
        raise InvalidParameterError(
            "Only a list of model IDs OR a train job ID must be provided, not both"
        )

    # Shallow-copy so the caller's payload dict is not mutated as a side effect.
    payload = dict(payload)
    if model_ids is not None:
        payload["model_id"] = model_ids
    else:
        payload["train_job_id"] = train_job_id

    response = session.post(endpoint, json=payload)
    return FutureFactory.create_future(session=session, response=response)


def create_predict_job(
    session: APISession,
    sequences: SequenceDataset,
    train_job: Optional[Any] = None,
    model_ids: Optional[List[str]] = None,
) -> FutureBase:
    """
    Create a Predict job for a set of sequences.

    Uses either a completed train job or explicit model IDs (one of the two
    must be supplied, but not both) to score the given sequences.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    sequences : SequenceDataset
        The dataset containing the sequences to predict
    train_job : Any
        The Train job: this model will be used for making Predicts.
    model_ids: List[str]
        specific IDs for models

    Returns
    -------
    PredictJob
        The job object representing the created Predict job.

    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job' is provided.
    InvalidParameterError
        If BOTH `model_ids` and `train_job` is provided
    HTTPError
        If the post request does not succeed.
    ValidationError
        If the response cannot be parsed into a 'Job' object.
    """
    # Accept a bare model-id string for convenience.
    model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
    train_job_id = None if train_job is None else train_job.id
    return _create_predict_job(
        session,
        "v1/workflow/predict",
        {"sequences": sequences.sequences},
        model_ids=model_ids,
        train_job_id=train_job_id,
    )


def create_predict_single_site(
    session: APISession,
    sequence: SequenceData,
    train_job: Optional[Any] = None,
    model_ids: Optional[List[str]] = None,
) -> FutureBase:
    """
    Creates a predict job for single site mutants with a given sequence and a train job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    sequence : SequenceData
        The sequence for which single site mutants predictions will be made.
    train_job : Any, optional
        The train job whose model will be used for making Predicts.
    model_ids: List[str]
        specific IDs for models

    Returns
    -------
    PredictJob
        The job object representing the created Predict job.

    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job' is provided.
    InvalidParameterError
        If BOTH `model_ids` and `train_job` is provided
    HTTPError
        If the post request does not succeed.
    ValidationError
        If the response cannot be parsed into a 'Job' object.
    """
    # Accept a bare model-id string, consistent with create_predict_job.
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    endpoint = "v1/workflow/predict/single_site"
    payload = {"sequence": sequence.sequence}
    # train_job may legitimately be None when model_ids are supplied instead;
    # the original dereferenced train_job.id unconditionally (AttributeError).
    train_job_id = train_job.id if train_job is not None else None
    return _create_predict_job(
        session, endpoint, payload, model_ids=model_ids, train_job_id=train_job_id
    )


def get_prediction_results(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = None,
) -> PredictJob:
    """
    Fetch one page of results for a Predict job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    job_id : str
        The ID of the job whose results are to be retrieved.
    page_size : Optional[int], default is None
        The number of results to be returned per page. If None, all results are returned.
    page_offset : Optional[int], default is None
        The number of results to skip. If None, defaults to 0.

    Returns
    -------
    PredictJob
        The job object representing the Predict job.

    Raises
    ------
    HTTPError
        If the GET request does not succeed.
    """
    # Only forward pagination parameters that were actually supplied.
    query = {
        key: value
        for key, value in (("page_size", page_size), ("page_offset", page_offset))
        if value is not None
    }
    response = session.get(f"v1/workflow/predict/{job_id}", params=query)
    # Parse the JSON body into the appropriate job model.
    return ResultsParser.parse_obj(response.json())


def get_single_site_prediction_results(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = None,
) -> PredictSingleSiteJob:
    """
    Retrieves the results of a single site Predict job.

    Parameters
    ----------
    session : APISession
        APIsession with auth
    job_id : str
        The ID of the job whose results are to be retrieved.
    page_size : Optional[int], default is None
        The number of results to be returned per page. If None, all results are returned.
    page_offset : Optional[int], default is None
        The number of results to skip. If None, defaults to 0.

    Returns
    -------
    PredictSingleSiteJob
        The job object representing the single site Predict job.

    Raises
    ------
    HTTPError
        If the GET request does not succeed.
    """
    endpoint = f"v1/workflow/predict/single_site/{job_id}"
    params = {}
    if page_size is not None:
        params["page_size"] = page_size
    if page_offset is not None:
        params["page_offset"] = page_offset

    response = session.get(endpoint, params=params)
    # Parse the decoded JSON body, consistent with get_prediction_results;
    # passing the raw Response object appears to be a bug -- confirm against
    # ResultsParser if it special-cases Response inputs.
    return ResultsParser.parse_obj(response.json())


class PredictFutureMixin:
    """
    Mixin providing result retrieval for Predict jobs.

    Attributes
    ----------
    session : APISession
        APIsession with auth
    job : PredictJob
        The job object that represents the current Predict job.

    Methods
    -------
    get_results(page_size: Optional[int] = None, page_offset: Optional[int] = None) -> Union[PredictSingleSiteJob, PredictJob]
        Retrieves results from a Predict job.
    """

    session: APISession
    job: PredictJob
    id: Optional[str] = None

    def get_results(
        self, page_size: Optional[int] = None, page_offset: Optional[int] = None
    ) -> Union[PredictSingleSiteJob, PredictJob]:
        """
        Fetch one page of Predict results, dispatching on the job type.

        Single-site jobs (job_type containing "single_site") go to the
        single-site endpoint; everything else uses the bulk endpoint.

        Parameters
        ----------
        page_size : Optional[int], default is None
            The number of results to be returned per page. If None, all results are returned.
        page_offset : Optional[int], default is None
            The number of results to skip. If None, defaults to 0.

        Returns
        -------
        Union[PredictSingleSiteJob, PredictJob]
            The job object representing the Predict job. The exact type of job depends on the job type.

        Raises
        ------
        HTTPError
            If the GET request does not succeed.
        """
        assert self.id is not None
        fetch = (
            get_single_site_prediction_results
            if "single_site" in self.job.job_type
            else get_prediction_results
        )
        return fetch(self.session, self.id, page_size, page_offset)


class PredictFuture(PredictFutureMixin, AsyncJobFuture, FutureBase):  # type: ignore
    """Future Job for manipulating results"""

    job_type = [JobType.workflow_predict, JobType.worflow_predict_single_site]

    def __init__(self, session: APISession, job: PredictJob, page_size=1000):
        super().__init__(session, job)
        self.page_size = page_size

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self) -> str:
        return repr(self.job)

    @property
    def id(self):
        # Underlying workflow job id (may be None when only fetching).
        return self.job.job_id

    @staticmethod
    def _property_names(results):
        """Collect each model's (first) property name from the first result row."""
        if not results:
            # Guard: the original indexed results[0] and crashed on no results.
            return set()
        return set(
            next(iter(pred["properties"])) for pred in results[0].dict()["predictions"]
        )

    @staticmethod
    def _stats_for(predictions, prop):
        """Return {"mean", "variance"} from the first prediction carrying `prop`."""
        vals = [p.properties[prop] for p in predictions if prop in p.properties][0]
        return {"mean": vals["y_mu"], "variance": vals["y_var"]}

    def _fmt_results(self, results):
        """Pivot sequence results into {property: {sequence: {mean, variance}}}."""
        dict_results = {}
        for prop in self._property_names(results):
            dict_results[prop] = {
                r.sequence: self._stats_for(r.predictions, prop) for r in results
            }
        return dict_results

    def _fmt_ssp_results(self, results):
        """Pivot single-site results into {property: {"<pos><aa>": {mean, variance}}}."""
        dict_results = {}
        for prop in self._property_names(results):
            dict_results[prop] = {
                # 1-based position followed by the substituted amino acid, e.g. "12A"
                f"{r.position + 1}{r.amino_acid}": self._stats_for(r.predictions, prop)
                for r in results
            }
        return dict_results

    def get(self, verbose: bool = False) -> Dict:
        """
        Get all the results of the predict job.

        Args:
            verbose (bool, optional): If True, print verbose output. Defaults False.

        Raises:
            APIError: If there is an issue with the API request.

        Returns:
            PredictJob: A list of predict objects representing the results.
        """
        step = self.page_size
        results: List = []
        num_returned = step
        offset = 0
        while num_returned >= step:
            try:
                response = self.get_results(page_offset=offset, page_size=step)
                assert isinstance(response.result, list)
                results += response.result
                num_returned = len(response.result)
                offset += num_returned
            except APIError as exc:
                if verbose:
                    print(f"Failed to get results: {exc}")
                # Stop paging on error: the original kept looping forever here
                # because neither offset nor num_returned changed.
                break
        if self.job.job_type == JobType.workflow_predict:
            return self._fmt_results(results)
        return self._fmt_ssp_results(results)
class PredictService:
    """Interface for calling Predict endpoints."""

    def __init__(self, session: APISession):
        """
        Initialize a new instance of the PredictService class.

        Parameters
        ----------
        session : APISession
            APIsession with auth
        """
        self.session = session

    def create_predict_job(
        self,
        sequences: List,
        train_job: Optional[Any] = None,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Creates a new Predict job for a given list of sequences and a trained model.

        Parameters
        ----------
        sequences : List
            The list of sequences to be used for the Predict job.
        train_job : Any
            The train job object representing the trained model.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.

        Returns
        -------
        PredictFuture
            The job object representing the Predict job.

        Raises
        ------
        InvalidParameterError
            If the sequences are not of the same length as the assay data or
            if the train job has not completed successfully.
        InvalidParameterError
            If BOTH train_job and model_ids are specified
        InvalidParameterError
            If NEITHER train_job or model_ids is specified
        APIError
            If the backend refuses the job (due to sequence length or invalid inputs)
        """
        if train_job is not None:
            meta = train_job.assaymetadata
            expected = meta.sequence_length if meta is not None else None
            # Validate sequence lengths against the training assay when known.
            if expected is not None and any(expected != len(s) for s in sequences):
                raise InvalidParameterError(
                    f"Predict sequences length {len(sequences[0])} != training assaydata ({expected})"
                )
            if not train_job.done():
                # Deliberately a warning, not an error (see commented-out raise
                # in the original): prediction may still be queued server-side.
                print(f"WARNING: training job has status {train_job.status}")
        sequence_dataset = SequenceDataset(sequences=sequences)
        return create_predict_job(
            self.session, sequence_dataset, train_job, model_ids=model_ids  # type: ignore
        )

    def create_predict_single_site(
        self,
        sequence: str,
        train_job: Any,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Creates a new Predict job for single site mutation analysis with a trained model.

        Parameters
        ----------
        sequence : str
            The sequence for single site analysis.
        train_job : Any
            The train job object representing the trained model.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.

        Returns
        -------
        PredictFuture
            The job object representing the Predict job.

        Raises
        ------
        InvalidParameterError
            If the sequences are not of the same length as the assay data or
            if the train job has not completed successfully.
        InvalidParameterError
            If BOTH train_job and model_ids are specified
        InvalidParameterError
            If NEITHER train_job or model_ids is specified
        APIError
            If the backend refuses the job (due to sequence length or invalid inputs)
        """
        if train_job.assaymetadata is not None:
            expected = train_job.assaymetadata.sequence_length
            # Direct comparison replaces the original one-element any([...]).
            if expected is not None and expected != len(sequence):
                raise InvalidParameterError(
                    f"Predict sequences length {len(sequence)} != training assaydata ({expected})"
                )
        train_job.refresh()
        if not train_job.done():
            print(f"WARNING: training job has status {train_job.status}")
        sequence_dataset = SequenceData(sequence=sequence)
        return create_predict_single_site(
            self.session, sequence_dataset, train_job, model_ids=model_ids  # type: ignore
        )