Hello,

I'm trying to use a more advanced/specific image augmentation setup for a pytorch-flash task. As a starting point I looked at the following examples and combined them:

#1107
https://lightning-flash.readthedocs.io/en/stable/reference/semantic_segmentation.html

The error I get is

As I understand it, icevision is not implemented for the semantic segmentation task. Since it seems to be implemented for instance_segmentation, shouldn't the code be transferable to the semantic segmentation functionality?

Another option I tried was https://github.com/PyTorchLightning/lightning-flash#flash-transforms, but then I realized that segmentation/input_transform.py does not yet have the same functionality as flash.image.classification.input_transform.

What would be the next recommended steps?

-
Hi @andife, thanks for the question! Here's an example showing how you can use custom transforms from albumentations with the semantic segmentation task:

```python
from dataclasses import dataclass
from functools import partial
from typing import Tuple

import albumentations as A
import torch

import flash
from flash import InputTransform
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the
# Kaggle Lyft Udacity Challenge. More info here:
# https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)

@dataclass
class CustomTransform(InputTransform):
    image_size: Tuple[int, int] = (300, 300)
    crop_size: Tuple[int, int] = (256, 256)

    def __post_init__(self):
        # Augmentations applied during training: resize, random crop, flip,
        # and a mild brightness/contrast jitter.
        self.train_transform = A.Compose(
            [
                A.Resize(width=self.image_size[0], height=self.image_size[1]),
                A.RandomCrop(width=self.crop_size[0], height=self.crop_size[1]),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
            ]
        )
        # Deterministic transforms used for validation and prediction.
        self.transform = A.Compose(
            [
                A.Resize(width=self.image_size[0], height=self.image_size[1]),
                A.CenterCrop(width=self.crop_size[0], height=self.crop_size[1]),
            ]
        )
        super().__post_init__()

    @staticmethod
    def _apply_transform(transform, sample):
        # Albumentations expects channels-last numpy arrays, so convert the CHW
        # tensors before the transform and convert back afterwards. The mask is
        # only passed through when the sample has a target (i.e. not at predict
        # time).
        if "target" in sample:
            kwargs = {"mask": sample["target"].numpy()}
        else:
            kwargs = {}
        transformed = transform(image=sample["input"].permute(1, 2, 0).numpy(), **kwargs)
        sample["input"] = torch.from_numpy(transformed["image"]).permute(2, 0, 1)
        if "mask" in transformed:
            sample["target"] = torch.from_numpy(transformed["mask"])
        return sample

    def per_sample_transform(self):
        return partial(self._apply_transform, self.transform)

    def train_per_sample_transform(self):
        return partial(self._apply_transform, self.train_transform)

    @staticmethod
    def _prepare_target(target) -> torch.Tensor:
        """Convert the target mask to long and remove the channel dimension."""
        return target.long().squeeze(1)

    def target_per_batch_transform(self):
        return self._prepare_target
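
# (Sanity check added for this writeup; not part of the original reply.)
# `_apply_transform` is a plain static method, so we can exercise the
# tensor <-> numpy round trip on a dummy sample. The HxW uint8 mask shape is an
# assumption for illustration; in the real pipeline Flash supplies the samples.
_dummy = {
    "input": torch.rand(3, 512, 512),                               # CHW float image
    "target": torch.randint(0, 21, (512, 512), dtype=torch.uint8),  # HW class mask
}
_out = CustomTransform._apply_transform(
    A.Compose(
        [
            A.Resize(width=300, height=300),
            A.CenterCrop(width=256, height=256),
        ]
    ),
    _dummy,
)
assert _out["input"].shape == (3, 256, 256)
assert _out["target"].shape == (256, 256)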

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    train_transform=CustomTransform,
    val_transform=CustomTransform,
    transform_kwargs=dict(image_size=(300, 300)),
    num_classes=21,
    batch_size=4,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count(), fast_dev_run=True)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
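# (Side note added for this writeup.) "freeze" keeps the pretrained backbone
# frozen for the whole run; Flash also provides other finetuning strategies
# such as "no_freeze" and "freeze_unfreeze" if the backbone should train too.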

# 4. Segment a few images!
datamodule = SemanticSegmentationData.from_files(
    predict_transform=CustomTransform,
    predict_files=[
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")
```

Hope that helps 😃
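A follow-up sketch (not from the original reply): once saved, the task can be restored with the standard Lightning `load_from_checkpoint` hook, assuming the checkpoint path used above:

```python
from flash.image import SemanticSegmentation

# Restore the finetuned task from the checkpoint written by save_checkpoint.
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.eval()
```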