# Source code for openprotein.api.align

from typing import Iterator, Optional, List, BinaryIO, Literal, Union
from openprotein.pydantic import BaseModel, Field, validator, root_validator
from enum import Enum
from io import BytesIO
import random
import csv
import codecs
import requests

from openprotein.base import APISession
from openprotein.api.jobs import (
    AsyncJobFuture,
)

from openprotein.jobs import (
    ResultsParser,
    Job,
    register_job_type,
    JobType,
    job_args_get,
)

import openprotein.config as config

from openprotein.errors import (
    InvalidParameterError,
    MissingParameterError,
    APIError,
)
from openprotein.futures import FutureBase, FutureFactory


class PoetInputType(str, Enum):
    """Kinds of align-job data that can be requested.

    The member value is what ``get_align_job_inputs`` sends as the
    ``msa_type`` query parameter.
    """

    INPUT = "RAW"  # original user seed; requested by get_seed
    MSA = "GENERATED"  # generated MSA; requested by get_msa
    PROMPT = "PROMPT"  # prompt; requested by get_prompt


class MSASamplingMethod(str, Enum):
    """Names of the MSA sampling methods accepted by ``prompt_post``.

    The selected member is sent as the ``msa_method`` request parameter.
    ``NEIGHBORS_NONGAP_NORM_NO_LIMIT`` is the default used by ``prompt_post``.
    """

    RANDOM = "RANDOM"
    NEIGHBORS = "NEIGHBORS"
    NEIGHBORS_NO_LIMIT = "NEIGHBORS_NO_LIMIT"
    NEIGHBORS_NONGAP_NORM_NO_LIMIT = "NEIGHBORS_NONGAP_NORM_NO_LIMIT"
    TOP = "TOP"


class PromptPostParams(BaseModel):
    """Validated parameter set for a prompt-creation request.

    Field names, defaults and bounds correspond to the keyword arguments
    of ``prompt_post``.
    """

    # ID of the MSA the prompt is sampled from.
    msa_id: str
    # Upper bounds are exclusive (lt): at most 99 sequences / 24576 residues.
    num_sequences: Optional[int] = Field(None, ge=0, lt=100)
    num_residues: Optional[int] = Field(None, ge=0, lt=24577)
    method: MSASamplingMethod = MSASamplingMethod.NEIGHBORS_NONGAP_NORM_NO_LIMIT
    # The three fractions below are constrained to the closed interval [0, 1].
    homology_level: float = Field(0.8, ge=0, le=1)
    max_similarity: float = Field(1.0, ge=0, le=1)
    min_similarity: float = Field(0.0, ge=0, le=1)
    always_include_seed_sequence: bool = False
    num_ensemble_prompts: int = 1
    # None means "not chosen here"; prompt_post draws a random seed in that case.
    random_seed: Optional[int] = None


@register_job_type(JobType.align_align)
class MSAJob(Job):
    """
    Job record for an ``align_align`` (MSA) job.

    ``msa_id`` falls back to the job's own ``job_id`` whenever it is
    missing or empty.
    """

    msa_id: Optional[str] = None
    job_type: Literal[JobType.align_align] = JobType.align_align

    @root_validator
    def set_msa_id(cls, values):
        """Fill in ``msa_id`` from ``job_id`` when no MSA id was supplied."""
        if values.get("msa_id"):
            # An explicit, non-empty msa_id wins; leave the values untouched.
            return values
        values["msa_id"] = values.get("job_id")
        return values


@register_job_type(JobType.align_prompt)
class PromptJob(MSAJob):
    """
    Job record for an ``align_prompt`` job.

    ``prompt_id`` falls back to the job's own ``job_id`` whenever it is
    missing or empty. The ``msa_id`` handling is inherited from ``MSAJob``.
    """

    prompt_id: Optional[str] = None
    job_type: Literal[JobType.align_prompt] = JobType.align_prompt

    @root_validator
    def set_prompt_id(cls, values):
        """Fill in ``prompt_id`` from ``job_id`` when no prompt id was supplied."""
        if values.get("prompt_id"):
            # An explicit, non-empty prompt_id wins; leave the values untouched.
            return values
        values["prompt_id"] = values.get("job_id")
        return values


def csv_stream(response: requests.Response) -> csv.reader:
    """
    Wrap the raw body of a response in a CSV reader.

    Parameters
    ----------
    response : requests.Response
        The response whose ``raw`` byte stream should be parsed.

    Returns
    -------
    csv.reader
        A reader that yields one list of strings per CSV row, decoding
        the stream as UTF-8 while it is consumed.
    """
    utf8_reader_factory = codecs.getreader("utf-8")
    # Always decode the raw byte stream as UTF-8 text.
    text_stream = utf8_reader_factory(response.raw)
    return csv.reader(text_stream)


def get_align_job_inputs(
    session: APISession,
    job_id,
    input_type: PoetInputType,
    prompt_index: Optional[int] = None,
) -> requests.Response:
    """
    Request the MSA-related data of an align job.

    Depending on ``input_type`` this asks for the original user seed (RAW),
    the generated MSA, or the prompt. For PROMPT, ``prompt_index`` selects
    the prompt of one specific replicate.

    Parameters
    ----------
    session : APISession
        The API session.
    job_id : int or str
        The job identifier.
    input_type : PoetInputType
        The type of MSA data, sent as ``msa_type``.
    prompt_index : Optional[int]
        The replicate number for the prompt (input_type=PROMPT only).
        When None, no ``replicate`` parameter is sent.

    Returns
    -------
    requests.Response
        The streaming response from the server.
    """
    # Only include the replicate selector when the caller asked for one.
    replicate = {} if prompt_index is None else {"replicate": prompt_index}
    query = {"job_id": job_id, "msa_type": input_type, **replicate}
    return session.get("v1/align/inputs", params=query, stream=True)


def get_input(
    self: APISession,
    job: Job,
    input_type: PoetInputType,
    prompt_index: Optional[int] = None,
) -> csv.reader:
    """
    Fetch one kind of input data for a job and expose it as CSV rows.

    Parameters
    ----------
    self : APISession
        The API session.
    job : Job
        The job whose data is requested; its ``job_id`` is used.
    input_type : PoetInputType
        The type of MSA data.
    prompt_index : Optional[int]
        The replicate number for the prompt (input_type=PROMPT only).

    Returns
    -------
    csv.reader
        A CSV reader over the streamed response body.
    """
    raw_response = get_align_job_inputs(
        self, job.job_id, input_type, prompt_index=prompt_index
    )
    return csv_stream(raw_response)


def get_prompt(
    self: APISession, job: Job, prompt_index: Optional[int] = None
) -> csv.reader:
    """
    Fetch the prompt of a job as CSV rows.

    Parameters
    ----------
    self : APISession
        The API session.
    job : Job
        The job for which to retrieve the prompt.
    prompt_index : Optional[int], default=None
        The replicate index of the prompt. When None, no replicate is
        specified in the request.

    Returns
    -------
    csv.reader
        A CSV reader for the prompt data.
    """
    wanted = PoetInputType.PROMPT
    return get_input(self, job, wanted, prompt_index=prompt_index)


def get_seed(self: APISession, job: Job) -> csv.reader:
    """
    Fetch the original seed input of an MSA job as CSV rows.

    Parameters
    ----------
    self : APISession
        The API session.
    job : Job
        The job for which to retrieve the seed.

    Returns
    -------
    csv.reader
        A CSV reader for the seed sequence.
    """
    wanted = PoetInputType.INPUT  # "RAW": the data as originally supplied
    return get_input(self, job, wanted)


def get_msa(self: APISession, job: Job) -> csv.reader:
    """
    Fetch the generated MSA (Multiple Sequence Alignment) of a job as CSV rows.

    Parameters
    ----------
    self : APISession
        The API session.
    job : Job
        The job for which to retrieve the MSA.

    Returns
    -------
    csv.reader
        A CSV reader for the MSA data.
    """
    wanted = PoetInputType.MSA  # "GENERATED"
    return get_input(self, job, wanted)


def msa_post(session: APISession, msa_file=None, seed=None):
    """
    Create an MSA.

    Either via a seed sequence (which will trigger MSA creation) or a
    ready-to-use MSA (via msa_file).

    Note that seed and msa_file are mutually exclusive, and one or the other
    must be set.

    Parameters
    ----------
    session : APISession
        Authorized session.
    msa_file : file-like, optional
        Ready-made MSA, uploaded as is. Defaults to None.
    seed : bytes or str, optional
        Seed sequence to trigger an MSA job. A ``str`` is encoded to bytes
        before upload. Defaults to None.

    Raises
    ------
    MissingParameterError
        If msa_file and seed are both None, or both provided.

    Returns
    -------
    The future created by ``FutureFactory.create_future`` for the POST
    response.
    """
    # Exactly one of the two inputs must be given.
    if (msa_file is None) == (seed is None):
        raise MissingParameterError("seed OR msa_file must be provided.")
    endpoint = "v1/align/msa"

    is_seed = seed is not None
    if is_seed:
        # bytes.join rejects str items, so normalise text seeds first.
        if isinstance(seed, str):
            seed = seed.encode()
        # Wrap the bare sequence in a one-record FASTA-style payload.
        msa_file = BytesIO(b"\n".join([b">seed", seed]))

    params = {"is_seed": is_seed}
    files = {"msa_file": msa_file}

    response = session.post(endpoint, files=files, params=params)
    return FutureFactory.create_future(session=session, response=response)


def prompt_post(
    session: APISession,
    msa_id: str,
    num_sequences: Optional[int] = None,
    num_residues: Optional[int] = None,
    method: MSASamplingMethod = MSASamplingMethod.NEIGHBORS_NONGAP_NORM_NO_LIMIT,
    homology_level: float = 0.8,
    max_similarity: float = 1.0,
    min_similarity: float = 0.0,
    always_include_seed_sequence: bool = False,
    num_ensemble_prompts: int = 1,
    random_seed: Optional[int] = None,
) -> PromptJob:
    """
    Create a protein sequence prompt from a linked MSA (Multiple Sequence Alignment) for PoET Jobs.

    The MSA is specified by msa_id and created in msa_post.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    msa_id : str
        The ID of the Multiple Sequence Alignment to use for the prompt.
    num_sequences : int, optional
        Maximum number of sequences in the prompt. Must be <100.
    num_residues : int, optional
        Maximum number of residues (tokens) in the prompt. Must be less than 24577.
        When neither this nor ``num_sequences`` is given, 12288 is used.
    method : MSASamplingMethod, optional
        Method to use for MSA sampling. Defaults to NEIGHBORS_NONGAP_NORM_NO_LIMIT.
    homology_level : float, optional
        Level of homology for sequences in the MSA (neighbors methods only). Must be between 0 and 1. Defaults to 0.8.
    max_similarity : float, optional
        Maximum similarity between sequences in the MSA and the seed. Must be between 0 and 1. Defaults to 1.0.
    min_similarity : float, optional
        Minimum similarity between sequences in the MSA and the seed. Must be between 0 and 1. Defaults to 0.0.
    always_include_seed_sequence : bool, optional
        Whether to always include the seed sequence in the MSA. Defaults to False.
    num_ensemble_prompts : int, optional
        Number of ensemble jobs to run. Defaults to 1.
    random_seed : int, optional
        Seed for random number generation. Defaults to a random number between 0 and 2**32-1.

    Raises
    ------
    InvalidParameterError
        If provided parameter values are not in the allowed range.
    MissingParameterError
        If both 'num_sequences' and 'num_residues' are specified.

    Returns
    -------
    The future created by ``FutureFactory.create_future`` for the POST
    response.
    """
    endpoint = "v1/align/prompt"

    # The three fractions share the same closed [0, 1] range check.
    fractions = (
        ("homology_level", homology_level),
        ("max_similarity", max_similarity),
        ("min_similarity", min_similarity),
    )
    for name, value in fractions:
        if not (0 <= value <= 1):
            raise InvalidParameterError(f"The '{name}' must be between 0 and 1.")

    # The two size limits are mutually exclusive.
    if num_sequences is not None and num_residues is not None:
        raise MissingParameterError(
            "Either 'num_sequences' or 'num_residues' must be set, but not both."
        )
    # With no limit given, fall back to a residue budget.
    if num_sequences is None and num_residues is None:
        num_residues = 12288

    if num_sequences is not None and not (0 <= num_sequences < 100):
        raise InvalidParameterError("The 'num_sequences' must be between 0 and 100.")

    if num_residues is not None and not (0 <= num_residues < 24577):
        raise InvalidParameterError("The 'num_residues' must be between 0 and 24577.")

    if random_seed is None:
        random_seed = random.randrange(2**32)

    params = {
        "msa_id": msa_id,
        "msa_method": method,
        "homology_level": homology_level,
        "max_similarity": max_similarity,
        "min_similarity": min_similarity,
        "force_include_first": always_include_seed_sequence,
        "replicates": num_ensemble_prompts,
        "seed": random_seed,
    }
    # Exactly one of the two limits is set at this point; send only that one.
    if num_sequences is not None:
        params["max_msa_sequences"] = num_sequences
    if num_residues is not None:
        params["max_msa_tokens"] = num_residues

    response = session.post(endpoint, params=params)
    return FutureFactory.create_future(session=session, response=response)


def upload_prompt_post(
    session: APISession,
    prompt_file: BinaryIO,
):
    """
    Upload a ready-made prompt.

    This skips the post_msa and prompt_post steps entirely; PoET will use the
    prompt as is. Several prompts (one per replicate) can be supplied by
    separating the CSVs with `<END_PROMPT>\n`.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    prompt_file : BinaryIO
        Binary I/O object representing the prompt file.

    Raises
    ------
    APIError
        If the request or the creation of the future fails for any reason.

    Returns
    -------
    The future created by ``FutureFactory.create_future`` for the POST
    response.
    """
    try:
        upload_response = session.post(
            "v1/align/upload_prompt", files={"prompt_file": prompt_file}
        )
        future = FutureFactory.create_future(
            session=session, response=upload_response
        )
    except Exception as exc:
        # Any failure above is surfaced as an APIError with the cause chained.
        raise APIError(f"Failed to upload prompt post: {exc}") from exc
    return future


def poet_score_post(
    session: APISession, prompt_id: str, queries: List[Union[bytes, str]]
):
    """
    Submits a job to score sequences based on the given prompt.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    prompt_id : str
        The ID of the prompt.
    queries : List[bytes] or List[str]
        Query sequences to be scored. ``str`` items are encoded to bytes;
        ``bytes`` items are sent unchanged. The two may be mixed.

    Raises
    ------
    MissingParameterError
        If ``queries`` is empty or None, or ``prompt_id`` is empty.
    APIError
        If there is an issue with the API request.

    Returns
    -------
    The future created by ``FutureFactory.create_future`` for the POST
    response.
    """
    endpoint = "v1/poet/score"

    if queries is None or len(queries) == 0:
        raise MissingParameterError("Must include queries for scoring!")
    if not prompt_id:
        raise MissingParameterError("Must include prompt_id in request!")

    # Decide per item, so a mixed str/bytes list is handled.
    encoded = [q.encode() if isinstance(q, str) else q for q in queries]
    try:
        # One query sequence per line of the uploaded file.
        variant_file = BytesIO(b"\n".join(encoded))
        params = {"prompt_id": prompt_id}
        response = session.post(
            endpoint, files={"variant_file": variant_file}, params=params
        )
        return FutureFactory.create_future(session=session, response=response)
    except Exception as exc:
        raise APIError(f"Failed to post poet score: {exc}") from exc


def poet_score_get(
    session: APISession, job_id, page_size=config.POET_PAGE_SIZE, page_offset=0
):
    """
    Fetch one page of results from a PoET score job.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    job_id : str
        The ID of the PoET scoring job to fetch results from.
    page_size : int, optional
        The number of results to fetch in a single page. Defaults to config.POET_PAGE_SIZE.
    page_offset : int, optional
        The offset (number of results) to start fetching results from. Defaults to 0.

    Raises
    ------
    APIError
        If the provided page size exceeds config.POET_MAX_PAGE_SIZE.

    Returns
    -------
    The future created by ``FutureFactory.create_future`` for the GET
    response.
    """
    # Reject oversized pages before making any request.
    if page_size > config.POET_MAX_PAGE_SIZE:
        raise APIError(
            f"Page size must be less than the max for PoET: {config.POET_MAX_PAGE_SIZE}"
        )

    query = dict(job_id=job_id, page_size=page_size, page_offset=page_offset)
    page_response = session.get("v1/poet/score", params=query)
    return FutureFactory.create_future(session=session, response=page_response)


class AlignFutureMixin:
    """
    Shared data accessors for futures that wrap an align job.

    The host class is expected to provide ``session`` and ``job``; every
    method forwards them to the module-level function of the same name.
    """

    session: APISession
    job: Job

    def get_input(
        self, input_type: PoetInputType, prompt_index: Optional[int] = None
    ):
        """
        Get input data of the given type for this job.

        Parameters
        ----------
        input_type : PoetInputType
            The type of MSA data.
        prompt_index : Optional[int]
            The replicate number for the prompt (input_type=PROMPT only).

        Returns
        -------
        csv.reader
            A CSV reader for the response data.
        """
        return get_input(
            self.session, self.job, input_type, prompt_index=prompt_index
        )

    def get_prompt(self, prompt_index: Optional[int] = None):
        """
        Get the prompt for this job.

        Parameters
        ----------
        prompt_index : Optional[int]
            The replicate index of the prompt. When None, no replicate is
            specified in the request.

        Returns
        -------
        csv.reader
            A CSV reader for the prompt data.
        """
        return get_prompt(self.session, self.job, prompt_index=prompt_index)

    def get_seed(self):
        """
        Get the original seed input for this job.

        Returns
        -------
        csv.reader
            A CSV reader for the seed sequence.
        """
        return get_seed(self.session, self.job)

    def get_msa(self):
        """
        Get the generated MSA for this job.

        Returns
        -------
        csv.reader
            A CSV reader for the MSA data.
        """
        return get_msa(self.session, self.job)

    @property
    def id(self):
        """The ``job_id`` of the wrapped job."""
        return self.job.job_id


class MSAFuture(AlignFutureMixin, AsyncJobFuture, FutureBase):
    """
    Represents the result of an MSA job.

    Attributes
    ----------
    session : APISession
        An instance of APISession for API interactions.
    job : Job
        The MSA job.
    page_size : int
        The number of results to fetch in a single page.

    Methods
    -------
    get(verbose=False)
        Get the generated MSA as a CSV reader.
    sample_prompt(...)
        Create a prompt from this MSA.
    """

    job_type = "/align/align"

    def __init__(self, session: APISession, job: Job, page_size=config.POET_PAGE_SIZE):
        """
        Initialise an MSAFuture instance.

        Parameters
        ----------
        session : APISession
            An instance of APISession for API interactions.
        job : Job
            The MSA job.
        page_size : int
            The number of results to fetch in a single page.
        """
        super().__init__(session, job)
        self.page_size = page_size
        # Both ids are resolved lazily by the properties below.
        self._msa_id = None
        self._prompt_id = None

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self) -> str:
        return repr(self.job)

    @property
    def id(self):
        """The ``job_id`` of the wrapped job."""
        return self.job.job_id

    @property
    def prompt_id(self):
        """The prompt id: the job id for prompt jobs, else whatever was stored (None by default)."""
        if self.job.job_type == "/align/prompt" and self._prompt_id is None:
            self._prompt_id = self.job.job_id
        return self._prompt_id

    @property
    def msa_id(self):
        """The MSA id: the job id for align jobs, else whatever was stored (None by default)."""
        if self.job.job_type == "/align/align" and self._msa_id is None:
            self._msa_id = self.job.job_id
        return self._msa_id

    def wait(self, verbose: bool = False):
        """
        Block until the job has finished, then return ``self.get()``.

        ``verbose`` is accepted for interface compatibility but is not
        forwarded: the underlying wait is always called with verbose=False.
        """
        _ = self.job.wait(
            self.session,
            interval=config.POLLING_INTERVAL,
            timeout=config.POLLING_TIMEOUT,
            verbose=False,
        )  # no progress to track
        return self.get()

    def get(self, verbose: bool = False) -> csv.reader:
        """Return the generated MSA as a CSV reader (``verbose`` is unused)."""
        return self.get_msa()

    def sample_prompt(
        self,
        num_sequences: Optional[int] = None,
        num_residues: Optional[int] = None,
        method: MSASamplingMethod = MSASamplingMethod.NEIGHBORS_NONGAP_NORM_NO_LIMIT,
        homology_level: float = 0.8,
        max_similarity: float = 1.0,
        min_similarity: float = 0.0,
        always_include_seed_sequence: bool = False,
        num_ensemble_prompts: int = 1,
        random_seed: Optional[int] = None,
    ) -> PromptJob:
        """
        Create a protein sequence prompt from a linked MSA (Multiple Sequence Alignment) for PoET Jobs.

        All arguments are forwarded to ``prompt_post`` together with this
        future's ``msa_id``.

        Parameters
        ----------
        num_sequences : int, optional
            Maximum number of sequences in the prompt. Must be <100.
        num_residues : int, optional
            Maximum number of residues (tokens) in the prompt. Must be less than 24577.
        method : MSASamplingMethod, optional
            Method to use for MSA sampling. Defaults to NEIGHBORS_NONGAP_NORM_NO_LIMIT.
        homology_level : float, optional
            Level of homology for sequences in the MSA (neighbors methods only). Must be between 0 and 1. Defaults to 0.8.
        max_similarity : float, optional
            Maximum similarity between sequences in the MSA and the seed. Must be between 0 and 1. Defaults to 1.0.
        min_similarity : float, optional
            Minimum similarity between sequences in the MSA and the seed. Must be between 0 and 1. Defaults to 0.0.
        always_include_seed_sequence : bool, optional
            Whether to always include the seed sequence in the MSA. Defaults to False.
        num_ensemble_prompts : int, optional
            Number of ensemble jobs to run. Defaults to 1.
        random_seed : int, optional
            Seed for random number generation. Defaults to a random number between 0 and 2**32-1.

        Raises
        ------
        InvalidParameterError
            If provided parameter values are not in the allowed range.
        MissingParameterError
            If both 'num_sequences' and 'num_residues' are specified.

        Returns
        -------
        Whatever ``prompt_post`` returns for the created prompt job.
        """
        msa_id = self.msa_id
        return prompt_post(
            self.session,
            msa_id,
            num_sequences=num_sequences,
            num_residues=num_residues,
            method=method,
            homology_level=homology_level,
            max_similarity=max_similarity,
            min_similarity=min_similarity,
            always_include_seed_sequence=always_include_seed_sequence,
            num_ensemble_prompts=num_ensemble_prompts,
            random_seed=random_seed,
        )
class PromptFuture(MSAFuture, FutureBase):
    """
    Represents the result of a prompt job.

    Attributes
    ----------
    session : APISession
        An instance of APISession for API interactions.
    job : Job
        The prompt job.
    page_size : int
        The number of results to fetch in a single page.

    Methods
    -------
    get(verbose=False)
        Get the prompt as a CSV reader.
    """

    job_type = "/align/prompt"

    def __init__(
        self,
        session: APISession,
        job: Job,
        page_size=config.POET_PAGE_SIZE,
        msa_id: Optional[str] = None,
    ):
        """
        Initialise a PromptFuture instance.

        Parameters
        ----------
        session : APISession
            An instance of APISession for API interactions.
        job : Job
            The prompt job.
        page_size : int, optional
            The number of results to fetch in a single page. Defaults to config.POET_PAGE_SIZE.
        msa_id : str, optional
            ID of the MSA this prompt was derived from. When None, it is
            looked up as ``root_msa`` in the job's arguments.
        """
        super().__init__(session, job)
        self.page_size = page_size
        if msa_id is None:
            # Costs one extra request; pass msa_id explicitly to avoid it.
            msa_id = job_args_get(self.session, job.job_id).get("root_msa")
        self._msa_id = msa_id

    def get(self, verbose: bool = False) -> csv.reader:
        """Return the prompt as a CSV reader (``verbose`` is unused)."""
        return self.get_prompt()
Prompt = Union[PromptFuture, str]


def validate_prompt(prompt: Prompt):
    """
    Resolve a prompt reference to a prompt id.

    Parameters
    ----------
    prompt : PromptFuture or str
        Either a prompt future or a prompt id.

    Raises
    ------
    ValueError
        If ``prompt`` is neither a PromptFuture nor a str.

    Returns
    -------
    The string itself when a str was given, otherwise the future's
    ``prompt_id``.
    """
    if not isinstance(prompt, (PromptFuture, str)):
        raise ValueError(
            f"Expect prompt to be either a PromptFuture or str, got {type(prompt)}"
        )
    if isinstance(prompt, str):
        return prompt
    return prompt.prompt_id


def validate_msa(msa: Union[MSAFuture, str]):
    """
    Resolve an MSA reference to an MSA id.

    Parameters
    ----------
    msa : MSAFuture or str
        Either an MSA future or an MSA id.

    Raises
    ------
    ValueError
        If ``msa`` is neither an MSAFuture nor a str.

    Returns
    -------
    The string itself when a str was given, otherwise the future's
    ``msa_id``.
    """
    if not isinstance(msa, (MSAFuture, str)):
        raise ValueError(
            f"Expect msa to be either a MSAFuture or str, got {type(msa)}"
        )
    if isinstance(msa, str):
        return msa
    return msa.msa_id
class AlignAPI:
    """API interface for calling Poet and Align endpoints"""

    def __init__(self, session: APISession):
        """Store the session used for all requests made through this object."""
        self.session = session

    def upload_msa(self, msa_file) -> MSAFuture:
        """
        Upload an MSA from file.

        Parameters
        ----------
        msa_file : file-like
            Ready-made MSA.

        Raises
        ------
        MissingParameterError
            If ``msa_file`` is None.

        Returns
        -------
        MSAFuture
            Future for the MSA upload job.
        """
        return msa_post(self.session, msa_file=msa_file)

    def create_msa(self, seed: bytes) -> MSAFuture:
        """
        Construct an MSA via homology search with the seed sequence.

        Parameters
        ----------
        seed : bytes
            Seed sequence for the MSA construction.

        Raises
        ------
        MissingParameterError
            If ``seed`` is None.

        Returns
        -------
        MSAFuture
            Future for the MSA construction job.
        """
        return msa_post(self.session, seed=seed)

    def upload_prompt(self, prompt_file: BinaryIO) -> Job:
        """
        Directly upload a prompt.

        Bypass post_msa and prompt_post steps entirely. In this case PoET will use the prompt as is.
        You can specify multiple prompts (one per replicate) with an <END_PROMPT> and newline between CSVs.

        Parameters
        ----------
        prompt_file : BinaryIO
            Binary I/O object representing the prompt file.

        Raises
        ------
        APIError
            If there is an issue with the API request.

        Returns
        -------
        Whatever ``upload_prompt_post`` returns for the uploaded prompt.
        """
        return upload_prompt_post(self.session, prompt_file)

    def get_prompt(self, job: Job, prompt_index: Optional[int] = None) -> csv.reader:
        """
        Get prompts for a given job.

        Parameters
        ----------
        job : Job
            The job for which to retrieve data.
        prompt_index : Optional[int]
            The replicate number for the prompt. When None, no replicate
            is specified in the request.

        Returns
        -------
        csv.reader
            A CSV reader for the response data.
        """
        return get_input(
            self.session, job, PoetInputType.PROMPT, prompt_index=prompt_index
        )

    def get_seed(self, job: Job) -> csv.reader:
        """
        Get input data for a given msa job.

        Parameters
        ----------
        job : Job
            The job for which to retrieve data.

        Returns
        -------
        csv.reader
            A CSV reader for the response data.
        """
        return get_input(self.session, job, PoetInputType.INPUT)

    def get_msa(self, job: Job) -> csv.reader:
        """
        Get generated MSA for a given job.

        Parameters
        ----------
        job : Job
            The job for which to retrieve data.

        Returns
        -------
        csv.reader
            A CSV reader for the response data.
        """
        return get_input(self.session, job, PoetInputType.MSA)