[Feature] Support uniform timesteps sampler for DDPM (#153)
* support uniform timesteps sampler for DDPM

* fix known issues
LeoXing1996 committed Nov 29, 2021
1 parent 5b9c1bb commit eb023d3
Showing 3 changed files with 57 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mmgen/models/diffusions/__init__.py
@@ -0,0 +1,3 @@
from .sampler import UniformTimeStepSampler

__all__ = ['UniformTimeStepSampler']
36 changes: 36 additions & 0 deletions mmgen/models/diffusions/sampler.py
@@ -0,0 +1,36 @@
import numpy as np
import torch

from ..builder import MODULES


@MODULES.register_module()
class UniformTimeStepSampler:
    """Timestep sampler for DDPM-based models.

    This sampler draws every timestep with equal probability.

    Args:
        num_timesteps (int): Total timesteps of the diffusion process.
    """

    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        # equal probability for each of the `num_timesteps` timesteps
        self.prob = [1 / self.num_timesteps for _ in range(self.num_timesteps)]

    def sample(self, batch_size):
        """Sample timesteps.

        Args:
            batch_size (int): The desired batch size of the sampled timesteps.

        Returns:
            torch.Tensor: Sampled timesteps.
        """
        # use numpy to keep the sampling behavior consistent with the
        # official implementation
        return torch.from_numpy(
            np.random.choice(
                self.num_timesteps, size=(batch_size, ), p=self.prob))

    def __call__(self, batch_size):
        """Return sampled timesteps."""
        return self.sample(batch_size)
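
As a point of reference (not part of this commit), a minimal usage sketch showing how the sampler could pick per-sample timesteps for a training batch; the batch shape and the comment about the noise schedule are illustrative assumptions.

# Illustrative sketch, not part of this commit.
import torch

from mmgen.models.diffusions import UniformTimeStepSampler

sampler = UniformTimeStepSampler(num_timesteps=1000)
imgs = torch.randn(4, 3, 32, 32)  # hypothetical training batch
t = sampler(imgs.size(0))  # LongTensor of shape (4,), values in [0, 1000)
# `t` would then index the noise schedule (e.g. alphas_cumprod[t]) when
# corrupting `imgs` and computing the denoising loss.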
18 changes: 18 additions & 0 deletions tests/test_models/test_base_ddpm.py
@@ -0,0 +1,18 @@
import torch

from mmgen.models.diffusions import UniformTimeStepSampler


def test_uniform_sampler():
    sampler = UniformTimeStepSampler(10)

    timesteps = sampler(2)
    assert timesteps.shape == torch.Size([2])
    assert timesteps.max() < 10 and timesteps.min() >= 0

    timesteps = sampler.__call__(2)
    assert timesteps.shape == torch.Size([2])
    assert timesteps.max() < 10 and timesteps.min() >= 0
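
Since the class is registered in MODULES, it could also be declared through a config dict and built via the registry; the snippet below is a sketch under that assumption, and the `timestep_sampler` field name is hypothetical rather than taken from this commit.

# Hypothetical config fragment; the field name is an assumption.
timestep_sampler = dict(type='UniformTimeStepSampler', num_timesteps=1000)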
