channel_mask and lin/finetuning
Modexus committed Nov 14, 2022
1 parent fecfe90 commit d1455c5
Showing 11 changed files with 164 additions and 80 deletions.
10 changes: 5 additions & 5 deletions .vscode/launch.json
@@ -280,19 +280,19 @@
"justMyCode": false
},
{
"name": "Python: mae_bigearthnet_evaluate",
"name": "Python: mae_bigearthnet_linearprobing",
"type": "python",
"request": "launch",
"program": "train.py",
"console": "integratedTerminal",
"args": [
"config_file=conf/mae_bigearthnet_evaluate.yaml",
"config_file=conf/mae_bigearthnet_linearprobing.yaml",
"program.seed=666",
"program.output_dir=/data/users/mike/experiments",
"program.log_dir=/data/users/mike/logs",
"program.output_dir=/scratch/users/mike/experiments",
"program.log_dir=/scratch/users/mike/logs",
"program.overwrite=False",
"logger.offline=True",
"trainer.gpus=[0]"
"trainer.gpus=[0,1,2,3]"
],
"justMyCode": true
},
21 changes: 16 additions & 5 deletions conf/mae_bigearthnet_linearprobing.yaml
@@ -1,7 +1,11 @@
trainer:
gpus: [0,1,2,3]
max_epochs: 100
accelerator: "gpu"
devices: [0,1,2,3]
precision: "bf16"
accumulate_grad_batches: 2
max_epochs: 100
max_time: "00:10:00:00"
limit_val_batches: 0
benchmark: True
experiment:
task: "mae_bigearthnet_linearprobing"
@@ -23,25 +27,32 @@ experiment:
mean_patches: True
channel_wise: True
multi_label: True
in_channels: 14
in_channels: 12
out_channels: 1
num_classes: 19
load_checkpoint: "last.ckpt"
imagenet_pretrained: False
lr: 1e-1
optimizer: "SGD"
optimizer_kwargs:
weight_decay: 0.0
momentum: 0.9
lr_min: 0.0
warmup_lr_init: 1.5e-7
num_warmup: 10
mask_fns:
- "random_channel_masking"
mask_kwargs:
random_channel_masking:
num_keep: 720
probability: 1.0
datamodule:
#root_dir: "/data/users/mike/data/BigEarthNetFixed"
root_dir: "/scratch/users/mike/data/FFCV"
bands: "all"
num_classes: 19
batch_size: 32
num_workers: 8
batch_size: 512
num_workers: 7
pin_memory: True
prefetch_factor: 5
persistent_workers: True
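
The new mask_fns / mask_kwargs entries configure a random_channel_masking step that keeps num_keep of the channel-wise patch tokens with the given probability. The commit does not show that function's implementation; the sketch below is only one plausible reading of those settings, with the token shape and return signature assumed.

import torch

def random_channel_masking(tokens: torch.Tensor, num_keep: int,
                           probability: float = 1.0):
    """Hypothetical sketch: keep `num_keep` randomly chosen channel-wise tokens.

    `tokens` is assumed to be (batch, num_tokens, dim); the repository's actual
    signature and shapes may differ.
    """
    batch, num_tokens, dim = tokens.shape
    if float(torch.rand(())) > probability:
        keep_idx = torch.arange(num_tokens, device=tokens.device).expand(batch, -1)
        return tokens, keep_idx
    noise = torch.rand(batch, num_tokens, device=tokens.device)
    keep_idx = noise.argsort(dim=1)[:, :num_keep]               # random subset per sample
    kept = torch.gather(tokens, 1, keep_idx.unsqueeze(-1).expand(-1, -1, dim))
    return kept, keep_idx
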
7 changes: 4 additions & 3 deletions conf/mae_bigearthnet_train.yaml
@@ -4,9 +4,9 @@ trainer:
precision: "bf16"
accumulate_grad_batches: 4
max_epochs: 800
# limit_train_batches: 1.0
limit_train_batches: 1.0
# log_every_n_steps: 1
limit_val_batches: 0.125
limit_val_batches: 1
check_val_every_n_epoch: 10
benchmark: True
enable_progress_bar: True
@@ -23,11 +23,12 @@ experiment:
sensor: "naip"
imagenet_pretrained: False
pretrained: False
resume_checkpoint: "epoch=639-step=41600.ckpt"
# resume_checkpoint: "epoch=799-step=52000.ckpt"
image_size: 120
crop_size: 96
patch_size: 8
batch_size: ${experiment.datamodule.batch_size}
create_sharded: True
channel_wise: True
channel_shuffle: True
mask_tokens_encoder: False
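
One subtlety with limit_val_batches above: PyTorch Lightning treats an integer as a number of batches and a float as a fraction of the validation set, so changing 0.125 to 1 means a single validation batch per check (assuming the YAML value is parsed as an int). The equivalent Trainer arguments, for illustration:

import pytorch_lightning as pl

# int -> number of batches, float -> fraction of the validation set
trainer = pl.Trainer(
    limit_val_batches=1,         # one validation batch per validation run
    check_val_every_n_epoch=10,  # matches the yaml above
)
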
2 changes: 1 addition & 1 deletion create_bigearthnet_ffcv.py
@@ -8,7 +8,7 @@
import numpy as np
import rasterio
import torch
from ffcv.fields import NDArrayField, FloatField
from ffcv.fields import NDArrayField
from ffcv.writer import DatasetWriter
from rasterio.enums import Resampling
from torch import Tensor
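
For context, NDArrayField plus DatasetWriter is FFCV's mechanism for serializing an indexed dataset into a .beton file (the FloatField import removed above was presumably unused). A minimal, self-contained sketch of that usage; the path, field names, shapes, and dtypes are placeholders rather than the script's actual values:

import numpy as np
from ffcv.fields import NDArrayField
from ffcv.writer import DatasetWriter

# Toy indexed dataset: anything with __len__/__getitem__ returning the fields.
dataset = [
    (np.zeros((12, 120, 120), dtype=np.float32), np.zeros(19, dtype=np.int64))
    for _ in range(4)
]

writer = DatasetWriter(
    "bigearthnet_toy.beton",                                    # placeholder path
    {
        "image": NDArrayField(dtype=np.dtype("float32"), shape=(12, 120, 120)),
        "label": NDArrayField(dtype=np.dtype("int64"), shape=(19,)),
    },
    num_workers=8,
)
writer.from_indexed_dataset(dataset)
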
2 changes: 1 addition & 1 deletion download_bigearthnet_dataset.py
@@ -7,4 +7,4 @@
# %%
ssl._create_default_https_context = ssl._create_unverified_context
# %%
ds = BigEarthNet(root="/scratch/users/mike/data/BigEarthNet", download=True, bands="s2")
ds = BigEarthNet(root="/scratch/users/mike/data/BigEarthNet", download=True, bands="s2")
9 changes: 5 additions & 4 deletions torchgeo/datamodules/bigearthnet.py
@@ -3,18 +3,19 @@

"""BigEarthNet datamodule."""

import os
from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from ffcv.fields.decoders import NDArrayDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
import os

from ..datasets import BigEarthNet
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import NDArrayDecoder
from ffcv.transforms import ToTensor

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
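
The reordered imports above are the pieces FFCV needs to read such a .beton file back: Loader drives the pipeline, OrderOption controls shuffling, and NDArrayDecoder/ToTensor decode each field into torch tensors. A hedged sketch of how a loader like this is typically assembled; the path, field names, and pipelines are illustrative, not the datamodule's exact code:

from ffcv.fields.decoders import NDArrayDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor

loader = Loader(
    "bigearthnet_train.beton",              # placeholder path
    batch_size=512,                         # matches the linear-probing config
    num_workers=7,
    order=OrderOption.RANDOM,               # shuffle for training
    pipelines={
        "image": [NDArrayDecoder(), ToTensor()],
        "label": [NDArrayDecoder(), ToTensor()],
    },
)

for images, labels in loader:               # each batch is a tuple of tensors
    pass
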
6 changes: 3 additions & 3 deletions torchgeo/models/maevit.py
@@ -2,12 +2,12 @@

from typing import cast

import deepspeed
import torch
from timm.models.layers import Mlp
from timm.models.vision_transformer import Block
from torch import Tensor
from torch.nn import Conv2d, LayerNorm, Linear, Module, Sequential
import deepspeed

from .utils import (
get_channel_encodings,
@@ -343,8 +343,8 @@ def __init__(
image_size: int,
patch_size: int = 16,
channel_wise: bool = False,
num_checkpoints_encoder: bool = False,
num_checkpoints_decoder: bool = False,
num_checkpoints_encoder: int = 0,
num_checkpoints_decoder: int = 0,
embed_dim: int = 1024,
depth: int = 24,
num_heads: int = 16,
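
Changing num_checkpoints_encoder / num_checkpoints_decoder from a False default to 0 suggests they are now counts of transformer blocks to run under activation checkpointing rather than on/off flags. How the model consumes these counts is outside this hunk; the loop below is only an illustrative sketch of the pattern, using torch's built-in checkpointing (the module also imports deepspeed, whose checkpointing API could serve the same role).

import torch
from torch.utils.checkpoint import checkpoint

def run_blocks(x: torch.Tensor, blocks, num_checkpoints: int) -> torch.Tensor:
    """Run transformer blocks, checkpointing the first `num_checkpoints` of them."""
    for i, block in enumerate(blocks):
        if i < num_checkpoints:
            # recompute this block's activations during backward to save memory
            x = checkpoint(block, x)
        else:
            x = block(x)
    return x
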
2 changes: 1 addition & 1 deletion torchgeo/models/utils.py
@@ -90,7 +90,7 @@ def init_weights(m: Module) -> None:
"""Initialize the weights."""
if isinstance(m, Linear):
init.xavier_uniform_(m.weight)
if isinstance(m, Linear) and m.bias is not None:
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
init.constant_(m.bias, 0)
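
For readability, here is roughly how init_weights reads after this change, reconstructed from the hunk above; the LayerNorm branch is truncated in the diff, so anything beyond its bias initialization is not shown.

from torch.nn import LayerNorm, Linear, Module, init

def init_weights(m: Module) -> None:
    """Initialize the weights."""
    if isinstance(m, Linear):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:   # the redundant isinstance(m, Linear) check was dropped
            init.constant_(m.bias, 0)
    elif isinstance(m, LayerNorm):
        init.constant_(m.bias, 0)
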
(Diffs for the remaining 3 changed files are not shown here.)
