From 35382519d636bbc67d083397ad17a38af424b187 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 7 Aug 2023 16:21:32 +0200 Subject: [PATCH 1/8] Add warning if the imported network does not have the right keys Signed-off-by: Matthias Hadlich --- monailabel/tasks/infer/basic_infer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/monailabel/tasks/infer/basic_infer.py b/monailabel/tasks/infer/basic_infer.py index 7a9c37932..d01389cf0 100644 --- a/monailabel/tasks/infer/basic_infer.py +++ b/monailabel/tasks/infer/basic_infer.py @@ -57,7 +57,7 @@ def __init__( output_label_key: str = "pred", output_json_key: str = "result", config: Union[None, Dict[str, Any]] = None, - load_strict: bool = False, + load_strict: bool = True, roi_size=None, preload=False, train_mode=False, @@ -452,6 +452,13 @@ def _get_network(self, device, data): if path: checkpoint = torch.load(path, map_location=torch.device(device)) + for key in self.model_state_dict: + if key not in checkpoint: + logger.warning( + f"Expected key {key} has not been found in the checkpoint keys: {checkpoint.keys()}" + ) + logger.warning("The run will now continue unless load_strict is set to True") + break model_state_dict = checkpoint.get(self.model_state_dict, checkpoint) network.load_state_dict(model_state_dict, strict=self.load_strict) else: From 27d5c23021e0c37a498b1074ed0dcfc2c8552087 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 8 Aug 2023 07:56:25 +0000 Subject: [PATCH 2/8] Update warning Signed-off-by: Matthias Hadlich --- monailabel/tasks/infer/basic_infer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/monailabel/tasks/infer/basic_infer.py b/monailabel/tasks/infer/basic_infer.py index d01389cf0..e69e5ac96 100644 --- a/monailabel/tasks/infer/basic_infer.py +++ b/monailabel/tasks/infer/basic_infer.py @@ -452,14 +452,17 @@ def _get_network(self, device, data): if path: checkpoint = torch.load(path, map_location=torch.device(device)) - for key in self.model_state_dict: - if key not in checkpoint: - logger.warning( - f"Expected key {key} has not been found in the checkpoint keys: {checkpoint.keys()}" - ) - logger.warning("The run will now continue unless load_strict is set to True") - break model_state_dict = checkpoint.get(self.model_state_dict, checkpoint) + + if set(self.network.state_dict().keys()) != set(checkpoint.keys()): + logger.warning( + f"Checkpoint keys don't match network.state_dict()! Items that exist in only one dict" + f" but not in the other: {set(self.network.state_dict().keys()) ^ set(checkpoint.keys())}" + ) + logger.warning( + "The run will now continue unless load_strict is set to True. " + "If loading fails or the network behaves abnormally, please check the loaded weights" + ) network.load_state_dict(model_state_dict, strict=self.load_strict) else: network = torch.jit.load(path, map_location=torch.device(device)) From c3c9d83cf446f43905a3d8d9c21bb27bbdadba7b Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 8 Aug 2023 07:56:54 +0000 Subject: [PATCH 3/8] Set load_strict to false for deepedit Signed-off-by: Matthias Hadlich --- sample-apps/radiology/lib/infers/deepedit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sample-apps/radiology/lib/infers/deepedit.py b/sample-apps/radiology/lib/infers/deepedit.py index a1e153b8a..a86fb46ce 100644 --- a/sample-apps/radiology/lib/infers/deepedit.py +++ b/sample-apps/radiology/lib/infers/deepedit.py @@ -71,6 +71,7 @@ def __init__( self.spatial_size = spatial_size self.target_spacing = target_spacing self.number_intensity_ch = number_intensity_ch + self.load_strict = False def pre_transforms(self, data=None): t = [ From 11b600a0cdcd786a2269641e0d08315289260e0b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Aug 2023 08:05:02 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monailabel/tasks/infer/basic_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monailabel/tasks/infer/basic_infer.py b/monailabel/tasks/infer/basic_infer.py index e69e5ac96..d5b5ad1a4 100644 --- a/monailabel/tasks/infer/basic_infer.py +++ b/monailabel/tasks/infer/basic_infer.py @@ -460,8 +460,8 @@ def _get_network(self, device, data): f" but not in the other: {set(self.network.state_dict().keys()) ^ set(checkpoint.keys())}" ) logger.warning( - "The run will now continue unless load_strict is set to True. " - "If loading fails or the network behaves abnormally, please check the loaded weights" + "The run will now continue unless load_strict is set to True. " + "If loading fails or the network behaves abnormally, please check the loaded weights" ) network.load_state_dict(model_state_dict, strict=self.load_strict) else: From f1bb6b591d87451e28afc3e76fbe92afe38d2223 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Sun, 27 Aug 2023 08:58:42 +0200 Subject: [PATCH 5/8] Set load_strict=False for all existing models Signed-off-by: Matthias Hadlich --- monailabel/tasks/infer/bundle.py | 1 + sample-apps/endoscopy/lib/infers/deepedit.py | 1 + sample-apps/endoscopy/lib/infers/inbody.py | 2 +- sample-apps/endoscopy/lib/infers/tooltracking.py | 2 +- sample-apps/pathology/lib/infers/classification_nuclei.py | 2 +- sample-apps/pathology/lib/infers/nuclick.py | 1 + sample-apps/pathology/lib/infers/segmentation_nuclei.py | 1 + sample-apps/radiology/lib/infers/deepedit.py | 1 + sample-apps/radiology/lib/infers/deepgrow.py | 1 + sample-apps/radiology/lib/infers/deepgrow_pipeline.py | 3 +++ sample-apps/radiology/lib/infers/localization_spine.py | 1 + sample-apps/radiology/lib/infers/localization_vertebra.py | 1 + sample-apps/radiology/lib/infers/segmentation.py | 1 + sample-apps/radiology/lib/infers/segmentation_spleen.py | 1 + sample-apps/radiology/lib/infers/segmentation_vertebra.py | 1 + sample-apps/radiology/lib/infers/vertebra_pipeline.py | 1 + 16 files changed, 18 insertions(+), 3 deletions(-) diff --git a/monailabel/tasks/infer/bundle.py b/monailabel/tasks/infer/bundle.py index 628408f16..a364ca02f 100644 --- a/monailabel/tasks/infer/bundle.py +++ b/monailabel/tasks/infer/bundle.py @@ -149,6 +149,7 @@ def __init__( dimension=dimension, description=description, preload=strtobool(conf.get("preload", "false")), + load_strict=False, **kwargs, ) diff --git a/sample-apps/endoscopy/lib/infers/deepedit.py b/sample-apps/endoscopy/lib/infers/deepedit.py index 265e121d0..129654aae 100644 --- a/sample-apps/endoscopy/lib/infers/deepedit.py +++ b/sample-apps/endoscopy/lib/infers/deepedit.py @@ -55,6 +55,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) diff --git a/sample-apps/endoscopy/lib/infers/inbody.py b/sample-apps/endoscopy/lib/infers/inbody.py index defeacca2..5a5c31ac1 100644 --- a/sample-apps/endoscopy/lib/infers/inbody.py +++ b/sample-apps/endoscopy/lib/infers/inbody.py @@ -24,7 +24,7 @@ class InBody(BundleInferTask): """ def __init__(self, path: str, conf: Dict[str, str], **kwargs): - super().__init__(path, conf, type=InferType.CLASSIFICATION, add_post_restore=False, **kwargs) + super().__init__(path, conf, type=InferType.CLASSIFICATION, add_post_restore=False, load_strict=False, **kwargs) # Override Labels self.labels = {"InBody": 0, "OutBody": 1} diff --git a/sample-apps/endoscopy/lib/infers/tooltracking.py b/sample-apps/endoscopy/lib/infers/tooltracking.py index 1e60512be..584866ddc 100644 --- a/sample-apps/endoscopy/lib/infers/tooltracking.py +++ b/sample-apps/endoscopy/lib/infers/tooltracking.py @@ -28,7 +28,7 @@ class ToolTracking(BundleInferTask): """ def __init__(self, path: str, conf: Dict[str, str], **kwargs): - super().__init__(path, conf, type=InferType.SEGMENTATION, **kwargs) + super().__init__(path, conf, type=InferType.SEGMENTATION, load_strict=False, **kwargs) # Override Labels self.labels = {"Tool": 1} diff --git a/sample-apps/pathology/lib/infers/classification_nuclei.py b/sample-apps/pathology/lib/infers/classification_nuclei.py index 49ccac35f..a73c94168 100644 --- a/sample-apps/pathology/lib/infers/classification_nuclei.py +++ b/sample-apps/pathology/lib/infers/classification_nuclei.py @@ -24,7 +24,7 @@ class ClassificationNuclei(BundleInferTask): """ def __init__(self, path: str, conf: Dict[str, str], **kwargs): - super().__init__(path, conf, type=InferType.CLASSIFICATION, add_post_restore=False, **kwargs) + super().__init__(path, conf, type=InferType.CLASSIFICATION, add_post_restore=False, load_strict=False, **kwargs) # Override Labels self.labels = { diff --git a/sample-apps/pathology/lib/infers/nuclick.py b/sample-apps/pathology/lib/infers/nuclick.py index 73930a5a3..c2431d3ec 100644 --- a/sample-apps/pathology/lib/infers/nuclick.py +++ b/sample-apps/pathology/lib/infers/nuclick.py @@ -41,6 +41,7 @@ def __init__(self, path: str, conf: Dict[str, str], **kwargs): **kwargs, pre_filter=[LoadImaged, SqueezeDimd], post_filter=[KeepLargestConnectedComponentd, SaveImaged], + load_strict=False, ) # Override Labels diff --git a/sample-apps/pathology/lib/infers/segmentation_nuclei.py b/sample-apps/pathology/lib/infers/segmentation_nuclei.py index 652202b25..e8e9e1691 100644 --- a/sample-apps/pathology/lib/infers/segmentation_nuclei.py +++ b/sample-apps/pathology/lib/infers/segmentation_nuclei.py @@ -48,6 +48,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) diff --git a/sample-apps/radiology/lib/infers/deepedit.py b/sample-apps/radiology/lib/infers/deepedit.py index 152d1c45c..e380d6fa9 100644 --- a/sample-apps/radiology/lib/infers/deepedit.py +++ b/sample-apps/radiology/lib/infers/deepedit.py @@ -65,6 +65,7 @@ def __init__( input_key="image", output_label_key="pred", output_json_key="result", + load_strict=False, **kwargs, ) diff --git a/sample-apps/radiology/lib/infers/deepgrow.py b/sample-apps/radiology/lib/infers/deepgrow.py index 107395834..59e891868 100644 --- a/sample-apps/radiology/lib/infers/deepgrow.py +++ b/sample-apps/radiology/lib/infers/deepgrow.py @@ -63,6 +63,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) diff --git a/sample-apps/radiology/lib/infers/deepgrow_pipeline.py b/sample-apps/radiology/lib/infers/deepgrow_pipeline.py index 97895b611..ea9e1c959 100644 --- a/sample-apps/radiology/lib/infers/deepgrow_pipeline.py +++ b/sample-apps/radiology/lib/infers/deepgrow_pipeline.py @@ -68,6 +68,7 @@ def __init__( dimension=dimension, description=description, config={"cache_transforms": True, "cache_transforms_in_memory": True, "cache_transforms_ttl": 300}, + load_strict=False, ) self.model_3d = model_3d self.spatial_size = spatial_size @@ -79,6 +80,8 @@ def __init__( self.random_point_density = random_point_density self.output_largest_cc = output_largest_cc + self.load_strict = False + def pre_transforms(self, data=None) -> Sequence[Callable]: t = [ LoadImaged(keys="image"), diff --git a/sample-apps/radiology/lib/infers/localization_spine.py b/sample-apps/radiology/lib/infers/localization_spine.py index 2c56fff16..347d1536e 100644 --- a/sample-apps/radiology/lib/infers/localization_spine.py +++ b/sample-apps/radiology/lib/infers/localization_spine.py @@ -54,6 +54,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) self.target_spacing = target_spacing diff --git a/sample-apps/radiology/lib/infers/localization_vertebra.py b/sample-apps/radiology/lib/infers/localization_vertebra.py index 8e5e39604..fec4cc5a9 100644 --- a/sample-apps/radiology/lib/infers/localization_vertebra.py +++ b/sample-apps/radiology/lib/infers/localization_vertebra.py @@ -56,6 +56,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) self.target_spacing = target_spacing diff --git a/sample-apps/radiology/lib/infers/segmentation.py b/sample-apps/radiology/lib/infers/segmentation.py index 8ce848f34..1eadab540 100644 --- a/sample-apps/radiology/lib/infers/segmentation.py +++ b/sample-apps/radiology/lib/infers/segmentation.py @@ -54,6 +54,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) self.target_spacing = target_spacing diff --git a/sample-apps/radiology/lib/infers/segmentation_spleen.py b/sample-apps/radiology/lib/infers/segmentation_spleen.py index 4bfa7ceeb..1e4c4102a 100644 --- a/sample-apps/radiology/lib/infers/segmentation_spleen.py +++ b/sample-apps/radiology/lib/infers/segmentation_spleen.py @@ -53,6 +53,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) self.target_spacing = target_spacing diff --git a/sample-apps/radiology/lib/infers/segmentation_vertebra.py b/sample-apps/radiology/lib/infers/segmentation_vertebra.py index cfaabced3..a85ed77ce 100644 --- a/sample-apps/radiology/lib/infers/segmentation_vertebra.py +++ b/sample-apps/radiology/lib/infers/segmentation_vertebra.py @@ -63,6 +63,7 @@ def __init__( labels=labels, dimension=dimension, description=description, + load_strict=False, **kwargs, ) self.target_spacing = target_spacing diff --git a/sample-apps/radiology/lib/infers/vertebra_pipeline.py b/sample-apps/radiology/lib/infers/vertebra_pipeline.py index b015993cf..39fa433de 100644 --- a/sample-apps/radiology/lib/infers/vertebra_pipeline.py +++ b/sample-apps/radiology/lib/infers/vertebra_pipeline.py @@ -47,6 +47,7 @@ def __init__( labels=task_seg_vertebra.labels, dimension=task_seg_vertebra.dimension, description=description, + load_strict=False, **kwargs, ) From bcf00f06d8bb49239597c5979adb8eed3d691510 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Sun, 27 Aug 2023 09:03:18 +0200 Subject: [PATCH 6/8] Fix double load_strict=False Signed-off-by: Matthias Hadlich --- sample-apps/radiology/lib/infers/deepgrow_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sample-apps/radiology/lib/infers/deepgrow_pipeline.py b/sample-apps/radiology/lib/infers/deepgrow_pipeline.py index ea9e1c959..7c7145493 100644 --- a/sample-apps/radiology/lib/infers/deepgrow_pipeline.py +++ b/sample-apps/radiology/lib/infers/deepgrow_pipeline.py @@ -80,8 +80,6 @@ def __init__( self.random_point_density = random_point_density self.output_largest_cc = output_largest_cc - self.load_strict = False - def pre_transforms(self, data=None) -> Sequence[Callable]: t = [ LoadImaged(keys="image"), From 520705360a2009dda1464948b7dcad7351d69f50 Mon Sep 17 00:00:00 2001 From: SACHIDANAND ALLE Date: Thu, 31 Aug 2023 09:36:29 -0700 Subject: [PATCH 7/8] Update nuclick.py Fix the order for load_strict Signed-off-by: SACHIDANAND ALLE --- sample-apps/pathology/lib/infers/nuclick.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sample-apps/pathology/lib/infers/nuclick.py b/sample-apps/pathology/lib/infers/nuclick.py index c2431d3ec..a780b3597 100644 --- a/sample-apps/pathology/lib/infers/nuclick.py +++ b/sample-apps/pathology/lib/infers/nuclick.py @@ -38,10 +38,10 @@ def __init__(self, path: str, conf: Dict[str, str], **kwargs): conf, type=InferType.ANNOTATION, add_post_restore=False, + load_strict=False, **kwargs, pre_filter=[LoadImaged, SqueezeDimd], post_filter=[KeepLargestConnectedComponentd, SaveImaged], - load_strict=False, ) # Override Labels From 67dfc498afe548bf8e9a5e0cb115f80c98297815 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Thu, 31 Aug 2023 10:06:23 -0700 Subject: [PATCH 8/8] Fix load_strict for bundle infer Signed-off-by: Sachidanand Alle --- monailabel/tasks/infer/bundle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monailabel/tasks/infer/bundle.py b/monailabel/tasks/infer/bundle.py index a364ca02f..18bd83e89 100644 --- a/monailabel/tasks/infer/bundle.py +++ b/monailabel/tasks/infer/bundle.py @@ -87,6 +87,7 @@ def __init__( extend_load_image: bool = True, add_post_restore: bool = True, dropout: float = 0.0, + load_strict=False, **kwargs, ): self.valid: bool = False @@ -149,7 +150,7 @@ def __init__( dimension=dimension, description=description, preload=strtobool(conf.get("preload", "false")), - load_strict=False, + load_strict=load_strict, **kwargs, )