Skip to content

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
Signed-off-by: Steven Lang <steven.lang.mz@gmail.com>
  • Loading branch information
braun-steven committed Dec 14, 2021
1 parent e676164 commit 28259af
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ pip install -e .
## Usage Example

```python

import torch
from simple_einet.clipper import DistributionClipper
from simple_einet.distributions import Normal
from simple_einet.distributions import RatNormal
from simple_einet.einet import Einet
from simple_einet.einet import EinetConfig

Expand All @@ -33,48 +34,45 @@ out_features = 3
x = torch.randn(batchsize, in_features)

# Construct Einet
einet = Einet(EinetConfig(
in_features=in_features,
D=2,
S=2,
I=2,
R=3,
C=out_features,
dropout=0.5,
leaf_base_class=Normal,
))
einet = Einet(EinetConfig(in_features=in_features, D=2, S=2, I=2, R=3, C=out_features, dropout=0.0, leaf_base_class=RatNormal, leaf_base_kwargs={"min_sigma": 1e-5, "max_sigma": 1.0},))

# Compute log-likelihoods
lls = einet(x)
print(f"lls={lls}")
print(f"lls.shape={lls.shape}")

# Construct samples
samples = einet.sample(2)
print(f"samples={samples}")
print(f"samples.shape={samples.shape}")

# Optimize Einet parameters (weights and leaf params)
optim = torch.optim.Adam(einet.parameters(), lr=0.001)
clipper = DistributionClipper()

for _ in range(1000):
optim.zero_grad()

# Forward pass: log-likelihoods
# Forward pass: compute log-likelihoods
lls = einet(x)

# Backprop NLL loss
# Backprop negative log-likelihood loss
nlls = -1 * lls.sum()
nlls.backward()

# Update weights
optim.step()

# Clip leaf distribution parameters (e.g. std > 0.0, etc.)
clipper(einet._leaf)
# Construct samples
samples = einet.sample(2)
print(f"samples={samples}")
print(f"samples.shape={samples.shape}")
```

## MNIST Samples
Some samples from the `[0, 1]` class-subset of MNIST [./mnist.py]:

**Samples**

![MNIST Samples]( ./res/mnist-0-1-samples.png )

**Reconstructions**

![MNIST Reconstructions]( ./res/mnist-0-1-samples.png )

## Citing EinsumNetworks

Expand Down
Binary file added res/mnist-0-1-rec.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/mnist-0-1-samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 28259af

Please sign in to comment.