Skip to content

Commit

Permalink
enhance output and results saving
Browse files Browse the repository at this point in the history
  • Loading branch information
etetteh committed Sep 23, 2021
1 parent a462b52 commit 5e90788
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
45 changes: 23 additions & 22 deletions chest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sklearn.metrics import roc_auc_score, accuracy_score

import torchxrayvision as xrv

from utils import write_results

parser = argparse.ArgumentParser(description='X-RAY Pathology Detection')
parser.add_argument('--seed', type=int, default=0, help='')
Expand Down Expand Up @@ -540,7 +540,7 @@ def get_model_inputimg(model_name, num_classes):
model = model.to(device)

print("\n Training Model \n")
output_dir = "merge_train-" + str(cfg.merge_train) + "_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
output_dir = model_name + "_merge_train-" + str(cfg.merge_train) + "_split-" + str(cfg.split) + "_valid-" + cfg.valid_data + "_seed-" + str(cfg.seed) + "/"

metrics, best_metric, = main(model, model_name, output_dir, num_epochs=cfg.num_epochs)
print(f"Best validation AUC: {best_metric:4.4f}")
Expand All @@ -565,24 +565,25 @@ def get_model_inputimg(model_name, num_classes):
print(f"Average AUC for all pathologies {test_auc:4.4f}")
print(f"Test loss: {test_loss:4.4f}")
print(f"AUC for each task {[round(x, 4) for x in task_aucs]}")

test_filename = "test_results_.csv"
field_names = ['Model', 'Test_AVG_AUC', 'Test_loss', 'Cardiomegaly', 'Effusion', 'Edema', 'Consolidation', ]
model_name = model_name + "_merge_train-" + str(cfg.merge_train) + "_split-" + str(cfg.split) + "_valid-" + cfg.valid_data + "_seed-" + str(cfg.seed)
results = {
'Model': model_name,
'Test_AVG_AUC': round(test_auc, 2),
'Test_loss': round(test_loss, 2),
'Cardiomegaly': round(task_aucs[0], 2),
'Effusion': round(task_aucs[1], 2),
'Edema': round(task_aucs[2], 2),
'Consolidation': round(task_aucs[3], 2),
}

if os.path.exists(test_filename):
write_results(filename=test_filename, field_names=field_names, results=results)
else:
with open(test_filename, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=field_names)

test_filename = output_dir + output_dir.strip("/") + "_test_results.csv"

with open(test_filename, 'w') as csvfile:
field_names = ['Test_loss', 'Test_AVG_AUC',
'Cardiomegaly',
'Effusion',
'Edema',
'Consolidation',
]
writer = csv.DictWriter(csvfile, fieldnames=field_names)

writer.writeheader()
writer.writerow({
'Test_loss': round(test_loss, 4),
'Test_AVG_AUC': round(test_auc, 4),
'Cardiomegaly': round(task_aucs[0], 4),
'Effusion': round(task_aucs[1], 4),
'Edema': round(task_aucs[2], 4),
'Consolidation': round(task_aucs[3], 4),
})
writer.writeheader()
writer.writerow(results)
37 changes: 37 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import csv

def write_results(filename: str, field_names: list, results: dict):
read_file = open(filename, "r")
results_file = csv.DictReader(read_file)
update = []
new = []
row = {}
for r in results_file:
if r['Model'] == results['Model']:
for key, value in results.items():
row[key] = value
update = row
else:
for key, value in results.items():
row[key] = value
new = row

read_file.close()

if update:
print("Results exists. Updating results in file...")
print(update)
update_file = open(filename, "w", newline='')
data = csv.DictWriter(update_file, delimiter=',', fieldnames=field_names)
data.writeheader()
data.writerows([update])
else:
print("Results does not exist. Writing results to file...")
print(new)
update_file = open(filename, "a+", newline='')
data = csv.DictWriter(update_file, delimiter=',', fieldnames=field_names)
data.writerows([new])

update_file.close()

0 comments on commit 5e90788

Please sign in to comment.