diff --git a/06_gpu_and_ml/protein_fold/protein_fold.py b/06_gpu_and_ml/protein_fold/protein_fold.py
index 664ca4a25..ea478ba3c 100644
--- a/06_gpu_and_ml/protein_fold/protein_fold.py
+++ b/06_gpu_and_ml/protein_fold/protein_fold.py
@@ -2,13 +2,14 @@
# output-directory: "/tmp/protein-fold"
# ---
-# # Protein fold anything with ESM3, plus some powerful visualizations
+# # Protein Fold sequences with ESM3
-# If you haven't heard of AlphaFold or ESM3, you've likely been living under a
-# rock, it's time to come out! Protein Folding is a hot topic in the world of
-# bioinformatics and machine learning. In this example, we'll show you how to
+# If you haven't heard of [AlphaFold](https://deepmind.google/technologies/alphafold/)
+# or [ESM3](https://github.com/facebookresearch/esm), you've likely been living under a
+# rock -- it's time to come out! Protein folding is all the rage in the world of
+# bioinformatics and ML. In this example, we'll show you how to
# use ESM3 to predict the three-dimensional structure of a protein from its raw
-# amino acid sequence. We'll also show you how to visualize the results using
+# amino acid sequence. As a bonus, we'll also show you how to visualize the results using
# [py3DMol](https://3dmol.org/).
# If you're new to protein folding check out the next section for a brief
@@ -16,16 +17,25 @@
# ## What is Protein Folding?
-# A protein's three-dimensional shape determines its function in the body, from
-# catalyzing biochemical reactions to building cellular structures. To develop
-# new drugs and treatments for diseases, researchers need to understand how
-# potential drugs bind to target proteins in the human body and thus need to predict
-# their threre-dimensional structure from their raw amino acid sequence.
+# A protein's three-dimensional shape determines its function in the body,
+# i.e. everything from catalyzing biochemical reactions to building cellular
+# structures. To develop new drugs and treatments for diseases, researchers
+# need to understand how potential drugs bind to target proteins in the human
+# body and thus need to predict their three-dimensional structure from their
+# raw amino acid sequence.
# Historically determining the protein structure of a single sequence could
# take years of wet-lab experiments and millions of dollars, but with the advent
-# of deep protein folding models like ESM3, this can be done in seconds for a
-# few dollars.
+# of deep protein folding models like Meta's ESM3, a quality approximation can be
+# done in seconds for a few cents on Modal.
+
+# The bioinformatics community has created a number of resources for managing
+# protein data, including the [Research Collaboratory for Structural Bioinformatics](https://www.rcsb.org/) (RCSB) which stores the known structure of over 200,000 proteins.
+# These sets of protein structures are known as Protein Data Banks (PDB).
+# We'll also be using the open source [Biotite](https://www.biotite-python.org/)
+# library to help us extract sequences and residues from PDBs. Biotite was
+# created by Patrick Kunzmann and members of the Protein Bioinformatics lab
+# at Ruhr University Bochum in Germany.
# ## Basic Setup
@@ -44,37 +54,43 @@
gpu = "A10G"
-# ### Create a Volume to store weights, data, and HTML
+# ### Create a Volume to store weights and PDB data
-# On the first launch we'll be downloading the model weights from Hugging
-# Face(?), to speed up future cold starts we'll store the weights on a
-# [Volume](https://modal.com/docs/guide/volumes) for quick reads. We'll also
-# use this volume to cache PDB data from the web and store HTML files for
-# visualization.
-volume = modal.Volume.from_name(
- "example-protein-fold", create_if_missing=True
-)
+# To minimize cold start times we'll store the model weights on a Modal
+# [Volume](https://modal.com/docs/guide/volumes) so they can be quickly loaded
+# onto the GPU. We'll also store PDB data on the Volume to avoid hitting
+# the RCSB server for the same data multiple times.
+
+
+volume = modal.Volume.from_name("example-protein-fold", create_if_missing=True)
VOLUME_PATH = Path("/vol/data")
PDBS_PATH = VOLUME_PATH / "pdbs"
HTMLS_PATH = VOLUME_PATH / "htmls"
+
# ### Define dependencies in container images
# The container image for inference is based on Modal's default slim Debian
-# Linux image with `esm` for loading and running or model
-# for the
+# Linux image with `esm` for loading and running the model, and `hf_transfer`
+# for a higher-bandwidth download of the model weights from Hugging Face.
+
+
esm3_image = (
- modal.Image.debian_slim(python_version="3.12")
- .pip_install(
+ modal.Image.debian_slim(python_version="3.12").pip_install(
"esm==3.0.5",
- "torch==2.4.1", # Needed? FIXME
- )
+ "torch==2.4.1",
+ "huggingface_hub[hf_transfer]==0.26.2",
+ ).env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)
-# We'll also define a web app image for extracting PDB data. FIXME
+
+# We'll also define a web app image as a frontend for the entire system that
+# includes `gradio` for building a UI, `biotite` for extracting sequences and
+# residues from PDBs, and `py3Dmol` for visualizing the 3D structures.
+
+
web_app_image = (
- modal.Image.debian_slim(python_version="3.11")
- .pip_install(
+ modal.Image.debian_slim(python_version="3.12").pip_install(
"esm==3.0.5",
"gradio~=4.44.0",
"biotite==0.41.2",
@@ -86,27 +102,37 @@
)
+# Here we "pre-import" libraries that will be used by the functions we run
+# on Modal in a given image using the `with image.imports` context manager.
+
-# We can also "pre-import" libraries that will be used by the functions we run on Modal in a given image
-# using the `with image.imports` context manager.
with esm3_image.imports():
+ from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
import torch
+
with web_app_image.imports():
from py3DmolWrapper import py3DMolViewWrapper
+ import os
+
+ import biotite.database.rcsb as rcsb
+ import biotite.structure as b_structure
+ import biotite.structure.io.pdbx as pdbx
+ from biotite.structure.io.pdbx import CIFFile
+ from esm.sdk.api import ESMProtein
+
# ## Defining a `Model` inference class for ESM3
# Next, we map the model's setup and inference code onto Modal.
-# 1. We run any setup that can be persisted to disk in methods decorated with `@build`.
-# In this example, that includes downloading the model weights.
-# 2. We run any additional setup, like moving the model to the GPU, in methods decorated with `@enter`.
-# We do our model optimizations in this step. For details, see the section on `torch.compile` below.
-# 3. We run the actual inference in methods decorated with `@method`.
+# 1. To ensure the cached image includes the model weights, we download them in a
+# method decorated with `@build`. This reduces cold boot times.
+# 2. Any additional setup code that involves CPU or GPU RAM goes in a
+# method decorated with `@enter`, which runs on container start.
+# 3. To run the actual inference, we put it in a method decorated with `@method`.
-# Here we define the
@app.cls(
image=esm3_image,
@@ -114,45 +140,69 @@
gpu=gpu,
timeout=20 * MINUTES,
)
-class ModelInference:
+class Model:
+ def setup_model(self):
+ """Load the pretrained `esm3_sm_open_v1` ESM3 model into `self.model`.
+
+ Called from both `build` and `enter`.
+ """
+ self.model: ESM3InferenceClient = ESM3.from_pretrained(
+ "esm3_sm_open_v1"
+ # FIXME: work out how to store these weights on the Volume.
+ )
+
@modal.build()
- @modal.enter()
def build(self):
- from esm.models.esm3 import ESM3
- self.model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda")
- # XXX Remove:
- # https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb#scrollTo=cc0f0186
+ self.setup_model()
+
+ @modal.enter()
+ def enter(self):
+ self.setup_model()
+ self.model.to("cuda")
+
+ # Enable half precision for faster inference
self.model = self.model.half()
torch.backends.cuda.matmul.allow_tf32 = True
self.max_steps = 250
+ L.info(f"Setting max ESM steps to: {self.max_steps}")
@modal.method()
- def generate(self, sequence: str) -> bool:
- start_time = time.monotonic() # TODO Remove
+ def inference(self, sequence: str) -> bool:
+ start_time = time.monotonic() # TODO Remove
num_steps = min(len(sequence), self.max_steps)
structure_generation_config = GenerationConfig(
track="structure",
num_steps=num_steps,
)
+ L.info(f"Running ESM3 inference with num_steps={num_steps}")
esm_protein = self.model.generate(
- ESMProtein(sequence=sequence),
- structure_generation_config
+ ESMProtein(sequence=sequence), structure_generation_config
)
latency_s = time.monotonic() - start_time
- print (f"Inference latency: {latency_s:.2f} seconds.")
+ L.info(f"Inference latency: {latency_s:.2f} seconds.")
- # Check that esm_protein did not error.
- print ("Checking for errors...")
+ L.info("Checking for errors...")
if hasattr(esm_protein, "error_msg"):
raise ValueError(esm_protein.error_msg)
- print ("Moving all data off GPU before returning to caller...")
- esm_protein.ptm = esm_protein.ptm.to('cpu')
+ L.info("Moving all data off GPU before returning to caller...")
+ esm_protein.ptm = esm_protein.ptm.to("cpu")
return esm_protein
+
+# ### Serving a Gradio UI `web_endpoint` with an `asgi_app`
+
+# The `Model` class above is available for use
+# from any other Python environment with the right Modal credentials
+# and the `modal` package installed -- just use [`lookup`](https://modal.com/docs/reference/modal.Cls#lookup).
+
+# But we can also expose it via a web app using the `@asgi_app` decorator. Here
+# we will specifically create a Gradio web app that allows us to visualize
+# the 3D structure output of the ESM3 model as well as any known structures from
+# the RCSB Protein Data Bank.
+
+# You should see the URL for this UI in the output of `modal deploy`
+# or on your [Modal app dashboard](https://modal.com/apps) for this app.
+
assets_path = Path(__file__).parent / "assets"
@app.function(
@@ -170,172 +220,53 @@ def protein_fold_fastapi_app():
import gradio as gr
from gradio.routes import mount_gradio_app
- import biotite.database.rcsb as rcsb
- import biotite.structure as b_structure
- import biotite.structure.io.pdbx as pdbx
- from biotite.structure.io.pdbx import CIFFile
- import biotite.application.blast as blast
- from esm.sdk.api import ESMProtein
-
- import os
-
width, height = 400, 400
- def fetch_pdb_if_necessary(pdb_id, pdb_type):
- # Only fetch from server if it's not on the volume already.
- assert pdb_type in ("pdb", "pdbx")
- file_path = PDBS_PATH / f"{pdb_id}.{pdb_type}"
- if not os.path.exists(file_path):
- print (f"Loading PDB {file_path} from server...")
- rcsb.fetch(pdb_id, pdb_type, str(PDBS_PATH))
- return file_path
-
- def get_sequence(pdb_id):
- try:
- pdb_id = pdb_id.strip() # Remove whitespace
- pdbx_file_path = fetch_pdb_if_necessary(pdb_id, "pdbx")
-
- structure = pdbx.get_structure(CIFFile.read(pdbx_file_path), model=1)
-
- # TODO Excluding atoms that are not in a Amino Acid?
- (amino_sequences, _) = (
- b_structure.to_sequence(
- structure[b_structure.filter_amino_acids(structure)]))
- # Note: len(amino_sequences) == # of chains
-
- # TODO Should I put something between chains?
- sequence = "".join([str(s) for s in amino_sequences])
- return sequence
-
- except Exception as e:
- return f"Error: {e}"
-
- def find_pdb_id_from_sequence(query_sequence, maybe_stale_pdb_id):
- pdb_id_sequence = get_sequence(maybe_stale_pdb_id)
-
- if pdb_id_sequence == query_sequence:
- print ("Skipped BLAST run, PDB ID already known.")
- not_stale_pdb_id = maybe_stale_pdb_id
- return not_stale_pdb_id
- else:
- # TODO Use Blast or some kind of Sequence Database
- # Skipping Blast which is rate limited for now.
- return None
-
- print ("Running blast to find PDB ID of sequence...")
- blast_app = blast.BlastWebApp(
- "blastp", query_sequence, database="pdb")
- blast_app.start()
- blast_app.join()
-
- alignments = blast_app.get_alignments()
- if len(alignments) == 0:
- return None
- return alignments[0].hit_id
-
- def build_database_html(sequence, maybe_stale_pdb_id):
- pdb_id = find_pdb_id_from_sequence(sequence, maybe_stale_pdb_id)
- if pdb_id is None:
- return "
Folding Structure of Sequence not found.
"
-
- # Remove chain information if present.
- if pdb_id.find("_") != -1: # "1CRN_A" -> "1CRN"
- pdb_id = pdb_id[:pdb_id.index("_")]
-
- # Extract secondary structure from PDBX file.
- pdbx_file_path = fetch_pdb_if_necessary(pdb_id, "pdbx")
- structure = pdbx.get_structure(CIFFile.read(pdbx_file_path), model=1)
- # TODO Should I only use AA atoms?
- atoms = structure[
- b_structure.filter_amino_acids(structure)]
- residue_secondary_structures = b_structure.annotate_sse(atoms)
- residue_ids = b_structure.get_residues(atoms)[0]
- residue_id_to_sse = {}
- for (residue_id, sse) in zip(residue_ids, residue_secondary_structures):
- residue_id_to_sse[int(residue_id)] = sse
-
- # Extract PDB string from PDB file.
- pdb_file_path = fetch_pdb_if_necessary(pdb_id, "pdb")
- pdb_string = Path(pdb_file_path).read_text()
-
- return py3DMolViewWrapper().build_html_with_secondary_structure(
- width, height, pdb_string, residue_id_to_sse)
-
- def remove_nested_quotes(html):
- """Remove nested quotes in HEADER"""
- if html.find("HEADER") == -1:
- return html
-
- # Need this because srcdoc adds another quote and triple is tricky.
- i = html.index("HEADER")
- # Getting the HEADER double quote end by matching: \n","pdb");
- j = html[i:].index('"pdb");') + i - len('",')
- header_html = html[i:j]
- for delete_me in ("\\'", '\\"', "'", '"'):
- header_html = header_html.replace(delete_me, "")
- html = html[:i] + header_html + html[j:]
- return html
-
- def postprocess_html(html):
- html = remove_nested_quotes(html)
- html = html.replace("'", '"')
-
- html_wrapped = f"""{html}"""
- iframe_html = (f"""""")
- return iframe_html
- # return html_wrapped
-
- def run_esm_and_graph_debug(sequence, maybe_stale_pdb_id):
- # Data cleaning
- sequence = sequence.strip() # Remove whitespace
- maybe_stale_pdb_id = maybe_stale_pdb_id.strip() # Remove whitespace
-
- database_sse_html = build_database_html(sequence, maybe_stale_pdb_id)
-
- # Save HTMLs to volume for debugging issues.
- HTMLS_PATH.mkdir(parents=True, exist_ok=True)
- with open(str(HTMLS_PATH / "db.html"), "w") as f:
- f.write(postprocess_html(database_sse_html))
-
- return [postprocess_html(h)
- # for h in (esm_sse_html, esm_pLDDT_html, database_sse_html)]
- for h in [database_sse_html] * 3]
-
- def run_esm_and_graph(sequence, maybe_stale_pdb_id):
- # Data cleaning
- sequence = sequence.strip() # Remove whitespace
- maybe_stale_pdb_id = maybe_stale_pdb_id.strip() # Remove whitespace
-
- # Run ESM
- esm_protein: ESMProtein = ModelInference().generate.remote(sequence)
+ def extract_data(esm_protein):
residue_pLDDTs = (100 * esm_protein.plddt).tolist()
atoms = esm_protein.to_protein_chain().atom_array
- residue_secondary_structures = b_structure.annotate_sse(atoms)
- residue_ids = b_structure.get_residues(atoms)[0]
- residue_id_to_sse = {}
- for (residue_id, sse) in zip(residue_ids, residue_secondary_structures):
- residue_id_to_sse[int(residue_id)] = sse
+ residue_id_to_sse = extract_residues(atoms)
+ return residue_pLDDTs, residue_id_to_sse
- esm_sse_html = (
- py3DMolViewWrapper().build_html_with_secondary_structure(
- width, height, esm_protein.to_pdb_string(),
- residue_id_to_sse))
+ def run_esm(sequence, maybe_stale_pdb_id):
+ L.info("Removing whitespace from text input")
+ sequence = sequence.strip()
+ maybe_stale_pdb_id = maybe_stale_pdb_id.strip()
- esm_pLDDT_html = (
- py3DMolViewWrapper().build_html_with_pLDDTs(
- width, height, esm_protein.to_pdb_string(), residue_pLDDTs))
+ L.info("Running ESM")
+ esm_protein: ESMProtein = Model().inference.remote(sequence)
+ residue_pLDDTs, residue_id_to_sse = extract_data(esm_protein)
- database_sse_html = build_database_html(sequence, maybe_stale_pdb_id)
+ L.info("Constructing HTML for ESM3 prediction with residues.")
+ esm_sse_html = py3DMolViewWrapper().build_html_with_secondary_structure(
+ width, height, esm_protein.to_pdb_string(), residue_id_to_sse
+ )
- # Save HTMLs to volume for debugging issues.
- HTMLS_PATH.mkdir(parents=True, exist_ok=True)
- with open(str(HTMLS_PATH / "db.html"), "w") as f:
- f.write(postprocess_html(database_sse_html))
+ L.info("Constructing HTML for ESM3 prediction with confidence.")
+ esm_pLDDT_html = py3DMolViewWrapper().build_html_with_pLDDTs(
+ width, height, esm_protein.to_pdb_string(), residue_pLDDTs
+ )
- return [postprocess_html(h)
- for h in (esm_sse_html, esm_pLDDT_html, database_sse_html)]
+ pdb_sequence = get_sequence(maybe_stale_pdb_id)
+ if pdb_sequence == sequence:
+ pdb_id = maybe_stale_pdb_id
+ # if pdb_id.find("_") != -1: # "1CRN_A" -> "1CRN"
+ # pdb_id = pdb_id[:pdb_id.index("_")] # WHy was this here.
+ L.info(f"Constructing HTML for RCSB entry for {pdb_id}")
+ pdb_string, residue_id_to_sse = extract_pdb_and_residues(pdb_id)
+ rcsb_sse_html = (
+ py3DMolViewWrapper().build_html_with_secondary_structure(
+ width, height, pdb_string, residue_id_to_sse
+ )
+ )
+ else:
+ L.info("Sequence structure is unknown, generating HTML as such.")
+ rcsb_sse_html = "Folding Structure of Sequence not found.
"
+
+ return [
+ postprocess_html(h)
+ for h in (esm_sse_html, esm_pLDDT_html, rcsb_sse_html)
+ ]
web_app = FastAPI()
@@ -365,23 +296,25 @@ async def background():
"Actin [1ATN]",
"Ribonuclease A [5RSA]",
]
+
+ # Number the examples.
example_pdbs = [f"{i+1}) {x}" for i, x in enumerate(example_pdbs)]
with gr.Blocks(
theme=theme, css=css, title="ESM3 Protein Folding"
) as interface:
- # Title
gr.Markdown("# Fold Proteins using ESM3 Fold")
with gr.Row():
- # PDB ID text box + Get sequence from it button
with gr.Column():
gr.Markdown("## Custom PDB ID")
pdb_id_box = gr.Textbox(
label="Enter PDB ID or select one on the right",
- placeholder="e.g. '5JQ4', '1MBO', '1TPO', etc.")
- get_sequence_button = (
- gr.Button("Retrieve Sequence from PDB ID", variant="primary"))
+ placeholder="e.g. '5JQ4', '1MBO', '1TPO', etc.",
+ )
+ get_sequence_button = gr.Button(
+ "Retrieve Sequence from PDB ID", variant="primary"
+ )
pdb_link_button = gr.Button(value="Open PDB page for ID")
rcsb_link = "https://www.rcsb.org/structure/"
@@ -389,50 +322,53 @@ async def background():
fn=None,
inputs=pdb_id_box,
js=f"""(pdb_id) => {{ window.open("{rcsb_link}" + pdb_id) }}""",
- )
+ )
with gr.Column():
- def extract_pdb_id(idx):
- pdb = example_pdbs[idx]
- s = pdb.index("[") + 1
- e = pdb.index("]")
- return pdb[s:e]
- h = int(len(example_pdbs) / 2)
+ def extract_pdb_id(example_idx):
+ pdb = example_pdbs[example_idx]
+ return pdb[pdb.index("[") + 1 : pdb.index("]")]
+
+ half_len = int(len(example_pdbs) / 2)
gr.Markdown("## Example PDB IDs")
with gr.Row():
with gr.Column():
- # add in a few examples to inspire users
- for i, pdb in enumerate(example_pdbs[:h]):
+ for i, pdb in enumerate(example_pdbs[:half_len]):
btn = gr.Button(pdb, variant="secondary")
btn.click(
- fn=lambda j=i: extract_pdb_id(j), outputs=pdb_id_box)
+ fn=lambda j=i: extract_pdb_id(j),
+ outputs=pdb_id_box,
+ )
with gr.Column():
- # more examples
- for i, pdb in enumerate(example_pdbs[h:]):
+ for i, pdb in enumerate(example_pdbs[half_len:]):
btn = gr.Button(pdb, variant="secondary")
btn.click(
- fn=lambda j=i+4: extract_pdb_id(j), outputs=pdb_id_box)
+ fn=lambda j=i + half_len: extract_pdb_id(j),
+ outputs=pdb_id_box,
+ )
- # Sequence text box + Run ESM from it button
gr.Markdown("## Sequence")
sequence_box = gr.Textbox(
label="Enter a sequence or retrieve it from a PDB ID",
- placeholder="e.g. 'MVTRLE...', 'GKQEG...', etc.")
- run_esm_and_graph_button = (
- gr.Button("Run ESM3 Fold on Sequence", variant="primary"))
+ placeholder="e.g. 'MVTRLE...', 'GKQEG...', etc.",
+ )
+ run_esm_button = gr.Button(
+ "Run ESM3 Fold on Sequence", variant="primary"
+ )
- # 3 Columns of Protein Folding
htmls = []
+ legend_height = 100
with gr.Row():
- # Column 1: ESM Prediction with SSE coloring.
with gr.Column():
gr.Markdown("## ESM3 Prediction - Secondary Structs")
gr.Image( # output image component
- height=100, width=400,
+ height=legend_height,
+ width=width,
value="/assets/secondaryStructureLegend.png",
- show_download_button=False, show_label=False,
+ show_download_button=False,
+ show_label=False,
show_fullscreen_button=False,
)
htmls.append(gr.HTML())
@@ -441,27 +377,34 @@ def extract_pdb_id(idx):
with gr.Column():
gr.Markdown("## ESM3 Prediction - PLTT Confidence")
gr.Image( # output image component
- height=100, width=400, value="/assets/plddtLegend2.png",
- show_download_button=False, show_label=False,
+ height=legend_height,
+ width=width,
+ value="/assets/plddtLegend2.png",
+ show_download_button=False,
+ show_label=False,
show_fullscreen_button=False,
)
htmls.append(gr.HTML())
# Column 3: Database showing Crystalized form of Protein if avail.
with gr.Column():
- gr.Markdown("## Crystalized Structure from PDB")
+ gr.Markdown("## Crystalized Structure from RCSB's PDB")
gr.Image( # output image component
- height=100, width=400,
+ height=legend_height,
+ width=width,
value="/assets/secondaryStructureLegend.png",
- show_download_button=False, show_label=False,
+ show_download_button=False,
+ show_label=False,
show_fullscreen_button=False,
)
htmls.append(gr.HTML())
get_sequence_button.click(
- fn=get_sequence, inputs=[pdb_id_box], outputs=[sequence_box])
- run_esm_and_graph_button.click(fn=run_esm_and_graph,
- inputs=[sequence_box, pdb_id_box], outputs=htmls)
+ fn=get_sequence, inputs=[pdb_id_box], outputs=[sequence_box]
+ )
+ run_esm_button.click(
+ fn=run_esm, inputs=[sequence_box, pdb_id_box], outputs=htmls
+ )
# mount for execution on Modal
return mount_gradio_app(
@@ -470,18 +413,106 @@ def extract_pdb_id(idx):
path="/",
)
+
@app.local_entrypoint()
def main():
- pdb_id = "1ZNI"
- # pdb_id = "1MBO"
- test = retrieve_pdb.remote(pdb_id)
-
- # # YFP?
- # sequence = (
- # "AMFSKVNNQKMLEDCFYIRKKVFVEEQGIPEESEIDEYESESIHLIGYDNGQPVATARIRPINETTVKIERVAVMKSHRGQGMGRMLMQAVESLAKDEGFYVATMNAQCHAIPFYESLNFKMRGNIFLEEGIEHIEMTKKLT")
- # for i in range(10):
- # ModelInference().generate.remote(sequence)
- # x = 1
-
-if __name__ == "__main__":
- main()
+ sequence = (
+ "AMFSKVNNQKMLEDCFYIRKKVFVEEQGIPEESEIDEYESESIHLIGYDNGQPVATARIRPINETTVKI"
+ "ERVAVMKSHRGQGMGRMLMQAVESLAKDEGFYVATMNAQCHAIPFYESLNFKMRGNIFLEEGIEHIEMT"
+ "KKLT"
+ )
+ esm_protein = Model().inference.remote(sequence)
+ L.info(esm_protein)
+
+
+# ## Addenda
+
+# The remainder of this code is boilerplate.
+
+# ### Extracting Sequences and Residues from PDBs
+
+# A fair amount of code is required to extract sequences and residue
+# information from PDB data, so we build several helper functions for it below.
+
+
+def fetch_pdb_if_necessary(pdb_id, pdb_type):
+    """Return the cached path for a PDB entry, fetching it from RCSB if absent.
+
+    Args:
+        pdb_id: RCSB entry ID, used verbatim as the file stem.
+        pdb_type: File format to fetch; must be "pdb" or "pdbx".
+
+    Returns:
+        `PDBS_PATH / f"{pdb_id}.{pdb_type}"`.
+
+    Raises:
+        ValueError: If `pdb_type` is not one of the supported formats.
+    """
+    # Explicit check rather than `assert`, which is stripped under `python -O`.
+    if pdb_type not in ("pdb", "pdbx"):
+        raise ValueError(
+            f"Unsupported pdb_type {pdb_type!r}; expected 'pdb' or 'pdbx'."
+        )
+
+    file_path = PDBS_PATH / f"{pdb_id}.{pdb_type}"
+    if not file_path.exists():
+        L.info(f"Loading PDB {file_path} from server...")
+        # NOTE(review): presumably rcsb.fetch writes the file into PDBS_PATH
+        # under this same name -- confirm against the Biotite version in use.
+        rcsb.fetch(pdb_id, pdb_type, str(PDBS_PATH))
+    return file_path
+
+
+def get_sequence(pdb_id):
+    """Return the amino-acid sequence of a PDB entry as one string.
+
+    The entry's PDBx file is obtained through `fetch_pdb_if_necessary`, model 1
+    is parsed, atoms are filtered to amino acids, and the sequences returned by
+    `to_sequence` are joined with no separator.
+
+    Any exception is reported as an "Error: ..." string instead of being
+    raised, since the result is written straight into the UI's sequence box.
+    """
+    try:
+        cif_path = fetch_pdb_if_necessary(pdb_id.strip(), "pdbx")
+        structure = pdbx.get_structure(CIFFile.read(cif_path), model=1)
+
+        amino_acid_atoms = structure[b_structure.filter_amino_acids(structure)]
+        chain_sequences, _ = b_structure.to_sequence(amino_acid_atoms)
+
+        return "".join(str(chain) for chain in chain_sequences)
+
+    except Exception as e:
+        return f"Error: {e}"
+
+
+def extract_residues(atoms):
+    """Map each residue ID in `atoms` to its annotated secondary-structure element.
+
+    Pairs the output of `annotate_sse` with the residue IDs from
+    `get_residues`, in order; a later duplicate residue ID overwrites an
+    earlier one.
+    """
+    sse_per_residue = b_structure.annotate_sse(atoms)
+    residue_ids = b_structure.get_residues(atoms)[0]
+    return {
+        int(residue_id): sse
+        for residue_id, sse in zip(residue_ids, sse_per_residue)
+    }
+
+
+def extract_pdb_and_residues(pdb_id):
+    """Return the raw PDB text and the residue-ID -> SSE mapping for an entry.
+
+    The PDB-format file supplies the text returned to the caller; the PDBx
+    file (model 1, amino-acid atoms only) supplies the atoms passed to
+    `extract_residues`.
+    """
+    pdb_string = Path(fetch_pdb_if_necessary(pdb_id, "pdb")).read_text()
+
+    cif_file = CIFFile.read(fetch_pdb_if_necessary(pdb_id, "pdbx"))
+    structure = pdbx.get_structure(cif_file, model=1)
+    amino_acid_mask = b_structure.filter_amino_acids(structure)
+
+    return pdb_string, extract_residues(structure[amino_acid_mask])
+
+# ### Miscellaneous
+# The remaining code includes small helper functions for cleaning up HTML
+# generated by py3DMol for viewing on a webpage.
+
+def remove_nested_quotes(html):
+    """Strip quote characters from the HEADER section of py3DMol HTML.
+
+    The section runs from the first "HEADER" up to the `",` that precedes the
+    closing `"pdb");`. Escaped and bare single/double quotes inside it are
+    removed, because embedding the HTML in an iframe `srcdoc` adds another
+    level of quoting.
+
+    Returns `html` unchanged when there is no "HEADER", or when the closing
+    `"pdb");` marker cannot be found after it.
+    """
+    start = html.find("HEADER")
+    if start == -1:
+        return html
+
+    marker = html.find('"pdb");', start)
+    if marker == -1:
+        # Unexpected layout: leave the HTML untouched instead of raising.
+        return html
+    # Step back over the `",` that closes the model string before "pdb".
+    end = marker - len('",')
+
+    header = html[start:end]
+    for quote in ("\\'", '\\"', "'", '"'):
+        header = header.replace(quote, "")
+    return html[:start] + header + html[end:]
+
+
+def postprocess_html(html):
+ """Clean up py3DMol HTML and return the iframe markup built from it.
+
+ Runs `remove_nested_quotes`, converts remaining single quotes to double
+ quotes, then builds `html_wrapped` and `iframe_html`. Only `iframe_html`
+ is returned.
+ """
+ html = remove_nested_quotes(html)
+ html = html.replace("'", '"')
+
+ L.info("Wrapping py3DMol HTML in iframe so we can view multiple HTMLs.")
+ # NOTE(review): both template literals below look empty or stripped in this
+ # view; presumably they carried wrapper and iframe markup -- confirm.
+ html_wrapped = f"""{html}"""
+ iframe_html = (
+ f""""""
+ )
+ return iframe_html # FIXME?
+ # return html_wrapped
diff --git a/06_gpu_and_ml/protein_fold/py3DmolWrapper.py b/06_gpu_and_ml/protein_fold/py3DmolWrapper.py
index 91c59c67e..30512e542 100644
--- a/06_gpu_and_ml/protein_fold/py3DmolWrapper.py
+++ b/06_gpu_and_ml/protein_fold/py3DmolWrapper.py
@@ -3,10 +3,12 @@
# For PDB file formatting and conventions refer to:
# https://www.biostat.jhsph.edu/~iruczins/teaching/260.655/links/pdbformat.pdf
-import py3Dmol
+
+from dataclasses import dataclass
import logging as L
import numpy as np
-from dataclasses import dataclass
+import py3Dmol
+
@dataclass
class pLDDTBands:
@@ -15,7 +17,7 @@ class pLDDTBands:
name: str
color: str
-class py3DMolViewWrapper(object):
+class py3DMolViewWrapper:
def __init__(self):
# Color & Ranges AlphaFold colab
self.pLDDT_bands = [
@@ -34,7 +36,7 @@ def __init__(self):
'a' : magenta_hex, # Alpha Helix
'b': gold_hex, # Beta Sheet
'c': white_hex, # Coil
- '': black_hex} # Loop?
+ '': black_hex} # Loop
#################################
### Secondary Structures Plot ###
@@ -64,13 +66,8 @@ def build_html_with_secondary_structure(self,
view = py3Dmol.view(width=width, height=height)
view.addModel(pdb_string, "pdb")
- lowest_residue_id = self.setup_secondary_structures_plot(
- pdb_string, residue_id_to_sse)
-
color_map = {rid : self.secondary_structure_to_color[sse]
for rid, sse in residue_id_to_sse.items()}
- print (lowest_residue_id)
- print (color_map)
view.setStyle(
{"cartoon":
@@ -78,7 +75,6 @@ def build_html_with_secondary_structure(self,
{"prop": "resi", # refers to residual index in pdb_string
"map": color_map}}})
view.zoomTo()
- # TODO: pdb_string sometimes creates invalid HTML (missing ')')
return view._make_html()
###############################
@@ -134,7 +130,6 @@ def setup_pLDDTs_plot(
assert num_residues == len(residue_pLDDTs), (
f"{num_residues} != {len(residue_pLDDTs)}")
-
new_pdb_string = '\n'.join(new_lines)
return new_pdb_string
@@ -155,5 +150,4 @@ def build_html_with_pLDDTs(self, width, height, pdb_string, residue_pLDDTs):
'map': color_map}}})
view.zoomTo()
- # TODO: pdb_string sometimes creates invalid HTML (missing ')')
return view._make_html()