diff --git a/process_report/process_report.py b/process_report/process_report.py index 7553e4d..2624276 100644 --- a/process_report/process_report.py +++ b/process_report/process_report.py @@ -60,12 +60,30 @@ def load_old_pis(old_pi_file): return old_pi_dict +def dump_old_pis(old_pi_file, old_pi_dict: dict): + with open(old_pi_file, "w") as f: + for pi, first_month in old_pi_dict.items(): + f.write(f"{pi},{first_month}\n") + + def is_old_pi(old_pi_dict, pi, invoice_month): - if pi in old_pi_dict and old_pi_dict[pi] != invoice_month: + first_invoice_month = old_pi_dict.get(pi, invoice_month) + if compare_invoice_month(first_invoice_month, invoice_month): + sys.exit( + f"PI {pi} from {first_invoice_month} found in {invoice_month} invoice!" + ) + if compare_invoice_month(invoice_month, first_invoice_month): return True return False +def compare_invoice_month(month_1, month_2): + """Returns True if 1st date is later than 2nd date""" + dt1 = datetime.datetime.strptime(month_1, "%Y-%m") + dt2 = datetime.datetime.strptime(month_2, "%Y-%m") + return dt1 > dt2 + + def get_invoice_bucket(): try: s3_resource = boto3.resource( @@ -103,7 +121,7 @@ def main(): parser.add_argument( "--upload-to-s3", action="store_true", - help="If set, uploads all processed invoices to S3", + help="If set, uploads all processed invoices and old PI file to S3", ) parser.add_argument( "--invoice-month", @@ -159,16 +177,20 @@ def main(): parser.add_argument( "--old-pi-file", required=False, - help="Name of csv file listing previously billed PIs", + help="Name of csv file listing previously billed PIs. If not provided, defaults to fetching from S3", ) args = parser.parse_args() invoice_month = args.invoice_month if args.fetch_from_s3: - csv_files = fetch_S3_invoices(invoice_month) + csv_files = fetch_s3_invoices(invoice_month) else: csv_files = args.csv_files + if args.old_pi_file: + old_pi_file = args.old_pi_file + else: + old_pi_file = fetch_s3_old_pi_file() merged_dataframe = merge_csv(csv_files) @@ -192,7 +214,7 @@ def main(): billable_projects = remove_non_billables(merged_dataframe, pi, projects) billable_projects = validate_pi_names(billable_projects) - credited_projects = apply_credits_new_pi(billable_projects, args.old_pi_file) + credited_projects = apply_credits_new_pi(billable_projects, old_pi_file) export_billables(credited_projects, args.output_file) export_pi_billables(credited_projects, args.output_folder, invoice_month) @@ -211,9 +233,10 @@ def main(): upload_to_s3(invoice_list, invoice_month) upload_to_s3_HU_BU(args.HU_BU_invoice_file, invoice_month) + upload_to_s3_old_pi_file(old_pi_file) -def fetch_S3_invoices(invoice_month): +def fetch_s3_invoices(invoice_month): """Fetches usage invoices from S3 given invoice month""" s3_invoice_list = list() invoice_bucket = get_invoice_bucket() @@ -290,7 +313,7 @@ def remove_billables(dataframe, pi, projects, output_file): def validate_pi_names(dataframe): invalid_pi_projects = dataframe[pandas.isna(dataframe[PI_FIELD])] for i, row in invalid_pi_projects.iterrows(): - print(f"Warning: Project {row[PROJECT_FIELD]} has empty PI field") + print(f"Warning: Billable project {row[PROJECT_FIELD]} has empty PI field") dataframe = dataframe[~pandas.isna(dataframe[PI_FIELD])] return dataframe @@ -316,6 +339,8 @@ def apply_credits_new_pi(dataframe, old_pi_file): for i, row in pi_projects.iterrows(): dataframe.at[i, BALANCE_FIELD] = row[COST_FIELD] else: + old_pi_dict[pi] = invoice_month + print(f"Found new PI {pi}") remaining_credit = new_pi_credit_amount for i, row in pi_projects.iterrows(): project_cost = row[COST_FIELD] @@ -329,9 +354,23 @@ def apply_credits_new_pi(dataframe, old_pi_file): if remaining_credit == 0: break + dump_old_pis(old_pi_file, old_pi_dict) + return dataframe +def fetch_s3_old_pi_file(): + local_name = "PI.csv" + invoice_bucket = get_invoice_bucket() + invoice_bucket.download_file("PIs/PI.csv", local_name) + return local_name + + +def upload_to_s3_old_pi_file(old_pi_file): + invoice_bucket = get_invoice_bucket() + invoice_bucket.upload_file(old_pi_file, "PIs/PI.csv") + + def add_institution(dataframe: pandas.DataFrame): """Determine every PI's institution name, logging any PI whose institution cannot be determined This is performed by `get_institution_from_pi()`, which tries to match the PI's username to diff --git a/process_report/tests/unit_tests.py b/process_report/tests/unit_tests.py index 1aa1428..baa58f3 100644 --- a/process_report/tests/unit_tests.py +++ b/process_report/tests/unit_tests.py @@ -328,6 +328,16 @@ def test_apply_credit_0002(self): self.assertEqual(0, credited_projects.loc[4, "Balance"]) self.assertEqual(800, credited_projects.loc[5, "Balance"]) + updated_old_pi_answer = "PI2,2023-09\nPI3,2024-02\nPI4,2024-03\nPI1,2024-03\n" + with open(self.old_pi_file, "r") as f: + self.assertEqual(updated_old_pi_answer, f.read()) + + def test_apply_credit_error(self): + old_pi_dict = {"PI1": "2024-12"} + invoice_month = "2024-03" + with self.assertRaises(SystemExit): + process_report.is_old_pi(old_pi_dict, "PI1", invoice_month) + class TestValidateBillables(TestCase): def setUp(self):