added training scripts and readme (#12)
* added training scripts and readme

* polish
FrankLeeeee authored Feb 26, 2024
1 parent 9644774 commit a33c656
Showing 4 changed files with 52 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -160,5 +160,7 @@ cython_debug/
#.idea/
.vscode/

# files needed to train
# misc files
dataset/
runs/
checkpoints/
8 changes: 7 additions & 1 deletion README.md
@@ -94,7 +94,13 @@ pip install -r requirements.txt

### Training

To be added.
Start training with the command below.

```bash
bash ./scripts/train.sh
```

You can also modify the arguments in `train.sh` to suit your needs.
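
`train.sh` reads the number of GPUs from its first positional argument. The sketch below is a standalone illustration of how bash handles an optional positional argument: it contrasts the default-value expansion `${1:-8}` with the similar-looking substring expansion `${1:8}`. The function names are hypothetical and do not appear in the repository.

```bash
#!/usr/bin/env bash
# Hypothetical helpers for illustration only; not part of train.sh.

# ${1:-8} means "use $1, or 8 when $1 is unset or empty".
gpus_with_default() {
    echo "${1:-8}"
}

# ${1:8} is substring expansion: the part of $1 from character offset 8 on.
gpus_with_substring() {
    echo "${1:8}"
}

default_result="$(gpus_with_default)"        # no argument given
explicit_result="$(gpus_with_default 4)"     # argument given
substring_result="$(gpus_with_substring 4)"  # offset 8 is past the end of "4"

echo "default=${default_result} explicit=${explicit_result} substring='${substring_result}'"
```

With the default-value form, `bash ./scripts/train.sh 4` would request four processes and a bare `bash ./scripts/train.sh` would fall back to eight. With the substring form, both invocations leave the variable empty.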

### Inference

1 change: 1 addition & 0 deletions requirements.txt
@@ -3,3 +3,4 @@ torchvision
datasets
transformers
av
tensorboard
41 changes: 41 additions & 0 deletions scripts/train.sh
@@ -0,0 +1,41 @@
#!/usr/bin/env bash

# number of GPUs: first positional argument, defaults to 8
GPUS=${1:-8}

# get root dir
FOLDER_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
ROOT_DIR=$FOLDER_DIR/..

# go to root dir
cd "$ROOT_DIR" || exit 1

# define dataset shards
COLLATED_VIDEO_DIR=./dataset/MSRVTT-collated/val/videos
PROCESSED_DATASET=(
./dataset/MSRVTT-processed/val/part-00000
./dataset/MSRVTT-processed/val/part-00001
./dataset/MSRVTT-processed/val/part-00002
./dataset/MSRVTT-processed/val/part-00003
./dataset/MSRVTT-processed/val/part-00004
./dataset/MSRVTT-processed/val/part-00005
./dataset/MSRVTT-processed/val/part-00006
./dataset/MSRVTT-processed/val/part-00007
./dataset/MSRVTT-processed/val/part-00008
./dataset/MSRVTT-processed/val/part-00009
)

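# run single node training
# "${PROCESSED_DATASET[@]}" passes every shard; a bare $PROCESSED_DATASET
# expands to the first element only (assumes --dataset accepts multiple paths)
torchrun --standalone \
    --nproc_per_node "$GPUS" \
    train.py \
    --epochs 1 \
    --batch_size 1 \
    --lr 1e-4 \
    --accumulation_steps 32 \
    --grad_checkpoint \
    --dataset "${PROCESSED_DATASET[@]}" \
    --video_dir "$COLLATED_VIDEO_DIR" \
    --save_interval 224 \
    --checkpoint_dir ./checkpoints \
    --tensorboard_dir ./runs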
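
The script defines the dataset shards as a bash array, so the way that array is expanded on the command line decides how many shard paths reach `train.py`. The following minimal sketch uses hypothetical, shortened paths and a hypothetical `count_args` helper to show the difference between a bare array name and the `[@]` expansion. Whether `train.py` accepts several paths after `--dataset` is an assumption; its argument parser is not part of this commit.

```bash
#!/usr/bin/env bash
# Hypothetical shard paths, shortened for illustration.
SHARDS=(
    ./data/part-00000
    ./data/part-00001
    ./data/part-00002
)

# Hypothetical helper: reports how many arguments it received.
count_args() {
    echo "$#"
}

# A bare array name expands to element 0 only.
first_only="$(count_args $SHARDS)"

# The quoted [@] form expands to one argument per element.
all_shards="$(count_args "${SHARDS[@]}")"

echo "bare name: ${first_only} argument(s), [@] expansion: ${all_shards} argument(s)"
```

Quoting the `[@]` expansion also keeps each path intact if it ever contains spaces.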
