From 27f28f101f07e2c6aa19135719cc5fc7f90b06c5 Mon Sep 17 00:00:00 2001 From: Mu-Magdy Date: Sat, 9 Nov 2024 19:38:08 +0200 Subject: [PATCH] pointer example + pointer dataset class --- deepforest/convert centroids.py | 11 +++++ deepforest/data/OSBS_029_centroids.csv | 62 ++++++++++++++++++++++++++ deepforest/dataset.py | 46 +++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 deepforest/convert centroids.py create mode 100644 deepforest/data/OSBS_029_centroids.csv diff --git a/deepforest/convert centroids.py b/deepforest/convert centroids.py new file mode 100644 index 00000000..99622344 --- /dev/null +++ b/deepforest/convert centroids.py @@ -0,0 +1,11 @@ +import pandas as pd + +df = pd.read_csv('deepforest/data/OSBS_029.csv') + +# Calculate centroids +df['x_center'] = round((df['xmin'] + df['xmax']) / 2) +df['y_center'] = round((df['ymin'] + df['ymax']) / 2) + +# Drop original bounding box columns +df_centroids = df.drop(columns=["xmin", "ymin", "xmax", "ymax"]) +df_centroids.to_csv("deepforest/data/OSBS_029_centroids.csv") diff --git a/deepforest/data/OSBS_029_centroids.csv b/deepforest/data/OSBS_029_centroids.csv new file mode 100644 index 00000000..050aad04 --- /dev/null +++ b/deepforest/data/OSBS_029_centroids.csv @@ -0,0 +1,62 @@ +,image_path,label,x_center,y_center +0,OSBS_029.tif,Tree,215.0,78.0 +1,OSBS_029.tif,Tree,272.0,120.0 +2,OSBS_029.tif,Tree,196.0,278.0 +3,OSBS_029.tif,Tree,382.0,14.0 +4,OSBS_029.tif,Tree,330.0,30.0 +5,OSBS_029.tif,Tree,382.0,46.0 +6,OSBS_029.tif,Tree,295.0,19.0 +7,OSBS_029.tif,Tree,382.0,225.0 +8,OSBS_029.tif,Tree,106.0,131.0 +9,OSBS_029.tif,Tree,132.0,130.0 +10,OSBS_029.tif,Tree,180.0,173.0 +11,OSBS_029.tif,Tree,140.0,172.0 +12,OSBS_029.tif,Tree,362.0,305.0 +13,OSBS_029.tif,Tree,27.0,185.0 +14,OSBS_029.tif,Tree,21.0,236.0 +15,OSBS_029.tif,Tree,91.0,166.0 +16,OSBS_029.tif,Tree,384.0,94.0 +17,OSBS_029.tif,Tree,176.0,126.0 +18,OSBS_029.tif,Tree,172.0,212.0 +19,OSBS_029.tif,Tree,128.0,220.0 +20,OSBS_029.tif,Tree,62.0,388.0 +21,OSBS_029.tif,Tree,134.0,384.0 +22,OSBS_029.tif,Tree,244.0,255.0 +23,OSBS_029.tif,Tree,314.0,384.0 +24,OSBS_029.tif,Tree,354.0,360.0 +25,OSBS_029.tif,Tree,20.0,38.0 +26,OSBS_029.tif,Tree,20.0,86.0 +27,OSBS_029.tif,Tree,76.0,30.0 +28,OSBS_029.tif,Tree,116.0,52.0 +29,OSBS_029.tif,Tree,168.0,22.0 +30,OSBS_029.tif,Tree,208.0,20.0 +31,OSBS_029.tif,Tree,391.0,282.0 +32,OSBS_029.tif,Tree,347.0,144.0 +33,OSBS_029.tif,Tree,346.0,101.0 +34,OSBS_029.tif,Tree,378.0,130.0 +35,OSBS_029.tif,Tree,133.0,285.0 +36,OSBS_029.tif,Tree,128.0,340.0 +37,OSBS_029.tif,Tree,196.0,378.0 +38,OSBS_029.tif,Tree,215.0,357.0 +39,OSBS_029.tif,Tree,196.0,326.0 +40,OSBS_029.tif,Tree,259.0,324.0 +41,OSBS_029.tif,Tree,292.0,353.0 +42,OSBS_029.tif,Tree,72.0,215.0 +43,OSBS_029.tif,Tree,133.0,84.0 +44,OSBS_029.tif,Tree,74.0,93.0 +45,OSBS_029.tif,Tree,276.0,262.0 +46,OSBS_029.tif,Tree,350.0,64.0 +47,OSBS_029.tif,Tree,268.0,67.0 +48,OSBS_029.tif,Tree,312.0,114.0 +49,OSBS_029.tif,Tree,302.0,152.0 +50,OSBS_029.tif,Tree,225.0,114.0 +51,OSBS_029.tif,Tree,273.0,215.0 +52,OSBS_029.tif,Tree,44.0,356.0 +53,OSBS_029.tif,Tree,36.0,384.0 +54,OSBS_029.tif,Tree,16.0,278.0 +55,OSBS_029.tif,Tree,93.0,264.0 +56,OSBS_029.tif,Tree,78.0,312.0 +57,OSBS_029.tif,Tree,102.0,376.0 +58,OSBS_029.tif,Tree,244.0,142.0 +59,OSBS_029.tif,Tree,331.0,194.0 +60,OSBS_029.tif,Tree,236.0,226.0 diff --git a/deepforest/dataset.py b/deepforest/dataset.py index d058fbd7..fafc3c2b 100644 --- a/deepforest/dataset.py +++ b/deepforest/dataset.py @@ -248,3 +248,49 @@ def __getitem__(self, idx): image = box return image + + +class PointDataset(Dataset): + """An in-memory dataset for point predictions (centroids). + + Args: + df: a pandas dataframe with image_path, x_center, y_center, and label columns. + transform: a function to apply to the image. + root_dir: the directory where the image is stored. + + Returns: + rgb: a tensor of shape (3, height, width) + """ + + def __init__(self, df, root_dir, transform=None, augment=False): + self.df = df + self.root_dir = root_dir + self.transform = transform + self.augment = augment + + unique_image = self.df['image_path'].unique() + assert len(unique_image + ) == 1, "There should be only one unique image for this class object" + + # Open the image using rasterio + self.src = rio.open(os.path.join(root_dir, unique_image[0])) + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + x_center = row['x_center'] + y_center = row['y_center'] + + # Read a small window around the centroid to create a context image for the point (optional) + box_size = 20 # e.g., small crop around the centroid + window = Window(x_center - box_size // 2, y_center - box_size // 2, box_size, + box_size) + image = self.src.read(window=window) + image = np.rollaxis(image, 0, 3) + + if self.transform: + image = self.transform(image) + + return image, (x_center, y_center)