Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
etetteh committed Sep 14, 2021
1 parent 7e48807 commit b0dac81
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,28 @@ The dataset names are condensed as short strings: `"nih"`= NIH ChestX-ray8 datas
For each setting, we compute the ROC-AUC for the following chest x-ray pathologies (labels): Cardiomegaly, Pneumonia, Effusion, Edema, Atelectasis, Consolidation, and Pneumothorax.

For each split, you train on two (2) datasets, validate on one (1) and test on the remaining one (1). \
The [chestREx.py](https://github.com/etetteh/OoD_Gen-Chest_Xray-REx/blob/main/chestREx.py) file contains code to run both our baseline and REx models.
The [chest.py](https://github.com/etetteh/OoD_Gen-Chest_Xray-REx/blob/main/chest.py) file contains code to run both our baseline and REx models.

To **finetune** or perform **feature extraction** with ImageNet weights pass the `--pretrained` and `--feat_extract` arguments **respectively**

### Train Using Baseline Model (Merged Datasets)
To train a DenseNet-121 **Baseline** model by fine-tuning on the first split, and validate on the MIMIC-CXR dataset, with seed=0 run the following code:
```
python chestREx.py --baseline --arch densenet121 --pretrained --split 0 --valid_data mc --seed 0
python chest.py --baseline --arch densenet121 --pretrained --split 0 --valid_data mc --seed 0
```
Note that for the first split, PadChest is automatically selected as the `test_data`, when you pass MIMIC-CXR as the validation data, and vice versa.

### Train Using Baseline Model (Balanced Mini-Batch Sampling)
To train a DenseNet-121 **Baseline REx-Off** model by fine-tuning on the first split, and validate on the MIMIC-CXR dataset, with seed=0 run the following code:
```
python chestREx.py --arch densenet121 --pretrained --weight_decay=0.0 --split 0 --valid_data mc --seed 0
python chest.py --arch densenet121 --pretrained --weight_decay=0.0 --split 0 --valid_data mc --seed 0
```
and always pass `--weight_decay=0.0`

### Train Using REx Model
To train the **REx** model, we run the same code above with some addtional arguments. We first switch the argument from `--baseline` to `--rex`, and also specify the amount of penalty weight (float in multiples of 10) to use by `--penalty_weight=<penalty weight amount>`, and always pass `--weight_decay=0.0` Example:
```
python chestREx.py --arch densenet121 --pretrained --weight_decay=0.0 --penalty_weight=100.0 --split 0 --valid_data mc --seed 0
python chest.py --arch densenet121 --pretrained --weight_decay=0.0 --penalty_weight=100.0 --split 0 --valid_data mc --seed 0
```
If no model architecture is specified, the code trains all the following architectures: `resnet50`, `shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`, and `densenet121`.

Expand Down

0 comments on commit b0dac81

Please sign in to comment.