From 780d57a9d8887164582daebc279d9807672c4515 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 13:23:12 +0200 Subject: [PATCH] Adjust import of util from tianshou to sensai.util --- docs/02_notebooks/L3_slicing.ipynb | 9 ++++++--- notebooks/experiments_plots.ipynb | 16 +++++++++------- scripts/armscan_array_obs.py | 2 +- scripts/armscan_debugging.py | 2 +- scripts/armscan_dqn_sac_hl.py | 2 +- scripts/armscan_ppo_hl.py | 2 +- scripts/armscan_sac_hl.py | 2 +- 7 files changed, 20 insertions(+), 15 deletions(-) diff --git a/docs/02_notebooks/L3_slicing.ipynb b/docs/02_notebooks/L3_slicing.ipynb index bfb7856..0b6bfcc 100644 --- a/docs/02_notebooks/L3_slicing.ipynb +++ b/docs/02_notebooks/L3_slicing.ipynb @@ -203,7 +203,8 @@ "plt.imshow(slice_img, aspect=6)\n", "plt.axis(\"off\")\n", "plt.show()" - ] + ], + "execution_count": null }, { "cell_type": "markdown", @@ -223,7 +224,8 @@ "print(f\"{reward=}\")\n", "plt.axis(\"off\")\n", "plt.show()" - ] + ], + "execution_count": null }, { "cell_type": "markdown", @@ -280,7 +282,8 @@ "\n", "animation = camera.animate()\n", "HTML(animation.to_jshtml())" - ] + ], + "execution_count": null }, { "cell_type": "markdown", diff --git a/notebooks/experiments_plots.ipynb b/notebooks/experiments_plots.ipynb index c514600..42ffa89 100644 --- a/notebooks/experiments_plots.ipynb +++ b/notebooks/experiments_plots.ipynb @@ -19,8 +19,8 @@ "source": [ "import os.path\n", "\n", - "import pandas as pd\n", "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", "import seaborn as sns" ], "outputs": [], @@ -59,19 +59,21 @@ }, "cell_type": "code", "source": [ - "experiment_name = \"run-sac-characteristic-array-rew-details-y_42_20240630-191219-tag-train_lens_stat_mean.csv\"\n", + "experiment_name = (\n", + " \"run-sac-characteristic-array-rew-details-y_42_20240630-191219-tag-train_lens_stat_mean.csv\"\n", + ")\n", "experiment_path = os.path.join(experiments_dir, experiment_name)\n", "df = pd.read_csv(experiment_path)\n", "\n", "# Simple Matplotlib plot\n", - "plt.plot(df['Step'], df['Value'])\n", - "plt.xlabel('Step')\n", - "plt.ylabel('Value')\n", - "plt.title('Plot Title')\n", + "plt.plot(df[\"Step\"], df[\"Value\"])\n", + "plt.xlabel(\"Step\")\n", + "plt.ylabel(\"Value\")\n", + "plt.title(\"Plot Title\")\n", "plt.show()\n", "\n", "# Seaborn plot for enhanced aesthetics\n", - "sns.lineplot(data=df, x='Step', y='Value')\n", + "sns.lineplot(data=df, x=\"Step\", y=\"Value\")\n", "plt.show()" ], "id": "initial_id", diff --git a/scripts/armscan_array_obs.py b/scripts/armscan_array_obs.py index f8204d6..0414854 100644 --- a/scripts/armscan_array_obs.py +++ b/scripts/armscan_array_obs.py @@ -9,6 +9,7 @@ from armscan_env.envs.rewards import LabelmapClusteringBasedReward from armscan_env.volumes.loading import load_sitk_volumes from armscan_env.wrapper import ArmscanEnvFactory +from sensai.util.logging import datetime_tag from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import VectorEnvType @@ -18,7 +19,6 @@ ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import SACParams -from tianshou.utils.logging import datetime_tag if __name__ == "__main__": config = get_config() diff --git a/scripts/armscan_debugging.py b/scripts/armscan_debugging.py index 7ad9e49..fbc8e4e 100644 --- a/scripts/armscan_debugging.py +++ b/scripts/armscan_debugging.py @@ -10,6 +10,7 @@ from armscan_env.envs.rewards import LabelmapClusteringBasedReward from armscan_env.volumes.loading import RegisteredLabelmap from armscan_env.wrapper import ArmscanEnvFactory +from sensai.util.logging import datetime_tag from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import VectorEnvType @@ -19,7 +20,6 @@ ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import SACParams -from tianshou.utils.logging import datetime_tag if __name__ == "__main__": config = get_config() diff --git a/scripts/armscan_dqn_sac_hl.py b/scripts/armscan_dqn_sac_hl.py index f1a79e3..1be132d 100644 --- a/scripts/armscan_dqn_sac_hl.py +++ b/scripts/armscan_dqn_sac_hl.py @@ -10,6 +10,7 @@ from armscan_env.network import ActorFactoryArmscanNet from armscan_env.volumes.loading import RegisteredLabelmap from armscan_env.wrapper import ArmscanEnvFactory +from sensai.util.logging import datetime_tag from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import VectorEnvType @@ -19,7 +20,6 @@ ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import SACParams -from tianshou.utils.logging import datetime_tag if __name__ == "__main__": config = get_config() diff --git a/scripts/armscan_ppo_hl.py b/scripts/armscan_ppo_hl.py index d0ee2af..bbafce3 100644 --- a/scripts/armscan_ppo_hl.py +++ b/scripts/armscan_ppo_hl.py @@ -10,6 +10,7 @@ from armscan_env.network import ActorFactoryArmscanNet from armscan_env.volumes.loading import RegisteredLabelmap from armscan_env.wrapper import ArmscanEnvFactory +from sensai.util.logging import datetime_tag from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import VectorEnvType @@ -18,7 +19,6 @@ DistributionFunctionFactoryIndependentGaussians, ) from tianshou.highlevel.params.policy_params import PPOParams -from tianshou.utils.logging import datetime_tag if __name__ == "__main__": config = get_config() diff --git a/scripts/armscan_sac_hl.py b/scripts/armscan_sac_hl.py index c4825fd..6a6e242 100644 --- a/scripts/armscan_sac_hl.py +++ b/scripts/armscan_sac_hl.py @@ -9,6 +9,7 @@ from armscan_env.envs.rewards import LabelmapClusteringBasedReward from armscan_env.volumes.loading import load_sitk_volumes from armscan_env.wrapper import ArmscanEnvFactory +from sensai.util.logging import datetime_tag from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import VectorEnvType @@ -18,7 +19,6 @@ ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import SACParams -from tianshou.utils.logging import datetime_tag if __name__ == "__main__": config = get_config()