Finish 2d implementation
Note: this is going to be reverted since the idea doesn't work well and
the code has become too complex, but a lot of effort has been put into
these changes. I want them to be committed at least, so that they can
be helpful in the future when I revisit these ideas.

Signed-off-by: Steven Lang <steven.lang.mz@gmail.com>
braun-steven committed Jan 7, 2022
1 parent 6c0a0ca commit ef37d62
Showing 2 changed files with 84 additions and 47 deletions.
57 changes: 41 additions & 16 deletions simple_einet/einsum_layer.py
@@ -1,12 +1,13 @@
from typing import Sequence, Tuple, Union
import torch.nn.functional as F
from typing import Tuple, Union

import numpy as np
from torch import nn
import torch
import torch.nn.functional as F
from torch import nn

from .utils import SamplingContext
from .type_checks import check_valid
from .layers import AbstractLayer
from .type_checks import check_valid
from .utils import SamplingContext, invert_permutation


class EinsumLayer(AbstractLayer):
@@ -33,14 +34,25 @@ def __init__(

# Compute correct output dimensions for width/height
if split_dim == "h":
out_height = self.in_shape[0] // self.cardinality
out_height = self.in_shape[0] // self.cardinality
out_width = self.in_shape[1]
else:
out_height = self.in_shape[0]
out_width = self.in_shape[1] // self.cardinality

self.out_shape = [out_height, out_width]

# Construct left/right partition indices
if split_dim == "h":
indices = torch.arange(self.in_shape[0])
else:
indices = torch.arange(self.in_shape[1])

self.left_idx = indices[0::2]
self.right_idx = indices[1::2]

self.inverse_idx = invert_permutation(torch.cat((self.left_idx, self.right_idx)))

# Weights, such that each sumnode has its own weights
weights = torch.randn(
out_height,
@@ -90,14 +102,20 @@ def forward(self, x: torch.Tensor):
assert self.num_features == H * W

# left/right shape: [n, h/2, w, c, r] or [n, h, w/2, c, r]
# if self.split_dim == "h":
# split_val = self.out_shape[0]
# left = x[:, split_val:]
# right = x[:, :split_val]
# else:
# split_val = self.out_shape[1]
# left = x[:, :, split_val:]
# right = x[:, :, :split_val]
if self.split_dim == "h":
split_val = self.out_shape[0]
left = x[:, split_val:]
right = x[:, :split_val]
left = x[:, self.left_idx]
right = x[:, self.right_idx]
else:
split_val = self.out_shape[1]
left = x[:, :, split_val:]
right = x[:, :, :split_val]
left = x[:, :, self.left_idx]
right = x[:, :, self.right_idx]

left_max = torch.max(left, dim=3, keepdim=True)[0]
left_prob = torch.exp(left - left_max)
@@ -156,7 +174,6 @@ def sample(
weights = weights.gather(dim=-1, index=r_idxs)
weights = weights.squeeze(-1)


# Index with parent indices
p_idxs = context.indices_out.view(-1, height, width, 1, 1, 1)
p_idxs = p_idxs.expand(-1, -1, -1, num_sums_in, num_sums_in, -1)
@@ -210,14 +227,22 @@ def sample(
# [[0, 1]
# [1, 0]]
#

# Map indices from product output tensor to partition tensors (last dim is now of size 2)
indices = self.unraveled_channel_indices[indices]
if self.split_dim == "h":
# Move partition dimension after the dimension at which we have split
indices = indices.permute(0, 1, 3, 2)
# Interleave tensors
indices = indices.reshape(num_samples, 2 * height, width)
else:
indices = indices.permute(0, 3, 2, 1)
indices = indices.permute(0, 1, 2, 3)
# Interleave tensors
indices = indices.reshape(num_samples, height, 2 * width)


assert list(indices.shape[-2:]) == self.in_shape

context.indices_out = indices
return context

@@ -243,8 +268,8 @@ def _disable_input_cache(self):
self._input_cache_right = None

def __repr__(self):
return "EinsumLayer(in_shape={}, out_shape={}, num_sums_in={}, num_sums_out={}, num_repetitions={})".format(
self.in_shape, self.out_shape, self.num_sums_in, self.num_sums_out, self.num_repetitions
return "EinsumLayer(in_shape={}, out_shape={}, num_sums_in={}, num_sums_out={}, num_repetitions={}, split={})".format(
self.in_shape, self.out_shape, self.num_sums_in, self.num_sums_out, self.num_repetitions, self.split_dim
)


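The indices constructed in EinsumLayer.__init__ above partition the split dimension into even positions (left_idx) and odd positions (right_idx) instead of contiguous halves; inverse_idx records how to undo that reordering. A minimal sketch of the bookkeeping, assuming invert_permutation (imported from .utils, not shown in this diff) returns the inverse permutation in the usual sense (inv[p[i]] == i):

import torch


def invert_permutation(p: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour of the helper from simple_einet.utils:
    # inv[p[i]] == i, so indexing with `inv` undoes indexing with `p`.
    inv = torch.empty_like(p)
    inv[p] = torch.arange(p.numel(), dtype=p.dtype)
    return inv


# Example: height 6, split along "h" into even (left) and odd (right) positions.
indices = torch.arange(6)
left_idx = indices[0::2]                 # tensor([0, 2, 4])
right_idx = indices[1::2]                # tensor([1, 3, 5])
perm = torch.cat((left_idx, right_idx))  # tensor([0, 2, 4, 1, 3, 5])
inverse_idx = invert_permutation(perm)   # tensor([0, 3, 1, 4, 2, 5])

# Splitting a feature vector with left_idx/right_idx, concatenating the parts,
# and indexing with inverse_idx restores the original feature order.
x = torch.arange(6) * 10
restored = torch.cat((x[left_idx], x[right_idx]))[inverse_idx]
assert torch.equal(restored, x)

This is why forward can now select the partitions directly with left = x[:, self.left_idx] and right = x[:, self.right_idx] instead of using the commented-out contiguous slices.
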
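The end of the forward hunk shows the usual max-shift stabilization before leaving log-space (left_max = torch.max(left, dim=3, keepdim=True)[0]; left_prob = torch.exp(left - left_max)). A standalone sketch of that idea on toy tensors; the channel axis is last here for brevity, whereas the layer reduces over dim=3 of its [n, h, w, c, r] inputs, and the full layer additionally contracts the result against its sum weights (not visible in the truncated hunk):

import torch

# Toy log-likelihoods for the left/right partitions: [batch, features, channels].
left = torch.randn(4, 3, 8)
right = torch.randn(4, 3, 8)

# Shift by the per-feature maximum so exp() cannot overflow.
left_max = torch.max(left, dim=2, keepdim=True)[0]
right_max = torch.max(right, dim=2, keepdim=True)[0]
left_prob = torch.exp(left - left_max)
right_prob = torch.exp(right - right_max)

# Combine the partitions in linear space (outer product over their channel
# axes), then return to log-space and add the shifts back.
prob = left_prob.unsqueeze(3) * right_prob.unsqueeze(2)  # [4, 3, 8, 8]
log_prob = torch.log(prob) + left_max.unsqueeze(3) + right_max.unsqueeze(2)
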
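In the sampling hunk above, indexing self.unraveled_channel_indices maps each sampled product channel to a pair of partition indices (the comment notes the last dim is now of size 2), and the permute/reshape pair interleaves that pair back into a grid of shape in_shape. A small sketch of just the interleaving step, using constant 0/1 tensors so the result is easy to read:

import torch

num_samples, height, width = 1, 2, 3

# Stand-ins for the per-position left/right partition indices; stacking them
# in a trailing dimension of size 2 mirrors the unraveled_channel_indices lookup.
left = torch.zeros(num_samples, height, width, dtype=torch.long)
right = torch.ones(num_samples, height, width, dtype=torch.long)
indices = torch.stack((left, right), dim=-1)  # [n, h, w, 2]

# Width split: merging the trailing pair into the width axis interleaves
# left/right column-wise.
by_width = indices.reshape(num_samples, height, 2 * width)
# by_width[0] == tensor([[0, 1, 0, 1, 0, 1],
#                        [0, 1, 0, 1, 0, 1]])

# Height split: first move the pair next to the height axis, then merge,
# interleaving left/right row-wise.
by_height = indices.permute(0, 1, 3, 2).reshape(num_samples, 2 * height, width)
# by_height[0] == tensor([[0, 0, 0],
#                         [1, 1, 1],
#                         [0, 0, 0],
#                         [1, 1, 1]])

Either way the resulting spatial shape matches self.in_shape, which is what the assert at the end of that hunk checks.
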
74 changes: 43 additions & 31 deletions simple_einet/factorized_leaf_layer.py
@@ -94,6 +94,7 @@ def forward(self, x: torch.Tensor, marginalized_scopes: List[int]):
def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor:
# Save original indices_out and set context indices_out to none, such that the out_channel
# are not filtered in the base_leaf sampling procedure
__import__("pdb").set_trace()
indices_out = context.indices_out
context.indices_out = None
samples = self.base_leaf.sample(context=context)
@@ -138,38 +139,49 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor:
# Interpretation: For each scope i,j in indices_out, the element (index) at position i,j
# is the index into the leaf sample output channel (0, ..., num_leaves - 1)
indices_out_i = indices_out[sample_idx]
assert list(indices_out_i.shape) == self.out_shape

# Get scope for the current repetition
upsize = self.scopes_w.shape[0] // self.scopes_w.shape[1]
upsample = nn.ConvTranspose2d(1, 1, upsize, stride=upsize)
upsample.weight[:] = 1.0
up = upsample(indices_out_i.view(1, 1, *indices_out_i.shape).float())
up = up.round().long()[0, 0]
scope_h = self.scopes_h[:, :, rep] # Which height scopes get merged
scope_w = self.scopes_w[:, :, rep] # Which width scopes get merged

# Turn one-hot encoded in-feature -> out-feature mapping into a linear index
rnge_in_h = torch.arange(self.out_shape[0], device=samples.device)
rnge_in_w = torch.arange(self.out_shape[1], device=samples.device)

# Mapping from in-scope to out-scope
# Read: element i in the following list is an index j, where
# i is an index into the in_shape and j is the corresponding index into the out_shape
scope_h = (scope_h * rnge_in_h).sum(-1).long()
scope_w = (scope_w * rnge_in_w).sum(-1).long()
assert scope_h.shape[0] == self.in_shape[0]
assert scope_w.shape[0] == self.in_shape[1]


# # Map indices_out from original "out_shape" view to "in_shape" view
scope_h = scope_h.view(-1, 1).expand(-1, indices_out_i.shape[0])
scope_w = scope_w.view(1, -1).expand(self.in_shape[1], -1)
indices_in = indices_out_i.gather(dim=0, index=scope_h)
indices_in = indices_in.gather(dim=1, index=scope_w)
# assert list(indices_out_i.shape) == self.out_shape

# # Get scope for the current repetition
# # upsize = self.scopes_w.shape[0] // self.scopes_w.shape[1]
# # upsample = nn.ConvTranspose2d(1, 1, upsize, stride=upsize)
# # upsample.weight[:] = 1.0
# # up = upsample(indices_out_i.view(1, 1, *indices_out_i.shape).float())
# # up = up.round().long()[0, 0]
# scope_h = self.scopes_h[:, :, rep] # Which height scopes get merged
# scope_w = self.scopes_w[:, :, rep] # Which width scopes get merged

# # Turn one-hot encoded in-feature -> out-feature mapping into a linear index
# rnge_in_h = torch.arange(self.out_shape[0], device=samples.device)
# rnge_in_w = torch.arange(self.out_shape[1], device=samples.device)

# # Mapping from in-scope to out-scope
# # Read: element i in the following list is an index j, where
# # i is an index into the in_shape and j is the corresponding index into the out_shape
# scope_h = (scope_h * rnge_in_h).sum(-1).long()
# scope_w = (scope_w * rnge_in_w).sum(-1).long()
# assert scope_h.shape[0] == self.in_shape[0]
# assert scope_w.shape[0] == self.in_shape[1]

# ic(indices_out_i.shape)
# ic(scope_h.shape)
# ic(scope_w.shape)

# # # Map indices_out from original "out_shape" view to "in_shape" view
# scope_h = scope_h.view(-1, 1).expand(-1, indices_out_i.shape[0])
# scope_w = scope_w.view(1, -1).expand(self.in_shape[1], -1)
# indices_in = indices_out_i.gather(dim=0, index=scope_h)
# indices_in = indices_in.gather(dim=1, index=scope_w)

# Note: the following is simpler but only works for the input dimension is a power of 2
repeat_h = self.scopes_h.shape[0] // self.scopes_h.shape[1]
repeat_w = self.scopes_w.shape[0] // self.scopes_w.shape[1]
indices_in = indices_out_i.repeat_interleave(repeat_h, dim=0).repeat_interleave(repeat_w, dim=1)

# assert (indices_in == indices_in_2).all()

# assert (indices_in == up).all()
indices_in.fill_(sample_idx % samples.shape[-1])
warnings.warn("Sampling indices fixed in factorizedleaf")
# indices_in.fill_(sample_idx % samples.shape[-1])
# warnings.warn("Sampling indices fixed in factorizedleaf")


# TODO: This is not yet working - something is off with the indexing
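The active path in FactorizedLeaf.sample above collapses the earlier scope bookkeeping (now commented out) into a single repeat_interleave per axis, which, as the in-code note says, only works when every output cell covers an equally sized block of input cells (input dimensions a power of two). A toy sketch with a hypothetical 2x2 out_shape mapped back onto a 4x4 in_shape:

import torch

in_shape = (4, 4)    # fine leaf grid
out_shape = (2, 2)   # coarse grid after factorization

# Per-cell leaf output-channel choices on the coarse (out_shape) grid.
indices_out = torch.tensor([[3, 1],
                            [0, 2]])

# Each coarse cell is responsible for a repeat_h x repeat_w block of fine
# cells; broadcast its choice to every cell in that block.
repeat_h = in_shape[0] // out_shape[0]
repeat_w = in_shape[1] // out_shape[1]
indices_in = indices_out.repeat_interleave(repeat_h, dim=0).repeat_interleave(repeat_w, dim=1)
# tensor([[3, 3, 1, 1],
#         [3, 3, 1, 1],
#         [0, 0, 2, 2],
#         [0, 0, 2, 2]])
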
