forked from city-super/Octree-GS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding.py
executable file
·79 lines (62 loc) · 2.31 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from abc import abstractmethod
from typing import Optional
from jaxtyping import Shaped
from torch import Tensor, nn
class FieldComponent(nn.Module):
    """Base class for composable field modules.

    Subclasses implement :meth:`forward` to transform tensors; the input and
    output dimensionalities are tracked so components can be chained.

    Args:
        in_dim: Input dimension to module.
        out_dim: Output dimension to module.
    """

    def __init__(self, in_dim: Optional[int] = None, out_dim: Optional[int] = None) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

    def build_nn_modules(self) -> None:
        """Instantiate any torch.nn members of this module (no-op by default)."""

    def set_in_dim(self, in_dim: int) -> None:
        """Set the encoding's input dimension.

        Args:
            in_dim: input dimension

        Raises:
            ValueError: If ``in_dim`` is not strictly positive.
        """
        if not in_dim > 0:
            raise ValueError("Input dimension should be greater than zero")
        self.in_dim = in_dim

    def get_out_dim(self) -> int:
        """Return the encoding's output dimension.

        Raises:
            ValueError: If the output dimension was never set.
        """
        if self.out_dim is None:
            raise ValueError("Output dimension has not been set")
        return self.out_dim

    @abstractmethod
    def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
        """Process ``in_tensor`` and return the transformed tensor.

        Args:
            in_tensor: Input tensor to process
        """
        raise NotImplementedError
class Embedding(FieldComponent):
    """Learnable lookup table that maps integer indices to embedding vectors.

    # TODO: add different types of initializations
    Args:
        in_dim: Number of embeddings
        out_dim: Dimension of the embedding vectors
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.build_nn_modules()

    def build_nn_modules(self) -> None:
        """Create the underlying ``torch.nn.Embedding`` weight table."""
        self.embedding = torch.nn.Embedding(self.in_dim, self.out_dim)

    def mean(self, dim=0):
        """Return the mean of the embedding weights along ``dim``."""
        weights = self.embedding.weight
        return weights.mean(dim)

    def forward(self, in_tensor: Shaped[Tensor, "*batch input_dim"]) -> Shaped[Tensor, "*batch output_dim"]:
        """Look up the embedding vectors for the given indices.

        Args:
            in_tensor: input tensor to process
        """
        return self.embedding(in_tensor)