diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md b/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md index 782db87c..d704f299 100644 --- a/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md +++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md @@ -37,10 +37,10 @@ Here is a full list of the notebooks with their published example links: | ------------- | ------------- | ------------- | ------------- | ------------- | 1 | PyTorch | Image Classification | Training a model to predict clothing categories in FashionMNIST, including accelerated inference with Torch-TensorRT. | [Link](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html) | 2 | PyTorch | Regression | Training a model to predict housing prices in the California Housing Dataset, including accelerated inference with Torch-TensorRT. | [Link](https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md) -| 3 | Tensorflow | Image Classification | Training a model to predict hand-written digits in MNIST. | [Link](https://www.tensorflow.org/tutorials/keras/save_and_load) -| 4 | Tensorflow | Feature Columns | Training a model with preprocessing layers to predict likelihood of pet adoption in the PetFinder mini dataset. | [Link](https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers) +| 3 | Tensorflow | Image Classification | Training a model to predict hand-written digits in MNIST. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/save_and_load.ipynb) +| 4 | Tensorflow | Feature Columns | Training a model with preprocessing layers to predict likelihood of pet adoption in the PetFinder mini dataset. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/structured_data/preprocessing_layers.ipynb) | 5 | Tensorflow | Keras Metadata | Training ResNet-50 to perform flower recognition on Databricks. | [Link](https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/keras-metadata.html) -| 6 | Tensorflow | Text Classification | Training a model to perform sentiment analysis on the IMDB dataset. | [Link](https://www.tensorflow.org/tutorials/keras/text_classification) +| 6 | Tensorflow | Text Classification | Training a model to perform sentiment analysis on the IMDB dataset. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/text_classification.ipynb) | 7+8 | HuggingFace | Conditional Generation | Sentence translation using the T5 text-to-text transformer, with notebooks demoing both Torch and Tensorflow. | [Link](https://huggingface.co/docs/transformers/model_doc/t5#t5) | 9+10 | HuggingFace | Pipelines | Sentiment analysis using Huggingface pipelines, with notebooks demoing both Torch and Tensorflow. | [Link](https://huggingface.co/docs/transformers/quicktour#pipeline-usage) | 11 | HuggingFace | Sentence Transformers | Sentence embeddings using the SentenceTransformers framework in Torch. | [Link](https://huggingface.co/sentence-transformers) diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/image_classification_torch.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/image_classification_torch.ipynb index 47489dcc..ea07ce28 100644 --- a/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/image_classification_torch.ipynb +++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/image_classification_torch.ipynb @@ -263,78 +263,78 @@ "text": [ "Epoch 1\n", "-------------------------------\n", - "loss: 2.299719 [ 64/60000]\n", - "loss: 2.293332 [ 6464/60000]\n", - "loss: 2.269917 [12864/60000]\n", - "loss: 2.260744 [19264/60000]\n", - "loss: 2.247810 [25664/60000]\n", - "loss: 2.222256 [32064/60000]\n", - "loss: 2.225422 [38464/60000]\n", - "loss: 2.195026 [44864/60000]\n", - "loss: 2.194622 [51264/60000]\n", - "loss: 2.158175 [57664/60000]\n", + "loss: 2.301038 [ 64/60000]\n", + "loss: 2.289769 [ 6464/60000]\n", + "loss: 2.268618 [12864/60000]\n", + "loss: 2.264085 [19264/60000]\n", + "loss: 2.244277 [25664/60000]\n", + "loss: 2.209504 [32064/60000]\n", + "loss: 2.220515 [38464/60000]\n", + "loss: 2.185288 [44864/60000]\n", + "loss: 2.186121 [51264/60000]\n", + "loss: 2.149065 [57664/60000]\n", "Test Error: \n", - " Accuracy: 47.1%, Avg loss: 2.153042 \n", + " Accuracy: 37.8%, Avg loss: 2.151644 \n", "\n", "Epoch 2\n", "-------------------------------\n", - "loss: 2.162534 [ 64/60000]\n", - "loss: 2.154336 [ 6464/60000]\n", - "loss: 2.091042 [12864/60000]\n", - "loss: 2.104471 [19264/60000]\n", - "loss: 2.054451 [25664/60000]\n", - "loss: 2.001035 [32064/60000]\n", - "loss: 2.025180 [38464/60000]\n", - "loss: 1.949615 [44864/60000]\n", - "loss: 1.957106 [51264/60000]\n", - "loss: 1.876436 [57664/60000]\n", + "loss: 2.164946 [ 64/60000]\n", + "loss: 2.157853 [ 6464/60000]\n", + "loss: 2.100765 [12864/60000]\n", + "loss: 2.117897 [19264/60000]\n", + "loss: 2.058581 [25664/60000]\n", + "loss: 1.995217 [32064/60000]\n", + "loss: 2.026708 [38464/60000]\n", + "loss: 1.948186 [44864/60000]\n", + "loss: 1.959582 [51264/60000]\n", + "loss: 1.881658 [57664/60000]\n", "Test Error: \n", - " Accuracy: 54.6%, Avg loss: 1.876885 \n", + " Accuracy: 52.5%, Avg loss: 1.886264 \n", "\n", "Epoch 3\n", "-------------------------------\n", - "loss: 1.906243 [ 64/60000]\n", - "loss: 1.879715 [ 6464/60000]\n", - "loss: 1.758657 [12864/60000]\n", - "loss: 1.795318 [19264/60000]\n", - "loss: 1.692177 [25664/60000]\n", - "loss: 1.652430 [32064/60000]\n", - "loss: 1.669603 [38464/60000]\n", - "loss: 1.583420 [44864/60000]\n", - "loss: 1.603508 [51264/60000]\n", - "loss: 1.493881 [57664/60000]\n", + "loss: 1.922469 [ 64/60000]\n", + "loss: 1.893279 [ 6464/60000]\n", + "loss: 1.780482 [12864/60000]\n", + "loss: 1.822908 [19264/60000]\n", + "loss: 1.696129 [25664/60000]\n", + "loss: 1.653140 [32064/60000]\n", + "loss: 1.675662 [38464/60000]\n", + "loss: 1.584822 [44864/60000]\n", + "loss: 1.609127 [51264/60000]\n", + "loss: 1.500899 [57664/60000]\n", "Test Error: \n", - " Accuracy: 61.9%, Avg loss: 1.514976 \n", + " Accuracy: 60.3%, Avg loss: 1.521902 \n", "\n", "Epoch 4\n", "-------------------------------\n", - "loss: 1.573342 [ 64/60000]\n", - "loss: 1.548722 [ 6464/60000]\n", - "loss: 1.402007 [12864/60000]\n", - "loss: 1.461628 [19264/60000]\n", - "loss: 1.353920 [25664/60000]\n", - "loss: 1.358175 [32064/60000]\n", - "loss: 1.361608 [38464/60000]\n", - "loss: 1.302804 [44864/60000]\n", - "loss: 1.330850 [51264/60000]\n", - "loss: 1.224925 [57664/60000]\n", + "loss: 1.593910 [ 64/60000]\n", + "loss: 1.555975 [ 6464/60000]\n", + "loss: 1.412051 [12864/60000]\n", + "loss: 1.480928 [19264/60000]\n", + "loss: 1.348195 [25664/60000]\n", + "loss: 1.352939 [32064/60000]\n", + "loss: 1.361179 [38464/60000]\n", + "loss: 1.298819 [44864/60000]\n", + "loss: 1.325064 [51264/60000]\n", + "loss: 1.226879 [57664/60000]\n", "Test Error: \n", - " Accuracy: 63.8%, Avg loss: 1.254037 \n", + " Accuracy: 63.2%, Avg loss: 1.254962 \n", "\n", "Epoch 5\n", "-------------------------------\n", - "loss: 1.321162 [ 64/60000]\n", - "loss: 1.315946 [ 6464/60000]\n", - "loss: 1.152864 [12864/60000]\n", - "loss: 1.244943 [19264/60000]\n", - "loss: 1.130193 [25664/60000]\n", - "loss: 1.160290 [32064/60000]\n", - "loss: 1.168214 [38464/60000]\n", - "loss: 1.123758 [44864/60000]\n", - "loss: 1.158085 [51264/60000]\n", - "loss: 1.063427 [57664/60000]\n", + "loss: 1.337471 [ 64/60000]\n", + "loss: 1.314826 [ 6464/60000]\n", + "loss: 1.155245 [12864/60000]\n", + "loss: 1.257553 [19264/60000]\n", + "loss: 1.123370 [25664/60000]\n", + "loss: 1.155071 [32064/60000]\n", + "loss: 1.168100 [38464/60000]\n", + "loss: 1.119365 [44864/60000]\n", + "loss: 1.149572 [51264/60000]\n", + "loss: 1.067573 [57664/60000]\n", "Test Error: \n", - " Accuracy: 65.1%, Avg loss: 1.089330 \n", + " Accuracy: 64.4%, Avg loss: 1.090368 \n", "\n", "Done!\n" ] @@ -531,7 +531,32 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 24, + "id": "362b266b", + "metadata": {}, + "outputs": [], + "source": [ + "import torch_tensorrt as trt\n", + "import time" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "f0ac1362", + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: set the filename for the TensorRT timing cache\n", + "timestamp = time.time()\n", + "timing_cache = f\"/tmp/timing_cache-{timestamp}.bin\"\n", + "with open(timing_cache, \"wb\") as f:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 20, "id": "f3e3bdc4", "metadata": {}, "outputs": [ @@ -539,10 +564,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models\n", "INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)\n", "INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.\n", - "INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')\n", + "INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache-1729187850.4862776.bin')\n", "\n", "WARNING:torch_tensorrt.dynamo._compiler:Node _param_constant1 of op type get_attr does not have metadata. This could sometimes lead to undefined behavior.\n", "WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.\n", @@ -552,23 +576,23 @@ "WARNING:py.warnings:/home/rishic/anaconda3/envs/spark-dl-torch/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/activation/base.py:40: DeprecationWarning: Use Deprecated in TensorRT 10.1. Superseded by explicit quantization. instead.\n", " if input_val.dynamic_range is not None and dyn_range_fn is not None:\n", "\n", - "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.005708\n", + "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.005662\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 21984\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 4 steps to complete.\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.257559ms to assign 2 blocks to 4 nodes requiring 4096 bytes.\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.115746ms to assign 2 blocks to 4 nodes requiring 4096 bytes.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 4096\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 2678824\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.023755 seconds.\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 1.58824 seconds.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 1 MiB, GPU 5 MiB\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3800 MiB\n", - "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.027501\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3950 MiB\n", + "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:01.591865\n", "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 2832188 bytes of Memory\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 176 timing cache entries\n" + "INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 43 timing cache entries\n" ] }, { @@ -580,8 +604,6 @@ } ], "source": [ - "import torch_tensorrt as trt\n", - "\n", "inputs_bs1 = torch.randn((1, 784), dtype=torch.float).to(\"cuda\")\n", "# This indicates dimension 0 of inputs_bs1 is dynamic whose range of values is [1, 50]. No recompilation will happen when the batch size changes.\n", "torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=64)\n", @@ -590,6 +612,7 @@ " ir=\"torch_compile\",\n", " inputs=inputs_bs1,\n", " enabled_precisions={torch.float},\n", + " timing_cache_path=timing_cache,\n", ")\n", "\n", "stream = torch.cuda.Stream()\n", @@ -612,7 +635,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "6b8f1b45", "metadata": {}, "outputs": [ @@ -620,31 +643,31 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=True, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')\n", + "INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=True, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache-1729187850.4862776.bin')\n", "\n", "INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 624, GPU 715 (MiB)\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1632, GPU +286, now: CPU 2256, GPU 1001 (MiB)\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 758, GPU 715 (MiB)\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1633, GPU +286, now: CPU 2391, GPU 1001 (MiB)\n", "WARNING:py.warnings:/home/rishic/anaconda3/envs/spark-dl-torch/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/activation/base.py:40: DeprecationWarning: Use Deprecated in TensorRT 10.1. Superseded by explicit quantization. instead.\n", " if input_val.dynamic_range is not None and dyn_range_fn is not None:\n", "\n", - "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004551\n", + "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004664\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 21984\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 4 steps to complete.\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.133258ms to assign 2 blocks to 4 nodes requiring 4096 bytes.\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.113766ms to assign 2 blocks to 4 nodes requiring 4096 bytes.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 4096\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 2678824\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.0190609 seconds.\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 0.022595 seconds.\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 1 MiB, GPU 5 MiB\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3818 MiB\n", - "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.021306\n", + "INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3968 MiB\n", + "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.025016\n", "INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 2833124 bytes of Memory\n", "INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.\n", - "INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 176 timing cache entries\n" + "INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 43 timing cache entries\n" ] }, { @@ -662,7 +685,9 @@ "# Produce traced graph in the ExportedProgram format\n", "exp_program = trt.dynamo.trace(model_from_state, inputs)\n", "# Compile the traced graph to produce an optimized module\n", - "trt_gm = trt.dynamo.compile(exp_program, inputs=inputs, require_full_compilation=True)\n", + "trt_gm = trt.dynamo.compile(exp_program, \n", + " inputs=inputs, \n", + " timing_cache_path=timing_cache)\n", "\n", "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", @@ -681,22 +706,10 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 23, "id": "d87e4b20", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:py.warnings:/home/rishic/anaconda3/envs/spark-dl-torch/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:364: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer\n", - " engine_node = gm.graph.get_attr(engine_name)\n", - "\n", - "WARNING:py.warnings:/home/rishic/anaconda3/envs/spark-dl-torch/lib/python3.11/site-packages/torch/fx/graph.py:1545: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target\n", - " warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '\n", - "\n" - ] - }, { "name": "stdout", "output_type": "stream", diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/regression_torch.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/regression_torch.ipynb index d2fd9157..5ccc22ec 100644 --- a/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/regression_torch.ipynb +++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/regression_torch.ipynb @@ -539,6 +539,31 @@ "(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ffb27fc", + "metadata": {}, + "outputs": [], + "source": [ + "import torch_tensorrt as trt\n", + "import time" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0c10f90", + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: set the filename for the TensorRT timing cache\n", + "timestamp = time.time()\n", + "timing_cache = f\"/tmp/timing_cache-{timestamp}.bin\"\n", + "with open(timing_cache, \"wb\") as f:\n", + " pass" + ] + }, { "cell_type": "code", "execution_count": 20, @@ -599,8 +624,6 @@ } ], "source": [ - "import torch_tensorrt as trt\n", - "\n", "inputs_bs1 = torch.randn((10, 8), dtype=torch.float).to(\"cuda\")\n", "# This indicates dimension 0 of inputs_bs1 is dynamic whose range of values is [1, 50]. No recompilation will happen when the batch size changes.\n", "torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=50)\n", @@ -609,6 +632,7 @@ " ir=\"torch_compile\",\n", " inputs=inputs_bs1,\n", " enabled_precisions={torch.float},\n", + " timing_cache_path=timing_cache,\n", ")\n", "\n", "stream = torch.cuda.Stream()\n", @@ -719,7 +743,9 @@ "# Produce traced graph in the ExportedProgram format\n", "exp_program = trt.dynamo.trace(loaded_mlp, inputs)\n", "# Compile the traced graph to produce an optimized module\n", - "trt_gm = trt.dynamo.compile(exp_program, inputs=inputs, device='cuda:0')\n", + "trt_gm = trt.dynamo.compile(exp_program,\n", + " inputs=inputs,\n", + " timing_cache_path=timing_cache)\n", "\n", "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/image_classification.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/image_classification.ipynb deleted file mode 100644 index 5add2686..00000000 --- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/image_classification.ipynb +++ /dev/null @@ -1,2460 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "52d55e3f", - "metadata": {}, - "source": [ - "# Pyspark TensorFlow Inference\n", - "\n", - "## Image classification\n", - "Based on: https://www.tensorflow.org/tutorials/keras/save_and_load" - ] - }, - { - "cell_type": "markdown", - "id": "5233632d", - "metadata": {}, - "source": [ - "### Using TensorFlow\n", - "Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) \n", - "This notebook does not demonstrate inference with TensorRT, as [TF-TRT](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#tensorrt-10) does not yet support `tf=2.17.0`. See the `pytorch` notebooks for TensorRT demos." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "c8b28f02", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-10-03 17:40:20.324462: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-10-03 17:40:20.331437: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-10-03 17:40:20.339109: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-10-03 17:40:20.341362: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-10-03 17:40:20.347337: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2024-10-03 17:40:20.672391: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.17.0\n" - ] - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import subprocess\n", - "import tensorflow as tf\n", - "import os\n", - "\n", - "from tensorflow import keras\n", - "\n", - "print(tf.version.VERSION)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "e2e67086", - "metadata": {}, - "outputs": [], - "source": [ - "# Enable GPU memory growth\n", - "gpus = tf.config.experimental.list_physical_devices('GPU')\n", - "if gpus:\n", - " try:\n", - " for gpu in gpus:\n", - " tf.config.experimental.set_memory_growth(gpu, True)\n", - " except RuntimeError as e:\n", - " print(e)" - ] - }, - { - "cell_type": "markdown", - "id": "7e0c7ad6", - "metadata": {}, - "source": [ - "### Load and preprocess dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "5b007f7c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[1m11490434/11490434\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n" - ] - }, - { - "data": { - "text/plain": [ - "((60000, 28, 28), (10000, 28, 28))" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# load dataset as numpy arrays\n", - "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", - "train_images.shape, test_images.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7b7cedd1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "((1000, 784), (1000, 784))" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_labels = train_labels[:1000]\n", - "test_labels = test_labels[:1000]\n", - "\n", - "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n", - "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0\n", - "\n", - "train_images.shape, test_images.shape" - ] - }, - { - "cell_type": "markdown", - "id": "867a4403", - "metadata": {}, - "source": [ - "### Define a model" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "746d94db", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", - " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n", - "2024-10-03 17:40:21.624052: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 45743 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n" - ] - }, - { - "data": { - "text/html": [ - "
Model: \"sequential\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"sequential\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ dense (Dense)                   │ (None, 512)            │       401,920 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dropout (Dropout)               │ (None, 512)            │             0 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_1 (Dense)                 │ (None, 10)             │         5,130 │\n",
-       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m401,920\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m5,130\u001b[0m │\n", - "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 407,050 (1.55 MB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 407,050 (1.55 MB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Define a simple sequential model\n", - "def create_model():\n", - " model = tf.keras.Sequential([\n", - " keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n", - " keras.layers.Dropout(0.2),\n", - " keras.layers.Dense(10)\n", - " ])\n", - "\n", - " model.compile(optimizer='adam',\n", - " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", - "\n", - " return model\n", - "\n", - "# Create a basic model instance\n", - "model = create_model()\n", - "\n", - "# Display the model's architecture\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "id": "605d082a", - "metadata": {}, - "source": [ - "### Save checkpoints during training" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "244746be", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/10\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "I0000 00:00:1727977222.161202 1835280 service.cc:146] XLA service 0x7ec778008e00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", - "I0000 00:00:1727977222.161216 1835280 service.cc:154] StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n", - "2024-10-03 17:40:22.168848: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", - "2024-10-03 17:40:22.206298: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m24s\u001b[0m 778ms/step - loss: 2.3278 - sparse_categorical_accuracy: 0.1250" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "I0000 00:00:1727977222.715572 1835280 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 17ms/step - loss: 1.5867 - sparse_categorical_accuracy: 0.5096 " - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-10-03 17:40:23.780912: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:393] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_33', 4 bytes spill stores, 4 bytes spill loads\n", - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Epoch 1: val_sparse_categorical_accuracy improved from -inf to 0.78700, saving model to training_1/checkpoint.model.keras\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 44ms/step - loss: 1.5733 - sparse_categorical_accuracy: 0.5144 - val_loss: 0.7061 - val_sparse_categorical_accuracy: 0.7870\n", - "Epoch 2/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.5514 - sparse_categorical_accuracy: 0.8438\n", - "Epoch 2: val_sparse_categorical_accuracy improved from 0.78700 to 0.83700, saving model to training_1/checkpoint.model.keras\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.4276 - sparse_categorical_accuracy: 0.8935 - val_loss: 0.5268 - val_sparse_categorical_accuracy: 0.8370\n", - "Epoch 3/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - loss: 0.1458 - sparse_categorical_accuracy: 0.9688\n", - "Epoch 3: val_sparse_categorical_accuracy improved from 0.83700 to 0.85600, saving model to training_1/checkpoint.model.keras\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2721 - sparse_categorical_accuracy: 0.9236 - val_loss: 0.4716 - val_sparse_categorical_accuracy: 0.8560\n", - "Epoch 4/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.2223 - sparse_categorical_accuracy: 0.9375\n", - "Epoch 4: val_sparse_categorical_accuracy did not improve from 0.85600\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2159 - sparse_categorical_accuracy: 0.9547 - val_loss: 0.4682 - val_sparse_categorical_accuracy: 0.8540\n", - "Epoch 5/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - loss: 0.1483 - sparse_categorical_accuracy: 0.9688\n", - "Epoch 5: val_sparse_categorical_accuracy improved from 0.85600 to 0.86900, saving model to training_1/checkpoint.model.keras\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.1457 - sparse_categorical_accuracy: 0.9716 - val_loss: 0.4285 - val_sparse_categorical_accuracy: 0.8690\n", - "Epoch 6/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.0836 - sparse_categorical_accuracy: 0.9688\n", - "Epoch 6: val_sparse_categorical_accuracy did not improve from 0.86900\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1292 - sparse_categorical_accuracy: 0.9712 - val_loss: 0.4551 - val_sparse_categorical_accuracy: 0.8580\n", - "Epoch 7/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0920 - sparse_categorical_accuracy: 0.9688\n", - "Epoch 7: val_sparse_categorical_accuracy did not improve from 0.86900\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0974 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.4016 - val_sparse_categorical_accuracy: 0.8670\n", - "Epoch 8/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0993 - sparse_categorical_accuracy: 0.9688\n", - "Epoch 8: val_sparse_categorical_accuracy did not improve from 0.86900\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0702 - sparse_categorical_accuracy: 0.9920 - val_loss: 0.3999 - val_sparse_categorical_accuracy: 0.8650\n", - "Epoch 9/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0599 - sparse_categorical_accuracy: 1.0000\n", - "Epoch 9: val_sparse_categorical_accuracy improved from 0.86900 to 0.87800, saving model to training_1/checkpoint.model.keras\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0457 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.4145 - val_sparse_categorical_accuracy: 0.8780\n", - "Epoch 10/10\n", - "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.0286 - sparse_categorical_accuracy: 1.0000\n", - "Epoch 10: val_sparse_categorical_accuracy did not improve from 0.87800\n", - "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0351 - sparse_categorical_accuracy: 0.9987 - val_loss: 0.4200 - val_sparse_categorical_accuracy: 0.8720\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint_path = \"training_1/checkpoint.model.keras\"\n", - "checkpoint_dir = os.path.dirname(checkpoint_path)\n", - "\n", - "# Create a callback that saves the model's weights\n", - "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n", - " monitor='val_sparse_categorical_accuracy',\n", - " mode='max',\n", - " save_best_only=True,\n", - " verbose=1)\n", - "\n", - "# Train the model with the new callback\n", - "model.fit(train_images, \n", - " train_labels, \n", - " epochs=10,\n", - " validation_data=(test_images, test_labels),\n", - " callbacks=[cp_callback]) # Pass callback to training\n", - "\n", - "# This may generate warnings related to saving the state of the optimizer.\n", - "# These warnings (and similar warnings throughout this notebook)\n", - "# are in place to discourage outdated usage, and can be ignored." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "310eae08", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['checkpoint.model.keras']" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "os.listdir(checkpoint_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "50eeb6e5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: mnist_model/assets\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: mnist_model/assets\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Saved artifact at 'mnist_model'. The following endpoints are available:\n", - "\n", - "* Endpoint 'serve'\n", - " args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor')\n", - "Output Type:\n", - " TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)\n", - "Captures:\n", - " 139403584120848: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", - " 139403240100240: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", - " 139403240100048: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", - " 139403240099856: TensorSpec(shape=(), dtype=tf.resource, name=None)\n" - ] - } - ], - "source": [ - "# Export model in saved_model format\n", - "model.export(\"mnist_model\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "6d3bba9e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", - " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "32/32 - 0s - 10ms/step - loss: 2.4196 - sparse_categorical_accuracy: 0.0590\n", - "Untrained model, accuracy: 5.90%\n" - ] - } - ], - "source": [ - "# Create a basic model instance\n", - "model = create_model()\n", - "\n", - "# Evaluate the model\n", - "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", - "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "22ad1708", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "32/32 - 0s - 713us/step - loss: 0.4145 - sparse_categorical_accuracy: 0.8780\n", - "Restored model, accuracy: 87.80%\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 10 variables. \n", - " saveable.load_own_variables(weights_store.get(inner_path))\n" - ] - } - ], - "source": [ - "# Load the weights from the checkpoint\n", - "model.load_weights(checkpoint_path)\n", - "\n", - "# Re-evaluate the model\n", - "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", - "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" - ] - }, - { - "cell_type": "markdown", - "id": "1c097d63", - "metadata": {}, - "source": [ - "### Checkpoint callback options" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "cb336e89", - "metadata": {}, - "outputs": [], - "source": [ - "!rm -rf training_2\n", - "!mkdir training_2" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "750b6deb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Epoch 5: saving model to training_2/cp-0005.weights.h5\n", - "\n", - "Epoch 10: saving model to training_2/cp-0010.weights.h5\n", - "\n", - "Epoch 15: saving model to training_2/cp-0015.weights.h5\n", - "\n", - "Epoch 20: saving model to training_2/cp-0020.weights.h5\n", - "\n", - "Epoch 25: saving model to training_2/cp-0025.weights.h5\n", - "\n", - "Epoch 30: saving model to training_2/cp-0030.weights.h5\n", - "\n", - "Epoch 35: saving model to training_2/cp-0035.weights.h5\n", - "\n", - "Epoch 40: saving model to training_2/cp-0040.weights.h5\n", - "\n", - "Epoch 45: saving model to training_2/cp-0045.weights.h5\n", - "\n", - "Epoch 50: saving model to training_2/cp-0050.weights.h5\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Include the epoch in the file name (uses `str.format`)\n", - "checkpoint_path = \"training_2/cp-{epoch:04d}.weights.h5\"\n", - "checkpoint_dir = os.path.dirname(checkpoint_path)\n", - "\n", - "batch_size = 32\n", - "\n", - "# Calculate the number of batches per epoch\n", - "import math\n", - "n_batches = len(train_images) / batch_size\n", - "n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n", - "\n", - "# Create a callback that saves the model's weights every 5 epochs\n", - "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", - " filepath=checkpoint_path, \n", - " verbose=1, \n", - " save_weights_only=True,\n", - " save_freq=5*n_batches)\n", - "\n", - "# Create a new model instance\n", - "model = create_model()\n", - "\n", - "# Save the weights using the `checkpoint_path` format\n", - "model.save_weights(checkpoint_path.format(epoch=0))\n", - "\n", - "# Train the model with the new callback\n", - "model.fit(train_images, \n", - " train_labels,\n", - " epochs=50, \n", - " batch_size=batch_size, \n", - " callbacks=[cp_callback],\n", - " validation_data=(test_images, test_labels),\n", - " verbose=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "1c43fd3d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['cp-0000.weights.h5',\n", - " 'cp-0015.weights.h5',\n", - " 'cp-0010.weights.h5',\n", - " 'cp-0035.weights.h5',\n", - " 'cp-0020.weights.h5',\n", - " 'cp-0040.weights.h5',\n", - " 'cp-0050.weights.h5',\n", - " 'cp-0005.weights.h5',\n", - " 'cp-0045.weights.h5',\n", - " 'cp-0025.weights.h5',\n", - " 'cp-0030.weights.h5']" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "os.listdir(checkpoint_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "0d7ae715", - "metadata": {}, - "outputs": [], - "source": [ - "latest = \"training_2/cp-0030.weights.h5\"" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "d345c6f7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "32/32 - 0s - 9ms/step - loss: 0.4501 - sparse_categorical_accuracy: 0.8720\n", - "Restored model, accuracy: 87.20%\n" - ] - } - ], - "source": [ - "# Create a new model instance\n", - "model = create_model()\n", - "\n", - "# Load the previously saved weights\n", - "model.load_weights(latest)\n", - "\n", - "# Re-evaluate the model from the latest checkpoint\n", - "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", - "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" - ] - }, - { - "cell_type": "markdown", - "id": "a86f4700", - "metadata": {}, - "source": [ - "## PySpark" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "7fcf07bb", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from pyspark.sql import SparkSession\n", - "from pyspark import SparkConf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2c022c24", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "conda_env = os.environ.get(\"CONDA_PREFIX\")\n", - "\n", - "conf = SparkConf()\n", - "if 'spark' not in globals():\n", - " # If Spark is not already started with Jupyter, attach to Spark Standalone\n", - " import socket\n", - " hostname = socket.gethostname()\n", - " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n", - "conf.set(\"spark.task.maxFailures\", \"1\")\n", - "conf.set(\"spark.driver.memory\", \"8g\")\n", - "conf.set(\"spark.executor.memory\", \"8g\")\n", - "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", - "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", - "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n", - "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n", - "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", - "conf.set(\"spark.python.worker.reuse\", \"true\")\n", - "# Create Spark Session\n", - "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", - "sc = spark.sparkContext" - ] - }, - { - "cell_type": "markdown", - "id": "c81d0b1b", - "metadata": {}, - "source": [ - "### Convert numpy array to Spark DataFrame (via Pandas DataFrame)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "49ff5203", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1000, 784)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# numpy array to pandas DataFrame\n", - "test_pdf = pd.DataFrame(test_images)\n", - "test_pdf.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "182ee0c7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 134 ms, sys: 15.5 ms, total: 149 ms\n", - "Wall time: 1.36 s\n" - ] - } - ], - "source": [ - "%%time\n", - "df = spark.createDataFrame(test_pdf).repartition(8)" - ] - }, - { - "cell_type": "markdown", - "id": "d4e1c7ec-64fa-43c4-9bcf-0868a401d1f2", - "metadata": {}, - "source": [ - "### Save as Parquet (784 columns of float)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "0061c39a-0871-429e-a4ff-751d26bf4b04", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "24/10/03 17:40:32 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", - "[Stage 0:> (0 + 8) / 8]\r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 2.49 ms, sys: 1.65 ms, total: 4.13 ms\n", - "Wall time: 1.66 s\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - } - ], - "source": [ - "%%time\n", - "df.write.mode(\"overwrite\").parquet(\"mnist_784\")" - ] - }, - { - "cell_type": "markdown", - "id": "18315afb-3fa2-4953-9297-52c04dd70c32", - "metadata": {}, - "source": [ - "### Save as Parquet (1 column of 784 float)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "302c73ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 6.71 ms, sys: 4.92 ms, total: 11.6 ms\n", - "Wall time: 11.4 ms\n" - ] - }, - { - "data": { - "text/plain": [ - "(1000, 1)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "test_pdf['data'] = test_pdf.values.tolist()\n", - "pdf = test_pdf[['data']]\n", - "pdf.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "5495901b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 46.6 ms, sys: 4.71 ms, total: 51.3 ms\n", - "Wall time: 91.7 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "df = spark.createDataFrame(pdf)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "5fa7faa8-c6bd-41b0-b5f7-fb121f0332e6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 807 μs, sys: 724 μs, total: 1.53 ms\n", - "Wall time: 211 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "df.write.mode(\"overwrite\").parquet(\"mnist_1\")" - ] - }, - { - "cell_type": "markdown", - "id": "c87b444e", - "metadata": {}, - "source": [ - "### Check arrow memory configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "3d4ca414", - "metadata": {}, - "outputs": [], - "source": [ - "spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"128\")\n", - "# This line will fail if the vectorized reader runs out of memory\n", - "assert len(df.head()) > 0, \"`df` should not be empty\" " - ] - }, - { - "cell_type": "markdown", - "id": "9b6dde30-98a9-45db-ab3a-d4546f9bed99", - "metadata": {}, - "source": [ - "## Inference using Spark DL API" - ] - }, - { - "cell_type": "markdown", - "id": "4238fb28-d002-4b4d-9aa1-8af1fbd5d569", - "metadata": {}, - "source": [ - "### 1 column of 784 float" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "db30fba6-24d0-4c00-8502-04f9b10e7e16", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import os\n", - "import pandas as pd\n", - "\n", - "from pyspark.ml.functions import predict_batch_udf\n", - "from pyspark.sql.functions import array, col, struct\n", - "from pyspark.sql.types import ArrayType, FloatType, Union, Dict" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "b9cf62f8-96b2-4716-80bd-bb93d5f939bd", - "metadata": {}, - "outputs": [], - "source": [ - "# get absolute path to model\n", - "model_dir = \"{}/training_1/checkpoint.model.keras\".format(os.getcwd())" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "b81fa297-d9d0-4600-880d-dbdcdf8bccc6", - "metadata": {}, - "outputs": [], - "source": [ - "def predict_batch_fn():\n", - " import tensorflow as tf\n", - "\n", - " # Enable GPU memory growth to avoid CUDA OOM\n", - " gpus = tf.config.experimental.list_physical_devices('GPU')\n", - " if gpus:\n", - " try:\n", - " for gpu in gpus:\n", - " tf.config.experimental.set_memory_growth(gpu, True)\n", - " except RuntimeError as e:\n", - " print(e)\n", - "\n", - " model = tf.keras.models.load_model(model_dir)\n", - " def predict(inputs: np.ndarray) -> np.ndarray:\n", - " return model.predict(inputs)\n", - " \n", - " return predict" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "72a689bd-dd82-492e-8740-1738a215325f", - "metadata": {}, - "outputs": [], - "source": [ - "mnist = predict_batch_udf(predict_batch_fn,\n", - " return_type=ArrayType(FloatType()),\n", - " batch_size=1024,\n", - " input_tensor_shapes=[[784]])" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "60a70150-26b1-4145-9e7d-6e17389216b7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = spark.read.parquet(\"mnist_1\")\n", - "len(df.columns)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "e027f0d2-0f65-47b7-a562-2f0965faceec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------------------+\n", - "| data|\n", - "+--------------------+\n", - "|[0.0, 0.0, 0.0, 0...|\n", - "|[0.0, 0.0, 0.0, 0...|\n", - "|[0.0, 0.0, 0.0, 0...|\n", - "|[0.0, 0.0, 0.0, 0...|\n", - "|[0.0, 0.0, 0.0, 0...|\n", - "+--------------------+\n", - "only showing top 5 rows\n", - "\n" - ] - } - ], - "source": [ - "df.show(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "f0c3fb2e-469e-47bc-b948-8f6b0d7f6513", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[Stage 4:===================================================> (7 + 1) / 8]\r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 18.5 ms, sys: 13.3 ms, total: 31.8 ms\n", - "Wall time: 5.03 s\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - } - ], - "source": [ - "%%time\n", - "# first pass caches model/fn\n", - "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "cdfa229a-f4a9-4c11-a410-de4a21c02c82", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 37.3 ms, sys: 12.4 ms, total: 49.8 ms\n", - "Wall time: 259 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "5586ce49-6f93-4343-9b66-0dbb64972179", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 22.9 ms, sys: 5.96 ms, total: 28.8 ms\n", - "Wall time: 237 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()" - ] - }, - { - "cell_type": "markdown", - "id": "004f1599-3c62-499e-9fd8-ed5cb0c90de4", - "metadata": { - "tags": [] - }, - "source": [ - "#### Check predictions" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "4f947dc0-6b18-4605-810b-e83250a161db", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
datapreds
0[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-5.88436, -3.1058547, 0.10873719, 12.67319, -...
1[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-3.273286, -8.362554, 1.8936121, -3.8881433, ...
2[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-3.3856308, 0.6785604, 1.3146863, 0.9275978, ...
3[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.7754688, -7.3659225, 11.768427, 1.3434286,...
4[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-4.9426627, 4.0774136, -0.4529277, -0.9312789...
5[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-5.226616, -3.1389174, 2.6100307, 3.695045, -...
6[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-4.3006196, 5.1169925, 0.5850615, -0.76248693...
7[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.3985956, -1.4814724, -4.884057, -0.2391600...
8[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[0.82160115, -2.8640625, -1.6951559, -4.489290...
9[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-1.2338604, -2.151981, -4.171742, 1.6106845, ...
\n", - "
" - ], - "text/plain": [ - " data \\\n", - "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "5 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "6 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "7 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "8 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "9 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "\n", - " preds \n", - "0 [-5.88436, -3.1058547, 0.10873719, 12.67319, -... \n", - "1 [-3.273286, -8.362554, 1.8936121, -3.8881433, ... \n", - "2 [-3.3856308, 0.6785604, 1.3146863, 0.9275978, ... \n", - "3 [-2.7754688, -7.3659225, 11.768427, 1.3434286,... \n", - "4 [-4.9426627, 4.0774136, -0.4529277, -0.9312789... \n", - "5 [-5.226616, -3.1389174, 2.6100307, 3.695045, -... \n", - "6 [-4.3006196, 5.1169925, 0.5850615, -0.76248693... \n", - "7 [-2.3985956, -1.4814724, -4.884057, -0.2391600... \n", - "8 [0.82160115, -2.8640625, -1.6951559, -4.489290... \n", - "9 [-1.2338604, -2.151981, -4.171742, 1.6106845, ... " - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "preds = df.withColumn(\"preds\", mnist(*df.columns)).limit(10).toPandas()\n", - "preds" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "de4964e0-d1f8-4753-afa1-a8f95ca3f151", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ -5.88436 , -3.1058547 , 0.10873719, 12.67319 ,\n", - " -5.143787 , 4.0859914 , -10.203137 , -1.4333997 ,\n", - " -3.3865087 , -3.8473575 ], dtype=float32)" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sample = preds.iloc[0]\n", - "sample.preds" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "44e9a874-e301-4b72-8df7-bf1c5133c287", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "c60e5af4-fc1e-4575-a717-f304664235be", - "metadata": {}, - "outputs": [], - "source": [ - "prediction = np.argmax(sample.preds)\n", - "img = np.array(sample.data).reshape(28,28)" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "eb45ecc9-d376-40c4-ad7b-2bd08ca5aaf6", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.figure()\n", - "plt.title(\"Prediction: {}\".format(prediction))\n", - "plt.imshow(img)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "39167347-0b99-4972-998c-e1230bf1d4d5", - "metadata": {}, - "source": [ - "### 784 columns of float" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "f1285e8b-1b96-437b-973a-eb868e33afb7", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "from pyspark.ml.functions import predict_batch_udf\n", - "from pyspark.sql.functions import array, col, struct\n", - "from pyspark.sql.types import ArrayType, FloatType, Union, Dict" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "6bea332e-f6de-494f-a0db-795d9fe3e134", - "metadata": {}, - "outputs": [], - "source": [ - "def predict_batch_fn():\n", - " import tensorflow as tf\n", - " # Enable GPU memory growth\n", - " gpus = tf.config.experimental.list_physical_devices('GPU')\n", - " if gpus:\n", - " try:\n", - " for gpu in gpus:\n", - " tf.config.experimental.set_memory_growth(gpu, True)\n", - " except RuntimeError as e:\n", - " print(e)\n", - " \n", - " model = tf.keras.models.load_model(model_dir)\n", - " def predict(inputs: np.ndarray) -> np.ndarray:\n", - " return model.predict(inputs)\n", - " \n", - " return predict" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "731d234c-549f-4df3-8a2b-312e63195396", - "metadata": {}, - "outputs": [], - "source": [ - "mnist = predict_batch_udf(predict_batch_fn,\n", - " return_type=ArrayType(FloatType()),\n", - " batch_size=1024,\n", - " input_tensor_shapes=[[784]])" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "a40fe207-6246-4b0e-abde-823979878d97", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "784" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = spark.read.parquet(\"mnist_784\")\n", - "len(df.columns)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "10904f12-03e7-4518-8f12-2aa11989ddf5", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[Stage 10:=======> (1 + 7) / 8]\r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 45.6 ms, sys: 26 ms, total: 71.6 ms\n", - "Wall time: 5.51 s\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", mnist(struct(*df.columns))).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "671128df-f0f4-4f54-b35c-d63a78c7f89a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 46.5 ms, sys: 34 ms, total: 80.5 ms\n", - "Wall time: 884 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "ce35deaf-7d49-4f34-9bf9-b4e6fc5761f4", - "metadata": {}, - "outputs": [], - "source": [ - "# should raise ValueError\n", - "# preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()" - ] - }, - { - "cell_type": "markdown", - "id": "01709833-484b-451f-9aa8-37be5b7baf14", - "metadata": {}, - "source": [ - "### Check prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "f9119632-b284-45d7-a262-c262e034c15c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
0123456789...775776777778779780781782783preds
00.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-5.88436, -3.1058552, 0.108737305, 12.67319, ...
10.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-3.2732859, -8.362555, 1.893612, -3.888143, 0...
20.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-3.3856308, 0.6785604, 1.3146865, 0.9275978, ...
30.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-2.775469, -7.3659234, 11.768431, 1.3434289, ...
40.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-4.942663, 4.0774136, -0.45292768, -0.9312788...
50.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-5.226616, -3.1389174, 2.6100307, 3.695045, -...
60.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-4.3006196, 5.116993, 0.5850617, -0.7624871, ...
70.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-2.398596, -1.4814726, -4.8840575, -0.2391601...
80.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[0.82160157, -2.8640628, -1.6951559, -4.489291...
90.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.0[-1.2338604, -2.151981, -4.1717424, 1.6106843,...
\n", - "

10 rows × 785 columns

\n", - "
" - ], - "text/plain": [ - " 0 1 2 3 4 5 6 7 8 9 ... 775 776 777 778 \\\n", - "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", - "\n", - " 779 780 781 782 783 preds \n", - "0 0.0 0.0 0.0 0.0 0.0 [-5.88436, -3.1058552, 0.108737305, 12.67319, ... \n", - "1 0.0 0.0 0.0 0.0 0.0 [-3.2732859, -8.362555, 1.893612, -3.888143, 0... \n", - "2 0.0 0.0 0.0 0.0 0.0 [-3.3856308, 0.6785604, 1.3146865, 0.9275978, ... \n", - "3 0.0 0.0 0.0 0.0 0.0 [-2.775469, -7.3659234, 11.768431, 1.3434289, ... \n", - "4 0.0 0.0 0.0 0.0 0.0 [-4.942663, 4.0774136, -0.45292768, -0.9312788... \n", - "5 0.0 0.0 0.0 0.0 0.0 [-5.226616, -3.1389174, 2.6100307, 3.695045, -... \n", - "6 0.0 0.0 0.0 0.0 0.0 [-4.3006196, 5.116993, 0.5850617, -0.7624871, ... \n", - "7 0.0 0.0 0.0 0.0 0.0 [-2.398596, -1.4814726, -4.8840575, -0.2391601... \n", - "8 0.0 0.0 0.0 0.0 0.0 [0.82160157, -2.8640628, -1.6951559, -4.489291... \n", - "9 0.0 0.0 0.0 0.0 0.0 [-1.2338604, -2.151981, -4.1717424, 1.6106843,... \n", - "\n", - "[10 rows x 785 columns]" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).limit(10).toPandas()\n", - "preds" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "7c067c62-03a6-461e-a1ff-4653276fbea1", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "a7084ad0-c021-4296-bad0-7a238971f53b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ -5.88436 , -3.1058552, 0.1087373, 12.67319 , -5.1437874,\n", - " 4.085992 , -10.203137 , -1.4333997, -3.3865087, -3.8473575],\n", - " dtype=float32)" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sample = preds.iloc[0]\n", - "sample.preds" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "8167c832-93ef-4f50-873b-07b67c19ef53", - "metadata": {}, - "outputs": [], - "source": [ - "prediction = np.argmax(sample.preds)\n", - "img = sample.drop('preds').to_numpy(dtype=float)\n", - "img = np.array(img).reshape(28,28)" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "297811e1-aecb-4afd-9a6a-30c49e8881cc", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.figure()\n", - "plt.title(\"Prediction: {}\".format(prediction))\n", - "plt.imshow(img)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5961593d-182e-4620-9a5e-f98ba3d2534d", - "metadata": {}, - "source": [ - "### Using Triton Inference Server\n", - "\n", - "Note: you can restart the kernel and run from this point to simulate running in a different node or environment." - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "a64d19b1-ba4a-4dc7-b3a9-368dc47d0fd8", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from pyspark.ml.functions import predict_batch_udf\n", - "from pyspark.sql.functions import col, struct\n", - "from pyspark.sql.types import ArrayType, FloatType" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "8fa92fe4-2e04-4d82-a357-bfdfca38bd8c", - "metadata": {}, - "outputs": [], - "source": [ - "%%bash\n", - "# copy model to expected layout for Triton\n", - "rm -rf models\n", - "mkdir -p models/mnist_model/1\n", - "cp -r mnist_model models/mnist_model/1/model.savedmodel\n", - "\n", - "# add config.pbtxt\n", - "cp models_config/mnist_model/config.pbtxt models/mnist_model/config.pbtxt" - ] - }, - { - "cell_type": "markdown", - "id": "f1673e0e-5c75-44e1-88c6-5f5cf1275e4b", - "metadata": {}, - "source": [ - "#### Start Triton Server on each executor" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "0f7ecb25-be16-40c4-bdbb-441e2f537000", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "data": { - "text/plain": [ - "[True]" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "num_executors = 1\n", - "triton_models_dir = \"{}/models\".format(os.getcwd())\n", - "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n", - "\n", - "def start_triton(it):\n", - " import docker\n", - " import time\n", - " import tritonclient.grpc as grpcclient\n", - " \n", - " client=docker.from_env()\n", - " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n", - " if containers:\n", - " \n", - " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n", - " else:\n", - " container=client.containers.run(\n", - " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n", - " detach=True,\n", - " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n", - " name=\"spark-triton\",\n", - " network_mode=\"host\",\n", - " remove=True,\n", - " shm_size=\"64M\",\n", - " volumes={triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"}}\n", - " )\n", - " print(\">>>> starting triton: {}\".format(container.short_id))\n", - "\n", - " # wait for triton to be running\n", - " time.sleep(15)\n", - " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n", - " ready = False\n", - " while not ready:\n", - " try:\n", - " ready = client.is_server_ready()\n", - " except Exception as e:\n", - " time.sleep(5)\n", - " \n", - " return [True]\n", - "\n", - "nodeRDD.barrier().mapPartitions(start_triton).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "43b93753-1d52-4060-9986-f24c30a67528", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "StructType([StructField('data', ArrayType(DoubleType(), True), True)])" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = spark.read.parquet(\"mnist_1\")\n", - "df.schema" - ] - }, - { - "cell_type": "markdown", - "id": "036680eb-babd-4b07-8b2c-ce6e724f4e85", - "metadata": {}, - "source": [ - "#### Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "3af08bd0-3838-4769-a8de-2643db4101c6", - "metadata": {}, - "outputs": [], - "source": [ - "def triton_fn(triton_uri, model_name):\n", - " import numpy as np\n", - " import tritonclient.grpc as grpcclient\n", - "\n", - " np_types = {\n", - " \"BOOL\": np.dtype(np.bool_),\n", - " \"INT8\": np.dtype(np.int8),\n", - " \"INT16\": np.dtype(np.int16),\n", - " \"INT32\": np.dtype(np.int32),\n", - " \"INT64\": np.dtype(np.int64),\n", - " \"FP16\": np.dtype(np.float16),\n", - " \"FP32\": np.dtype(np.float32),\n", - " \"FP64\": np.dtype(np.float64),\n", - " \"FP64\": np.dtype(np.double),\n", - " \"BYTES\": np.dtype(object)\n", - " }\n", - "\n", - " client = grpcclient.InferenceServerClient(triton_uri)\n", - " model_meta = client.get_model_metadata(model_name)\n", - "\n", - " def predict(inputs):\n", - " if isinstance(inputs, np.ndarray):\n", - " # single ndarray input\n", - " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n", - " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n", - " else:\n", - " # dict of multiple ndarray inputs\n", - " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n", - " for i in request:\n", - " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n", - "\n", - " response = client.infer(model_name, inputs=request)\n", - "\n", - " if len(model_meta.outputs) > 1:\n", - " # return dictionary of numpy arrays\n", - " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n", - " else:\n", - " # return single numpy array\n", - " return response.as_numpy(model_meta.outputs[0].name)\n", - "\n", - " return predict" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "6658d2a1-ef7b-4ca1-9fb6-f2ac9050f3e5", - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "\n", - "predict = predict_batch_udf(partial(triton_fn, \"localhost:8001\", \"mnist_model\"),\n", - " return_type=ArrayType(FloatType()),\n", - " input_tensor_shapes=[[784]],\n", - " batch_size=8192)" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "8397aa14-82fd-4351-a477-dc8e8b321fa2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 20.1 ms, sys: 3.41 ms, total: 23.5 ms\n", - "Wall time: 625 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", predict(struct(\"data\"))).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "82698bd9-377a-4415-8971-835487f876cc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 30.3 ms, sys: 8.81 ms, total: 39.2 ms\n", - "Wall time: 154 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", predict(\"data\")).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "419ad7bd-fa28-49d3-b98d-db9fba5aeaef", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 2.67 ms, sys: 4.2 ms, total: 6.87 ms\n", - "Wall time: 131 ms\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
datapreds
0[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-5.7614846, -3.52228, -1.1202906, 13.053683, ...
1[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-3.1390061, -8.71185, 0.82955813, -4.034869, ...
2[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-3.046528, 0.3521706, 0.6788677, 0.72303534, ...
3[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.401024, -7.6780066, 11.145876, 1.2857256, ...
4[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-5.0012593, 3.806796, -0.8154834, -0.9550028,...
5[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-5.0425925, -3.4815094, 1.641246, 3.608149, -...
6[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-4.288771, 5.0072904, 0.27649477, -0.797148, ...
7[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.2032878, -1.6879876, -5.874276, -0.5945335...
8[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[1.1337761, -3.1751056, -2.5246286, -5.028277,...
9[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-0.92484117, -2.4703276, -5.023897, 1.46669, ...
\n", - "
" - ], - "text/plain": [ - " data \\\n", - "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "5 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "6 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "7 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "8 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "9 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "\n", - " preds \n", - "0 [-5.7614846, -3.52228, -1.1202906, 13.053683, ... \n", - "1 [-3.1390061, -8.71185, 0.82955813, -4.034869, ... \n", - "2 [-3.046528, 0.3521706, 0.6788677, 0.72303534, ... \n", - "3 [-2.401024, -7.6780066, 11.145876, 1.2857256, ... \n", - "4 [-5.0012593, 3.806796, -0.8154834, -0.9550028,... \n", - "5 [-5.0425925, -3.4815094, 1.641246, 3.608149, -... \n", - "6 [-4.288771, 5.0072904, 0.27649477, -0.797148, ... \n", - "7 [-2.2032878, -1.6879876, -5.874276, -0.5945335... \n", - "8 [1.1337761, -3.1751056, -2.5246286, -5.028277,... \n", - "9 [-0.92484117, -2.4703276, -5.023897, 1.46669, ... " - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "preds = df.withColumn(\"preds\", predict(col(\"data\"))).limit(10).toPandas()\n", - "preds" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "id": "79d90a26", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 62, - "id": "4ca495f5", - "metadata": {}, - "outputs": [], - "source": [ - "sample = preds.iloc[0]\n", - "sample.preds\n", - "\n", - "prediction = np.argmax(sample.preds)\n", - "img = np.array(sample.data).reshape(28,28)" - ] - }, - { - "cell_type": "code", - "execution_count": 63, - "id": "a5d10903", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.figure()\n", - "plt.title(\"Prediction: {}\".format(prediction))\n", - "plt.imshow(img)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "6377f41a-5654-410b-8bad-d392e9dce7b8", - "metadata": { - "tags": [] - }, - "source": [ - "#### Stop Triton Server on each executor" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "9c9fd967-5cd9-4265-add9-db5c1ccf9893", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "data": { - "text/plain": [ - "[True]" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def stop_triton(it):\n", - " import docker\n", - " import time\n", - " \n", - " client=docker.from_env()\n", - " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n", - " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n", - " if containers:\n", - " container=containers[0]\n", - " container.stop(timeout=120)\n", - "\n", - " return [True]\n", - "\n", - "nodeRDD.barrier().mapPartitions(stop_triton).collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "f612dc0b-538f-4ecf-81f7-ef6b58c493ab", - "metadata": {}, - "outputs": [], - "source": [ - "spark.stop()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "490fc849-e47a-48d7-accc-429ff1cced6b", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "spark-dl-tf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md index 2879b94a..4b5fd567 100644 --- a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md +++ b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md @@ -5,12 +5,12 @@ The notebook uses PCA to reduce a random dataset with 2048 feature dimensions to ## Build -Please refer to the Spark-Rapids-ML [README](https://github.com/NVIDIA/spark-rapids-ml/blob/HEAD/python) for environment setup instructions and API usage. +Please refer to the Spark-Rapids-ML [README](https://github.com/NVIDIA/spark-rapids-ml/blob/HEAD/python) to setup the RAPIDS conda environment and install Spark-Rapids-ML dependencies. ## Download RAPIDS Jar from Maven Central -Download the RAPIDS jar from Maven Central: [rapids-4-spark_2.12-24.08.1.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/24.08.1/rapids-4-spark_2.12-24.08.1.jar) -Alternatively, see the Spark-Rapids [download page](https://nvidia.github.io/spark-rapids/docs/download.html#download-rapids-accelerator-for-apache-spark-v24081) for version selection. +Download the [Spark-Rapids plugin](https://nvidia.github.io/spark-rapids/docs/download.html#download-rapids-accelerator-for-apache-spark-v24081). +For Spark-RAPIDS-ML version 24.08, download the RAPIDS jar from Maven Central: [rapids-4-spark_2.12-24.08.1.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/24.08.1/rapids-4-spark_2.12-24.08.1.jar). ## Running the Notebooks diff --git a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb index 3bf570fe..37a11238 100644 --- a/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb +++ b/examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb @@ -17,7 +17,8 @@ "source": [ "import numpy as np\n", "import pandas as pd\n", - "import time" + "import time\n", + "import os" ] }, { @@ -58,7 +59,6 @@ "\n", " SPARK_RAPIDS_VERSION = \"24.08.1\"\n", " rapids_jar = f\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\"\n", - "\n", " if not os.path.exists(rapids_jar):\n", " print(\"Downloading spark rapids jar\")\n", " url = f\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\"\n", @@ -71,7 +71,6 @@ " print(f\"Failed to download the file. Status code: {response.status_code}\")\n", " else:\n", " print(\"File already exists. Skipping download.\")\n", - " \n", " return rapids_jar\n", "\n", "def initialize_spark(rapids_jar: str):\n", @@ -119,7 +118,9 @@ "# Check if Spark session is already active, if not, initialize it\n", "if 'spark' not in globals():\n", " print(\"No active Spark session found, initializing manually.\")\n", - " rapids_jar = get_rapids_jar()\n", + " rapids_jar = os.environ.get('RAPIDS_JAR')\n", + " if rapids_jar is None:\n", + " rapids_jar = get_rapids_jar()\n", " spark = initialize_spark(rapids_jar)\n", "else:\n", " print(\"Using existing Spark session.\")"