Skip to content

Commit

Permalink
Update western_usa_live_fuel_moisture.py with plot class
Browse files Browse the repository at this point in the history
  • Loading branch information
preethatr07 authored Dec 17, 2024
1 parent bc1f441 commit 56e4264
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions torchgeo/datasets/western_usa_live_fuel_moisture.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import pandas as pd
import torch

import matplotlib.pyplot as plt
import seaborn as sns

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, which
Expand Down Expand Up @@ -297,3 +300,77 @@ def _download(self) -> None:
os.makedirs(self.root, exist_ok=True)
azcopy = which('azcopy')
azcopy('sync', self.url, self.root, '--recursive=true')


def plot(
self,
x_feature: str = None,
y_feature: str = None,
kind: str = "scatter",
title: str = None,
save_path: str = None,
) -> None:
"""Plot features or relationships within the dataset.
Args:
x_feature: Name of the feature to plot on the x-axis.
y_feature: Name of the feature to plot on the y-axis.
Defaults to the label if not specified.
kind: Type of plot ('scatter', 'hist', 'box', or 'geo').
title: Title of the plot.
save_path: If provided, save the plot to the given path.
"""
if x_feature not in self.input_features:
raise ValueError(f"'{x_feature}' is not a valid input feature.")
if y_feature is None:
y_feature = self.label_name
if y_feature not in self.input_features and y_feature != self.label_name:
raise ValueError(f"'{y_feature}' is not a valid feature or label.")

plt.figure(figsize=(10, 6))

if kind == "scatter":
# Scatter plot for feature relationships
sns.scatterplot(
x=self.dataframe[x_feature],
y=self.dataframe[y_feature],
alpha=0.7,
)
plt.xlabel(x_feature)
plt.ylabel(y_feature)
plt.title(title or f"Scatter plot: {x_feature} vs {y_feature}")

elif kind == "hist":
# Histogram for a single feature
sns.histplot(self.dataframe[x_feature], kde=True, bins=30, color="blue")
plt.xlabel(x_feature)
plt.title(title or f"Distribution of {x_feature}")

elif kind == "box":
# Boxplot for feature distributions
sns.boxplot(y=self.dataframe[x_feature])
plt.title(title or f"Boxplot of {x_feature}")

elif kind == "geo":
# Spatial scatter plot using latitude and longitude
if "lat" not in self.input_features or "lon" not in self.input_features:
raise ValueError("Latitude ('lat') and longitude ('lon') must be input features for geo plots.")
sns.scatterplot(
x=self.dataframe["lon"],
y=self.dataframe["lat"],
hue=self.dataframe[self.label_name],
palette="viridis",
alpha=0.7,
)
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(title or "Geographic Distribution of Fuel Moisture")

else:
raise ValueError(f"Plot kind '{kind}' is not supported.")

plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300)
else:
plt.show()

0 comments on commit 56e4264

Please sign in to comment.