Out-of-Distribution Generalization of Chest X-ray Using Risk Extrapolation
Install the following libraries/packages with pip
torch
torchvision
torchxrayvision
There are 12 different training, validation, and test settings, generated by combining 4 chest X-ray datasets (NIH ChestX-ray8, PadChest, CheXpert, and MIMIC-CXR). These 12 settings are grouped into 6 splits (numbered 0 to 5), selected by passing the argument --split=<split>. For each split, you can choose between 2 validation datasets by passing the argument --valid_data=<name of valid dataset>.
The dataset names are abbreviated as short strings: "nih" = NIH ChestX-ray8 dataset, "pc" = PadChest dataset, "cx" = CheXpert, and "mc" = MIMIC-CXR.
For each setting, we compute the ROC-AUC for the following chest x-ray pathologies (labels): Cardiomegaly, Pneumonia, Effusion, Edema, Atelectasis, Consolidation, and Pneumothorax.
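As a rough illustration of the per-pathology metric, the sketch below computes a rank-based ROC-AUC for each label column in plain Python. The helper names roc_auc and per_pathology_auc are hypothetical and are not taken from chest.py, which may well rely on a library routine instead.

```python
# Pathology names as listed in this README.
PATHOLOGIES = ["Cardiomegaly", "Pneumonia", "Effusion", "Edema",
               "Atelectasis", "Consolidation", "Pneumothorax"]

def roc_auc(labels, scores):
    """Rank-based ROC-AUC: the fraction of (positive, negative) pairs
    in which the positive gets the higher score; ties count as half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return float("nan")  # AUC is undefined without both classes
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def per_pathology_auc(y_true, y_score, names=PATHOLOGIES):
    """y_true and y_score are lists of rows with one column per name
    (hypothetical layout, chosen for illustration)."""
    return {
        name: roc_auc([row[j] for row in y_true], [row[j] for row in y_score])
        for j, name in enumerate(names)
    }

# Toy example with two label columns only.
toy_true = [[0, 1], [0, 0], [1, 1], [1, 0]]
toy_score = [[0.1, 0.9], [0.4, 0.2], [0.35, 0.7], [0.8, 0.3]]
toy_auc = per_pathology_auc(toy_true, toy_score, names=["Cardiomegaly", "Pneumonia"])
print(toy_auc)
```

Each label is scored independently, so a model can rank one pathology well and another poorly within the same setting.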
For each split, you train on two (2) datasets, validate on one (1) and test on the remaining one (1).
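The count of 12 settings follows from choosing 2 of the 4 datasets for training (6 ways) and then picking either of the 2 remaining datasets for validation. The sketch below enumerates them; the function name enumerate_settings is hypothetical, and the enumeration order is only illustrative, so it should not be read as the mapping from --split indices to dataset pairs used in chest.py.

```python
from itertools import combinations

DATASETS = ["nih", "pc", "cx", "mc"]  # short names from this README

def enumerate_settings(datasets=DATASETS):
    """Return every (train_pair, valid, test) combination."""
    settings = []
    for train in combinations(datasets, 2):
        held_out = [d for d in datasets if d not in train]
        # Either held-out dataset can serve as validation;
        # the other one becomes the test set.
        for valid in held_out:
            test = next(d for d in held_out if d != valid)
            settings.append((train, valid, test))
    return settings

settings = enumerate_settings()
for train, valid, test in settings:
    print(f"train={'+'.join(train)} valid={valid} test={test}")
```

Under this scheme, a split that validates on "mc" and tests on "pc" must train on "nih" and "cx".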
The chest.py file contains code to run both our baseline and REx models.
To fine-tune or perform feature extraction with ImageNet weights, pass the --pretrained and --feat_extract arguments, respectively.
To train a ResNet-50 baseline model from scratch on the first split, validating on the MIMIC-CXR dataset, run the following command:
python chest.py --baseline --arch resnet50 --split 0 --valid_data mc
Note that for the first split, PadChest is automatically selected as the test_data when you pass MIMIC-CXR as the validation data, and vice versa.
To train the REx model, run the same command with some additional arguments. Omit the --baseline argument, specify the penalty weight (a float in multiples of 10) with --penalty_weight=<penalty weight amount>, and always pass --weight_decay=0.0.
Example:
python chest.py --arch resnet50 --weight_decay=0.0 --penalty_weight=100.0 --split 0 --valid_data mc
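To give a sense of what the penalty weight controls, here is a minimal sketch of a variance-based REx objective: the mean of the per-dataset training risks plus penalty_weight times their variance. This is an assumption about the general form of the method, written in plain Python with a hypothetical function name; the exact loss, scaling, and tensor handling in chest.py may differ.

```python
def rex_objective(env_risks, penalty_weight):
    """Mean training risk plus a weighted variance penalty.

    env_risks: one scalar risk (average loss) per training dataset.
    penalty_weight: corresponds conceptually to --penalty_weight.
    """
    n = len(env_risks)
    mean_risk = sum(env_risks) / n
    # Population variance of the risks across training datasets.
    variance = sum((r - mean_risk) ** 2 for r in env_risks) / n
    return mean_risk + penalty_weight * variance

# Two training datasets with illustrative (made-up) risks.
risks = [0.5, 0.7]
print(rex_objective(risks, penalty_weight=0.0))    # plain average risk
print(rex_objective(risks, penalty_weight=100.0))  # penalizes the risk gap
```

With a penalty weight of zero the objective reduces to the average risk, while larger weights push the model toward equal risk on both training datasets.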
If no model architecture is specified, the code trains all of the following architectures: resnet50, shufflenet_v2_x0_5, shufflenet_v2_x1_0, and densenet121.