Skip to content

Commit

Permalink
Updated to remove writing of download report, returns byte data
Browse files Browse the repository at this point in the history
  • Loading branch information
jcadam14 committed Nov 7, 2024
1 parent 4398f7e commit 62b2726
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/regtech_data_validator/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def validate(
case OutputFormat.TABLE:
print(df_to_table(final_df))
case OutputFormat.DOWNLOAD:
df_to_download(final_df)
print(df_to_download(final_df))
case _:
raise ValueError(f'output format "{output}" not supported')

Expand Down
33 changes: 7 additions & 26 deletions src/regtech_data_validator/data_formatters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import boto3
import ujson
import polars as pl
import fsspec

from tabulate import tabulate

Expand Down Expand Up @@ -116,13 +114,13 @@ def format_findings(df: pl.DataFrame, phase, checks):

def df_to_download(
df: pl.DataFrame,
path: str = "download_report.csv",
warning_count: int = 0,
error_count: int = 0,
max_errors: int = 1000000,
):
if df.is_empty():
# return headers of csv for 'emtpy' report
buffer = BytesIO()
empty_df = pl.DataFrame(
{
"validation_type": [],
Expand All @@ -134,9 +132,9 @@ def df_to_download(
"validation_description": [],
}
)
with fsspec.open(path, mode='wb') as f:
empty_df.write_csv(f, quote_style='non_numeric')
return
empty_df.write_csv(buffer, quote_style='non_numeric', include_header=True)
buffer.seek(0)
return buffer.getvalue()

# get the check for the phase the results were in, so we can pull out static data from each
# found check
Expand Down Expand Up @@ -196,26 +194,9 @@ def df_to_download(
f'"Your register contains {total_errors} {error_type}, however, only {max_errors} records are displayed in this report. To see additional {error_type}, correct the listed records, and upload a new file."\n'.encode()
)

if path.startswith("s3"):
sorted_df.write_csv(buffer, quote_style='non_numeric', include_header=False)
buffer.seek(0)
upload(path, buffer.getvalue())
else:
with fsspec.open(path, mode='wb') as f:
sorted_df.write_csv(buffer, quote_style='non_numeric', include_header=False)
buffer.seek(0)
f.write(buffer.getvalue())


def upload(path: str, content: bytes) -> None:
bucket = path.split("s3://")[1].split("/")[0]
opath = path.split("s3://")[1].replace(bucket + "/", "")
s3 = boto3.client("s3")
s3.put_object(
Bucket=bucket,
Key=opath,
Body=content,
)
sorted_df.write_csv(buffer, quote_style='non_numeric', include_header=False)
buffer.seek(0)
return buffer.getvalue()


def df_to_csv(df: pl.DataFrame) -> str:
Expand Down
18 changes: 2 additions & 16 deletions tests/test_output_formats.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import polars as pl
import ujson

import tempfile
from pathlib import Path

from regtech_data_validator import global_data
from regtech_data_validator.data_formatters import df_to_csv, df_to_str, df_to_json, df_to_table, df_to_download
from regtech_data_validator.validation_results import ValidationPhase
Expand Down Expand Up @@ -254,14 +251,8 @@ def test_download_csv(self):
"""
).strip('\n')

gf = tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.csv')
temp_path = Path(gf.name)
df_to_download(self.findings_df, str(temp_path.resolve()))
with open(temp_path, 'r') as output:
actual_output = output.read()
print(f"{actual_output}")
actual_output = df_to_download(self.findings_df).decode('utf-8')
assert actual_output.strip() == expected_output
temp_path.unlink()

def test_empty_download_csv(self):
expected_output = dedent(
Expand All @@ -270,10 +261,5 @@ def test_empty_download_csv(self):
"""
).strip('\n')

gf = tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.csv')
temp_path = Path(gf.name)
df_to_download(pl.DataFrame(), str(temp_path.resolve()))
with open(temp_path, 'r') as output:
actual_output = output.read()
actual_output = df_to_download(pl.DataFrame()).decode('utf-8')
assert actual_output.strip() == expected_output
temp_path.unlink()

0 comments on commit 62b2726

Please sign in to comment.