# Source code for openprotein.api.train

from typing import Optional, List, Union
from openprotein.pydantic import BaseModel

import openprotein.pydantic as pydantic
from openprotein.base import APISession
from openprotein.api.jobs import AsyncJobFuture, Job
from openprotein.futures import FutureFactory, FutureBase

from openprotein.errors import InvalidParameterError, APIError, InvalidJob
from openprotein.api.data import AssayDataset, AssayMetadata
from openprotein.jobs import JobType
from openprotein.api.predict import PredictService, PredictFuture
from datetime import datetime


# One row of cross-validation output, as returned in CVResults.result.
# NOTE(review): the field meanings in the trailing comments are inferred from
# their names only; confirm against the API schema.
class CVItem(BaseModel):
    row_index: int  # presumably the row's index in the assay dataset
    sequence: str
    measurement_name: str
    y: float  # presumably the observed measurement value
    y_mu: float  # presumably the predicted mean
    y_var: float  # presumably the predicted variance


# One page of cross-validation results, parsed by get_crossvalidation.
# Extends Job with paging metadata and the rows contained in this page.
class CVResults(Job):
    num_rows: int  # NOTE(review): presumably the total row count across pages
    page_size: int
    page_offset: int
    result: List[CVItem]  # rows contained in this page


# A single recorded training step. TrainFutureMixin groups `loss` values by
# `tag` when formatting training results.
class TrainStep(BaseModel):
    step: int
    loss: float
    tag: str
    tags: dict  # NOTE(review): extra tag metadata; its schema is not visible here


# Training history of a job, built by get_training_results from the
# `v1/workflow/train/{job_id}` response.
class TrainGraph(BaseModel):
    traingraph: List[TrainStep]  # recorded steps, in the order the server sent them
    created_date: datetime
    job_id: str


def list_models(session: APISession, job_id: str) -> List:
    """
    List the models associated with a job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job whose models should be listed.

    Returns
    -------
    List
        The decoded JSON body of the ``GET v1/models`` response.
    """
    query = {"job_id": job_id}
    return session.get("v1/models", params=query).json()


def crossvalidate(session: APISession, train_job_id: str, n_splits: int = 5) -> Job:
    """
    Submit a cross-validation job for a trained model.

    Args:
        session (APISession): Authenticated session used for the request.
        train_job_id (str): ID of the training job to cross-validate.
        n_splits (int, optional): Number of CV splits. Defaults to 5.

    Returns:
        Job: The submitted cross-validation job, parsed from the response body.
    """
    payload = {"train_job_id": train_job_id, "n_splits": n_splits}
    reply = session.post("v1/workflow/crossvalidate", json=payload)
    return pydantic.parse_obj_as(Job, reply.json())


def get_crossvalidation(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = 0,
) -> CVResults:
    """
    Fetch one page of cross-validation results.

    Args:
        session (APISession): Authenticated session used for the request.
        job_id (str): ID of the cross-validation job.
        page_size (int, optional): Number of rows to request. Defaults to None.
        page_offset (int, optional): Row offset of the requested page.
            Defaults to 0.

    Raises:
        InvalidJob: If the server answers with HTTP 404.

    Returns:
        CVResults: The requested page of cross-validation results.
    """
    reply = session.get(
        f"v1/workflow/crossvalidate/{job_id}",
        params={"page_size": page_size, "page_offset": page_offset},
    )
    if reply.status_code != 404:
        return pydantic.parse_obj_as(CVResults, reply.json())
    raise InvalidJob("No CV job has been submitted for this job!")


def _train_job(
    session: APISession,
    endpoint: str,
    assaydataset: AssayDataset,
    measurement_name: Union[str, List[str]],
    model_name: str = "",
    force_preprocess: Optional[bool] = False,
) -> Job:
    """
    Create a training job.

    Validates the inputs, builds the request payload, posts it to `endpoint`
    and wraps the response via ``FutureFactory.create_future``.

    Parameters
    ----------
    session : APISession
        The current API session for communication with the server.
    endpoint : str
        The endpoint to which the job training request is to be sent.
    assaydataset : AssayDataset
        An AssayDataset object from which the assay_id is extracted.
    measurement_name : str or List[str]
        The name(s) of the measurement(s) to be used in the training job.
        A single string is treated as a one-element list.
    model_name : str, optional
        The name to give the model. ``None`` is sent as an empty string.
    force_preprocess : bool, optional
        If True, preprocessing is forced even if preprocessed data already
        exists. ``None`` is treated as False.

    Returns
    -------
    Job
        The future produced by ``FutureFactory.create_future`` for the
        server response.

    Raises
    ------
    InvalidParameterError
        If `assaydataset` is not an AssayDataset object, if no measurement
        name is given, if any measurement name does not exist in the
        AssayDataset, or if the AssayDataset has fewer than 3 data points.
    HTTPError
        If the request to the server fails.
    """
    if not isinstance(assaydataset, AssayDataset):
        raise InvalidParameterError("assaydataset should be an assaydata Job result")
    if isinstance(measurement_name, str):
        measurement_name = [measurement_name]
    else:
        # Copy into a fresh list so the caller's sequence is never shared
        # with the request payload.
        measurement_name = list(measurement_name)
    if not measurement_name:
        raise InvalidParameterError("At least one measurement name is required")

    for measurement in measurement_name:
        if measurement not in assaydataset.measurement_names:
            raise InvalidParameterError(f"No {measurement} in measurement names")
    if assaydataset.shape[0] < 3:
        raise InvalidParameterError("Assaydata must have >=3 data points for training!")

    data = {
        "assay_id": assaydataset.id,
        "measurement_name": measurement_name,
        "model_name": "" if model_name is None else model_name,
    }
    # Serialized as lowercase "true"/"false"; bool() maps None to False
    # instead of sending the string "none".
    params = {"force_preprocess": str(bool(force_preprocess)).lower()}

    response = session.post(endpoint, params=params, json=data)
    response.raise_for_status()
    return FutureFactory.create_future(session=session, response=response)


def create_train_job(
    session: APISession,
    assaydataset: AssayDataset,
    measurement_name: Union[str, List[str]],
    model_name: str = "",
    force_preprocess: Optional[bool] = False,
):
    """
    Create a training job on the ``v1/workflow/train`` endpoint.

    Delegates to ``_train_job``, which validates the inputs, posts the request
    and wraps the server response in a future.

    Parameters
    ----------
    session : APISession
        The current API session for communication with the server.
    assaydataset : AssayDataset
        An AssayDataset object from which the assay_id is extracted.
    measurement_name : str or List[str]
        The name(s) of the measurement(s) to be used in the training job.
    model_name : str, optional
        The name to give the model.
    force_preprocess : bool, optional
        If set to True, preprocessing is forced even if preprocessed data
        already exists.

    Returns
    -------
    Job
        Whatever ``_train_job`` returns for the submitted request.

    Raises
    ------
    InvalidParameterError
        If `assaydataset` is not an AssayDataset object, if any measurement
        name does not exist in the AssayDataset, or if the AssayDataset has
        fewer than 3 data points.
    HTTPError
        If the request to the server fails.
    """
    return _train_job(
        session=session,
        endpoint="v1/workflow/train",
        assaydataset=assaydataset,
        measurement_name=measurement_name,
        model_name=model_name,
        force_preprocess=force_preprocess,
    )


def _create_train_job_br(
    session: APISession,
    assaydataset: AssayDataset,
    measurement_name: Union[str, List[str]],
    model_name: str = "",
    force_preprocess: Optional[bool] = False,
):
    """
    Variant of ``create_train_job`` that posts to ``v1/workflow/train/br``.

    Takes the same arguments and performs the same validation (via
    ``_train_job``); only the target endpoint differs.
    """
    return _train_job(
        session=session,
        endpoint="v1/workflow/train/br",
        assaydataset=assaydataset,
        measurement_name=measurement_name,
        model_name=model_name,
        force_preprocess=force_preprocess,
    )


def _create_train_job_gp(
    session: APISession,
    assaydataset: AssayDataset,
    measurement_name: Union[str, List[str]],
    model_name: str = "",
    force_preprocess: Optional[bool] = False,
):
    """
    Variant of ``create_train_job`` that posts to ``v1/workflow/train/gp``.

    Takes the same arguments and performs the same validation (via
    ``_train_job``); only the target endpoint differs.
    """
    return _train_job(
        session=session,
        endpoint="v1/workflow/train/gp",
        assaydataset=assaydataset,
        measurement_name=measurement_name,
        model_name=model_name,
        force_preprocess=force_preprocess,
    )


def get_training_results(session: APISession, job_id: str) -> TrainGraph:
    """
    Get the training results (e.g. per-step loss) of a training job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the training job.

    Returns
    -------
    TrainGraph
        The ``GET v1/workflow/train/{job_id}`` response parsed into a model.
    """
    payload = session.get(f"v1/workflow/train/{job_id}").json()
    return TrainGraph(**payload)


class CVFutureMixin:
    """
    Mixin that adds cross-validation submission and result retrieval.

    Classes using it must provide ``session`` and ``train_job_id``;
    ``job`` is assigned by :meth:`crossvalidate`.

    Attributes
    ----------
    session : APISession
        The session object to use for API communication.
    train_job_id : str
        ID of the training job being cross-validated.
    job : Job
        The cross-validation job, once submitted.

    Methods
    -------
    crossvalidate():
        Submits a cross-validation job to the server.
    get_crossvalidation(page_size=None, page_offset=0):
        Retrieves one page of the cross-validation results.
    """

    session: APISession
    train_job_id: str
    job: Job

    def crossvalidate(self):
        """
        Submit a cross-validation job for ``self.train_job_id``.

        The submitted job is stored on ``self.job`` and returned.

        Returns
        -------
        Job
            The submitted cross-validation job.
        """
        # Resolves to the module-level crossvalidate(), not this method.
        self.job = crossvalidate(self.session, self.train_job_id)
        return self.job

    def get_crossvalidation(
        self, page_size: Optional[int] = None, page_offset: Optional[int] = 0
    ):
        """
        Retrieve one page of results of the cross-validation job.

        Parameters
        ----------
        page_size : int, optional
            Number of rows to request. Default is None.
        page_offset : int, optional
            Row offset of the requested page. Default is 0.

        Returns
        -------
        CVResults
            The requested page, as returned by the module-level
            ``get_crossvalidation``.
        """
        cv_job_id = self.job.job_id
        return get_crossvalidation(
            self.session, cv_job_id, page_size=page_size, page_offset=page_offset
        )


class CVFuture(CVFutureMixin, AsyncJobFuture, FutureBase):
    """
    Future that submits a cross-validation job and retrieves its results.

    Attributes
    ----------
    session : APISession
        The session object to use for API communication.
    train_job_id : str
        ID of the training job being cross-validated.
    job : Job or None
        The cross-validation job; None until submitted unless passed in.
    page_size : int
        Number of rows requested per page by :meth:`get` (1000 by default).
    """

    job_type = [JobType.workflow_crossvalidate]

    def __init__(
        self, session: APISession, train_job_id: str, job: Optional[Job] = None
    ):
        """
        Construct a new CVFuture instance.

        Parameters
        ----------
        session : APISession
            The session object to use for API communication.
        train_job_id : str
            ID of the training job being cross-validated.
        job : Job, optional
            An already-submitted cross-validation job, if any.
        """
        super().__init__(session, job)
        self.train_job_id = train_job_id
        self.page_size = 1000

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self) -> str:
        return repr(self.job)

    @property
    def id(self):
        """Job ID of the underlying cross-validation job."""
        return self.job.job_id

    def _fmt_results(self, results):
        """Convert a list of CVItem models into plain dicts."""
        return [i.dict() for i in results]

    def get(self, verbose: bool = False) -> List:
        """
        Get all the results of the CV job, fetching them page by page.

        Paging stops at the first page shorter than ``self.page_size``. If a
        page request fails with APIError, the rows fetched so far are
        returned (best effort).

        Args:
            verbose (bool, optional): If True, print the error when a page
                request fails. Defaults to False.

        Raises:
            ValueError: If ``self.page_size`` is not positive; such a page
                size would make the paging loop never terminate.

        Returns:
            List[dict]: One dict per CVItem row, in the order received.
        """
        step = self.page_size
        if step <= 0:
            raise ValueError(f"page_size must be positive, got {step}")

        results = []
        offset = 0
        while True:
            try:
                response = self.get_crossvalidation(page_offset=offset, page_size=step)
            except APIError as exc:
                if verbose:
                    print(f"Failed to get results: {exc}")
                break
            results.extend(response.result)
            num_returned = len(response.result)
            offset += num_returned
            # A short page means there is nothing left to fetch.
            if num_returned < step:
                break
        return self._fmt_results(results)


class TrainFutureMixin:
    """
    Mixin for retrieving training results and starting cross-validation.

    Attributes
    ----------
    session : APISession
        The session object to use for API communication.
    job : Job
        The Job object for this training job.

    Methods
    -------
    get_results() -> dict:
        Returns the training losses grouped by tag.
    crossvalidate():
        Submits a new cross-validation job and returns its future.
    list_models():
        Lists the models associated with this training job.
    """

    session: APISession
    job: Job

    def _fmt_results(self, results):
        """
        Group the training losses of a TrainGraph by tag.

        Parameters
        ----------
        results : TrainGraph
            Training history returned by ``get_training_results``.

        Returns
        -------
        dict
            Maps each tag to the list of its losses, in the order the steps
            appear in ``results.traingraph``. Tags are ordered by first
            appearance.
        """
        train_dict = {}
        # Single pass over the steps, grouping losses by tag.
        for step in results.traingraph:
            train_dict.setdefault(step.tag, []).append(step.loss)
        return train_dict

    def get_results(self) -> dict:
        """
        Get the results of the training job.

        Returns
        -------
        dict
            The training losses grouped by tag (see ``_fmt_results``).
        """
        results = get_training_results(self.session, self.job.job_id)
        return self._fmt_results(results)

    def crossvalidate(self):
        """
        Submit a new cross-validation job for this training job.

        A new cross-validation job is submitted on every call.

        Returns
        -------
        CVFuture
            Future for the submitted cross-validation job.
        """
        cv = CVFuture(self.session, train_job_id=self.job.job_id)
        cv.crossvalidate()
        return cv

    def list_models(self):
        """
        List the models associated with this training job.

        Returns
        -------
        List
            List of models, as returned by the module-level ``list_models``.
        """
        return list_models(self.session, self.job.job_id)


class TrainFuture(TrainFutureMixin, AsyncJobFuture, FutureBase):
    """Future for a training job: results, prediction and cross-validation."""

    job_type = [JobType.workflow_train]

    def __init__(
        self,
        session: APISession,
        job: Job,
        assaymetadata: Optional[AssayMetadata] = None,
    ):
        """
        Construct a TrainFuture.

        Parameters
        ----------
        session : APISession
            The session object to use for API communication.
        job : Job
            The training job this future tracks.
        assaymetadata : AssayMetadata, optional
            Metadata of the assay used for training, if available.
        """
        super().__init__(session, job)
        self.assaymetadata = assaymetadata
        self._predict = PredictService(session)

    def predict(
        self, sequences: List[str], model_ids: Optional[List[str]] = None
    ) -> PredictFuture:
        """
        Create a predict job based on this training job.

        Parameters
        ----------
        sequences : List[str]
            The list of sequences to be used for the Predict job.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.

        Returns
        -------
        PredictFuture
            The job object representing the Predict job.
        """
        return self._predict.create_predict_job(sequences, self, model_ids=model_ids)

    def predict_single_site(
        self,
        sequence: str,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Create a single-site mutation Predict job using this training job's model.

        Parameters
        ----------
        sequence : str
            The sequence for single site analysis.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.

        Returns
        -------
        PredictFuture
            The job object representing the Predict job.
        """
        return self._predict.create_predict_single_site(
            sequence, self, model_ids=model_ids
        )

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self) -> str:
        return repr(self.job)

    @property
    def id(self):
        """Job ID of the underlying training job."""
        return self.job.job_id

    def get(self, verbose: bool = False) -> dict:
        """
        Get the training results (losses grouped by tag).

        Parameters
        ----------
        verbose : bool, optional
            If True, print the error before re-raising it. Default is False.

        Returns
        -------
        dict
            The result of ``get_results``.

        Raises
        ------
        APIError
            If the results request fails.
        """
        try:
            return self.get_results()
        except APIError as exc:
            if verbose:
                print(f"Failed to get results: {exc}")
            raise
class TrainingAPI:
    """API interface for calling Train endpoints"""

    def __init__(
        self,
        session: APISession,
    ):
        """
        Parameters
        ----------
        session : APISession
            The session object to use for API communication.
        """
        self.session = session
        self.assay = None

    def create_training_job(
        self,
        assaydataset: AssayDataset,
        measurement_name: Union[str, List[str]],
        model_name: str = "",
        force_preprocess: Optional[bool] = False,
    ) -> TrainFuture:
        """
        Create a training job on your data.

        This function validates the inputs, formats the data, and sends the job.

        Parameters
        ----------
        assaydataset : AssayDataset
            An AssayDataset object from which the assay_id is extracted.
        measurement_name : str or List[str]
            The name(s) of the measurement(s) to be used in the training job.
        model_name : str, optional
            The name to give the model.
        force_preprocess : bool, optional
            If set to True, preprocessing is forced even if data already exists.

        Returns
        -------
        TrainFuture
            A TrainFuture Job. NOTE(review): the object is built by
            ``FutureFactory.create_future``; confirm it is always a TrainFuture.

        Raises
        ------
        InvalidParameterError
            If the `assaydataset` is not an AssayDataset object,
            If any measurement name provided does not exist in the AssayDataset,
            or if the AssayDataset has fewer than 3 data points.
        HTTPError
            If the request to the server fails.
        """
        if isinstance(measurement_name, str):
            measurement_name = [measurement_name]
        return create_train_job(
            self.session, assaydataset, measurement_name, model_name, force_preprocess
        )

    def _create_training_job_br(
        self,
        assaydataset: AssayDataset,
        measurement_name: Union[str, List[str]],
        model_name: str = "",
        force_preprocess: Optional[bool] = False,
    ) -> TrainFuture:
        """Like create_training_job, but posts to ``v1/workflow/train/br``."""
        return _create_train_job_br(
            self.session, assaydataset, measurement_name, model_name, force_preprocess
        )

    def _create_training_job_gp(
        self,
        assaydataset: AssayDataset,
        measurement_name: Union[str, List[str]],
        model_name: str = "",
        force_preprocess: Optional[bool] = False,
    ) -> TrainFuture:
        """Like create_training_job, but posts to ``v1/workflow/train/gp``."""
        return _create_train_job_gp(
            self.session, assaydataset, measurement_name, model_name, force_preprocess
        )

    def get_training_results(self, job_id: str) -> TrainGraph:
        """
        Get training results (e.g. loss etc).

        Parameters
        ----------
        job_id : str
            ID of the training job.

        Returns
        -------
        TrainGraph
            The recorded training steps, creation date and job ID.
        """
        return get_training_results(self.session, job_id)