Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Polygon model support #837

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from albumentations.pytorch import ToTensorV2
import torch
import typing
from PIL import Image
from PIL import Image, ImageDraw
import rasterio as rio
from deepforest import preprocess
from rasterio.windows import Window
from torchvision import transforms
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Mask
import slidingwindow
import warnings

Expand Down Expand Up @@ -315,3 +317,55 @@ def __getitem__(self, idx):
image = box

return image


class PolygonDataset(Dataset):

def __init__(self, img_keys, annotation_df, class_to_idx, transforms=None):
super(Dataset, self).__init__()
self._img_keys = img_keys
self._annotation_df = annotation_df
self._class_to_idx = class_to_idx
self._transforms = transforms

def __len__(self):
return len(self._img_keys)

def __getitem__(self, index):
img_key = self._img_keys[index]
annotation = self._annotation_df.loc[img_key]
image, target = self._load_image_and_target(annotation)

if self._transforms:
image, target = self._transforms(image, target)

return image, target

def create_polygon_mask(self, image_size, vertices):
mask_img = Image.new('L', image_size, 0)
ImageDraw.Draw(mask_img, 'L').polygon(vertices, fill=(255))

return mask_img

def _load_image_and_target(self, annotation):
filepath = annotation['image_path']
image = Image.open(filepath).convert('RGB')

labels = [shape['label'] for shape in annotation['shapes']]
labels = torch.Tensor([self._class_to_idx[label] for label in labels])
labels = labels.to(dtype=torch.int64)

shape_points = [shape['points'] for shape in annotation['shapes']]
xy_coords = [[tuple(p) for p in points] for points in shape_points]
mask_imgs = [self.create_polygon_mask(image.size, xy) for xy in xy_coords]
masks = Mask(
torch.concat([
Mask(transforms.PILToTensor()(mask_img), dtype=torch.bool)
for mask_img in mask_imgs
]))

bboxes = BoundingBoxes(data=torchvision.ops.masks_to_boxes(masks),
format='xyxy',
canvas_size=image.size[::-1])

return image, {'masks': masks, 'boxes': bboxes, 'labels': labels}
51 changes: 51 additions & 0 deletions deepforest/models/MaskRCNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from deepforest.model import Model


class Model(Model):

def __init__(self, config, **kwargs):
super().__init__(config)

def load_backbone(self):
backbone = maskrcnn_resnet50_fpn_v2(weights='DEFAULT')

return backbone

def create_model(self, backbone=None):
model = maskrcnn_resnet50_fpn_v2(weights='DEFAULT')

# Modify the box predictor for the desired number of classes
in_features_box = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
in_features_box, num_classes=self.config["num_classes"])

# Modify the mask predictor for the desired number of classes
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
dim_reduced = model.roi_heads.mask_predictor.conv5_mask.out_channels
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask, dim_reduced, num_classes=self.config["num_classes"])

return model

def check_model(self):
"""Ensure that model follows deepforest guidelines, see ##### If fails,
raise ValueError."""
# This assumes model creation is not expensive
test_model = self.create_model()
test_model.eval()

# Create a dummy batch of 3 band data.
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

predictions = test_model(x)
# Model takes in a batch of images
assert len(predictions) == 2

# Returns a list equal to number of images with proper keys per image
model_keys = list(predictions[1].keys())
model_keys.sort()
assert model_keys == ['boxes', 'masks', 'labels', 'scores']