Add "bare_vocab" fake positional embedding without positions #747

Open · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions xformers/components/positional_embedding/__init__.py
@@ -72,11 +72,13 @@ class MyEncoding(PositionEncoding):
from .rotary import RotaryEmbedding # noqa
from .sine import SinePositionalEmbedding # type: ignore # noqa
from .vocab import VocabEmbedding # noqa
from .bare_vocab import BareVocabEmbedding # noqa

__all__ = [
    "RotaryEmbedding",
    "SinePositionalEmbedding",
    "VocabEmbedding",
    "BareVocabEmbedding",
    "build_positional_embedding",
    "register_positional_embedding",
]
53 changes: 53 additions & 0 deletions xformers/components/positional_embedding/bare_vocab.py
@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass

import torch.nn

from xformers.components.positional_embedding import (
    PositionEmbedding,
    register_positional_embedding,
)


@dataclass
class VocabEmbeddingConfig:
    name: str
    dim_model: int
    vocab_size: int
    dropout: float
    init_std: float


@register_positional_embedding("bare_vocab", VocabEmbeddingConfig)
class BareVocabEmbedding(PositionEmbedding):
    """
    Vocabulary (token) embedding that adds no positional information.

    Useful for ALiBi-like schemes, where position is injected in the attention
    computation rather than added to the token embeddings.
    """

    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
        init_std: float = 0.02,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model
        self.init_std = init_std

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = torch.nn.Embedding(self.vocab_size, self.dim_model)

        self.init_weights()

    def init_weights(self, gain: float = 1.0):
        torch.nn.init.normal_(self.word_embeddings.weight, std=self.init_std * gain)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Plain token lookup followed by dropout; no positional encoding is added.
        y = self.dropout(self.word_embeddings(x))
        return y
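
For reference, a minimal usage sketch (not part of the diff): building the new embedding through the registry via build_positional_embedding, which this PR exposes under the name "bare_vocab". The config-dict calling convention and the concrete sizes below are assumptions for illustration, mirroring how the other registered embeddings are typically built.

# Hypothetical usage sketch; model sizes and the config-dict form are illustrative.
import torch

from xformers.components.positional_embedding import build_positional_embedding

pos_emb = build_positional_embedding(
    {
        "name": "bare_vocab",  # name registered by the decorator in bare_vocab.py
        "dim_model": 512,
        "vocab_size": 32000,
        "dropout": 0.1,
        "init_std": 0.02,
    }
)

token_ids = torch.randint(0, 32000, (2, 128))  # (batch, sequence) of token ids
embedded = pos_emb(token_ids)  # (2, 128, 512); no positional signal is added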