ImageFlowNet: Forecasting Multiscale Image-Level Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images

[Temporary] NOTE

If you are looking for the code implementation of DiffKillR: Killing and Recreating Diffeomorphisms for Cell Annotation in Dense Microscopy Images, please visit https://github.com/KrishnaswamyLab/DiffKillR.

ImageFlowNet

Krishnaswamy Lab, Yale University


This is the official implementation of ImageFlowNet.

Please raise issues on the GitHub issue tracker of this repository.

A Glimpse into the Methods

Citation

@article{liu2024imageflownet,
  title={ImageFlowNet: Forecasting Multiscale Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images},
  author={Liu, Chen and Xu, Ke and Shen, Liangbo L and Huguet, Guillaume and Wang, Zilong and Tong, Alexander and Bzdok, Danilo and Stewart, Jay and Wang, Jay C and Del Priore, Lucian V and Krishnaswamy, Smita},
  journal={arXiv preprint arXiv:2406.14794},
  year={2024}
}

Abstract

The forecasting of disease progression from images is a holy grail for clinical decision making. However, this task is complicated by the inherent high dimensionality, temporal sparsity and sampling irregularity in longitudinal image acquisitions. Existing methods often rely on extracting hand-crafted features and performing time-series analysis in this vector space, leading to a loss of rich spatial information within the images. To overcome these challenges, we introduce ImageFlowNet, a novel framework that learns latent-space flow fields that evolve multiscale representations in joint embedding spaces using neural ODEs and SDEs to model disease progression in the image domain. Notably, ImageFlowNet learns multiscale joint representation spaces by combining cohorts of patients together so that information can be transferred between the patient samples. The dynamics then provide plausible trajectories of progression, with the SDE providing alternative trajectories from the same starting point. We provide theoretical insights that support our formulation of ODEs, and motivate our regularizations involving high-level visual features, latent space organization, and trajectory smoothness. We then demonstrate ImageFlowNet's effectiveness through empirical evaluations on three longitudinal medical image datasets depicting progression in retinal geographic atrophy, multiple sclerosis, and glioblastoma.
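To make the core idea concrete, below is a minimal, self-contained sketch of evolving a latent feature map from an observation time to a target time with a neural ODE via torchdiffeq (which is part of the dependency list). The class and function names are illustrative only and do not correspond to the actual modules in src/nn; the SDE variant would analogously use torchsde.

# Illustrative sketch only: evolve a latent feature map from time t0 to time t1
# with a neural ODE, using torchdiffeq. Names here do not mirror `src/nn`.
import torch
import torch.nn as nn
from torchdiffeq import odeint

class LatentODEFunc(nn.Module):
    """Parameterizes dz/dt for a latent feature map of shape (B, C, H, W)."""
    def __init__(self, channels: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, t, z):
        # A time-invariant vector field; a time-conditioned one could also use t.
        return self.net(z)

def flow_forward(func: LatentODEFunc, z0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
    """Integrate the latent state from t0 to t1 and return the state at t1."""
    t = torch.tensor([t0, t1], dtype=z0.dtype)
    z_traj = odeint(func, z0, t, method='dopri5')  # solution at each requested time
    return z_traj[-1]

# Example: push a 64-channel encoder feature map forward by 2 (normalized) time units.
func = LatentODEFunc(channels=64)
z0 = torch.randn(1, 64, 32, 32)
z1 = flow_forward(func, z0, t0=0.0, t1=2.0)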


Repository Hierarchy

ImageFlowNet
    ├── comparison: some comparisons are in the `src` folder instead.
    |   └── interpolation
    |
    ├── checkpoints: only for segmentor model weights. Other model weights in `results`.
    |
    ├── data: folders containing data files.
    |   ├── brain_LUMIERE: Brain Glioblastoma
    |   ├── brain_MS: Brain Multiple Sclerosis
    |   └── retina_ucsf: Retinal Geographic Atrophy
    |
    ├── external_src: other repositories or code.
    |
    ├── results: generated results, including training log, model weights, and evaluation results.
    |
    └── src
        ├── data_utils
        ├── datasets
        ├── nn
        ├── preprocessing
        ├── utils
        └── *.py: some main scripts

Pre-trained weights

We have uploaded the weights for the retinal images.

  1. The weights for the segmentor can be found in checkpoints/segment_retinaUCSF_seed1.pty
  2. The weights for the ImageFlowNetODE models can be found in Google Drive. You can put them under results/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1/run_1/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1_best_pred_psnr.pty and results/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1/run_1/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1_best_seg_dice.pty.
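If you want to inspect or reuse these checkpoints outside the provided training scripts, a minimal sketch along the following lines should work, assuming the .pty files are ordinary PyTorch checkpoints saved with torch.save (this is illustrative and not part of the repository):

# Illustrative sketch (assumes the `.pty` file is a plain PyTorch state dict):
# load it on CPU and list a few parameter names and shapes.
import torch

ckpt_path = 'checkpoints/segment_retinaUCSF_seed1.pty'
state_dict = torch.load(ckpt_path, map_location='cpu')

for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))

# To reuse the weights, instantiate the matching network from `src/nn`
# and call `model.load_state_dict(state_dict)`.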

Reproduce the results

Image registration

cd src/preprocessing
python test_registration.py

Training a segmentation network (only for quantitative evaluation purposes)

cd src/
python train_segmentor.py

Training the main network.

cd src/
# ImageFlowNet_{ODE}
python train_2pt_all.py --model ImageFlowNetODE --random-seed 1
python train_2pt_all.py --model ImageFlowNetODE --random-seed 1 --mode test --run-count 1

# ImageFlowNet_{SDE}
python train_2pt_all.py --model ImageFlowNetSDE --random-seed 1
python train_2pt_all.py --model ImageFlowNetSDE --random-seed 1 --mode test --run-count 1

Some common arguments.

--dataset-name: name of the dataset (`retina_ucsf`, `brain_ms`, `brain_gbm`)
--segmentor-ckpt: path to the segmentor model weights, used both when training the segmentor and when applying it.

Ablations.

  1. Flow field formulation.
python train_2pt_all.py --model ODEUNet
python train_2pt_all.py --model ImageFlowNetODE
  2. Single-scale vs. multiscale ODEs.
python train_2pt_all.py --model ImageFlowNetODE --ode-location 'bottleneck'
python train_2pt_all.py --model ImageFlowNetODE --ode-location 'all_resolutions'
python train_2pt_all.py --model ImageFlowNetODE --ode-location 'all_connections' # default
  3. Visual feature regularization.
python train_2pt_all.py --model ImageFlowNetODE --coeff-latent 0.1
  4. Contrastive learning regularization.
python train_2pt_all.py --model ImageFlowNetODE --coeff-contrastive 0.1
  5. Trajectory smoothness regularization.
python train_2pt_all.py --model ImageFlowNetODE --coeff-smoothness 0.1

Comparisons

Image interpolation/extrapolation methods.

cd comparison/interpolation
python run_baseline_interp.py --method linear
python run_baseline_interp.py --method cubic_spline

Time-conditional UNet.

cd src
python train_2pt_all.py --model T_UNet --random-seed 1 --mode train
python train_2pt_all.py --model T_UNet --random-seed 1 --mode test --run-count 1

Time-aware diffusion model (Image-to-Image Schrodinger Bridge)

cd src
python train_2pt_all.py --model I2SBUNet --random-seed 1
python train_2pt_all.py --model I2SBUNet --random-seed 1 --mode test --run-count 1

Style-based Manifold Extrapolation (Nat. Mach. Intell. 2022).

conda deactivate
conda activate stylegan

cd src/preprocessing
python 04_unpack_retina_UCSF.py

cd ../../comparison/style_manifold_extrapolation/stylegan2-ada-pytorch
python train.py --outdir=../training-runs --data='../../../data/retina_ucsf/UCSF_images_final_unpacked_256x256/' --gpus=1

Datasets

  1. Retinal Geographic Atrophy dataset from the METforMIN study (UCSF).
  2. Brain Multiple Sclerosis dataset.
  3. Brain Glioblastoma dataset.

Data preparation and preprocessing

  1. Retinal Geographic Atrophy dataset.
  • Put data under: data/retina_ucsf/Images/
cd src/preprocessing
python 01_preprocess_retina_UCSF.py
python 02_register_retina_UCSF.py
python 03_crop_retina_UCSF.py
  2. Brain Multiple Sclerosis dataset.
  • Put data under: data/brain_MS/brain_MS_images/trainX/ after unzipping.
cd src/preprocessing
python 01_preprocess_brain_MS.py
  3. Brain Glioblastoma dataset.
  • Put data under: data/brain_LUMIERE/ after unzipping.
cd src/preprocessing
python 01_preprocess_brain_GBM.py

Segment Anything Model (SAM)

This is only used by test_registration.py to facilitate visualization; it is not used anywhere else.

cd external_src/
mkdir SAM && cd SAM
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
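For reference, the downloaded file is the standard ViT-H SAM checkpoint and can be loaded with the official segment-anything API, for example as in the sketch below (illustrative only; not necessarily how test_registration.py uses it internally):

# Illustrative sketch: load the ViT-H checkpoint with the official
# segment-anything API and generate masks for a single RGB image.
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry['vit_h'](checkpoint='external_src/SAM/sam_vit_h_4b8939.pth')
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread('example.png'), cv2.COLOR_BGR2RGB)  # placeholder image path
masks = mask_generator.generate(image)  # list of dicts with 'segmentation', 'area', ...
print(len(masks), 'masks found')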

Dependencies

We developed the codebase in a miniconda environment. Here is how we created the conda environment:

# Optional: Update to libmamba solver.
conda update -n base conda
conda install -n base conda-libmamba-solver
conda config --set solver libmamba

conda create --name imageflownet pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -c nvidia -c anaconda -c conda-forge
conda activate imageflownet
conda install scikit-learn scikit-image pillow matplotlib seaborn tqdm -c pytorch -c anaconda -c conda-forge
conda install read-roi -c conda-forge
python -m pip install -U albumentations
python -m pip install timm
python -m pip install opencv-python
python -m pip install git+https://github.com/facebookresearch/segment-anything.git
python -m pip install monai
python -m pip install torchdiffeq
python -m pip install torch-ema
python -m pip install torchcde
python -m pip install torchsde
python -m pip install phate
python -m pip install psutil
python -m pip install ninja

# For 3D registration
python -m pip install antspyx

Acknowledgements

We adapted some of the code from

  1. I^2SB: Image-to-Image Schrodinger Bridge
