
Commit 4fbbf6b

Style edits
bw4sz committed Jun 6, 2024
1 parent 942f383 commit 4fbbf6b
Showing 4 changed files with 116 additions and 92 deletions.
28 changes: 18 additions & 10 deletions deepforest/dataset.py
@@ -27,6 +27,7 @@
from rasterio.windows import Window
from torchvision import transforms


def get_transform(augment):
"""Albumentations transformation of bounding boxs"""
if augment:
@@ -137,6 +138,7 @@ def __getitem__(self, idx):


class TileDataset(Dataset):

def __init__(self,
tile: typing.Optional[np.ndarray],
preload_images: bool = False,
@@ -179,16 +181,20 @@ def __getitem__(self, idx):

return crop


def bounding_box_transform(augment=False):
data_transforms = []
data_transforms.append(transforms.ToTensor())
data_transforms.append(resnet_normalize)
data_transforms.append(transforms.Resize([224,224]))
data_transforms.append(transforms.Resize([224, 224]))
if augment:
data_transforms.append(transforms.RandomHorizontalFlip(0.5))
return transforms.Compose(data_transforms)

resnet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

resnet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])


class BoundingBoxDataset(Dataset):
"""An in memory dataset for bounding box predictions
@@ -200,37 +206,39 @@ class BoundingBoxDataset(Dataset):
Returns:
rgb: a tensor of shape (3, height, width)
"""

def __init__(self, df, root_dir, transform=None, augment=False):
self.df = df

if transform is None:
self.transform = bounding_box_transform(augment=augment)
else:
self.transform = transform

unique_image = self.df['image_path'].unique()
assert len(unique_image) == 1, "There should be only one unique image for this class object"
assert len(unique_image
) == 1, "There should be only one unique image for this class object"

# Open the image using rasterio
self.src = rio.open(os.path.join(root_dir,unique_image[0]))
self.src = rio.open(os.path.join(root_dir, unique_image[0]))

def __len__(self):
return len(self.df)

def __getitem__(self, idx):
row = self.df.iloc[idx]
xmin = row['xmin']
xmax = row['xmax']
ymin = row['ymin']
ymax = row['ymax']

# Read the RGB data
box = self.src.read(window=Window(xmin, ymin, xmax-xmin, ymax-ymin))
box = self.src.read(window=Window(xmin, ymin, xmax - xmin, ymax - ymin))
box = np.rollaxis(box, 0, 3)

if self.transform:
image = self.transform(box)
else:
image = box
return image

return image
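
For orientation, here is a minimal sketch of how the reformatted BoundingBoxDataset above might be used: build a single-image predictions dataframe, wrap it in the dataset, and batch the 3 x 224 x 224 crops with a DataLoader. The column names come from the diff; the raster path, file name, and box coordinates are hypothetical.

import pandas as pd
from torch.utils.data import DataLoader

from deepforest import dataset

# Hypothetical predictions for a single raster; every row must share one
# image_path, matching the assert in BoundingBoxDataset.__init__.
df = pd.DataFrame({
    "image_path": ["tile.tif", "tile.tif"],
    "xmin": [10, 250],
    "ymin": [20, 300],
    "xmax": [110, 350],
    "ymax": [120, 400],
})

# Each item is read from the raster window, rolled to HWC, then normalized
# and resized to 3 x 224 x 224 by the default bounding_box_transform.
ds = dataset.BoundingBoxDataset(df, root_dir="/path/to/rasters", augment=False)
loader = DataLoader(ds, batch_size=2, shuffle=False)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3, 224, 224])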
13 changes: 6 additions & 7 deletions deepforest/main.py
@@ -525,13 +525,12 @@ def predict_tile(self,

if crop_model:
# If a crop model is provided, predict on each crop
results = predict._predict_crop_model_(
crop_model=crop_model,
results=results,
raster_path=raster_path,
trainer=self.trainer,
transform=crop_transform,
augment=crop_augment)
results = predict._predict_crop_model_(crop_model=crop_model,
results=results,
raster_path=raster_path,
trainer=self.trainer,
transform=crop_transform,
augment=crop_augment)

return results

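The hunk above threads an optional crop model through predict_tile. Below is a hedged sketch of what the call site might look like, assuming predict_tile accepts crop_model as a keyword argument (implied, but not shown, by this diff) and that a pretrained release model is available; the raster path and patch settings are hypothetical.

from deepforest import main
from deepforest.model import CropModel

m = main.deepforest()
m.use_release()  # assumes the standard pretrained tree-detection release
crop_model = CropModel(num_classes=2)

# Boxes come from the detection model; each crop is then classified by
# crop_model, adding cropmodel_label and cropmodel_score to the results.
results = m.predict_tile(
    raster_path="/path/to/tile.tif",  # hypothetical path
    patch_size=400,
    patch_overlap=0.05,
    crop_model=crop_model)
print(results[["xmin", "ymin", "xmax", "ymax",
               "cropmodel_label", "cropmodel_score"]].head())
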
149 changes: 78 additions & 71 deletions deepforest/model.py
@@ -15,6 +15,7 @@
import torch.nn.functional as F
import cv2


class Model():
"""A architecture agnostic class that controls the basic train, eval and predict functions.
A model should optionally allow a backbone for pretraining. To add new architectures, simply create a new module in models/ and write a create_model.
@@ -62,42 +63,55 @@ def check_model(self):
model_keys.sort()
assert model_keys == ['boxes', 'labels', 'scores']


def simple_resnet_50(num_classes=2):
m = models.resnet50(pretrained=True)
num_ftrs = m.fc.in_features
m.fc = torch.nn.Linear(num_ftrs, num_classes)
m.fc = torch.nn.Linear(num_ftrs, num_classes)

return m


return m

class CropModel(LightningModule):

def __init__(self, num_classes=2, batch_size=4, num_workers=0, lr=0.0001, model=None):
super().__init__()

# Model
self.num_classes = num_classes
if model == None:
self.model = simple_resnet_50(num_classes=num_classes)
else:
self.model = model
self.model = model

# Metrics
self.accuracy = torchmetrics.Accuracy(average='none', num_classes=num_classes, task="multiclass")
self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes, task="multiclass")
self.precision_metric = torchmetrics.Precision(num_classes=num_classes, task="multiclass")
self.metrics = torchmetrics.MetricCollection({"Class Accuracy":self.accuracy, "Accuracy":self.total_accuracy, "Precision":self.precision_metric})
self.accuracy = torchmetrics.Accuracy(average='none',
num_classes=num_classes,
task="multiclass")
self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes,
task="multiclass")
self.precision_metric = torchmetrics.Precision(num_classes=num_classes,
task="multiclass")
self.metrics = torchmetrics.MetricCollection({
"Class Accuracy": self.accuracy,
"Accuracy": self.total_accuracy,
"Precision": self.precision_metric
})

# Training Hyperparameters
self.batch_size = batch_size
self.num_workers = num_workers
self.lr = lr

def create_trainer(self, **kwargs):
"""Create a pytorch lightning trainer object"""
self.trainer = Trainer(**kwargs)

def load_from_disk(self, train_dir, val_dir):
self.train_ds = ImageFolder(root=train_dir, transform=self.get_transform(augment=True))
self.val_ds = ImageFolder(root=val_dir, transform=self.get_transform(augment=False))
self.train_ds = ImageFolder(root=train_dir,
transform=self.get_transform(augment=True))
self.val_ds = ImageFolder(root=val_dir,
transform=self.get_transform(augment=False))

def get_transform(self, augment):
"""
@@ -112,11 +126,11 @@ def get_transform(self, augment):
data_transforms = []
data_transforms.append(transforms.ToTensor())
data_transforms.append(self.normalize())
data_transforms.append(transforms.Resize([224,224]))
data_transforms.append(transforms.Resize([224, 224]))
if augment:
data_transforms.append(transforms.RandomHorizontalFlip(0.5))
return transforms.Compose(data_transforms)

def write_crops(self, root_dir, images, boxes, labels, savedir):
"""
Write crops to disk.
@@ -134,7 +148,7 @@ def write_crops(self, root_dir, images, boxes, labels, savedir):

# Create a directory for each label
for label in labels:
os.makedirs(os.path.join(savedir,label), exist_ok=True)
os.makedirs(os.path.join(savedir, label), exist_ok=True)

# Use rasterio to read the image
for index, box in enumerate(boxes):
@@ -150,63 +164,57 @@ def write_crops(self, root_dir, images, boxes, labels, savedir):
cv2.imwrite(img_path, img)

def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def forward(self, x):
output = self.model(x)
output = F.sigmoid(output)

return output

def train_dataloader(self):
train_loader = torch.utils.data.DataLoader(
self.train_ds,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers
)

train_loader = torch.utils.data.DataLoader(self.train_ds,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers)

return train_loader

def predict_dataloader(self, ds):
loader = torch.utils.data.DataLoader(
ds,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers
)

loader = torch.utils.data.DataLoader(ds,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)

return loader

def val_dataloader(self):
val_loader = torch.utils.data.DataLoader(
self.val_ds,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers
)

val_loader = torch.utils.data.DataLoader(self.val_ds,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers)

return val_loader

def training_step(self, batch, batch_idx):
x,y = batch
x, y = batch
outputs = self.forward(x)
loss = F.cross_entropy(outputs,y)
self.log("train_loss",loss)
loss = F.cross_entropy(outputs, y)
self.log("train_loss", loss)

return loss

def predict_step(self, batch, batch_idx):
outputs = self.forward(batch)
yhat = F.softmax(outputs, 1)

return yhat

def validation_step(self, batch, batch_idx):
x,y = batch
x, y = batch
outputs = self(x)
loss = F.cross_entropy(outputs,y)
self.log("val_loss",loss)
loss = F.cross_entropy(outputs, y)
self.log("val_loss", loss)
metric_dict = self.metrics(outputs, y)
for key, value in metric_dict.items():
for key, value in metric_dict.items():
@@ -216,35 +224,34 @@ def validation_step(self, batch, batch_idx):
else:
self.log(key, value, on_step=False, on_epoch=True)
return loss



def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
mode='min',
factor=0.5,
patience=10,
verbose=True,
threshold=0.0001,
threshold_mode='rel',
cooldown=0,
min_lr=0,
eps=1e-08)
mode='min',
factor=0.5,
patience=10,
verbose=True,
threshold=0.0001,
threshold_mode='rel',
cooldown=0,
min_lr=0,
eps=1e-08)

        # Monitor val_loss when validation data is used
return {'optimizer':optimizer, 'lr_scheduler': scheduler,"monitor":'val_loss'}
return {'optimizer': optimizer, 'lr_scheduler': scheduler, "monitor": 'val_loss'}

def dataset_confusion(self, loader):
"""Create a confusion matrix from a data loader"""
true_class = []
predicted_class = []
self.eval()
for batch in loader:
x,y = batch
true_class.append(F.one_hot(y,num_classes=self.num_classes).detach().numpy())
x, y = batch
true_class.append(F.one_hot(y, num_classes=self.num_classes).detach().numpy())
prediction = self(x)
predicted_class.append(prediction.detach().numpy())

true_class = np.concatenate(true_class)
predicted_class = np.concatenate(predicted_class)

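As a rough end-to-end sketch of the reformatted CropModel: write labeled crops to disk, load them as ImageFolder datasets, and fit with the Lightning trainer the class creates. The method names come from the class above; the paths and labels are hypothetical, and the assumption that each box is an [xmin, ymin, xmax, ymax] sequence is not shown in the collapsed hunk.

from deepforest.model import CropModel

crop_model = CropModel(num_classes=2, batch_size=4, lr=0.0001)

# Hypothetical inputs: one raster, two boxes, two class labels. write_crops
# saves one image per box under savedir/<label>/.
crop_model.write_crops(root_dir="/path/to/rasters",
                       images=["tile.tif", "tile.tif"],
                       boxes=[[10, 20, 110, 120], [250, 300, 350, 400]],
                       labels=["Alive", "Dead"],
                       savedir="/path/to/crops/train")

# ImageFolder-style directories: one subfolder per class under each root.
crop_model.load_from_disk(train_dir="/path/to/crops/train",
                          val_dir="/path/to/crops/val")

# create_trainer stores a pytorch-lightning Trainer on the model; kwargs pass through.
crop_model.create_trainer(max_epochs=10, accelerator="cpu")
crop_model.trainer.fit(crop_model)
crop_model.trainer.validate(crop_model)
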
18 changes: 14 additions & 4 deletions deepforest/predict.py
@@ -190,7 +190,13 @@ def _dataloader_wrapper_(model,

return results

def _predict_crop_model_(crop_model, trainer, results, raster_path, transform=None, augment=False):

def _predict_crop_model_(crop_model,
trainer,
results,
raster_path,
transform=None,
augment=False):
"""
    Runs the crop model on each predicted bounding box from a raster file.
@@ -203,14 +209,18 @@ def _predict_crop_model_(crop_model, trainer, results, raster_path, transform=No
Returns:
The updated results dataframe with predicted labels and scores.
"""
bounding_box_dataset = dataset.BoundingBoxDataset(results, root_dir=os.path.dirname(raster_path), transform=transform, augment=augment)
bounding_box_dataset = dataset.BoundingBoxDataset(
results,
root_dir=os.path.dirname(raster_path),
transform=transform,
augment=augment)
crop_dataloader = crop_model.predict_dataloader(bounding_box_dataset)
crop_results = trainer.predict(crop_model, crop_dataloader)
stacked_outputs = np.vstack(np.concatenate(crop_results))
label = np.argmax(stacked_outputs, 1)
score = np.max(stacked_outputs, 1)
score = np.max(stacked_outputs, 1)

results["cropmodel_label"] = label
results["cropmodel_score"] = score

return results
return results
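
For clarity, the label and score assignment above collapses the per-crop softmax outputs from trainer.predict into a predicted class index and its confidence. A small numpy illustration of that reduction, with made-up probabilities:

import numpy as np

# Two batches of softmax outputs from trainer.predict, shape (batch, num_classes).
crop_results = [np.array([[0.9, 0.1], [0.3, 0.7]]),
                np.array([[0.2, 0.8]])]

stacked_outputs = np.vstack(np.concatenate(crop_results))
label = np.argmax(stacked_outputs, 1)  # array([0, 1, 1])
score = np.max(stacked_outputs, 1)     # array([0.9, 0.7, 0.8])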
