Skip to content

Commit

Permalink
Allows old PI file to be fetched from S3, updated during processing, …
Browse files Browse the repository at this point in the history
…and uploaded back to S3

Also provides logic for cases where future PIs are encountered in past invoices, in which case they will be
considered new PIs for the purpose of the New PI credit, and their first-month date will be changed to the
past invoice's date
  • Loading branch information
QuanMPhm committed Apr 25, 2024
1 parent d61d2ba commit 62d985f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
41 changes: 34 additions & 7 deletions process_report/process_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,14 @@ def load_old_pis(old_pi_file):
return old_pi_dict


def dump_old_pis(old_pi_file, old_pi_dict: dict):
with open(old_pi_file, "w") as f:
for pi, first_month in old_pi_dict.items():
f.write(f"{pi},{first_month}\n")


def is_old_pi(old_pi_dict, pi, invoice_month):
if pi in old_pi_dict and old_pi_dict[pi] != invoice_month:
if old_pi_dict.get(pi, invoice_month) < invoice_month:
return True
return False

Expand Down Expand Up @@ -101,7 +107,7 @@ def main():
parser.add_argument(
"--upload-to-s3",
action="store_true",
help="If set, uploads all processed invoices to S3",
help="If set, uploads all processed invoices and old PI file to S3",
)
parser.add_argument(
"--invoice-month",
Expand Down Expand Up @@ -163,16 +169,20 @@ def main():
parser.add_argument(
"--old-pi-file",
required=False,
help="Name of csv file listing previously billed PIs",
help="Name of csv file listing previously billed PIs. If not provided, defaults to fetching from S3",
)
args = parser.parse_args()

invoice_month = args.invoice_month

if args.fetch_from_s3:
csv_files = fetch_S3_invoices(invoice_month)
csv_files = fetch_s3_invoices(invoice_month)
else:
csv_files = args.csv_files
if args.old_pi_file:
old_pi_file = args.old_pi_file
else:
old_pi_file = fetch_s3_old_pi_file()

merged_dataframe = merge_csv(csv_files)

Expand All @@ -196,7 +206,7 @@ def main():

billable_projects = remove_non_billables(merged_dataframe, pi, projects)
billable_projects = validate_pi_names(billable_projects)
credited_projects = apply_credits_new_pi(billable_projects, args.old_pi_file)
credited_projects = apply_credits_new_pi(billable_projects, old_pi_file)

export_billables(credited_projects, args.output_file)
export_pi_billables(credited_projects, args.output_folder, invoice_month)
Expand All @@ -217,9 +227,10 @@ def main():
invoice_list.append(os.path.join(args.output_folder, pi_invoice))

upload_to_s3(invoice_list, invoice_month)
upload_to_s3_old_pi_file(old_pi_file)


def fetch_S3_invoices(invoice_month):
def fetch_s3_invoices(invoice_month):
"""Fetches usage invoices from S3 given invoice month"""
s3_invoice_list = list()
invoice_bucket = get_invoice_bucket()
Expand Down Expand Up @@ -294,7 +305,7 @@ def remove_billables(dataframe, pi, projects, output_file):
def validate_pi_names(dataframe):
invalid_pi_projects = dataframe[pandas.isna(dataframe[PI_FIELD])]
for i, row in invalid_pi_projects.iterrows():
print(f"Warning: Project {row[PROJECT_FIELD]} has empty PI field")
print(f"Warning: Billable project {row[PROJECT_FIELD]} has empty PI field")
dataframe = dataframe[~pandas.isna(dataframe[PI_FIELD])]

return dataframe
Expand All @@ -320,6 +331,8 @@ def apply_credits_new_pi(dataframe, old_pi_file):
for i, row in pi_projects.iterrows():
dataframe.at[i, BALANCE_FIELD] = row[COST_FIELD]
else:
old_pi_dict[pi] = invoice_month
print(f"Found new PI {pi}")
remaining_credit = new_pi_credit_amount
for i, row in pi_projects.iterrows():
project_cost = row[COST_FIELD]
Expand All @@ -333,9 +346,23 @@ def apply_credits_new_pi(dataframe, old_pi_file):
if remaining_credit == 0:
break

dump_old_pis(old_pi_file, old_pi_dict)

return dataframe


def fetch_s3_old_pi_file():
local_name = "PI.csv"
invoice_bucket = get_invoice_bucket()
invoice_bucket.download_file("PIs/PI.csv", local_name)
return local_name


def upload_to_s3_old_pi_file(old_pi_file):
invoice_bucket = get_invoice_bucket()
invoice_bucket.upload_file(old_pi_file, "PIs/PI.csv")


def add_institution(dataframe: pandas.DataFrame):
"""Determine every PI's institution name, logging any PI whose institution cannot be determined
This is performed by `get_institution_from_pi()`, which tries to match the PI's username to
Expand Down
18 changes: 10 additions & 8 deletions process_report/tests/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ def setUp(self):
self.dataframe = pandas.DataFrame(data)
old_pi = [
"PI2,2023-09",
"PI3,2024-02",
"PI3,2024-06",
"PI4,2024-03",
] # Case with old and new pi in pi file
] # Case with old, new, and future pi in pi file
old_pi_file = tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".csv")
for pi in old_pi:
old_pi_file.write(pi + "\n")
Expand All @@ -306,28 +306,30 @@ def test_apply_credit_0002(self):
non_credited_project = dataframe[pandas.isna(dataframe["Credit Code"])]
credited_projects = dataframe[dataframe["Credit Code"] == "0002"]

self.assertEqual(2, len(non_credited_project))
self.assertEqual(1, len(non_credited_project))
self.assertEqual(
non_credited_project.loc[2, "Cost"], non_credited_project.loc[2, "Balance"]
)
self.assertEqual(
non_credited_project.loc[3, "Cost"], non_credited_project.loc[3, "Balance"]
)

self.assertEqual(4, len(credited_projects.index))
self.assertEqual(5, len(credited_projects.index))
self.assertTrue("PI2" not in credited_projects["Manager (PI)"].unique())
self.assertTrue("PI3" not in credited_projects["Manager (PI)"].unique())

self.assertEqual(10, credited_projects.loc[0, "Credit"])
self.assertEqual(100, credited_projects.loc[1, "Credit"])
self.assertEqual(1000, credited_projects.loc[3, "Credit"])
self.assertEqual(800, credited_projects.loc[4, "Credit"])
self.assertEqual(200, credited_projects.loc[5, "Credit"])

self.assertEqual(0, credited_projects.loc[0, "Balance"])
self.assertEqual(0, credited_projects.loc[1, "Balance"])
self.assertEqual(4000, credited_projects.loc[3, "Balance"])
self.assertEqual(0, credited_projects.loc[4, "Balance"])
self.assertEqual(800, credited_projects.loc[5, "Balance"])

updated_old_pi_answer = "PI2,2023-09\nPI3,2024-03\nPI4,2024-03\nPI1,2024-03\n"
with open(self.old_pi_file, "r") as f:
self.assertEqual(updated_old_pi_answer, f.read())


class TestValidateBillables(TestCase):
def setUp(self):
Expand Down

0 comments on commit 62d985f

Please sign in to comment.