"""Application futures for waiting for results from jobs."""
import concurrent.futures
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from types import UnionType
from typing import Collection, Generator
import tqdm
from requests import Response
from typing_extensions import Self
from openprotein import config
from openprotein.base import APISession
from openprotein.errors import TimeoutException
from openprotein.jobs.schemas import Job, JobStatus, JobType
from . import api
logger = logging.getLogger(__name__)
class Future(ABC):
    """
    Base class for all Futures returning results from a job.
    """
    # NOTE: Subclasses must inherit directly from this base class so that the
    # factory `create` method can discover them via `__subclasses__()`.
    session: APISession
    job: Job
    def __init__(self, session: APISession, job: Job):
        self.session = session
        self.job = job
    @classmethod
    def create(
        cls: type[Self],
        session: APISession,
        job_id: str | None = None,
        job: Job | None = None,
        response: Response | dict | None = None,
        **kwargs,
    ) -> Self:
        """Create an instance of the appropriate Future class based on the job type.
        Parameters
        ----------
        session : APISession
            Session for API interactions.
        job_id : str | None, optional
            The ID of the Job to initialize this future with.
        job : Job | None, optional
            The Job object to initialize this future with.
        response : Response | dict | None, optional
            The response from a job request returning a job-like object.
        **kwargs
            Additional keyword arguments to pass to the Future class constructor.
        Returns
        -------
        Self
            An instance of the appropriate Future class.
        Raises
        ------
        ValueError
            If `job_id`, `job`, and `response` are all None.
        ValueError
            If an appropriate Future subclass cannot be found for the job type.
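
        Examples
        --------
        A minimal sketch; the job id shown is hypothetical:

        >>> future = Future.create(session=session, job_id="job-1234")
        >>> future.wait_until_done()  # block until the job finishes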
        :meta private:
        """
        # parse job
        # default to use job_id first
        if job_id is not None:
            # get job
            job = api.job_get(session=session, job_id=job_id)
        # set obj to parse using job or response
        obj = job or response
        if obj is None:
            raise ValueError("Expected job_id, job or response")
        # parse specific job
        job = Job.create(obj, **kwargs)
        # Dynamically discover all subclasses of FutureBase
        future_classes = Future.__subclasses__()
        # Find the Future class that matches the job
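        # a subclass matches if its `job` annotation is exactly the parsed
        # job's type, or is a union (e.g. `JobA | JobB`) containing it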
        for future_class in future_classes:
            future_type = future_class.__annotations__.get("job")
            if type(job) is future_type or (
                isinstance(future_type, UnionType)
                and type(job) in future_type.__args__
            ):
                if isinstance(future_class.__dict__.get("create"), classmethod):
                    future = future_class.create(session=session, job=job, **kwargs)
                else:
                    future = future_class(session=session, job=job, **kwargs)
                return future  # type: ignore - the type checker doesn't know the subclass
        raise ValueError(f"Unsupported job type: {job.job_type}")
    def __str__(self) -> str:
        return str(self.job)
    def __repr__(self):
        return repr(self.job)
    @property
    def id(self) -> str:
        """The unique identifier of the job."""
        return self.job.job_id
    job_id = id
    @property
    def job_type(self) -> str:
        """The type of the job."""
        return self.job.job_type
    @property
    def status(self) -> JobStatus:
        """The current status of the job."""
        return self.job.status
    @property
    def created_date(self) -> datetime:
        """The creation timestamp of the job."""
        return self.job.created_date
    @property
    def start_date(self) -> datetime | None:
        """The start timestamp of the job."""
        return self.job.start_date
    @property
    def end_date(self) -> datetime | None:
        """The end timestamp of the job."""
        return self.job.end_date
    @property
    def progress_counter(self) -> int:
        """The progress counter of the job."""
        return self.job.progress_counter or 0
    def done(self) -> bool:
        """Check if the job has completed.
        Returns
        -------
        bool
            True if the job is done, False otherwise.
        """
        return self.status.done()
    def cancelled(self) -> bool:
        """Check if the job has been cancelled.
        Returns
        -------
        bool
            True if the job is cancelled, False otherwise.
        """
        return self.status.cancelled()
    def _update_progress(self, job: Job) -> int:
        """Update progress for jobs that may not have explicit counters.
        Parameters
        ----------
        job : Job
            The job object to update progress from.
        Returns
        -------
        int
            The calculated progress value (0-100).
        """
        progress = job.progress_counter
        if progress is None:
            if job.status == JobStatus.PENDING:
                progress = 5
            elif job.status == JobStatus.RUNNING:
                progress = 25
        if job.status in (JobStatus.SUCCESS, JobStatus.FAILURE):
            progress = 100
        return progress or 0  # never None
    def _refresh_job(self) -> Job:
        """Refresh and return the internal job object.
        Returns
        -------
        Job
            The refreshed job object.
        """
        # dump extra kwargs to keep on refresh
        kwargs = {
            k: v for k, v in self.job.model_dump().items() if k not in Job.model_fields
        }
        job = Job.create(
            api.job_get(session=self.session, job_id=self.job_id), **kwargs
        )
        return job
    def refresh(self):
        """Refresh the job status and internal job object."""
        self.job = self._refresh_job()
    @abstractmethod
    def get(self, verbose: bool = False, **kwargs):
        """
        Return the results from this job.
        Parameters
        ----------
        verbose : bool, optional
            Flag to enable verbose output, by default False.
        **kwargs
            Additional keyword arguments.
        """
        raise NotImplementedError() 
    def _wait_job(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ) -> Job:
        """Wait for a job to finish and return the final job object.
        Parameters
        ----------
        interval : float, optional
            Time in seconds to wait between polls.
            Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait before raising an error.
            Defaults to None (unlimited).
        verbose : bool, optional
            If True, print status updates. Defaults to False.
        Returns
        -------
        Job
            The completed job object.
        Raises
        ------
        TimeoutException
            If the wait time exceeds the specified timeout.
        """
        start_time = time.time()
        def is_done(job: Job):
            if timeout is not None:
                elapsed_time = time.time() - start_time
                if elapsed_time >= timeout:
                    raise TimeoutException(
                        f"Wait time exceeded timeout {timeout}, waited {elapsed_time}"
                    )
            return job.status.done()
        pbar = None
        if verbose:
            pbar = tqdm.tqdm(total=100, desc="Waiting", position=0)
        job = self._refresh_job()
        while not is_done(job):
            if pbar is not None:
                # update the bar to the estimated progress and show the status
                pbar.n = self._update_progress(job)
                pbar.set_postfix({"status": job.status})
            time.sleep(interval)
            job = self._refresh_job()
        if pbar is not None:
            # final update with the terminal status
            pbar.n = self._update_progress(job)
            pbar.set_postfix({"status": job.status})
        return job
    def wait_until_done(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ):
        """Wait for the job to complete.
        Parameters
        ----------
        interval : float, optional
            Time in seconds between polling. Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait. Defaults to None (unlimited).
        verbose : bool, optional
            Verbosity flag. Defaults to False.
        Returns
        -------
        bool
            True if the job completed successfully.
        Notes
        -----
        This method does not fetch the job results, unlike `wait()`.
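
        Examples
        --------
        A minimal sketch, assuming `future` was created from a submitted job:

        >>> if future.wait_until_done(timeout=600, verbose=True):
        ...     results = future.get()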
        """
        job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
        self.job = job
        return self.done() 
    def wait(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ):
        """Wait for the job to complete, then fetch results.
        Parameters
        ----------
        interval : float, optional
            Time in seconds between polling. Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait. Defaults to None.
        verbose : bool, optional
            Verbosity flag. Defaults to False.
        Returns
        -------
        Any
            The results of the job.
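
        Examples
        --------
        A minimal sketch; equivalent to `wait_until_done()` followed by `get()`:

        >>> results = future.wait(timeout=600, verbose=True)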
        """
        time.sleep(1)  # brief buffer to let the backend register the job
        job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
        self.job = job
        return self.get() 
 
class StreamingFuture(ABC):
    """Abstract base class for Futures that support streaming results."""
    @abstractmethod
    def stream(self, **kwargs) -> Generator:
        """Return the results from this job as a generator.
        Parameters
        ----------
        **kwargs
            Keyword arguments passed to the streaming implementation.
        Returns
        -------
        Generator
            A generator that yields job results.
        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()
    def get(self, verbose: bool = False, **kwargs) -> list:
        """Return all results from the job by consuming the stream.
        Parameters
        ----------
        verbose : bool, optional
            If True, display a progress bar. Defaults to False.
        **kwargs
            Keyword arguments passed to the `stream` method.
        Returns
        -------
        list
            A list containing all results from the job.
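
        Examples
        --------
        A minimal sketch, assuming a concrete streaming future:

        >>> results = future.get(verbose=True)  # consumes the whole stream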
        """
        generator = self.stream(**kwargs)
        if verbose:
            total = None
            if hasattr(self, "__len__"):
                total = len(self)  # type: ignore - static type checker doesn't know
            generator = tqdm.tqdm(
                generator, desc="Retrieving", total=total, position=0, mininterval=1.0
            )
        return list(generator)
class MappedFuture(StreamingFuture, ABC):
    """Base future for jobs with a key-to-result mapping.
    This class provides methods to retrieve results from jobs where each result
    is associated with a unique key (e.g., sequence to embedding).
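
    Examples
    --------
    A minimal sketch, assuming a concrete subclass keyed by sequence:

    >>> embedding = future["MSKLV"]  # fetch a single result by key
    >>> for sequence, embedding in future:  # or stream (key, value) pairs
    ...     pass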
    """
    def __init__(
        self,
        session: APISession,
        job: Job,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """Initialize the MappedFuture.
        Parameters
        ----------
        session : APISession
            The session for API interactions.
        job : Job
            The job to retrieve results from.
        max_workers : int, optional
            The number of workers for concurrent result retrieval.
            Defaults to `config.MAX_CONCURRENT_WORKERS`.
        Notes
        -----
        Use `max_workers` > 0 to enable concurrent retrieval.
        """
        self.session = session
        self.job = job
        self.max_workers = max_workers
        self._cache = {}
    @abstractmethod
    def __keys__(self):
        """Return the keys for the mapped results.
        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()
    @abstractmethod
    def get_item(self, k):
        """Retrieve a single item by its key.
        Parameters
        ----------
        k
            The key of the item to retrieve.
        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()
    def stream_sync(self):
        """Stream the results synchronously.
        Yields
        ------
        tuple
            A tuple of (key, value) for each result.
        :meta private:
        """
        for k in self.__keys__():
            v = self[k]
            yield k, v
    def stream_parallel(self):
        """Stream the results in parallel using a thread pool.
        Yields
        ------
        tuple
            A tuple of (key, value) for each result.
        :meta private:
        """
        num_workers = self.max_workers
        def process(k):
            v = self[k]
            return k, v
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = []
            for k in self.__keys__():
                if k in self._cache:
                    yield k, self._cache[k]
                else:
                    f = executor.submit(process, k)
                    futures.append(f)
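            # yield in submission order; f.result() blocks until that fetch completes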
            for f in futures:
                yield f.result()
    def stream(self):
        """Retrieve results for this job as a stream.
        Returns
        -------
        Generator
            A generator that yields (key, value) tuples.
        """
        if self.max_workers > 0:
            return self.stream_parallel()
        return self.stream_sync()
    def __getitem__(self, k):
        """Get an item by key, using the cache if available.
        Parameters
        ----------
        k
            The key of the item to retrieve.
        Returns
        -------
        Any
            The value associated with the key.
        """
        if k in self._cache:
            return self._cache[k]
        v = self.get_item(k)
        self._cache[k] = v
        return v
    def __len__(self):
        """Return the total number of items."""
        return len(self.__keys__())
    def __iter__(self):
        """Return an iterator over the results."""
        return self.stream()
class PagedFuture(StreamingFuture, ABC):
    """Base future class for jobs which have paged results."""
    DEFAULT_PAGE_SIZE = 1024
    def __init__(
        self,
        session: APISession,
        job: Job,
        page_size: int | None = None,
        num_records: int | None = None,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """Initialize the PagedFuture.
        Parameters
        ----------
        session : APISession
            The session for API interactions.
        job : Job
            The job to retrieve results from.
        page_size : int | None, optional
            The number of records per page. Defaults to `DEFAULT_PAGE_SIZE`.
        num_records : int | None, optional
            The total number of records expected.
        max_workers : int, optional
            Number of workers for concurrent page retrieval.
            Defaults to `config.MAX_CONCURRENT_WORKERS`.
        Notes
        -----
        Use `max_workers` > 0 to enable concurrent retrieval of multiple pages.
        """
        if page_size is None:
            page_size = self.DEFAULT_PAGE_SIZE
        self.session = session
        self.job = job
        self.page_size = page_size
        self.max_workers = max_workers
        self._num_records = num_records
    @abstractmethod
    def get_slice(self, start: int, end: int, **kwargs) -> Collection:
        """Retrieve a slice of results.
        Parameters
        ----------
        start : int
            The starting index of the slice.
        end : int
            The ending index of the slice.
        **kwargs
            Additional keyword arguments.
        Returns
        -------
        Collection
            A collection of results for the specified slice.
        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()
    def stream_sync(self):
        """Stream results by fetching pages synchronously.
        Yields
        ------
        Any
            Individual results from the paged endpoint.
        :meta private:
        """
        step = self.page_size
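        # prime the loop: assume a full first page; a short page signals the end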
        num_returned = step
        offset = 0
        while num_returned >= step:
            result_page = self.get_slice(start=offset, end=offset + step)
            for result in result_page:
                yield result
            num_returned = len(result_page)
            offset += num_returned
    def stream_parallel(self):
        """Stream results by fetching pages in parallel.
        Yields
        ------
        Any
            Individual results from the paged endpoint.
        Notes
        -----
        Ideally the total number of results would be known or stored up front,
        so that completion would not have to be inferred from a short page,
        which is awkward to coordinate under concurrency.
        :meta private:
        """
        step = self.page_size
        offset = 0
        num_workers = self.max_workers
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            # submit the paged requests
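            # keep up to 2x num_workers page requests in flight to pipeline I/O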
            futures: dict[concurrent.futures.Future, int] = {}
            index: int = 0
            for _ in range(num_workers * 2):
                f = executor.submit(self.get_slice, offset, offset + step)
                futures[f] = index
                index += 1
                offset += step
            # until we've retrieved all pages (signalled by a page with fewer than the requested number of records)
            done = False
            while not done:
                results: list[list | None] = [None] * len(futures)
                futures_next: dict[concurrent.futures.Future, int] = {}
                index_next: int = 0
                next_result_index = 0
                # iterate the futures and submit new requests as needed
                for f in concurrent.futures.as_completed(futures):
                    index = futures[f]
                    result_page = f.result()
                    results[index] = result_page
                    # check if we're done, meaning the result page is not full
                    done = done or len(result_page) < step
                    # if we aren't done, submit another request
                    if not done:
                        f = executor.submit(self.get_slice, offset, offset + step)
                        futures_next[f] = index_next
                        index_next += 1
                        offset += step
                    # yield the results from this page
                    while (
                        next_result_index < len(results)
                        and results[next_result_index] is not None
                    ):
                        result_page = results[next_result_index]
                        assert result_page is not None  # checked above
                        for result in result_page:
                            yield result
                        next_result_index += 1
                # update the list of futures and wait on them again
                futures = futures_next
    def stream(self):
        """Retrieve results for this job as a stream.
        Returns
        -------
        Generator
            A generator that yields job results.
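
        Examples
        --------
        A minimal sketch, assuming a concrete paged future:

        >>> for result in future.stream():
        ...     print(result)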
        """
        if self.max_workers > 0:
            return self.stream_parallel()
        return self.stream_sync()
class InvalidFutureError(Exception):
    """Error for when an unexpected future is created from a job."""
    def __init__(self, future: Future, expected: type[Future]):
        """Initialize the InvalidFutureError.
        Parameters
        ----------
        future : Future
            The future instance that was created.
        expected : type[Future]
            The type of future that was expected.
        """
        self.future = future
        self.expected = expected
        self.message = f"Expected future of type {expected}, got {type(future)}"
        super().__init__(self.message)