Extend support for MultivariateNormal
Signed-off-by: Steven Lang <steven.lang.mz@gmail.com>
braun-steven committed Dec 14, 2021
1 parent 53d115e commit e676164
Showing 5 changed files with 110 additions and 49 deletions.
18 changes: 14 additions & 4 deletions mnist.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

import torch
from rich.traceback import install
install(suppress=[torch])
import argparse
import os
from typing import Tuple
@@ -11,7 +13,7 @@
install()

import torch
from distributions import Binomial, MultivariateNormal, Normal, RatNormal
from simple_einet.distributions import Binomial, MultivariateNormal, Normal, RatNormal
from simple_einet.einet import Einet, EinetConfig
from torchvision import datasets, transforms
import torchvision
@@ -108,7 +110,8 @@

device = torch.device(args.device)
# digits = [0, 1, 5, 8]
digits = list(range(10))
# digits = list(range(10))
digits = [0, 1]

# Construct Einet
num_classes = len(digits) if args.classification else 1
@@ -122,6 +125,10 @@
C=num_classes,
leaf_base_class=Binomial,
leaf_base_kwargs={"total_count": 255},
# leaf_base_class=MultivariateNormal,
# leaf_base_kwargs={"cardinality": 2, "min_sigma": 1e-5, "max_sigma": 0.1},
# leaf_base_class=RatNormal,
# leaf_base_kwargs={"min_sigma": 1e-5, "max_sigma": 1.0},
dropout=0.0,
)
model = Einet(config).to(device)
@@ -375,7 +382,7 @@ def test(model, device, loader, tag):
if has_gauss_dist:
test_x = preprocess(test_x, n_bits=n_bits, image_shape=(1, 28, 28), device=device, pertubate=False)
else:
test_x = (test_x * 255).long()
test_x = (test_x * 255)


grid = torchvision.utils.make_grid(
@@ -403,3 +410,6 @@ def test(model, device, loader, tag):
reconstructions = reconstructions.view(-1, 1, 28, 28)
grid = torchvision.utils.make_grid(reconstructions, **grid_kwargs)
torchvision.utils.save_image(grid, os.path.join(result_dir, "reconstructions.png"))

print(f"Result directory: {result_dir}")
print("Done.")
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
numpy==1.20.1
torch==1.10.0
torchvision==0.11.1
rich==10.16.0
103 changes: 64 additions & 39 deletions simple_einet/distributions.py
@@ -105,9 +105,6 @@ def dist_sample(distribution: dist.Distribution, context: SamplingContext = None
-1
)

if type(distribution) in (dist.Binomial, dist.Bernoulli):
samples = samples.long()

return samples


@@ -121,7 +118,7 @@ class Leaf(AbstractLayer):
If the input at a specific position is NaN, the variable will be marginalized.
"""

def __init__(self, in_features: int, out_channels: int, num_repetitions: int = 1, dropout=0.0):
def __init__(self, in_features: int, out_channels: int, num_repetitions: int = 1, dropout=0.0, cardinality=1):
"""
Create the leaf layer.
@@ -130,11 +127,13 @@ def __init__(self, in_features: int, out_channels: int, num_repetitions: int = 1
out_channels: Number of parallel representations for each input feature.
num_repetitions: Number of parallel repetitions of this layer.
dropout: Dropout probability.
cardinality: Number of random variables covered by a single leaf.
"""
super().__init__(in_features=in_features, num_repetitions=num_repetitions)
self.in_features = check_valid(in_features, int, 1)
self.out_channels = check_valid(out_channels, int, 1)
self.num_repetitions = check_valid(num_repetitions, int, 1)
self.cardinality = check_valid(cardinality, int, 1)
dropout = check_valid(dropout, float, 0.0, 1.0)
self.dropout = nn.Parameter(torch.tensor(dropout), requires_grad=False)

@@ -157,7 +156,8 @@ def _apply_dropout(self, x: torch.Tensor) -> torch.Tensor:
def _marginalize_input(self, x: torch.Tensor, marginalized_scopes: List[int]) -> torch.Tensor:
# Marginalize nans set by user
if marginalized_scopes:
x[:, marginalized_scopes] = self.marginalization_constant
s = list(set(torch.tensor(marginalized_scopes).div(self.cardinality, rounding_mode="floor").tolist()))
x[:, s] = self.marginalization_constant
return x

def forward(self, x, marginalized_scopes: List[int]):
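The new `_marginalize_input` maps per-variable scope indices to leaf groups by integer-dividing each scope by `cardinality`. A standalone sketch of that mapping:

```python
import torch

# Sketch of the scope-to-group mapping in _marginalize_input above:
# with cardinality=2, variables 0 and 1 share leaf group 0, and
# variable 5 falls into group 2.
cardinality = 2
marginalized_scopes = [0, 1, 5]
s = list(set(torch.tensor(marginalized_scopes)
             .div(cardinality, rounding_mode="floor").tolist()))
print(sorted(s))  # [0, 2]
```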
@@ -233,10 +233,19 @@ def _get_base_distribution(self):
probs_ratio = torch.sigmoid(self.probs)
return dist.Bernoulli(probs=probs_ratio)

class Binomial(Leaf):

def __init__(self, in_features: int, out_channels: int, total_count: int, num_repetitions: int = 1, dropout=0):
super().__init__(in_features, out_channels, num_repetitions=num_repetitions, dropout=dropout)
class Binomial(Leaf):
def __init__(
self,
in_features: int,
out_channels: int,
total_count: int,
num_repetitions: int = 1,
dropout=0,
):
super().__init__(
in_features, out_channels, num_repetitions=num_repetitions, dropout=dropout
)

self.total_count = check_valid(total_count, int, lower_bound=1)

@@ -245,13 +254,11 @@ def __init__(self, in_features: int, out_channels: int, total_count: int, num_re
# Learnable s
self.sigmoid_scale = nn.Parameter(torch.tensor(2.0))


def _get_base_distribution(self):
# Use sigmoid to ensure that probs are in a valid range
return dist.Binomial(self.total_count, probs=torch.sigmoid(self.probs * self.sigmoid_scale))
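The learnable `sigmoid_scale` acts as a temperature on the logits before they are squashed into (0, 1). A hedged sketch of the effect:

```python
import torch
import torch.distributions as dist

# Sketch, not from the commit: a larger sigmoid_scale pushes the
# probabilities toward 0 or 1; a smaller one flattens them toward 0.5.
probs_logit = torch.randn(3)
for sigmoid_scale in (0.5, 2.0):
    d = dist.Binomial(total_count=255,
                      probs=torch.sigmoid(probs_logit * sigmoid_scale))
    print(sigmoid_scale, d.mean)  # expected counts in [0, 255]
```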



class MultivariateNormal(Leaf):
"""Multivariate Gaussian layer."""

@@ -262,8 +269,10 @@ def __init__(
cardinality: int,
num_repetitions: int = 1,
dropout=0.0,
min_sigma: float = 0.1,
max_sigma: float = 1.0,
):
"""Creat a gaussian layer.
"""Creat a multivariate gaussian layer.
Args:
out_channels: Number of parallel representations for each input feature.
@@ -273,13 +282,14 @@
"""
# TODO: Fix for num_repetitions
super().__init__(in_features, out_channels, num_repetitions, dropout)
# self.cardinality = check_valid(cardinality, int, 2, in_features + 1)
self.cardinality = cardinality
super().__init__(in_features, out_channels, num_repetitions, dropout, cardinality)
self._pad_value = in_features % cardinality
self.out_features = np.ceil(in_features / cardinality).astype(int)
self._n_dists = np.ceil(in_features / cardinality).astype(int)

self.min_sigma = check_valid(min_sigma, float, 0.0, max_sigma)
self.max_sigma = check_valid(max_sigma, float, min_sigma)

# Create gaussian means and covs
self.means = nn.Parameter(
torch.randn(out_channels * self._n_dists * self.num_repetitions, cardinality)
@@ -296,10 +306,12 @@ def __init__(

rand = rand + torch.randn_like(rand) * 1e-1

# Make matrices triangular
trils = rand.tril()
# Make matrices triangular and remove diagonal entries
cov_tril_wo_diag = rand.tril(diagonal=-1)
cov_tril_wi_diag = torch.rand(out_channels * self._n_dists * self.num_repetitions, cardinality, cardinality)

self.triangular = nn.Parameter(trils)
self.cov_tril_wo_diag = nn.Parameter(cov_tril_wo_diag)
self.cov_tril_wi_diag = nn.Parameter(cov_tril_wi_diag)
# self._mv = dist.MultivariateNormal(loc=self.means, scale_tril=self.triangular)
# Reassign means since mv __init__ creates a copy and thus would lose track for autograd
# self._mv.loc.requires_grad_(True)
@@ -308,23 +320,20 @@
self.out_shape = f"(N, {self.out_features}, {self.out_channels})"

def forward(self, x: torch.Tensor, marginalized_scopes: List[int]) -> torch.Tensor:
batch_size = x.shape[0]

# Pad dummy variable via reflection
if self._pad_value != 0:
x = F.pad(x, pad=[0, 0, 0, self._pad_value], mode="reflect")

# Make room for out_channels of layer
# Output shape: [n, 1, d]
batch_size = x.shape[0]
# Push repetitions into dim=1
x = x.permute(0, 2, 1) # [n, r, d]
# Make room for repetitions: [n, 1, d]
x = x.unsqueeze(1)

# Split features into groups
x = x.view(
batch_size, self.num_repetitions, 1, self._n_dists, self.cardinality
) # [n, r, 1, d/k, k]
x = x.view(batch_size, 1, 1, self._n_dists, self.cardinality) # [n, 1, 1, d/k, k]

# Repeat groups by number of output_channels
x = x.repeat(1, 1, self.out_channels, 1, 1) # [n, r, oc, d/k, k]
# Repeat groups by number of output_channels and number of repetitions
x = x.repeat(1, self.num_repetitions, self.out_channels, 1, 1) # [n, r, oc, d/k, k]

# Merge groups and repetitions
x = x.view(
@@ -341,7 +350,7 @@ def forward(self, x: torch.Tensor, marginalized_scopes: List[int]) -> torch.Tensor:
x = x.permute(0, 3, 2, 1) # [n, d/k, oc, r]

# Marginalize and apply dropout
x = self._marginalize_input(x)
x = self._marginalize_input(x, marginalized_scopes)
x = self._apply_dropout(x)

return x
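A toy walkthrough (made-up sizes) of the reshaping above: the input now arrives without a repetition axis and is split into groups of `cardinality` variables, then replicated over repetitions and channels.

```python
import torch

# Shape sketch for forward above: n=2 samples, d=6 features, k=3
# cardinality (so d/k = 2 groups), oc=4 channels, r=2 repetitions.
n, d, k, oc, r = 2, 6, 3, 4, 2
x = torch.randn(n, d)
x = x.unsqueeze(1)              # [n, 1, d]
x = x.view(n, 1, 1, d // k, k)  # [n, 1, 1, d/k, k]
x = x.repeat(1, r, oc, 1, 1)    # [n, r, oc, d/k, k]
print(x.shape)                  # torch.Size([2, 2, 4, 2, 3])
```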
@@ -378,24 +387,40 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor:

samples = tmp # [n, oc, d/k, k]

# If parent indices into out_channels are given

if context.parent_indices is not None:
indices = context.parent_indices.unsqueeze(1).unsqueeze(-1).repeat(1, 1, 1, cardinality)
# Choose only specific samples for each feature/scope
samples = torch.gather(samples, dim=1, index=indices).squeeze(-1)

samples = samples.squeeze(0) # Squeeze out_channels dim
samples = samples.view(
context.num_samples,
self._n_dists * self.cardinality,
self.out_channels,
self._n_dists,
self.cardinality,
)
samples = samples.permute(0, 2, 3, 1)

return samples
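A sketch of the `torch.gather` selection above, picking one out_channel per group while keeping the cardinality axis (the exact index shaping in the commit differs slightly; this only illustrates the pattern):

```python
import torch

# Sketch: choose one out_channel per (sample, group) pair, keeping the
# cardinality axis, analogous to the parent_indices gather above.
n, oc, g, k = 2, 4, 3, 2                    # samples, channels, groups, cardinality
samples = torch.randn(n, oc, g, k)
parent_indices = torch.randint(oc, (n, g))  # hypothetical per-group choices
idx = parent_indices.unsqueeze(1).unsqueeze(-1).expand(n, 1, g, k)
picked = torch.gather(samples, dim=1, index=idx).squeeze(1)  # [n, g, k]
print(picked.shape)
```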

def _get_base_distribution(self):
triang = self.triangular.sigmoid().tril()
mv = dist.MultivariateNormal(loc=self.means, scale_tril=triang)

if self.min_sigma < self.max_sigma:
# scale diag to [min_sigma, max_sigma]
cov_diag = self.cov_tril_wi_diag
sigma_ratio = torch.sigmoid(cov_diag)
cov_diag = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma_ratio
cov_diag = cov_diag.tril()

# scale tril to [-max_sigma, max_sigma]
cov_tril = self.cov_tril_wo_diag
sigma_ratio = torch.sigmoid(cov_tril)
cov_tril = -1 * self.max_sigma + 2 * self.max_sigma * sigma_ratio
cov_tril = cov_tril.tril(-1)

else:
cov_tril = self.cov_tril_wo_diag.tril(-1)
cov_diag = self.cov_tril_wi_diag.tril().sigmoid()

scale_tril = cov_tril + cov_diag
mv = dist.MultivariateNormal(loc=self.means, scale_tril=scale_tril)

# ic(cov_diag.mean(0))
# ic(cov_tril.mean(0))
return mv
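The construction above can be reproduced standalone: the diagonal is squashed into [min_sigma, max_sigma] so the Cholesky factor stays valid, and the strictly-lower entries into [-max_sigma, max_sigma]. A hedged sketch with made-up shapes:

```python
import torch
import torch.distributions as dist

# Sketch mirroring _get_base_distribution above: build a valid scale_tril
# from two unconstrained parameter tensors (5 leaf distributions, k=2).
min_sigma, max_sigma, k = 1e-5, 0.1, 2
wi_diag = torch.randn(5, k, k)  # carries the (positive) diagonal
wo_diag = torch.randn(5, k, k)  # carries the strictly-lower entries

cov_diag = (min_sigma + (max_sigma - min_sigma) * torch.sigmoid(wi_diag)).tril()
cov_tril = (-max_sigma + 2 * max_sigma * torch.sigmoid(wo_diag)).tril(-1)
scale_tril = cov_tril + cov_diag  # lower-triangular, positive diagonal

mv = dist.MultivariateNormal(loc=torch.zeros(5, k), scale_tril=scale_tril)
print(mv.sample((3,)).shape)      # torch.Size([3, 5, 2])
```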


2 changes: 1 addition & 1 deletion simple_einet/einet.py
@@ -221,7 +221,7 @@ def _build_input_distribution(self):
)

return FactorizedLeaf(
in_features=self.config.F,
in_features=base_leaf.out_features,
out_features=2 ** self.config.D,
num_repetitions=self.config.R,
base_leaf=base_leaf,
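The one-line einet.py fix follows from the leaf's grouping: a multivariate leaf over groups of `cardinality` variables exposes ceil(F / cardinality) output features, so `FactorizedLeaf` must be sized by `base_leaf.out_features` rather than the raw `config.F`. For example:

```python
import numpy as np

# Sketch: output feature count of a MultivariateNormal leaf, as computed
# in its __init__ above.
F, cardinality = 28 * 28, 2  # flattened MNIST, pairs of pixels
out_features = int(np.ceil(F / cardinality))
print(out_features)          # 392, the in_features seen by FactorizedLeaf
```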
35 changes: 30 additions & 5 deletions simple_einet/factorized_leaf_layer.py
@@ -65,11 +65,33 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor:
samples = self.base_leaf.sample(context=context)

# Check that shapes match as expected
assert samples.shape == (context.num_samples, self.in_features, self.base_leaf.out_channels)
if samples.dim() == 3:
assert samples.shape == (
context.num_samples,
self.in_features,
self.base_leaf.out_channels,
)
elif samples.dim() == 4:
assert self.in_features == samples.shape[1]
assert hasattr(self.base_leaf, "cardinality")
assert samples.shape == (
context.num_samples,
self.base_leaf.out_features,
self.base_leaf.cardinality,
self.base_leaf.out_channels,
)

# Collect final samples in temporary tensor
if hasattr(self.base_leaf, "cardinality"):
cardinality = self.base_leaf.cardinality
else:
cardinality = 1
tmp = torch.zeros(
context.num_samples, self.in_features, device=samples.device, dtype=samples.dtype
context.num_samples,
self.in_features,
cardinality,
device=samples.device,
dtype=samples.dtype,
)
for sample_idx in range(context.num_samples):
# Get correct repetition
@@ -86,13 +108,16 @@ def sample(self, num_samples: int = None, context: SamplingContext = None) -> torch.Tensor:
scope = (scope * rnge_in).sum(-1).long()

# Map parent_indices from original "out_features" view to "in_features" view
paren_indices_in = parent_indices_out[scope]
parent_indices_in = parent_indices_out[scope]

# Access base leaf samples based on
rnge_out = torch.arange(self.in_features, device=samples.device)
tmp[sample_idx] = samples[sample_idx, rnge_out, paren_indices_in]

samples = tmp
tmp[sample_idx] = samples[sample_idx, rnge_out, ..., parent_indices_in].view(
self.in_features, cardinality
)

samples = tmp.view(context.num_samples, -1)
return samples
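A toy version (made-up sizes) of the updated gather-and-flatten above: one channel index per leaf group, with the trailing cardinality axis flattened back into a plain feature vector.

```python
import torch

# Sketch: select one out_channel per leaf group, then flatten the
# cardinality axis, as the updated sample() above does per sample index.
g, k, oc = 3, 2, 4                           # groups, cardinality, channels
samples = torch.randn(g, k, oc)              # one sample's leaf outputs
parent_indices_in = torch.randint(oc, (g,))  # hypothetical channel choices
rnge_out = torch.arange(g)
picked = samples[rnge_out, ..., parent_indices_in]  # [g, k]
flat = picked.reshape(g * k)                 # flat sample over all variables
print(flat.shape)                            # torch.Size([6])
```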

def __repr__(self):
