[1]:
%matplotlib inline

Getting started with Embeddings

This notebook will briefly cover how to run Embedding workflows.

For more information please read the docs.

[2]:
import numpy as np
import matplotlib.pyplot as plt
import json

Setup

Connect to the OpenProtein backend with your credentials:

[3]:
import openprotein

# Load credentials from a local config file instead of hardcoding them.
# NOTE(review): make sure 'secrets.config' is excluded from version control.
with open('secrets.config', 'r') as f:
    config = json.load(f)

# Authenticate against the OpenProtein backend; `session` is used by all later cells.
session = openprotein.connect(username=config['username'], password=config['password'])

Model metadata

You can list the available models, and fetch metadata for more information (including publications and DOIs where available):

[4]:
# List every embedding model available on the backend.
session.embedding.list_models()
[4]:
[esm1b_t33_650M_UR50S,
 esm1v_t33_650M_UR90S_1,
 esm1v_t33_650M_UR90S_2,
 esm1v_t33_650M_UR90S_3,
 esm1v_t33_650M_UR90S_4,
 esm1v_t33_650M_UR90S_5,
 esm2_t12_35M_UR50D,
 esm2_t30_150M_UR50D,
 esm2_t33_650M_UR50D,
 esm2_t36_3B_UR50D,
 esm2_t6_8M_UR50D,
 prot-seq,
 rotaprot-large-uniref50w,
 rotaprot-large-uniref90-ft]

You can view more information on each model:

[5]:
# Grab the first available model and inspect its citation / summary metadata.
available_models = session.embedding.list_models()
esm_model = available_models[0]
model_info = esm_model.metadata.dict()
model_info['description']
[5]:
{'citation_title': 'Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences',
 'doi': '10.1101/622803',
 'summary': 'ESM1b model with 650M parameters'}

There’s data available on supported tokens and outputs too:

[6]:
# Full metadata: max sequence length, embedding dimension, output types,
# and the input/output token vocabularies.
esm_model.metadata.dict()
[6]:
{'model_id': 'esm1b_t33_650M_UR50S',
 'description': {'citation_title': 'Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences',
  'doi': '10.1101/622803',
  'summary': 'ESM1b model with 650M parameters'},
 'max_sequence_length': 1022,
 'dimension': 1280,
 'output_types': ['attn', 'embed', 'logits'],
 'input_tokens': ['A',
  'R',
  'N',
  'D',
  'C',
  'Q',
  'E',
  'G',
  'H',
  'I',
  'L',
  'K',
  'M',
  'F',
  'P',
  'S',
  'T',
  'W',
  'Y',
  'V',
  'X',
  'O',
  'U',
  'B',
  'Z'],
 'output_tokens': ['<cls>',
  '<pad>',
  '<eos>',
  '<unk>',
  'L',
  'A',
  'G',
  'V',
  'S',
  'E',
  'R',
  'T',
  'I',
  'D',
  'P',
  'K',
  'Q',
  'N',
  'F',
  'Y',
  'M',
  'H',
  'W',
  'C',
  '<null_0>',
  'B',
  'U',
  'Z',
  'O',
  '.',
  '-',
  '<null_1>',
  'X'],
 'token_descriptions': [[{'id': 0,
    'token': '<cls>',
    'primary': True,
    'description': 'Start token'}],
  [{'id': 1,
    'token': '<pad>',
    'primary': True,
    'description': 'Padding token'}],
  [{'id': 2, 'token': '<eos>', 'primary': True, 'description': 'Stop token'}],
  [{'id': 3,
    'token': '<unk>',
    'primary': True,
    'description': 'Unknown token'}],
  [{'id': 4, 'token': 'L', 'primary': True, 'description': 'Leucine'}],
  [{'id': 5, 'token': 'A', 'primary': True, 'description': 'Alanine'}],
  [{'id': 6, 'token': 'G', 'primary': True, 'description': 'Glycine'}],
  [{'id': 7, 'token': 'V', 'primary': True, 'description': 'Valine'}],
  [{'id': 8, 'token': 'S', 'primary': True, 'description': 'Serine'}],
  [{'id': 9, 'token': 'E', 'primary': True, 'description': 'Glutamic acid'}],
  [{'id': 10, 'token': 'R', 'primary': True, 'description': 'Arginine'}],
  [{'id': 11, 'token': 'T', 'primary': True, 'description': 'Threonine'}],
  [{'id': 12, 'token': 'I', 'primary': True, 'description': 'Isoleucine'}],
  [{'id': 13, 'token': 'D', 'primary': True, 'description': 'Aspartic acid'}],
  [{'id': 14, 'token': 'P', 'primary': True, 'description': 'Proline'}],
  [{'id': 15, 'token': 'K', 'primary': True, 'description': 'Lysine'}],
  [{'id': 16, 'token': 'Q', 'primary': True, 'description': 'Glutamine'}],
  [{'id': 17, 'token': 'N', 'primary': True, 'description': 'Asparagine'}],
  [{'id': 18, 'token': 'F', 'primary': True, 'description': 'Phenylalanine'}],
  [{'id': 19, 'token': 'Y', 'primary': True, 'description': 'Tyrosine'}],
  [{'id': 20, 'token': 'M', 'primary': True, 'description': 'Methionine'}],
  [{'id': 21, 'token': 'H', 'primary': True, 'description': 'Histidine'}],
  [{'id': 22, 'token': 'W', 'primary': True, 'description': 'Tryptophan'}],
  [{'id': 23, 'token': 'C', 'primary': True, 'description': 'Cysteine'}],
  [{'id': 24,
    'token': '<null_0>',
    'primary': True,
    'description': 'Null token, unused'}],
  [{'id': 25,
    'token': 'B',
    'primary': True,
    'description': 'Aspartic acid or Asparagine'}],
  [{'id': 26, 'token': 'U', 'primary': True, 'description': 'Selenocysteine'}],
  [{'id': 27,
    'token': 'Z',
    'primary': True,
    'description': 'Glutamic acid or Glutamine'}],
  [{'id': 28, 'token': 'O', 'primary': True, 'description': 'Pyrrolysine'}],
  [{'id': 29,
    'token': '.',
    'primary': True,
    'description': 'Insertion token, unused'}],
  [{'id': 30,
    'token': '-',
    'primary': True,
    'description': 'Gap token, unused'}],
  [{'id': 31,
    'token': '<null_1>',
    'primary': True,
    'description': 'Null token, unused'}],
  [{'id': 32,
    'token': 'X',
    'primary': True,
    'description': 'Mask token; represents any amino acid'}]]}

Making requests

We can make embedding requests from the model directly or from the API:

[7]:
# Dummy data: the API expects sequences as bytes, not str.
sequences = ["AAAAPLHLALA".encode()]
[8]:

# Submit an embedding request directly from the model object.
# (Original had both statements on one line, which is not valid Python.)
esm_job = esm_model.embed(sequences=sequences)
esm_job.job
[8]:
Job(status=<JobStatus.PENDING: 'PENDING'>, job_id='89089c15-9e76-41fa-af9c-05452efb3014', job_type='/embeddings/embed_reduced', created_date=datetime.datetime(2023, 8, 4, 4, 10, 29, 565648, tzinfo=datetime.timezone.utc), start_date=None, end_date=None, prerequisite_job_id=None, progress_message=None, progress_counter=0, num_records=1)
[9]:
# Equivalent request via the session API, selecting the model by id.
embedjob = session.embedding.embed(model="esm1b_t33_650M_UR50S", sequences=sequences)
embedjob.job
[9]:
Job(status=<JobStatus.PENDING: 'PENDING'>, job_id='a0187ab9-1a72-4e03-b0d6-cd48fdc04d19', job_type='/embeddings/embed_reduced', created_date=datetime.datetime(2023, 8, 4, 4, 10, 29, 617604, tzinfo=datetime.timezone.utc), start_date=None, end_date=None, prerequisite_job_id=None, progress_message=None, progress_counter=0, num_records=1)

Getting results

You can get the results with wait(), which blocks until the job completes:

[10]:
# Block until the backend job completes; verbose=True renders progress bars.
results = embedjob.wait(verbose=True) # wait for results
Waiting:   0%|          | 0/100 [00:00<?, ?it/s, status=RUNNING]Waiting: 100%|██████████| 100/100 [06:43<00:00,  4.04s/it, status=SUCCESS]
Retrieving: 100%|██████████| 1/1 [00:00<00:00, 21.70it/s]
[11]:
# Each result is a (sequence, embedding) pair; the embedding here is a 1280-dim vector.
results[0][0],results[0][1].shape
[11]:
(b'AAAAPLHLALA', (1280,))
[12]:
# Peek at the first three components of the embedding vector.
results[0][1][0:3]
[12]:
array([ 0.15882437, -0.03162469,  0.11416737], dtype=float32)
[13]:
# Non-blocking check of the job's completion status.
esm_job.done()
[13]:
False
[14]:
# The server-side job has since finished, so this wait() returns almost immediately.
results2 = esm_job.wait(verbose=True) # wait for results
Waiting: 100%|██████████| 100/100 [00:00<00:00, 6243.29it/s, status=SUCCESS]
Retrieving: 100%|██████████| 1/1 [00:00<00:00, 26.15it/s]
[15]:
# Same (sequence, embedding) structure as the session-API job above.
results2[0][0],results2[0][1].shape
[15]:
(b'AAAAPLHLALA', (1280,))
[16]:
# Values match the earlier job — both ran the same model on the same sequence.
results2[0][1][0:3]
[16]:
array([ 0.15882437, -0.03162469,  0.11416737], dtype=float32)

You can also fetch results by sequence (useful for when we have many sequence embeddings!):

[17]:
# Fetch a single embedding keyed by its (bytes) sequence.
esm_job.get_item(b"AAAAPLHLALA")[0:3]
[17]:
array([ 0.15882437, -0.03162469,  0.11416737], dtype=float32)

Lastly, you can also use the get() method as with other workflows:

[18]:
# Retrieve the full list of (sequence, embedding) results in one call.
esm_job.get()
[18]:
[(b'AAAAPLHLALA',
  array([ 0.15882437, -0.03162469,  0.11416737, ..., -0.17913206,
          0.19573624,  0.13490376], dtype=float32))]

Resume workflows

Lastly, it’s possible to resume from where you left off with the job id:

[19]:
# Save the job id so the workflow can be resumed later (e.g. from a new session).
esm_job_id = esm_job.job.job_id
[20]:
# Reload the job from its id; status and metadata are recovered from the backend.
reloaded_job = session.embedding.load_job(esm_job_id)
reloaded_job.job
[20]:
Job(status=<JobStatus.SUCCESS: 'SUCCESS'>, job_id='89089c15-9e76-41fa-af9c-05452efb3014', job_type='/embeddings/embed_reduced', created_date=datetime.datetime(2023, 8, 4, 4, 10, 29, 565648), start_date=datetime.datetime(2023, 8, 4, 4, 10, 29, 782187), end_date=datetime.datetime(2023, 8, 4, 4, 17, 13, 119073), prerequisite_job_id=None, progress_message=None, progress_counter=100, num_records=None)
[21]:
# The reloaded job remembers which sequences were submitted.
reloaded_job.sequences
[21]:
[b'AAAAPLHLALA']
[22]:
# Results remain retrievable by sequence after reloading the job.
reloaded_job.get_item(b"AAAAPLHLALA")
[22]:
array([ 0.15882437, -0.03162469,  0.11416737, ..., -0.17913206,
        0.19573624,  0.13490376], dtype=float32)