from typing import Optional, List, Union, Any, Dict, Literal
from openprotein.pydantic import BaseModel, root_validator
from openprotein.base import APISession
from openprotein.api.jobs import AsyncJobFuture
from openprotein.jobs import ResultsParser, Job, register_job_type, JobType, JobStatus
from openprotein.errors import InvalidParameterError, APIError
from openprotein.futures import FutureFactory, FutureBase
class SequenceData(BaseModel):
    """A single sequence submitted to the single-site predict endpoint."""

    sequence: str
class SequenceDataset(BaseModel):
    """A batch of sequences submitted to the predict endpoint."""

    sequences: List[str]
class _Prediction(BaseModel):
    """Prediction details, flattened to a single property.

    The incoming payload nests values as
    ``{"properties": {<name>: {"y_mu": ..., "y_var": ...}, ...}, ...}``.
    The pre-validator lifts the FIRST property's name and its ``y_mu`` /
    ``y_var`` onto the top level of the model.
    """

    @root_validator(pre=True)
    def extract_pred(cls, values):
        try:
            properties = values.pop("properties")
        except KeyError:
            # ValueError (unlike KeyError) is converted into a ValidationError.
            raise ValueError("prediction payload has no 'properties' field") from None
        if not properties:
            raise ValueError("prediction payload has an empty 'properties' mapping")
        name = next(iter(properties))
        # Build a shallow copy so the caller's nested dict is not mutated.
        flattened = dict(properties)
        flattened["name"] = name
        flattened["y_mu"] = properties[name]["y_mu"]
        flattened["y_var"] = properties[name]["y_var"]
        values.update(flattened)
        return values

    model_id: str
    model_name: str
    y_mu: Optional[float] = None
    y_var: Optional[float] = None
    name: Optional[str]
class Prediction(BaseModel):
    """Output of one model for one sequence (or variant)."""

    model_id: str
    model_name: str
    # property name -> {statistic name: value}
    properties: Dict[str, Dict[str, float]]
class PredictJobBase(Job):
    """Fields shared by the batch and single-site predict job models."""

    # May be None when the job is only being fetched.
    job_id: Optional[str] = None
    job_type: str
    status: JobStatus
@register_job_type(JobType.workflow_predict)
class PredictJob(PredictJobBase):
    """Properties about predict job returned via API."""

    @root_validator(pre=True)
    def extract_pred(cls, values):
        # ``result`` is optional (defaults to None), so a payload may omit it or
        # carry an explicit null; leave the field at its default in that case
        # instead of failing on ``pop`` / iterating None.
        raw = values.pop("result", None)
        if raw is not None:
            # Keep only the keys SequencePrediction consumes from each entry.
            values["result"] = [
                {"sequence": item["sequence"], "predictions": item["predictions"]}
                for item in raw
            ]
        return values

    class SequencePrediction(BaseModel):
        """Sequence prediction."""

        sequence: str
        predictions: List[Prediction] = []

    result: Optional[List[SequencePrediction]] = None
    job_type: str
@register_job_type(JobType.worflow_predict_single_site)
class PredictSingleSiteJob(PredictJobBase):
    """Single-site prediction job as returned by the API.

    ``result`` (None until populated) lists one entry per variant, each with
    its ``position``, ``amino_acid`` and model ``predictions``.
    """

    class SequencePrediction(BaseModel):
        """Predictions for one single-site variant."""

        position: int
        amino_acid: str
        predictions: List[Prediction] = []

    result: Optional[List[SequencePrediction]] = None
    # The Literal restricts this model to the single-site job type.
    job_type: Literal[JobType.worflow_predict_single_site] = (
        JobType.worflow_predict_single_site
    )
def _create_predict_job(
    session: APISession,
    endpoint: str,
    payload: dict,
    model_ids: Optional[List[str]] = None,
    train_job_id: Optional[str] = None,
) -> FutureBase:
    """
    Creates a Predict request and returns a future for the created job.
    This function makes a post request to the specified endpoint with the payload.
    Either 'model_ids' or 'train_job_id' should be provided but not both.
    Parameters
    ----------
    session : APISession
        APIsession with auth
    endpoint : str
        The endpoint to which the post request is to be made.
        either predict or predict/single_site
    payload : dict
        The payload to be sent in the post request. It is not modified; the
        model selector is added to a copy.
    model_ids : List[str], optional
        The list of model ids to be used for Predict. A single id given as a
        str is wrapped in a list. Default is None.
    train_job_id : str, optional
        The id of the train job to be used for Predict. Default is None.
    Returns
    -------
    FutureBase
        The future returned by ``FutureFactory.create_future`` for the
        created Predict job.
    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job_id' is provided.
        If both 'model_ids' and 'train_job_id' are provided.
    HTTPError
        If the post request does not succeed.
    ValidationError
        If the response cannot be parsed into a 'Job' object.
    """
    if model_ids is None and train_job_id is None:
        raise InvalidParameterError(
            "Either a list of model IDs or a train job ID must be provided"
        )
    if model_ids is not None and train_job_id is not None:
        raise InvalidParameterError(
            "Only a list of model IDs OR a train job ID must be provided, not both"
        )
    # Accept a bare model id for convenience by wrapping it in a list.
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    # Copy so the caller's payload dict is left untouched.
    body = dict(payload)
    if model_ids is not None:
        body["model_id"] = model_ids
    else:
        body["train_job_id"] = train_job_id
    response = session.post(endpoint, json=body)
    return FutureFactory.create_future(session=session, response=response)
def create_predict_job(
    session: APISession,
    sequences: SequenceDataset,
    train_job: Optional[Any] = None,
    model_ids: Optional[List[str]] = None,
) -> FutureBase:
    """
    Submits a batch Predict job for a set of sequences.
    The model(s) are selected either by a train job or by explicit model ids.
    Parameters
    ----------
    session : APISession
        APIsession with auth
    sequences : SequenceDataset
        The dataset containing the sequences to predict
    train_job : Any
        The Train job whose model will be used for making Predicts.
    model_ids: List[str]
        specific IDs for models (a single str id is also accepted)
    Returns
    -------
    FutureBase
        Future for the created Predict job.
    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job' is provided.
    InvalidParameterError
        If BOTH `model_ids` and `train_job` is provided
    HTTPError
        If the post request does not succeed.
    ValidationError
        If the response cannot be parsed into a 'Job' object.
    """
    ids = [model_ids] if isinstance(model_ids, str) else model_ids
    body = {"sequences": sequences.sequences}
    job_id = None if train_job is None else train_job.id
    return _create_predict_job(
        session,
        "v1/workflow/predict",
        body,
        model_ids=ids,
        train_job_id=job_id,
    )
def create_predict_single_site(
    session: APISession,
    sequence: SequenceData,
    train_job: Optional[Any] = None,
    model_ids: Optional[List[str]] = None,
) -> FutureBase:
    """
    Creates a predict job for single site mutants with a given sequence and a train job.
    Parameters
    ----------
    session : APISession
        APIsession with auth
    sequence : SequenceData
        The sequence for which single site mutants predictions will be made.
    train_job : Any, optional
        The train job whose model will be used for making Predicts.
        May be None when `model_ids` is given instead.
    model_ids: List[str]
        specific IDs for models (a single str id is also accepted)
    Returns
    -------
    FutureBase
        Future for the created single-site Predict job.
    Raises
    ------
    InvalidParameterError
        If neither 'model_ids' nor 'train_job' is provided.
    InvalidParameterError
        If BOTH `model_ids` and `train_job` is provided
    HTTPError
        If the post request does not succeed.
    ValidationError
        If the response cannot be parsed into a 'Job' object.
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    endpoint = "v1/workflow/predict/single_site"
    payload = {"sequence": sequence.sequence}
    # Only dereference train_job when given, so the model_ids-only path works
    # and the neither/both checks raise InvalidParameterError, not AttributeError.
    train_job_id = train_job.id if train_job is not None else None
    return _create_predict_job(
        session, endpoint, payload, model_ids=model_ids, train_job_id=train_job_id
    )
def get_prediction_results(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = None,
) -> PredictJob:
    """
    Fetches one page of results for a batch Predict job.
    Parameters
    ----------
    session : APISession
        APIsession with auth
    job_id : str
        The ID of the job whose results are to be retrieved.
    page_size : Optional[int], default is None
        The number of results to be returned per page. If None, all results are returned.
    page_offset : Optional[int], default is None
        The number of results to skip. If None, defaults to 0.
    Returns
    -------
    PredictJob
        The job object representing the Predict job.
    Raises
    ------
    HTTPError
        If the GET request does not succeed.
    """
    # Only send the paging parameters the caller actually set.
    query = {
        name: value
        for name, value in (("page_size", page_size), ("page_offset", page_offset))
        if value is not None
    }
    response = session.get(f"v1/workflow/predict/{job_id}", params=query)
    return ResultsParser.parse_obj(response.json())
def get_single_site_prediction_results(
    session: APISession,
    job_id: str,
    page_size: Optional[int] = None,
    page_offset: Optional[int] = None,
) -> PredictSingleSiteJob:
    """
    Retrieves the results of a single site Predict job.
    Parameters
    ----------
    session : APISession
        APIsession with auth
    job_id : str
        The ID of the job whose results are to be retrieved.
    page_size : Optional[int], default is None
        The number of results to be returned per page. If None, all results are returned.
    page_offset : Optional[int], default is None
        The number of results to skip. If None, defaults to 0.
    Returns
    -------
    PredictSingleSiteJob
        The job object representing the single site Predict job.
    Raises
    ------
    HTTPError
        If the GET request does not succeed.
    """
    endpoint = f"v1/workflow/predict/single_site/{job_id}"
    params = {}
    if page_size is not None:
        params["page_size"] = page_size
    if page_offset is not None:
        params["page_offset"] = page_offset
    response = session.get(endpoint, params=params)
    # Parse the decoded JSON body (as get_prediction_results does), not the
    # raw HTTP response object.
    return ResultsParser.parse_obj(response.json())
class PredictFutureMixin:
    """
    Mixin that fetches result pages for a Predict job.
    Attributes
    ----------
    session : APISession
        APIsession with auth
    job : PredictJob
        The job object that represents the current Predict job.
    Methods
    -------
    get_results(page_size: Optional[int] = None, page_offset: Optional[int] = None) -> Union[PredictSingleSiteJob, PredictJob]
        Retrieves results from a Predict job.
    """

    session: APISession
    job: PredictJob
    id: Optional[str] = None

    def get_results(
        self, page_size: Optional[int] = None, page_offset: Optional[int] = None
    ) -> Union[PredictSingleSiteJob, PredictJob]:
        """
        Fetches one page of results, choosing the single-site or batch
        endpoint from whether ``job_type`` contains "single_site".
        Parameters
        ----------
        page_size : Optional[int], default is None
            The number of results to be returned per page. If None, all results are returned.
        page_offset : Optional[int], default is None
            The number of results to skip. If None, defaults to 0.
        Returns
        -------
        Union[PredictSingleSiteJob, PredictJob]
            The parsed job; its type follows the job type.
        Raises
        ------
        HTTPError
            If the GET request does not succeed.
        """
        assert self.id is not None
        fetch = (
            get_single_site_prediction_results
            if "single_site" in self.job.job_type
            else get_prediction_results
        )
        return fetch(self.session, self.id, page_size, page_offset)
class PredictFuture(PredictFutureMixin, AsyncJobFuture, FutureBase):  # type: ignore
    """Future Job for manipulating results"""

    job_type = [JobType.workflow_predict, JobType.worflow_predict_single_site]

    def __init__(self, session: APISession, job: PredictJob, page_size=1000):
        super().__init__(session, job)
        # Number of results requested per page by ``get``.
        self.page_size = page_size

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self) -> str:
        return repr(self.job)

    @property
    def id(self):
        return self.job.job_id

    @staticmethod
    def _format(results, key_of) -> Dict:
        """Group results as ``{property: {key_of(result): {"mean", "variance"}}}``.

        Property names are the first key of each prediction's ``properties``
        in the first result. For every result, the first prediction that
        contains the property supplies ``y_mu`` (mean) and ``y_var`` (variance);
        a result with no such prediction is skipped for that property.
        Returns an empty dict when there are no results.
        """
        if not results:
            return {}
        # dict.fromkeys keeps first-seen order; a set would vary between runs.
        properties = dict.fromkeys(
            next(iter(pred.properties))
            for pred in results[0].predictions
            if pred.properties
        )
        formatted: Dict = {}
        for prop in properties:
            by_key = formatted.setdefault(prop, {})
            for result in results:
                stats = next(
                    (
                        pred.properties[prop]
                        for pred in result.predictions
                        if prop in pred.properties
                    ),
                    None,
                )
                if stats is None:
                    continue
                by_key[key_of(result)] = {
                    "mean": stats["y_mu"],
                    "variance": stats["y_var"],
                }
        return formatted

    def _fmt_results(self, results):
        """Format batch predict results keyed by sequence."""
        return self._format(results, lambda r: r.sequence)

    def _fmt_ssp_results(self, results):
        """Format single-site results keyed by ``f"{position+1}{amino_acid}"``."""
        return self._format(results, lambda r: f"{r.position+1}{r.amino_acid}")

    def get(self, verbose: bool = False) -> Dict:
        """
        Get all the results of the predict job.

        Pages through results ``self.page_size`` at a time until a short or
        empty page is returned, then formats them.
        Args:
            verbose (bool, optional): If True, print the error before it is
                re-raised. Defaults False.
        Raises:
            APIError: If there is an issue with the API request.
        Returns:
            Dict: ``{property: {key: {"mean": ..., "variance": ...}}}`` where
            key is the sequence for batch jobs, or
            ``f"{position+1}{amino_acid}"`` for single-site jobs.
        """
        step = self.page_size
        results: List = []
        offset = 0
        while True:
            try:
                response = self.get_results(page_offset=offset, page_size=step)
            except APIError as exc:
                # Propagate: retrying the same offset forever would never end.
                if verbose:
                    print(f"Failed to get results: {exc}")
                raise
            page = response.result
            assert isinstance(page, list)
            results.extend(page)
            offset += len(page)
            # A short (or empty) page means the last page has been read.
            if not page or len(page) < step:
                break
        # Same predicate get_results uses to pick the endpoint.
        if "single_site" in self.job.job_type:
            return self._fmt_ssp_results(results)
        return self._fmt_results(results)
 
class PredictService:
    """interface for calling Predict endpoints"""

    def __init__(self, session: APISession):
        """
        Initialize a new instance of the PredictService class.
        Parameters
        ----------
        session : APISession
            APIsession with auth
        """
        self.session = session

    @staticmethod
    def _check_sequence_lengths(train_job: Any, sequences: List[str]) -> None:
        """Raise InvalidParameterError for the first sequence whose length differs
        from ``train_job.assaymetadata.sequence_length`` (no-op when unknown)."""
        metadata = train_job.assaymetadata
        if metadata is None or metadata.sequence_length is None:
            return
        expected = metadata.sequence_length
        for s in sequences:
            if len(s) != expected:
                raise InvalidParameterError(
                    f"Predict sequences length {len(s)}  != training assaydata ({expected})"
                )

    def create_predict_job(
        self,
        sequences: List,
        train_job: Optional[Any] = None,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Creates a new Predict job for a given list of sequences and a trained model.
        Parameters
        ----------
        sequences : List
            The list of sequences to be used for the Predict job.
        train_job : Any
            The train job object representing the trained model.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.
        Returns
        -------
        PredictFuture
            The job object representing the Predict job.
        Raises
        ------
        InvalidParameterError
            If the sequences are not of the same length as the assay data.
        InvalidParameterError
            If BOTH train_job and model_ids are specified
        InvalidParameterError
            If NEITHER train_job or model_ids is specified
        APIError
            If the backend refuses the job (due to sequence length or invalid inputs)
        """
        # All train_job checks live under this guard so the model_ids-only path
        # does not dereference None.
        if train_job is not None:
            self._check_sequence_lengths(train_job, sequences)
            # Only a warning: the job is still submitted for unfinished training.
            if not train_job.done():
                print(f"WARNING: training job has status {train_job.status}")
        sequence_dataset = SequenceDataset(sequences=sequences)
        return create_predict_job(
            self.session, sequence_dataset, train_job, model_ids=model_ids  # type: ignore
        )

    def create_predict_single_site(
        self,
        sequence: str,
        train_job: Any,
        model_ids: Optional[List[str]] = None,
    ) -> PredictFuture:
        """
        Creates a new Predict job for single site mutation analysis with a trained model.
        Parameters
        ----------
        sequence : str
            The sequence for single site analysis.
        train_job : Any
            The train job object representing the trained model.
        model_ids : List[str], optional
            The list of model ids to be used for Predict. Default is None.
        Returns
        -------
        PredictFuture
            The job object representing the Predict job.
        Raises
        ------
        InvalidParameterError
            If the sequence is not of the same length as the assay data.
        InvalidParameterError
            If BOTH train_job and model_ids are specified
        InvalidParameterError
            If NEITHER train_job or model_ids is specified
        APIError
            If the backend refuses the job (due to sequence length or invalid inputs)
        """
        if train_job is not None:
            self._check_sequence_lengths(train_job, [sequence])
            # Refresh so the status check below is not stale.
            train_job.refresh()
            # Only a warning: the job is still submitted for unfinished training.
            if not train_job.done():
                print(f"WARNING: training job has status {train_job.status}")
        sequence_dataset = SequenceData(sequence=sequence)
        return create_predict_single_site(
            self.session, sequence_dataset, train_job, model_ids=model_ids  # type: ignore
        )