Skip to content

Commit

Permalink
Style
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Jun 13, 2024
1 parent 45ac4bf commit 177c324
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
20 changes: 10 additions & 10 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,17 @@ def create_trainer(self, logger=None, callbacks=[], **kwargs):
enable_checkpointing = True
else:
enable_checkpointing = False

trainer_args = {
"logger":logger,
"max_epochs":self.config["train"]["epochs"],
"enable_checkpointing":enable_checkpointing,
"devices":self.config["devices"],
"accelerator":self.config["accelerator"],
"fast_dev_run":self.config["train"]["fast_dev_run"],
"callbacks":callbacks,
"limit_val_batches":limit_val_batches,
"num_sanity_val_steps":num_sanity_val_steps
"logger": logger,
"max_epochs": self.config["train"]["epochs"],
"enable_checkpointing": enable_checkpointing,
"devices": self.config["devices"],
"accelerator": self.config["accelerator"],
"fast_dev_run": self.config["train"]["fast_dev_run"],
"callbacks": callbacks,
"limit_val_batches": limit_val_batches,
"num_sanity_val_steps": num_sanity_val_steps
}
# Update with kwargs to allow them to override config
trainer_args.update(kwargs)
Expand Down
24 changes: 16 additions & 8 deletions deepforest/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,16 @@ def project_boxes(df, root_dir, transform=True):

return df

def download_ArcGIS_REST(url, xmin, ymin, xmax, ymax, savedir, additional_params=None, image_name="image.tiff", download_service="exportImage"):

def download_ArcGIS_REST(url,
xmin,
ymin,
xmax,
ymax,
savedir,
additional_params=None,
image_name="image.tiff",
download_service="exportImage"):
"""
Fetch data from a web server using geographic boundaries. The data is saved as a GeoTIFF file. The bbox is in the format of xmin, ymin, xmax, ymax for lat long coordinates..
This function is used to download data from an ArcGIS REST service, not WMTS or WMS services.
Expand All @@ -582,9 +591,7 @@ def download_ArcGIS_REST(url, xmin, ymin, xmax, ymax, savedir, additional_params
The response from the web server.
"""
# Construct the query parameters with the geographic boundaries
params = {
"f": "json"
}
params = {"f": "json"}
# add any additional parameters
if additional_params:
params.update(additional_params)
Expand All @@ -593,7 +600,7 @@ def download_ArcGIS_REST(url, xmin, ymin, xmax, ymax, savedir, additional_params
response = requests.get(url, params=params)

# turn into dict
response_dict = json.loads(response.content)
response_dict = json.loads(response.content)
spatialReference = response_dict["spatialReference"]
if "latestWkid" in spatialReference:
wkid = spatialReference["latestWkid"]
Expand All @@ -603,13 +610,14 @@ def download_ArcGIS_REST(url, xmin, ymin, xmax, ymax, savedir, additional_params

# Convert bbox into image coordinates
bbox = f"{xmin},{ymin},{xmax},{ymax}"
bounds = gpd.GeoDataFrame(geometry=[shapely.geometry.box(ymin, xmin, ymax, xmax)], crs='EPSG:4326').to_crs(crs).bounds
bounds = gpd.GeoDataFrame(geometry=[shapely.geometry.box(ymin, xmin, ymax, xmax)],
crs='EPSG:4326').to_crs(crs).bounds

# update the params
params.update({
"bbox": f"{bounds.minx[0]},{bounds.miny[0]},{bounds.maxx[0]},{bounds.maxy[0]}",
"f": "image",
'format':'tiff',
'format': 'tiff',
"noData": "0"
})

Expand All @@ -622,4 +630,4 @@ def download_ArcGIS_REST(url, xmin, ymin, xmax, ymax, savedir, additional_params
f.write(response.content)
return filename
else:
raise Exception(f"Failed to fetch data: {response.code}")
raise Exception(f"Failed to fetch data: {response.code}")

0 comments on commit 177c324

Please sign in to comment.