Skip to content

Commit

Permalink
Adjust import of util from tianshou to sensai.util
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Aug 14, 2024
1 parent d3b7d37 commit 780d57a
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 15 deletions.
9 changes: 6 additions & 3 deletions docs/02_notebooks/L3_slicing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@
"plt.imshow(slice_img, aspect=6)\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -223,7 +224,8 @@
"print(f\"{reward=}\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
],
"execution_count": null
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -280,7 +282,8 @@
"\n",
"animation = camera.animate()\n",
"HTML(animation.to_jshtml())"
]
],
"execution_count": null
},
{
"cell_type": "markdown",
Expand Down
16 changes: 9 additions & 7 deletions notebooks/experiments_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
"source": [
"import os.path\n",
"\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns"
],
"outputs": [],
Expand Down Expand Up @@ -59,19 +59,21 @@
},
"cell_type": "code",
"source": [
"experiment_name = \"run-sac-characteristic-array-rew-details-y_42_20240630-191219-tag-train_lens_stat_mean.csv\"\n",
"experiment_name = (\n",
" \"run-sac-characteristic-array-rew-details-y_42_20240630-191219-tag-train_lens_stat_mean.csv\"\n",
")\n",
"experiment_path = os.path.join(experiments_dir, experiment_name)\n",
"df = pd.read_csv(experiment_path)\n",
"\n",
"# Simple Matplotlib plot\n",
"plt.plot(df['Step'], df['Value'])\n",
"plt.xlabel('Step')\n",
"plt.ylabel('Value')\n",
"plt.title('Plot Title')\n",
"plt.plot(df[\"Step\"], df[\"Value\"])\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Value\")\n",
"plt.title(\"Plot Title\")\n",
"plt.show()\n",
"\n",
"# Seaborn plot for enhanced aesthetics\n",
"sns.lineplot(data=df, x='Step', y='Value')\n",
"sns.lineplot(data=df, x=\"Step\", y=\"Value\")\n",
"plt.show()"
],
"id": "initial_id",
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory
from sensai.util.logging import datetime_tag

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
Expand All @@ -18,7 +19,6 @@
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

if __name__ == "__main__":
config = get_config()
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.volumes.loading import RegisteredLabelmap
from armscan_env.wrapper import ArmscanEnvFactory
from sensai.util.logging import datetime_tag

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
Expand All @@ -19,7 +20,6 @@
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

if __name__ == "__main__":
config = get_config()
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_dqn_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from armscan_env.network import ActorFactoryArmscanNet
from armscan_env.volumes.loading import RegisteredLabelmap
from armscan_env.wrapper import ArmscanEnvFactory
from sensai.util.logging import datetime_tag

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
Expand All @@ -19,7 +20,6 @@
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

if __name__ == "__main__":
config = get_config()
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from armscan_env.network import ActorFactoryArmscanNet
from armscan_env.volumes.loading import RegisteredLabelmap
from armscan_env.wrapper import ArmscanEnvFactory
from sensai.util.logging import datetime_tag

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
Expand All @@ -18,7 +19,6 @@
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils.logging import datetime_tag

if __name__ == "__main__":
config = get_config()
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory
from sensai.util.logging import datetime_tag

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
Expand All @@ -18,7 +19,6 @@
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

if __name__ == "__main__":
config = get_config()
Expand Down

0 comments on commit 780d57a

Please sign in to comment.