diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c1df722..f9b300c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,9 +31,9 @@ jobs: strategy: max-parallel: 2 matrix: - python-version: ['3.8'] - torch-version: [1.10.0, 2.0.0] - os: [ubuntu-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest] + python-version: ['3.9'] + torch-version: [2.1.1] + os: [ubuntu-latest, macos-latest, windows-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest] # Steps represent a sequence of tasks that will be executed as part of the job steps: diff --git a/requirements-dev.txt b/requirements-dev.txt index 8a9abe5..bcf1023 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ -r requirements.txt +wheel pytest -pydicom>=2.3.1 \ No newline at end of file +pydicom>=2.3.1 diff --git a/torchxrayvision/autoencoders.py b/torchxrayvision/autoencoders.py index ee3a1ac..a03e1ee 100644 --- a/torchxrayvision/autoencoders.py +++ b/torchxrayvision/autoencoders.py @@ -222,7 +222,7 @@ def ResNetAE(weights=None): """A ResNet based autoencoder. Possible weights for this class include: - + - "101-elastic" trained on PadChest, NIH, CheXpert, and MIMIC. From the paper https://arxiv.org/abs/2102.09475 .. code-block:: python diff --git a/torchxrayvision/baseline_models/chexpert/model.py b/torchxrayvision/baseline_models/chexpert/model.py index 23d1788..c32a0a9 100644 --- a/torchxrayvision/baseline_models/chexpert/model.py +++ b/torchxrayvision/baseline_models/chexpert/model.py @@ -79,7 +79,7 @@ def infer(self, x, tasks): for task in tasks: idx = self.task_sequence[task] - #task_prob = probs.detach().cpu().numpy()[idx] + # task_prob = probs.detach().cpu().numpy()[idx] task_prob = probs[idx] task2results[task] = task_prob @@ -226,7 +226,7 @@ def infer(self, img, tasks): else: task2ensemble_results[task].append(individual_task2results[task]) - assert all([task in task2ensemble_results for task in tasks]),\ + assert all([task in task2ensemble_results for task in tasks]), \ "Not all tasks in task2ensemble_results" task2results = {} diff --git a/torchxrayvision/baseline_models/riken/__init__.py b/torchxrayvision/baseline_models/riken/__init__.py index 18dcd13..409930f 100644 --- a/torchxrayvision/baseline_models/riken/__init__.py +++ b/torchxrayvision/baseline_models/riken/__init__.py @@ -43,18 +43,18 @@ class AgeModel(nn.Module): url = {https://www.nature.com/articles/s43856-022-00220-6}, year = {2022} } - + """ targets: List[str] = ["Age"] """""" - + def __init__(self): - + super(AgeModel, self).__init__() - + url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/baseline_models_riken_xray_age_every_model_age_senet154_v2_tl_26_ft_7_fp32.pt" - + weights_filename = os.path.basename(url) weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) @@ -81,17 +81,17 @@ def __init__(self): [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], ) - + def forward(self, x): x = x.repeat(1, 3, 1, 1) x = self.upsample(x) - + # expecting values between [-1024,1024] x = (x + 1024) / 2048 # now between [0,1] - + x = self.norm(x) return self.model(x) - + def __repr__(self): return "riken-age-prediction" diff --git a/torchxrayvision/baseline_models/xinario/__init__.py b/torchxrayvision/baseline_models/xinario/__init__.py index ac97e21..e8e90e0 100644 --- a/torchxrayvision/baseline_models/xinario/__init__.py +++ b/torchxrayvision/baseline_models/xinario/__init__.py @@ -10,7 +10,7 @@ class ViewModel(nn.Module): """ - + The native resolution of the model is 320x320. Images are scaled automatically. @@ -26,7 +26,7 @@ class ViewModel(nn.Module): pred = model(image) # tensor([[17.3186, 26.7156]]), grad_fn=) - + model.targets[pred.argmax()] # Lateral @@ -37,13 +37,13 @@ class ViewModel(nn.Module): targets: List[str] = ['Frontal', 'Lateral'] """""" - + def __init__(self): - + super(ViewModel, self).__init__() - + url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/xinario_chestViewSplit_resnet-50.pt" - + weights_filename = os.path.basename(url) weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) @@ -54,7 +54,6 @@ def __init__(self): pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True) xrv.utils.download(url, self.weights_filename_local) - self.model = torchvision.models.resnet.resnet50() try: weights = torch.load(self.weights_filename_local) @@ -74,17 +73,17 @@ def __init__(self): [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], ) - + def forward(self, x): x = x.repeat(1, 3, 1, 1) x = self.upsample(x) - + # expecting values between [-1024,1024] x = (x + 1024) / 2048 # now between [0,1] - + x = self.norm(x) - return self.model(x)[:,:2] # cut off the rest of the outputs - + return self.model(x)[:, :2] # cut off the rest of the outputs + def __repr__(self): return "xinario-view-prediction" diff --git a/torchxrayvision/datasets.py b/torchxrayvision/datasets.py index 31b3863..6eba9fc 100644 --- a/torchxrayvision/datasets.py +++ b/torchxrayvision/datasets.py @@ -128,6 +128,7 @@ class Dataset: metadata file and for some the metadata files are packaged in the library so only the imgpath needs to be specified. """ + def __init__(self): pass @@ -262,7 +263,7 @@ def __init__(self, datasets, seed=0, label_concat=False): print("Could not merge dataframes (.csv not available):", sys.exc_info()[0]) self.csv = self.csv.reset_index(drop=True) - + def __setattr__(self, name, value): if hasattr(self, 'labels'): # check only if have finished init, otherwise __init__ breaks @@ -346,6 +347,7 @@ class SubsetDataset(Dataset): - of PC_Dataset num_samples=94825 views=['PA', 'AP'] data_aug=None """ + def __init__(self, dataset, idxs=None): super(SubsetDataset, self).__init__() self.dataset = dataset @@ -365,7 +367,7 @@ def __setattr__(self, name, value): # check only if have finished init, otherwise __init__ breaks if name in ['transform', 'data_aug', 'labels', 'pathologies', 'targets']: raise NotImplementedError(f'Cannot set {name} on a subset dataset. Set the transforms directly on the dataset object. If it was to be set via this subset dataset it would have to modify the internal dataset which could have unexpected side effects') - + object.__setattr__(self, name, value) def string(self): @@ -895,17 +897,17 @@ def __init__(self, "216840111366964012373310883942009170084120009_00-097-074.png", "216840111366964012819207061112010315104455352_04-024-184.png", "216840111366964012819207061112010306085429121_04-020-102.png", - "216840111366964012989926673512011083134050913_00-168-009.png", # broken PNG file (chunk b'\x00\x00\x00\x00') - "216840111366964012373310883942009152114636712_00-102-045.png", # "OSError: image file is truncated" - "216840111366964012819207061112010281134410801_00-129-131.png", # "OSError: image file is truncated" - "216840111366964012487858717522009280135853083_00-075-001.png", # "OSError: image file is truncated" - "216840111366964012989926673512011151082430686_00-157-045.png", # broken PNG file (chunk b'\x00\x00\x00\x00') - "216840111366964013686042548532013208193054515_02-026-007.png", # "OSError: image file is truncated" - "216840111366964013590140476722013058110301622_02-056-111.png", # "OSError: image file is truncated" - "216840111366964013590140476722013043111952381_02-065-198.png", # "OSError: image file is truncated" - "216840111366964013829543166512013353113303615_02-092-190.png", # "OSError: image file is truncated" - "216840111366964013962490064942014134093945580_01-178-104.png", # "OSError: image file is truncated" - ] + "216840111366964012989926673512011083134050913_00-168-009.png", # broken PNG file (chunk b'\x00\x00\x00\x00') + "216840111366964012373310883942009152114636712_00-102-045.png", # "OSError: image file is truncated" + "216840111366964012819207061112010281134410801_00-129-131.png", # "OSError: image file is truncated" + "216840111366964012487858717522009280135853083_00-075-001.png", # "OSError: image file is truncated" + "216840111366964012989926673512011151082430686_00-157-045.png", # broken PNG file (chunk b'\x00\x00\x00\x00') + "216840111366964013686042548532013208193054515_02-026-007.png", # "OSError: image file is truncated" + "216840111366964013590140476722013058110301622_02-056-111.png", # "OSError: image file is truncated" + "216840111366964013590140476722013043111952381_02-065-198.png", # "OSError: image file is truncated" + "216840111366964013829543166512013353113303615_02-092-190.png", # "OSError: image file is truncated" + "216840111366964013962490064942014134093945580_01-178-104.png", # "OSError: image file is truncated" + ] self.csv = self.csv[~self.csv["ImageID"].isin(missing)] if unique_patients: @@ -920,7 +922,7 @@ def __init__(self, mask = self.csv["Labels"].str.contains(pathology.lower()) if pathology in mapping: for syn in mapping[pathology]: - #print("mapping", syn) + # print("mapping", syn) mask |= self.csv["Labels"].str.contains(syn.lower()) labels.append(mask.values) self.labels = np.asarray(labels).T @@ -1094,7 +1096,7 @@ def __getitem__(self, idx): sample["lab"] = self.labels[idx] imgid = self.csv['Path'].iloc[idx] - #clean up path in csv so the user can specify the path + # clean up path in csv so the user can specify the path imgid = imgid.replace("CheXpert-v1.0-small/", "").replace("CheXpert-v1.0/", "") img_path = os.path.join(self.imgpath, imgid) img = imread(img_path) @@ -1344,7 +1346,7 @@ def __init__(self, imgpath, mask = self.csv["labels_automatic"].str.contains(pathology.lower()) if pathology in mapping: for syn in mapping[pathology]: - #print("mapping", syn) + # print("mapping", syn) mask |= self.csv["labels_automatic"].str.contains(syn.lower()) labels.append(mask.values) @@ -1994,7 +1996,7 @@ def __init__(self, transform=None, data_aug=None, seed=0 - ): + ): super(ObjectCXR_Dataset, self).__init__() np.random.seed(seed) # Reset the seed so all runs are the same. @@ -2053,6 +2055,7 @@ def __call__(self, x): class XRayResizer(object): """Resize an image to a specific size""" + def __init__(self, size: int, engine="skimage"): self.size = size self.engine = engine @@ -2076,6 +2079,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: class XRayCenterCrop(object): """Perform a center crop on the long dimension of the input image""" + def crop_center(self, img: np.ndarray) -> np.ndarray: _, y, x = img.shape crop_size = np.min([y, x]) diff --git a/torchxrayvision/models.py b/torchxrayvision/models.py index 4d0663e..dc72bbe 100644 --- a/torchxrayvision/models.py +++ b/torchxrayvision/models.py @@ -80,6 +80,8 @@ } # Just created for documentation + + class Model: """The library is composed of core and baseline classifiers. Core classifiers are trained specifically for this library and baseline @@ -132,6 +134,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ pass + class _DenseLayer(nn.Sequential): def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() @@ -190,7 +193,7 @@ class DenseNet(nn.Module): :param weights: Specify a weight name to load pre-trained weights :param op_threshs: Specify a weight name to load pre-trained weights :param apply_sigmoid: Apply a sigmoid - + """ targets: List[str] = [ @@ -379,7 +382,6 @@ class ResNet(nn.Module): ] """""" - def __init__(self, weights: str = None, apply_sigmoid: bool = False): super(ResNet, self).__init__() diff --git a/torchxrayvision/utils.py b/torchxrayvision/utils.py index b07a1aa..e829d22 100644 --- a/torchxrayvision/utils.py +++ b/torchxrayvision/utils.py @@ -80,7 +80,8 @@ def load_image(fname: str): return img -def read_xray_dcm(path:PathLike, voi_lut:bool=False, fix_monochrome:bool=True)->ndarray: + +def read_xray_dcm(path: PathLike, voi_lut: bool = False, fix_monochrome: bool = True) -> ndarray: """read a dicom-like file and convert to numpy array Args: @@ -98,26 +99,26 @@ def read_xray_dcm(path:PathLike, voi_lut:bool=False, fix_monochrome:bool=True)-> # get the pixel array ds = pydicom.dcmread(path, force=True) - data = ds.pixel_array # we have not tested RGB, YBR_FULL, or YBR_FULL_422 yet. - if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']: + if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']: raise NotImplementedError(f'PhotometricInterpretation `{ds.PhotometricInterpretation}` is not yet supported.') # get the max possible pixel value from DCM header max_possible_pixel_val = (2**ds.BitsStored - 1) + data = ds.pixel_array + # LUT for human friendly view if voi_lut: data = pydicom.pixel_data_handlers.util.apply_voi_lut(data, ds, index=0) - # `MONOCHROME1` have an inverted view; Bones are black; background is white # https://web.archive.org/web/20150920230923/http://www.mccauslandcenter.sc.edu/mricro/dicom/index.html if fix_monochrome and ds.PhotometricInterpretation == "MONOCHROME1": warnings.warn(f"Coverting MONOCHROME1 to MONOCHROME2 interpretation for file: {path}. Can be avoided by setting `fix_monochrome=False`") data = max_possible_pixel_val - data - # normalize data to [-1024, 1024] + # normalize data to [-1024, 1024] data = normalize(data, max_possible_pixel_val) return data @@ -129,13 +130,13 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4, batch_size=threads, num_workers=threads, ) - + preds = [] with torch.inference_mode(): for i, batch in enumerate(tqdm(dl)): output = model(batch["img"].to(device)) - + output = output.detach().cpu().numpy() preds.append(output) - + return np.concatenate(preds)