Out-of-Distribution Generalization of Chest X-ray Using Risk Extrapolation

12341123/OoD_Gen-Chest_Xray

OoD_Gen-Chest_Xray-REx

Requirements (Installations)

Install the following packages with pip:

torch
torchvision
torchxrayvision
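
A quick way to confirm the installation is to check that the three packages can be found before launching a training run. The snippet below is a minimal sketch that uses only the standard library; the helper name `missing_packages` is hypothetical and is not part of this repository.

```python
from importlib.util import find_spec

# Package names as listed in the requirements above.
REQUIRED = ("torch", "torchvision", "torchxrayvision")


def missing_packages(names=REQUIRED):
    """Return the package names that cannot be found on the current Python path."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages, install with: pip install " + " ".join(missing))
    else:
        print("All required packages are importable.")
```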

Seven (7) Pathologies, Four (4) Datasets, & Twelve Training Settings

Combining 4 chest X-ray datasets (NIH ChestX-ray8, PadChest, CheXpert, and MIMIC-CXR) yields 12 different training, validation, and test settings. These 12 settings are organized into 6 splits, numbered 0 to 5, which you select by passing the argument --split=<split>. For each split, you can choose between 2 validation datasets by passing the argument --valid_data=<name of valid dataset>. The dataset names are abbreviated as short strings: "nih" = NIH ChestX-ray8, "pc" = PadChest, "cx" = CheXpert, and "mc" = MIMIC-CXR.
For each setting, we compute the ROC-AUC for the following chest X-ray pathologies (labels): Cardiomegaly, Pneumonia, Effusion, Edema, Atelectasis, Consolidation, and Pneumothorax.

For each split, you train on two (2) datasets, validate on one (1) and test on the remaining one (1).
The chest.py file contains code to run both our baseline and REx models.
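
The count of 12 settings follows from the rule above: there are 6 ways to pick two training datasets out of four, and for each pick the two held-out datasets can swap the validation and test roles. The sketch below enumerates them under that reading. The function name is hypothetical, and the order of the resulting list is not claimed to match the split numbers 0 to 5 used by chest.py.

```python
from itertools import combinations

DATASETS = ("nih", "pc", "cx", "mc")


def enumerate_settings(datasets=DATASETS):
    """List every (train pair, validation, test) assignment of the datasets."""
    settings = []
    for train in combinations(datasets, 2):
        held_out = [d for d in datasets if d not in train]
        for valid in held_out:
            # The held-out dataset not used for validation becomes the test set.
            test = next(d for d in held_out if d != valid)
            settings.append({"train": train, "valid": valid, "test": test})
    return settings


if __name__ == "__main__":
    for setting in enumerate_settings():
        print(setting)
```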

To fine-tune from ImageNet weights or to use them for feature extraction, pass the '--pretrained' and '--feat_extract' arguments, respectively.

Train Using Baseline Model

To train a ResNet-50 baseline model from scratch on the first split (split 0) and validate on the MIMIC-CXR dataset, run the following command:

python chest.py --baseline --arch resnet50 --split 0 --valid_data mc

Note that for the first split, PadChest is automatically selected as the test_data when you pass MIMIC-CXR as the validation data, and vice versa.

Train Using REx Model

To train the REx model, run the same command as above with some additional arguments. Omit the '--baseline' argument, specify the penalty weight (a float in multiples of 10) with --penalty_weight=<penalty weight amount>, and always pass '--weight_decay=0.0'. Example:

python chest.py --arch resnet50 --weight_decay=0.0 --penalty_weight=100.0 --split 0 --valid_data mc
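
To illustrate what the penalty weight controls, the sketch below shows the variance-based form of Risk Extrapolation (often called V-REx): the average risk over the training datasets plus the penalty weight times the variance of those risks. This is an assumption about the general technique, not a copy of the objective in chest.py. The function name, the use of the mean rather than the sum of risks, and the population-variance convention are all illustrative choices.

```python
from statistics import fmean, pvariance


def vrex_objective(domain_risks, penalty_weight):
    """Mean training risk plus a weighted variance-of-risks penalty.

    domain_risks: one scalar loss per training dataset (two per split here).
    penalty_weight: the value that would be passed as --penalty_weight.
    """
    mean_risk = fmean(domain_risks)
    penalty = pvariance(domain_risks)  # zero when all domains have equal risk
    return mean_risk + penalty_weight * penalty


if __name__ == "__main__":
    # Equal risks incur no penalty; unequal risks are penalized heavily.
    print(vrex_objective([2.0, 2.0], penalty_weight=100.0))
    print(vrex_objective([1.0, 3.0], penalty_weight=100.0))
```

With a large penalty weight, the optimizer is pushed toward solutions whose losses are similar across the training datasets, which is the property REx relies on for out-of-distribution generalization.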

If no model architecture is specified, the code trains all the following architectures: resnet50, shufflenet_v2_x0_5, shufflenet_v2_x1_0, and densenet121.
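
When running several configurations, it can help to assemble the command line programmatically. The helper below is a hypothetical sketch and not part of the repository. It uses only the flags documented above and reproduces the two example commands; leaving `arch` unset omits `--arch`, which according to the note above trains all four architectures.

```python
import shlex


def build_command(split, valid_data, arch=None, baseline=False, penalty_weight=None):
    """Build the argument list for a chest.py run from the documented flags."""
    cmd = ["python", "chest.py"]
    if baseline:
        cmd.append("--baseline")
    elif penalty_weight is None:
        raise ValueError("REx runs need a penalty_weight")
    if arch is not None:
        cmd += ["--arch", arch]
    if not baseline:
        # REx runs always disable weight decay, as described above.
        cmd += ["--weight_decay=0.0", f"--penalty_weight={penalty_weight}"]
    cmd += ["--split", str(split), "--valid_data", valid_data]
    return cmd


if __name__ == "__main__":
    print(shlex.join(build_command(0, "mc", arch="resnet50", baseline=True)))
    print(shlex.join(build_command(0, "mc", arch="resnet50", penalty_weight=100.0)))
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)` from the repository root.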
