✨ feat: EVA02 & DINOv2 Exps
Yingyue-L committed Aug 31, 2023
1 parent 535ba45 commit 87a3470
Showing 14 changed files with 1,195 additions and 42 deletions.
21 changes: 7 additions & 14 deletions .gitignore
@@ -167,18 +167,11 @@ run_affinitynet.sh
 CenterPoints/exp_run.sh
 colorful.py
 *.json
-/data/coco/
-/data/voc12/
-/OnlineRetraining/run.sh
-/OnlineRetraining/run_coco.sh
-/OnlineRetraining/run_coco_vit.sh
-/OnlineRetraining/run_vit.sh
-/OnlineRetraining/val_tmp.py
-mlruns/*
-MCTformerV2_results/*
-WeakTr_results/*
-MCTformerV2_coco_results/*
-OnlineRetraining/
-!OnlineRetraining/segm
-WeakTr_results_coco
+data/coco
+data/voc12
+mlruns/*
+*results/*
+OnlineRetraining/mlruns
+OnlineRetraining/start*
+OnlineRetraining/seg*
 !OnlineRetraining/segm
44 changes: 44 additions & 0 deletions OnlineRetraining/segm/config.yml
@@ -102,6 +102,50 @@ model:
     n_heads: 16
     n_layers: 24
     normalization: vit
+
+  # dino
+  dino_small_patch16_224:
+    image_size: 224
+    patch_size: 16
+    d_model: 384
+    n_heads: 6
+    n_layers: 12
+    normalization: deit
+    distilled: false
+  dinov2_small_patch16_224:
+    image_size: 224
+    patch_size: 16
+    d_model: 384
+    n_heads: 6
+    n_layers: 12
+    normalization: deit
+    distilled: false
+  dinov2_small_patch14_224:
+    image_size: 224
+    patch_size: 14
+    d_model: 384
+    n_heads: 6
+    n_layers: 12
+    normalization: deit
+    distilled: false
+  # eva
+  eva02_tiny_patch16_224:
+    image_size: 224
+    patch_size: 16
+    d_model: 192
+    n_heads: 3
+    n_layers: 12
+    normalization: eva02
+    distilled: false
+  # eva
+  eva02_small_patch16_224:
+    image_size: 224
+    patch_size: 16
+    d_model: 384
+    n_heads: 6
+    n_layers: 12
+    normalization: eva02
+    distilled: false
 decoder:
   linear: {}
   deeplab_dec:
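
Each new entry under model: registers a backbone variant by name together with its patch size, embedding width (d_model), head count, depth, and the normalization statistics it expects. A minimal sketch of reading one of these entries with a plain YAML load follows; the direct file path and the chosen key are illustrative assumptions, not the repository's own config helper.

# Illustrative only: reads config.yml directly with PyYAML rather than the
# repository's own config loader, and picks one of the new entries as an example.
import yaml

with open("OnlineRetraining/segm/config.yml") as f:
    cfg = yaml.safe_load(f)

backbone = cfg["model"]["eva02_small_patch16_224"]
print(backbone["d_model"], backbone["n_heads"], backbone["n_layers"])
# 384 6 12 -- and backbone["normalization"] == "eva02" selects the matching image stats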
1 change: 1 addition & 0 deletions OnlineRetraining/segm/data/utils.py
@@ -8,6 +8,7 @@
 STATS = {
     "vit": {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)},
     "deit": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
+    "eva02": {"mean": (0.48145466, 0.4578275, 0.40821073), "std": (0.26862954, 0.26130258, 0.27577711)},
 }


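
The new "eva02" key holds the CLIP-style mean/std that EVA-02 checkpoints are commonly normalized with, and is selected by the normalization: eva02 field in config.yml. A small sketch of how such stats are typically applied to an input image follows; the torchvision transform is an assumption for illustration, not the repository's actual pipeline.

# Illustrative sketch: applying the "eva02" normalization stats to a dummy image.
# The torchvision transform here is an assumption, not the repository's own pipeline.
import torch
import torchvision.transforms as T

eva02_stats = {
    "mean": (0.48145466, 0.4578275, 0.40821073),
    "std": (0.26862954, 0.26130258, 0.27577711),
}

normalize = T.Normalize(mean=eva02_stats["mean"], std=eva02_stats["std"])
x = torch.rand(3, 224, 224)   # dummy image tensor scaled to [0, 1]
print(normalize(x).shape)     # torch.Size([3, 224, 224])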
95 changes: 92 additions & 3 deletions OnlineRetraining/segm/model/blocks.py
@@ -11,7 +11,10 @@
 import torch.nn.functional as F
 
 from timm.models.layers import DropPath
+from torch import Tensor
+from typing import Union
 
+import xformers.ops as xops
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout, out_dim=None):
@@ -78,10 +81,10 @@ def forward(self, x, mask=None):


 class Block(nn.Module):
-    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
+    def __init__(self, dim, heads, mlp_dim, dropout, drop_path, norm_layer=nn.LayerNorm):
         super().__init__()
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
         self.attn = Attention(dim, heads, dropout)
         self.mlp = FeedForward(dim, mlp_dim, dropout)
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
@@ -93,3 +96,89 @@ def forward(self, x, mask=None, return_attention=False):
         x = x + self.drop_path(y)
         x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x
+
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class DINOBlock(Block):
+    def __init__(self, dim, heads, mlp_dim, dropout, drop_path,
+                 init_values=1e-5, use_bn=False):
+        super().__init__(dim, heads, mlp_dim, dropout, drop_path)
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+
+    def forward(self, x, mask=None, return_attention=False):
+        y, attn = self.attn(self.norm1(x), mask)
+        y = self.ls1(y)
+        if return_attention:
+            return attn
+        x = x + self.drop_path(y)
+        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
+        return x
+
+
+class EVA02Attention(Attention):
+    def __init__(self, dim, *args, rope=None):
+        super().__init__(dim, *args)
+        self.qkv = nn.Linear(dim, dim * 3, bias=False)
+        self.q_bias = nn.Parameter(torch.zeros(dim))
+        self.v_bias = nn.Parameter(torch.zeros(dim))
+
+        self.rope = rope
+
+    def forward(self, x, mask=None):
+        B, N, C = x.shape
+        qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        qkv = (
+            (self.qkv(x) + qkv_bias)
+            .reshape(B, N, 3, self.heads, C // self.heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )
+
+        if self.rope:
+            q_t = q[:, :, 1:, :]
+            ro_q_t = self.rope(q_t)
+            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
+
+            k_t = k[:, :, 1:, :]
+            ro_k_t = self.rope(k_t)
+            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x, attn
+
+class EVA02Block(Block):
+    def __init__(self, dim, heads, mlp_dim, dropout, drop_path,
+                 init_values=1e-5, use_bn=False, rope=None, norm_layer=nn.LayerNorm):
+        super().__init__(dim, heads, mlp_dim, dropout, drop_path, norm_layer=norm_layer)
+        self.attn = EVA02Attention(dim, heads, dropout, rope=rope)
+        self.mlp = xops.SwiGLU(
+            in_features=dim,
+            hidden_features=mlp_dim
+        ) # hidden_features: 2/3
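
The trailing "# hidden_features: 2/3" comment refers to the usual SwiGLU convention: since SwiGLU uses three projection matrices instead of two, its hidden width is set to 2/3 of the standard 4*dim MLP width so the parameter count stays roughly the same. A quick sketch of that arithmetic for the eva02_small config above (d_model = 384); the exact width used by the pretrained checkpoints is an assumption here.

# Hedged sketch of the SwiGLU "2/3" width convention referenced in the comment above.
# Assumes d_model = 384 as in eva02_small_patch16_224; real checkpoints may round differently.
d_model = 384
mlp_ratio = 4
gelu_hidden = mlp_ratio * d_model            # 1536: hidden width of a standard 2-matrix MLP
swiglu_hidden = int(2 / 3 * gelu_hidden)     # 1024: shrunk so 3 matrices carry ~ the same params

params_gelu = 2 * d_model * gelu_hidden      # in + out projections
params_swiglu = 3 * d_model * swiglu_hidden  # gate, value, and out projections
print(gelu_hidden, swiglu_hidden)            # 1536 1024
print(params_gelu, params_swiglu)            # 1179648 1179648 (identical in this case)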