
Commit

update config
MzeroMiko committed Mar 10, 2024
1 parent c42de5b commit e8f154d
Showing 1 changed file with 11 additions and 6 deletions.
classification/models/vmamba.py: 17 changes (11 additions & 6 deletions)
@@ -497,6 +497,7 @@ def cross_selective_scan(
     CrossScan=CrossScan,
     CrossMerge=CrossMerge,
     no_einsum=False, # replace einsum with linear or conv1d to raise throughput
+    dt_low_rank=True,
 ):
     # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);...

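Note (not part of the diff): the comment on no_einsum above refers to swapping the per-direction einsum projection for a grouped 1x1 conv1d. Below is a minimal sketch, with made-up shapes and random weights (B, K, D, C, L are illustrative only), checking that the two formulations agree:

    import torch
    import torch.nn.functional as F

    B, K, D, C, L = 2, 4, 8, 6, 16          # batch, scan directions, inner dim, proj channels, tokens
    xs = torch.randn(B, K, D, L)            # cross-scanned features
    x_proj_weight = torch.randn(K, C, D)    # one (C, D) projection per scan direction

    # einsum formulation: an independent linear projection for each of the K directions
    ref = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)

    # conv1d formulation: fold K into the channel axis and use groups=K with 1x1 kernels
    out = F.conv1d(xs.view(B, K * D, L), x_proj_weight.view(K * C, D, 1), groups=K)

    assert torch.allclose(ref.reshape(B, K * C, L), out, atol=1e-5)

The no_einsum=True branch in the next hunk relies on exactly this kind of grouped-conv rewrite.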
@@ -528,13 +529,18 @@ def cross_selective_scan(
     def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
         return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)

-    xs = CrossScan.apply(x)
-
-    if no_einsum:
+    if (not dt_low_rank):
+        x_dbl = F.conv1d(x.view(B, -1, L), x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
+        dts, Bs, Cs = torch.split(x_dbl.view(B, -1, L), [D, 4 * N, 4 * N], dim=1)
+        xs = CrossScan.apply(x)
+        dts = CrossScan.apply(dts)
+    elif no_einsum:
+        xs = CrossScan.apply(x)
         x_dbl = F.conv1d(xs.view(B, -1, L), x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
         dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
         dts = F.conv1d(dts.contiguous().view(B, -1, L), dt_projs_weight.view(K * D, -1, 1), groups=K)
     else:
+        xs = CrossScan.apply(x)
         x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
         if x_proj_bias is not None:
             x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
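Note (not part of the diff): a small shape sketch of how the projection output is split in the two conv1d branches above, using illustrative values rather than anything from the repository. In the low-rank case the split runs per direction over [dt_rank, d_state, d_state]; with dt_low_rank=False the projection emits full-width dts and stacks the four directions' state channels along one flat axis:

    import torch

    B, K, D, R, N, L = 2, 4, 8, 3, 5, 16    # batch, directions, d_inner, dt_rank, d_state, tokens

    # dt_low_rank=True (elif branch): per-direction channels are [dt_rank | d_state | d_state]
    x_dbl = torch.randn(B, K, R + 2 * N, L)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    print(dts.shape, Bs.shape, Cs.shape)    # (B, K, R, L) (B, K, N, L) (B, K, N, L)

    # dt_low_rank=False (new branch): dts already has D channels, and the K=4 directions'
    # state projections sit flat along one channel axis
    x_dbl = torch.randn(B, D + 2 * 4 * N, L)
    dts, Bs, Cs = torch.split(x_dbl, [D, 4 * N, 4 * N], dim=1)
    print(dts.shape, Bs.shape, Cs.shape)    # (B, D, L) (B, 4*N, L) (B, 4*N, L)

The flat (B, 4*N, L) layout in the second case appears to be why the next hunk adds the explicit .view(B, K, N, L) on Bs and Cs.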
@@ -544,8 +550,8 @@ def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
     xs = xs.view(B, -1, L)
     dts = dts.contiguous().view(B, -1, L)
     As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state)
-    Bs = Bs.contiguous()
-    Cs = Cs.contiguous()
+    Bs = Bs.contiguous().view(B, K, N, L)
+    Cs = Cs.contiguous().view(B, K, N, L)
     Ds = Ds.to(torch.float) # (K * c)
     delta_bias = dt_projs_bias.view(-1).to(torch.float)

@@ -944,7 +950,6 @@ def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):

         us = CrossScanTriton.apply(us).contiguous().view(B, -1, L)
         dts = CrossScanTriton.apply(dts)
-        # dts = torch.einsum("bkrl,kdr->bkdl", dts, dt_projs_weight).contiguous().view(B, -1, L)
         dts = F.conv1d(dts.view(B, -1, L), dt_projs_weight.view(K * self.d_inner, self.dt_rank, 1), None, groups=K).contiguous().view(B, -1, L)
         Bs, Cs = Bs.view(B, K, -1, L).contiguous(), Cs.view(B, K, -1, L).contiguous()

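Note (not part of the diff): the deleted comment records the einsum that the grouped conv1d now replaces in the Triton path. A quick sanity sketch of that equivalence, again with hypothetical shapes (d_inner, dt_rank, and the random weights are placeholders, not values from the model):

    import torch
    import torch.nn.functional as F

    B, K, d_inner, dt_rank, L = 2, 4, 8, 3, 16
    dts = torch.randn(B, K, dt_rank, L)                 # low-rank delta, per scan direction
    dt_projs_weight = torch.randn(K, d_inner, dt_rank)  # rank -> d_inner expansion per direction

    # einsum from the removed comment
    ref = torch.einsum("bkrl,kdr->bkdl", dts, dt_projs_weight).contiguous().view(B, -1, L)

    # grouped conv1d used instead
    out = F.conv1d(dts.reshape(B, -1, L), dt_projs_weight.view(K * d_inner, dt_rank, 1), None, groups=K)

    assert torch.allclose(ref, out, atol=1e-5)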
