Commit

fine tune binary
l-z-l committed Sep 28, 2023
1 parent 4f9c35d commit 8d34bff
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions enformer_pytorch/finetune.py
@@ -105,7 +105,9 @@ def __init__(
         bottleneck_num_codebooks = 4,
         bottleneck_decay = 0.9,
         transformer_embed_fn: nn.Module = nn.Identity(),
-        auto_set_target_length = True
+        auto_set_target_length = True,
+        layer_size=16,
+        target_length=200,
     ):
         super().__init__()
         assert isinstance(enformer, Enformer)
@@ -124,6 +126,8 @@ def __init__(
             decay = bottleneck_decay,
         )
 
+        self.target_length = target_length
+
         self.post_transformer_embed = post_transformer_embed
 
         self.enformer = enformer
@@ -148,10 +152,11 @@ def __init__(
         self.to_tracks = Sequential(
             # NOTE: ZELUN Added nn.Flatten()
             nn.Flatten(),
-            nn.Linear(enformer_hidden_dim * self.enformer.target_length, num_tracks)
+            nn.Linear(enformer_hidden_dim * target_length, layer_size),
+            nn.Linear(layer_size, num_tracks)
         )
         # NOTE: ZELUN
-        # print(f"this is the enformer target length {self.enformer.target_length} and this is the enformer hidden dim {enformer_hidden_dim} this is the product {enformer_hidden_dim * self.enformer.target_length}")
+        print(f"this is the layer_size {layer_size} this is the enformer target length {target_length} and this is the enformer hidden dim {enformer_hidden_dim} this is the product {enformer_hidden_dim * target_length}")
 
     def forward(
         self,
@@ -168,10 +173,12 @@ def forward(
             #NOTE: ZELUN DEBUG
             print(f"========================== TARGET {target} ==========================")
-            enformer_kwargs = dict(target_length = target.shape[-2])
 
             # NOTE: this is trying to set the target length to a fixed value regardless of
             # how the target tensor is being encoded
             # enformer_kwargs = dict(target_length = 2)
+
+            # NOTE: Setting the target length to a smaller value
+            enformer_kwargs = dict(target_length = self.target_length)
         if self.discrete_key_value_bottleneck:
             embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
         else:
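
For context, here is a minimal shape walkthrough of the prediction head introduced in this commit. This is a sketch, not code from the repository: the batch size, num_tracks = 1, and enformer_hidden_dim = 3072 (2 * dim for the published Enformer config) are assumptions, while target_length = 200 and layer_size = 16 are the constructor defaults added here.

import torch
from torch import nn

batch_size = 2
target_length = 200          # constructor default added in this commit
layer_size = 16              # constructor default added in this commit
enformer_hidden_dim = 3072   # assumed; depends on the Enformer config used
num_tracks = 1               # assumed, e.g. a single binary track

# Sketch of the reworked head: flatten all positions, then a small bottleneck
# layer before the final track prediction.
to_tracks = nn.Sequential(
    nn.Flatten(),                                                # (b, t, d) -> (b, t * d)
    nn.Linear(enformer_hidden_dim * target_length, layer_size),  # (b, t * d) -> (b, layer_size)
    nn.Linear(layer_size, num_tracks),                           # (b, layer_size) -> (b, num_tracks)
)

embeddings = torch.randn(batch_size, target_length, enformer_hidden_dim)
print(to_tracks(embeddings).shape)  # torch.Size([2, 1])

Because forward now passes target_length = self.target_length to the Enformer call, the flattened embedding size stays fixed at enformer_hidden_dim * target_length, which keeps the first Linear layer's input dimension valid regardless of the target tensor's own length.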
