Skip to content

Commit

Permalink
refactor: replace every if ... else to if ... is None else
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Muñoz committed Jan 20, 2025
1 parent 25c45ba commit 3c4a2f2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
24 changes: 15 additions & 9 deletions geetools/ee_feature_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def toDictionary(
print(json.dumps(countries.getInfo(), indent=2))
"""
uniqueIds = self._obj.aggregate_array(keyColumn)
selectors = ee.List(selectors) if selectors else self._obj.first().propertyNames()
selectors = ee.List(selectors) if selectors is None else self._obj.first().propertyNames()
keyColumn = ee.String(keyColumn)

features = self._obj.toList(self._obj.size())
Expand Down Expand Up @@ -321,14 +321,16 @@ def byProperties(
features = features.map(lambda i: ee.Algorithms.If(isString(i), i, ee.Number(i).format()))

# retrieve properties for each feature
properties = ee.List(properties) if properties else self._obj.first().propertyNames()
properties = (
ee.List(properties) if properties is None else self._obj.first().propertyNames()
)
properties = properties.remove(featureId)
values = properties.map(
lambda p: ee.Dictionary.fromLists(features, self._obj.aggregate_array(p))
)

# get the label to use in the dictionary if requested
labels = ee.List(labels) if labels else properties
labels = ee.List(labels) if labels is None else properties

return ee.Dictionary.fromLists(labels, values)

Expand Down Expand Up @@ -380,9 +382,9 @@ def byFeatures(
"""
# compute the properties and their labels
props = ee.List(properties) if properties else self._obj.first().propertyNames()
props = ee.List(properties) if properties is None else self._obj.first().propertyNames()
props = props.remove(featureId)
labels = ee.List(labels) if labels else props
labels = ee.List(labels) if labels is None else props

# create a function to get the properties of a feature
# we need to map the featureCollection into a list as it's not possible to return something else than a
Expand Down Expand Up @@ -455,14 +457,18 @@ def plot_by_features(
label.set_rotation(45)
"""
# Get the features and properties
props = ee.List(properties) if properties else self._obj.first().propertyNames().getInfo()
props = (
ee.List(properties)
if properties is None
else self._obj.first().propertyNames().getInfo()
)
props = props.remove(featureId)

# get the data from server
data = self.byProperties(featureId, props, labels).getInfo()

# reorder the data according to the labels or properties set by the user
labels = labels if labels else props.getInfo()
labels = labels if labels is None else props.getInfo()
data = {k: data[k] for k in labels}

return plot_data(type=type, data=data, label_name=featureId, colors=colors, ax=ax, **kwargs)
Expand Down Expand Up @@ -518,14 +524,14 @@ def plot_by_properties(
"""
# Get the features and properties
fc = self._obj
props = ee.List(properties) if properties else fc.first().propertyNames()
props = ee.List(properties) if properties is None else fc.first().propertyNames()
props = props.remove(featureId)

# get the data from server
data = self.byFeatures(featureId, props, labels).getInfo()

# reorder the data according to the lapbes or properties set by the user
labels = labels if labels else props.getInfo()
labels = labels if labels is None else props.getInfo()
data = {f: {k: data[f][k] for k in labels} for f in data.keys()}

return plot_data(type=type, data=data, label_name=featureId, colors=colors, ax=ax, **kwargs)
Expand Down
38 changes: 19 additions & 19 deletions geetools/ee_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def addDate(self, format: str | ee.String = "") -> ee.Image:
"""
# parse the inputs
isMillis = ee.String(format).equals(ee.String(""))
format = ee.String(format) if format else ee.String("YYYYMMdd")
format = ee.String(format) if format is None else ee.String("YYYYMMdd")

# extract the date from the object and create a image band from it
date = self._obj.date()
Expand Down Expand Up @@ -222,7 +222,7 @@ def doyToDate(
print(image.reduceRegion(ee.Reducer.min(), vatican, 1).getInfo())
"""
year = ee.Number(year)
band = ee.String(band) if band else ee.String(self._obj.bandNames().get(0))
band = ee.String(band) if band is None else ee.String(self._obj.bandNames().get(0))
dateFormat = ee.String(dateFormat)

doyList = ee.List.sequence(0, 365)
Expand Down Expand Up @@ -345,7 +345,7 @@ def toGrid(
grid = image.geetools.toGrid(1, 'B2', buffer)
print(grid.getInfo())
"""
band = ee.String(band) if band else self._obj.bandNames().get(0)
band = ee.String(band) if band is None else self._obj.bandNames().get(0)
projection = self._obj.select(band).projection()
size = projection.nominalScale().multiply(ee.Number(size).toInt())

Expand Down Expand Up @@ -468,8 +468,8 @@ def full(
image = ee.Image.geetools.full([1, 2, 3], ['a', 'b', 'c'])
print(image.bandNames().getInfo())
"""
values = ee.List(values) if values else ee.List([0])
names = ee.List(names) if names else ee.List(["constant"])
values = ee.List(values) if values is None else ee.List([0])
names = ee.List(names) if names is None else ee.List(["constant"])

# resize value to the same length as names
values = ee.List(
Expand Down Expand Up @@ -573,7 +573,7 @@ def reduceBands(
if not isinstance(reducer, str):
raise TypeError("reducer must be a Python string")

bands = ee.List(bands) if bands else ee.List([])
bands = ee.List(bands) if bands is None else ee.List([])
name = ee.String(name)
bands = ee.Algorithms.If(bands.size().eq(0), self._obj.bandNames(), bands)
name = ee.Algorithms.If(name.equals(ee.String("")), reducer, name)
Expand Down Expand Up @@ -678,7 +678,7 @@ def gauss(self, band: str | ee.String = "") -> ee.Image:
image = image.geetools.gauss()
print(image.bandNames().getInfo())
"""
band = ee.String(band) if band else ee.String(self._obj.bandNames().get(0))
band = ee.String(band) if band is None else ee.String(self._obj.bandNames().get(0))
image = self._obj.select(band)

kwargs = {"geometry": image.geometry(), "bestEffort": True}
Expand Down Expand Up @@ -1739,10 +1739,10 @@ def byBands(
features = features.map(lambda i: ee.Algorithms.If(isString(i), i, ee.Number(i).format()))

# get the bands to be used in the reducer
eeBands = ee.List(bands) if bands else self._obj.bandNames()
eeBands = ee.List(bands) if bands is None else self._obj.bandNames()

# retrieve the label to use for each bands if provided
eeLabels = ee.List(labels) if labels else eeBands
eeLabels = ee.List(labels) if labels is None else eeBands

# by default for 1 band image, the reducers are renaming the output band. To ensure it keeps
# the original band name we add setOutputs that is ignored for multi band images.
Expand Down Expand Up @@ -1830,10 +1830,10 @@ def byRegions(
features = features.map(lambda i: ee.Algorithms.If(isString(i), i, ee.Number(i).format()))

# get the bands to be used in the reducer
bands = ee.List(bands) if bands else self._obj.bandNames()
bands = ee.List(bands) if bands is None else self._obj.bandNames()

# retrieve the label to use for each bands if provided
labels = ee.List(labels) if labels else bands
labels = ee.List(labels) if labels is None else bands

# by default for 1 band image, the reducers are renaming the output band. To ensure it keeps
# the original band name we add setOutputs that is ignored for multi band images.
Expand Down Expand Up @@ -1942,8 +1942,8 @@ def plot_by_regions(
features = features.getInfo()

# extract the labels from the parameters
eeBands = ee.List(bands) if bands else self._obj.bandNames()
labels = labels if labels else eeBands.getInfo()
eeBands = ee.List(bands) if bands is None else self._obj.bandNames()
labels = labels if labels is None else eeBands.getInfo()

# reorder the data according to the labels id set by the user
data = {b: {f: data[b][f] for f in features} for b in labels}
Expand Down Expand Up @@ -2033,8 +2033,8 @@ def plot_by_bands(
features = features.getInfo()

# extract the labels from the parameters
eeBands = ee.List(bands) if bands else self._obj.bandNames()
labels = labels if labels else eeBands.getInfo()
eeBands = ee.List(bands) if bands is None else self._obj.bandNames()
labels = labels if labels is None else eeBands.getInfo()

# reorder the data according to the labels id set by the user
data = {f: {b: data[f][b] for b in labels} for f in features}
Expand Down Expand Up @@ -2102,14 +2102,14 @@ def plot_hist(
# TODO: In this case, the default is "to all bands",
# but the original implementation the default was an empty list.
# Is that correct?
eeBands = ee.List(bands) if bands else self._obj.bandNames()
eeBands = ee.List(bands) if bands is None else self._obj.bandNames()
# TODO: Same here
eeLabels = ee.List(labels).flatten() if labels else eeBands
eeLabels = ee.List(labels).flatten() if labels is None else eeBands
new_labels: list[str] = eeLabels.getInfo()
new_colors: list[str] = colors if colors else plt.get_cmap("tab10").colors
new_colors: list[str] = colors if colors is None else plt.get_cmap("tab10").colors

# retrieve the region from the parameters
region = region if region else self._obj.geometry()
region = region if region is None else self._obj.geometry()

# extract the data from the server
image = self._obj.select(eeBands).rename(eeLabels).clip(region)
Expand Down
10 changes: 5 additions & 5 deletions geetools/ee_image_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def outliers(
"""
# cast parameters and compute the outlier band names
initBands = self._obj.first().bandNames()
statBands = ee.List(bands) if bands else initBands
statBands = ee.List(bands) if bands is None else initBands
outBands = statBands.map(lambda b: ee.String(b).cat("_outlier"))

# compute the mean and std dev for each band
Expand Down Expand Up @@ -1214,8 +1214,8 @@ def datesByBands(
print(reduced.getInfo())
"""
# cast parameters
eeBands = ee.List(bands) if bands else self._obj.first().bandNames()
eeLabels = ee.List(labels) if labels else eeBands
eeBands = ee.List(bands) if bands is None else self._obj.first().bandNames()
eeLabels = ee.List(labels) if labels is None else eeBands

# recast band names as labels in the source collection
ic = self._obj.select(eeBands).map(lambda i: i.rename(eeLabels))
Expand Down Expand Up @@ -1384,8 +1384,8 @@ def doyByBands(
- :docstring:`ee.ImageCollection.geetools.plot_doy_by_years`
"""
# cast parameters
bands = ee.List(bands) if bands else self._obj.first().bandNames()
labels = ee.List(labels) if labels else bands
bands = ee.List(bands) if bands is None else self._obj.first().bandNames()
labels = ee.List(labels) if labels is None else bands

# recast band names as labels in the source collection
ic = self._obj.select(bands).map(lambda i: i.rename(labels))
Expand Down

0 comments on commit 3c4a2f2

Please sign in to comment.