Skip to content

Commit

Permalink
Merge pull request #30 from linyan439/elixr-demo-colab
Browse files Browse the repository at this point in the history
Elixr demo colab
  • Loading branch information
asellergren authored Oct 27, 2023
2 parents 6a6d069 + f004b2f commit 11d4af8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 29 deletions.
97 changes: 72 additions & 25 deletions cxr-foundation/CXR_Foundation_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@
},
"outputs": [],
"source": [
"!git clone https://github.com/Google-Health/imaging-research.git\n",
"!pip install imaging-research/cxr-foundation/\n",
"\n",
"# Notebook specific dependencies\n",
"!pip install matplotlib sklearn tf-models-official>=2.13.0 google-cloud-storage"
"!pip install matplotlib tf-models-official>=2.13.0 google-cloud-storage\n",
"\n",
"!git clone https://github.com/Google-Health/imaging-research.git\n",
"!pip install imaging-research/cxr-foundation/"
]
},
{
Expand Down Expand Up @@ -156,7 +156,8 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SiFTgj8Icyyl"
"id": "SiFTgj8Icyyl",
"cellView": "form"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -227,10 +228,44 @@
"id": "1r1YZDcEcyym"
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"source": [
"# Selection between embedding versions"
],
"metadata": {
"id": "ru_dJOFliER2"
}
},
{
"cell_type": "code",
"source": [
"from cxr_foundation.inference import ModelVersion\n",
"import shutil\n",
"\n",
"EMBEDDING_VERSION = 'elixr' #@param ['elixr', 'cxr_foundation']\n",
"if EMBEDDING_VERSION == 'cxr_foundation':\n",
" MODEL_VERSION = ModelVersion.V1\n",
" TOKEN_NUM = 1\n",
" EMBEDDINGS_SIZE = 1376\n",
"elif EMBEDDING_VERSION == 'elixr':\n",
" MODEL_VERSION = ModelVersion.V2\n",
" TOKEN_NUM = 32\n",
" EMBEDDINGS_SIZE = 768\n",
"if not os.path.exists(EMBEDDINGS_DIR):\n",
" os.makedirs(EMBEDDINGS_DIR)\n",
"else:\n",
" # Empty embedding dir to avoid caching when switching embedding versions\n",
" shutil.rmtree(EMBEDDINGS_DIR)\n",
" os.makedirs(EMBEDDINGS_DIR)"
]
],
"metadata": {
"id": "53heuYYCh1x0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -247,8 +282,9 @@
"source": [
"import random\n",
"max_index = len(df_labels[\"dicom_file\"].values)\n",
"print('There are total of %s images. We will sample 5 to display.' % max_index)\n",
"list_of_images = random.sample(range(0, max_index-1), 5)\n",
"file_to_visualize = 2\n",
"print('There are total of %s images. We will sample %d to display.' % (max_index, file_to_visualize))\n",
"list_of_images = random.sample(range(0, max_index-1), file_to_visualize)\n",
"\n",
"import io\n",
"import pydicom\n",
Expand All @@ -258,7 +294,7 @@
"from cxr_foundation import constants\n",
"from cxr_foundation import example_generator_lib\n",
"\n",
"file_to_visualize = 2\n",
"\n",
"\n",
"def show_dicom(adicom):\n",
" \"\"\"Shows the DICOM in a format as passed to CXR Foundation.\"\"\"\n",
Expand Down Expand Up @@ -293,15 +329,15 @@
"source": [
"import logging\n",
"\n",
"from cxr_foundation.inference import generate_embeddings, InputFileType, OutputFileType\n",
"from cxr_foundation.inference import generate_embeddings, InputFileType, OutputFileType, ModelVersion\n",
"\n",
"\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"\n",
"# Generate and store a few embeddings in .npz format\n",
"generate_embeddings(input_files=df_labels[\"dicom_file\"].values[:5], output_dir=EMBEDDINGS_DIR,\n",
" input_type=InputFileType.DICOM, output_type=OutputFileType.NPZ)"
" input_type=InputFileType.DICOM, output_type=OutputFileType.NPZ, model_version=MODEL_VERSION)"
]
},
{
Expand All @@ -318,7 +354,7 @@
"filename = df_labels[\"embedding_file\"][0].replace(\"tfrecord\", \"npz\")\n",
"values = embeddings_data.read_npz_values(filename)\n",
"\n",
"print(values)\n",
"print(values.shape)\n",
"\n",
"# NOTE: The rest of the notebook will use the .tfrecord data"
]
Expand All @@ -333,7 +369,7 @@
"source": [
"# Generate all the embedings in .tfrecord format\n",
"generate_embeddings(input_files=df_labels[\"dicom_file\"].values, output_dir=EMBEDDINGS_DIR,\n",
" input_type=InputFileType.DICOM, output_type=OutputFileType.TFRECORD)"
" input_type=InputFileType.DICOM, output_type=OutputFileType.TFRECORD, model_version=MODEL_VERSION)"
]
},
{
Expand All @@ -355,7 +391,7 @@
"# you don't use Tensorflow, you can use the following function to read\n",
"# the values directly into a numpy array.\n",
"values = embeddings_data.read_tfrecord_values(filename)\n",
"print(values)"
"print(values.shape)"
]
},
{
Expand Down Expand Up @@ -407,6 +443,9 @@
"\n",
"\n",
"def create_model(heads,\n",
" # token_num=1 for original CXR foundation embedding\n",
" # token_num=32 for ELIXR embedding\n",
" token_num=1,\n",
" embeddings_size=1376,\n",
" learning_rate=0.1,\n",
" end_lr_factor=1.0,\n",
Expand All @@ -421,8 +460,10 @@
"\n",
"\n",
" \"\"\"\n",
" inputs = tf.keras.Input(shape=(embeddings_size,))\n",
" hidden = inputs\n",
" inputs = tf.keras.Input(shape=(token_num * embeddings_size,))\n",
" inputs_reshape = tf.keras.layers.Reshape((token_num, embeddings_size))(inputs)\n",
" inputs_pooled = tf.keras.layers.GlobalAveragePooling1D(data_format='channels_last')(inputs_reshape)\n",
" hidden = inputs_pooled\n",
" # If no hidden_layer_sizes are provided, model will be a linear probe.\n",
" for size in hidden_layer_sizes:\n",
" hidden = tf.keras.layers.Dense(\n",
Expand Down Expand Up @@ -475,15 +516,23 @@
"outputs": [],
"source": [
"# Create training and validation Datasets\n",
"training_data = embeddings_data.get_dataset(filenames=df_train[\"embedding_file\"].values,\n",
" labels=df_train[DIAGNOSIS].values)\n",
"training_data = embeddings_data.get_dataset(\n",
" filenames=df_train[\"embedding_file\"].values,\n",
" labels=df_train[DIAGNOSIS].values,\n",
" embeddings_size=TOKEN_NUM * EMBEDDINGS_SIZE)\n",
"\n",
"\n",
"validation_data = embeddings_data.get_dataset(filenames=df_validate[\"embedding_file\"].values,\n",
" labels=df_validate[DIAGNOSIS].values)\n",
"validation_data = embeddings_data.get_dataset(\n",
" filenames=df_validate[\"embedding_file\"].values,\n",
" labels=df_validate[DIAGNOSIS].values,\n",
" embeddings_size=TOKEN_NUM * EMBEDDINGS_SIZE)\n",
"\n",
"# Create and train the model\n",
"model = create_model([DIAGNOSIS])\n",
"model = create_model(\n",
" [DIAGNOSIS],\n",
" token_num=TOKEN_NUM,\n",
" embeddings_size = EMBEDDINGS_SIZE,\n",
")\n",
"\n",
"model.fit(\n",
" x=training_data.batch(512).prefetch(tf.data.AUTOTUNE).cache(),\n",
Expand Down Expand Up @@ -576,6 +625,7 @@
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"labels = eval_df[f'{DIAGNOSIS}_value'].values\n",
"predictions = eval_df[f'{DIAGNOSIS}_prediction'].values\n",
"false_positive_rate, true_positive_rate, thresholds = sklearn.metrics.roc_curve(\n",
Expand All @@ -589,9 +639,6 @@
],
"metadata": {
"colab": {
"collapsed_sections": [
"8LpEO7UrU9eS"
],
"provenance": []
},
"kernelspec": {
Expand Down Expand Up @@ -619,4 +666,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
11 changes: 7 additions & 4 deletions cxr-foundation/cxr_foundation/embeddings_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def read_npz_values(filename: str) -> np.ndarray:

def parse_serialized_example_values(
serialized_example: bytes,
embeddings_size: int = constants.DEFAULT_EMBEDDINGS_SIZE
) -> tf.Tensor:
"""Parses and extracts the embeddings values from a serialized tf.Example generated by the CXR foundation service.
Expand All @@ -140,10 +141,10 @@ def parse_serialized_example_values(
"""
features = {
constants.EMBEDDING_KEY: tf.io.FixedLenFeature(
[constants.DEFAULT_EMBEDDINGS_SIZE],
[embeddings_size],
tf.float32,
default_value=tf.constant(
0.0, shape=[constants.DEFAULT_EMBEDDINGS_SIZE]
0.0, shape=[embeddings_size]
),
)
}
Expand All @@ -152,7 +153,9 @@ def parse_serialized_example_values(


def get_dataset(
filenames: Iterable[str], labels: Iterable[int]
filenames: Iterable[str],
labels: Iterable[int],
embeddings_size: int = constants.DEFAULT_EMBEDDINGS_SIZE
) -> tf.data.Dataset:
"""Create a tf.data.Dataset from the specified tfrecord files and labels.
Expand All @@ -165,7 +168,7 @@ def get_dataset(
"""
ds_embeddings = tf.data.TFRecordDataset(
filenames, num_parallel_reads=tf.data.AUTOTUNE
).map(parse_serialized_example_values)
).map(lambda x: parse_serialized_example_values(x, embeddings_size))
ds_labels = tf.data.Dataset.from_tensor_slices(labels)

return tf.data.Dataset.zip((ds_embeddings, ds_labels))

0 comments on commit 11d4af8

Please sign in to comment.