
[feature] support normal sample without cfg #16

Merged
merged 5 commits on Feb 29, 2024
15 changes: 10 additions & 5 deletions open_sora/modeling/dit.py
@@ -237,7 +237,7 @@ def __init__(
         self.proj = nn.Linear(in_features, embed_dim, bias=bias)
 
     def drop_sample(self, x: torch.Tensor) -> torch.Tensor:
-        drop_ids = torch.rand(x.shape[0], device=x.device) < self.dropout_prob
+        drop_ids = torch.rand(x.shape[0], 1, 1, device=x.device) < self.dropout_prob
         x = torch.where(drop_ids, torch.zeros_like(x), x)
         return x
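The added singleton dimensions are what make the dropout mask act per sample: against token embeddings of shape [B, T, C], a [B]-shaped mask broadcasts along the trailing channel axis, while [B, 1, 1] broadcasts along the batch axis. A standalone sketch with illustrative shapes (not code from this PR):

import torch

x = torch.randn(4, 10, 4)                             # [B, T, C]; B == C so both masks broadcast
bad_ids = torch.rand(x.shape[0]) < 0.5                # [B]: aligns with the last (channel) axis
good_ids = torch.rand(x.shape[0], 1, 1) < 0.5         # [B, 1, 1]: aligns with the batch axis
bad = torch.where(bad_ids, torch.zeros_like(x), x)    # zeroes channels across every sample
good = torch.where(good_ids, torch.zeros_like(x), x)  # zeroes whole samples, as intended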

@@ -517,13 +517,18 @@ def forward_with_cfg(
         # three channels by default. The standard approach to cfg applies it to all channels.
         # This can be done by uncommenting the following line and commenting-out the line following that.
         # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
-        c = model_out.shape[2]
-        assert c == 2 * self.in_channels
-        eps, rest = model_out.chunk(2, dim=2)
+        if self.learn_sigma:
+            c = model_out.shape[2]
+            assert c == 2 * self.in_channels
+            eps, rest = model_out.chunk(2, dim=2)
+        else:
+            eps = model_out
         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
         half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
         eps = torch.cat([half_eps, half_eps], dim=0)
-        return torch.cat([eps, rest], dim=2)
+        if self.learn_sigma:
+            return torch.cat([eps, rest], dim=2)
+        return eps
 
 
 #################################################################################
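For context, a minimal standalone sketch of the guidance combination forward_with_cfg performs when learn_sigma is disabled: the batch stacks the conditional half before the unconditional half, and only the noise prediction is returned. The function name and shapes below are illustrative, not part of the PR.

import torch

def combine_cfg(model_out: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # model_out: [2 * B, T, C], conditional half followed by unconditional half
    cond_eps, uncond_eps = torch.split(model_out, len(model_out) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    # duplicate so the output batch matches the duplicated input batch
    return torch.cat([half_eps, half_eps], dim=0)

eps = combine_cfg(torch.randn(4, 16, 8), cfg_scale=4.0)
print(eps.shape)  # torch.Size([4, 16, 8])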
123 changes: 110 additions & 13 deletions open_sora/utils/data.py
@@ -7,8 +7,10 @@
 from colossalai.utils import get_current_device
 from datasets import Dataset as HFDataset
 from datasets import dataset_dict, load_from_disk
+from diffusers.models import AutoencoderKL
 from torch.utils.data import ConcatDataset, Dataset
 from torchvision.io import read_video
+from transformers import AutoModel
 
 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -155,11 +157,116 @@ def unnormalize_video(video: torch.Tensor) -> torch.Tensor:
     return (video + 0.5) * 255


+class VideoCompressor:
+    t_factor: int
+    h_w_factor: int
+    out_channels: int
+
+    def encode(self, video: torch.Tensor) -> torch.Tensor:
+        """Encode a video.
+
+        Args:
+            video (torch.Tensor): [T, H, W, C]
+
+        Returns:
+            torch.Tensor: [T, C, H, W]
+        """
+        raise NotImplementedError
+
+    def decode(self, latent: torch.Tensor) -> torch.Tensor:
+        """Decode a latent tensor.
+
+        Args:
+            latent (torch.Tensor): [T, C, H, W]
+
+        Returns:
+            torch.Tensor: [T, H, W, C]
+        """
+        raise NotImplementedError
+
+
+class RawVideoCompressor(VideoCompressor):
+    t_factor = 1
+    h_w_factor = 1
+    out_channels = 3
+
+    def encode(self, video: torch.Tensor) -> torch.Tensor:
+        # [T, H, W, C] -> [T, C, H, W]
+        return video.permute(0, 3, 1, 2).contiguous()
+
+    def decode(self, latent: torch.Tensor) -> torch.Tensor:
+        # [T, C, H, W] -> [T, H, W, C]
+        return latent.permute(0, 2, 3, 1).contiguous()
+
+
+class VqvaeVideoCompressor(VideoCompressor):
+    t_factor = 2
+    h_w_factor = 4
+
+    def __init__(self, vqvae: nn.Module):
+        self.vqvae = vqvae
+        self.out_channels = vqvae.embedding_dim
+
+    def encode(self, video: torch.Tensor) -> torch.Tensor:
+        # [T, H, W, C] -> [B, C, T, H, W]
+        video = video.permute(3, 0, 1, 2).unsqueeze(0)
+        latent_indices, embeddings = self.vqvae.encode(video, include_embeddings=True)
+        # [B, C, T, H, W] -> [T, C, H, W]
+        return embeddings.squeeze(0).permute(1, 0, 2, 3)
+
+    def decode(self, latent: torch.Tensor) -> torch.Tensor:
+        # [T, C, H, W] -> [B, C, T, H, W]
+        latent = latent.permute(1, 0, 2, 3).unsqueeze(0)
+        video = self.vqvae.decode_from_embeddings(latent)
+        # [B, C, T, H, W] -> [T, H, W, C]
+        video = video.squeeze(0).permute(1, 2, 3, 0)
+        return video
+
+
+class VaeVideoCompressor(VideoCompressor):
+    t_factor = 1
+    h_w_factor = 8
+    out_channels = 4
+
+    def __init__(self, vae: nn.Module):
+        self.vae = vae
+
+    def encode(self, video: torch.Tensor) -> torch.Tensor:
+        # [T, H, W, C] -> [T, C, H, W]
+        video = video.permute(0, 3, 1, 2)
+        return self.vae.encode(video).latent_dist.sample()
+
+    def decode(self, latent: torch.Tensor) -> torch.Tensor:
+        video = self.vae.decode(latent).sample
+        # [T, C, H, W] -> [T, H, W, C]
+        return video.permute(0, 2, 3, 1).contiguous()
+
+
+def create_video_compressor(
+    compressor_type: str,
+    vqvae_path="hpcai-tech/vqvae",
+    vae_path="stabilityai/sd-vae-ft-mse",
+) -> VideoCompressor:
+    if compressor_type == "raw":
+        return RawVideoCompressor()
+    if compressor_type == "vqvae":
+        vqvae = (
+            AutoModel.from_pretrained(vqvae_path, trust_remote_code=True)
+            .to(get_current_device())
+            .eval()
+        )
+        return VqvaeVideoCompressor(vqvae)
+    if compressor_type == "vae":
+        vae = AutoencoderKL.from_pretrained(vae_path).to(get_current_device()).eval()
+        return VaeVideoCompressor(vae)
+    raise ValueError(f"Unsupported video compressor type {compressor_type}")

 @torch.no_grad()
 def preprocess_batch(
     batch: dict,
     patch_size: int,
-    vqvae: Optional[nn.Module] = None,
+    video_compressor: VideoCompressor,
     device=None,
     use_cross_attn=True,
 ) -> dict:
@@ -169,18 +276,8 @@ def preprocess_batch(
     for video in batch.pop("videos"):
         video = video.to(device)
         video = normalize_video(video)
-        if vqvae is not None:
-            # [T, H, W, C] -> [B, C, T, H, W]
-            video = video.permute(3, 0, 1, 2)
-            video = video.unsqueeze(0)
-            latent_indices, embeddings = vqvae.encode(video, include_embeddings=True)
-            # [B, C, T, H, W] -> [T, C, H, W]
-            embeddings = embeddings.squeeze(0).permute(1, 0, 2, 3)
-            videos.append(embeddings)
-        else:
-            # [T, H, W, C] -> [T, C, H, W]
-            video = video.permute(0, 3, 1, 2).contiguous()
-            videos.append(video)
+        video = video_compressor.encode(video)
+        videos.append(video)
     video_latent_states, video_padding_mask = patchify_batch(videos, patch_size)
     batch["video_latent_states"] = video_latent_states
     batch["video_padding_mask"] = video_padding_mask
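A small usage sketch of the new compressor abstraction, using the "raw" type so no pretrained weights are downloaded; the import path follows the updated sample.py, and the shapes are illustrative.

import torch
from open_sora.utils.data import create_video_compressor

compressor = create_video_compressor("raw")
video = torch.randn(16, 320, 480, 3)   # [T, H, W, C]
latent = compressor.encode(video)      # [T, C, H, W]
assert latent.shape == (
    16 // compressor.t_factor,
    compressor.out_channels,
    320 // compressor.h_w_factor,
    480 // compressor.h_w_factor,
)
restored = compressor.decode(latent)   # back to [T, H, W, C]
assert restored.shape == video.shape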
81 changes: 37 additions & 44 deletions sample.py
@@ -15,37 +15,26 @@
 
 from colossalai.utils import get_current_device
 from torchvision.io import write_video
-from transformers import AutoModel, AutoTokenizer, CLIPTextModel
+from transformers import AutoTokenizer, CLIPTextModel
 
 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
-from open_sora.utils.data import col2video, unnormalize_video
+from open_sora.utils.data import col2video, create_video_compressor, unnormalize_video
 
 
 def main(args):
     # Setup PyTorch:
     torch.manual_seed(args.seed)
     torch.set_grad_enabled(False)
     device = get_current_device()
-    if len(args.vqvae) > 0:
-        vqvae = (
-            AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
-            .to(device)
-            .eval()
-        )
-        in_channels = vqvae.embedding_dim
-        w_h_factor = 4
-        t_factor = 2
-    else:
-        # disable VQ-VAE if not provided, just use raw video frames
-        vqvae = None
-        in_channels = 3
-        w_h_factor = 1
-        t_factor = 1
-
+    video_compressor = create_video_compressor(args.compressor)
+    model_kwargs = {"in_channels": video_compressor.out_channels}
 
     text_model = CLIPTextModel.from_pretrained(args.text_model).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained(args.text_model)
 
-    model = DiT_models[args.model](in_channels=in_channels).to(device).eval()
+    model = DiT_models[args.model](**model_kwargs).to(device).eval()
     patch_size = model.patch_size
     model.load_state_dict(torch.load(args.ckpt))
     diffusion = create_diffusion(str(args.num_sampling_steps))
@@ -58,55 +47,56 @@ def main(args):
     num_frames = args.fps * args.sec
     z = torch.randn(
         1,
-        (args.height // patch_size // w_h_factor)
-        * (args.width // patch_size // w_h_factor)
-        * (num_frames // t_factor),
-        in_channels,
+        (args.height // patch_size // video_compressor.h_w_factor)
+        * (args.width // patch_size // video_compressor.h_w_factor)
+        * (num_frames // video_compressor.t_factor),
+        video_compressor.out_channels,
         patch_size,
         patch_size,
         device=device,
     )
 
     # Setup classifier-free guidance:
-    model_kwargs = {}
-    z = torch.cat([z, z], 0)
-    model_kwargs["text_latent_states"] = torch.cat(
-        [text_latent_states, torch.zeros_like(text_latent_states)], 0
-    )
-    model_kwargs["cfg_scale"] = args.cfg_scale
+    if not args.disable_cfg:
+        z = torch.cat([z, z], 0)
+        model_kwargs["text_latent_states"] = torch.cat(
+            [text_latent_states, torch.zeros_like(text_latent_states)], 0
+        )
+        model_kwargs["cfg_scale"] = args.cfg_scale
+    else:
+        model_kwargs["text_latent_states"] = text_latent_states
     model_kwargs["attention_mask"] = torch.ones(
-        2, 1, z.shape[1], text_latent_states.shape[1], device=device, dtype=torch.int
+        z.shape[0],
+        1,
+        z.shape[1],
+        text_latent_states.shape[1],
+        device=device,
+        dtype=torch.int,
     )
 
     # Sample images:
     samples = diffusion.p_sample_loop(
-        model.forward_with_cfg,
+        model if args.disable_cfg else model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
-    samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
+    if not args.disable_cfg:
+        samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
     samples = col2video(
         samples.squeeze(),
         (
-            num_frames // t_factor,
-            in_channels,
-            args.height // w_h_factor,
-            args.width // w_h_factor,
+            num_frames // video_compressor.t_factor,
+            video_compressor.out_channels,
+            args.height // video_compressor.h_w_factor,
+            args.width // video_compressor.h_w_factor,
         ),
     )
-    if vqvae is not None:
-        # [T, C, H, W] -> [B, C, T, H, W]
-        samples = samples.permute(1, 0, 2, 3).unsqueeze(0)
-        samples = vqvae.decode_from_embeddings(samples)
-        # [B, C, T, H, W] -> [T, H, W, C]
-        samples = samples.squeeze(0).permute(1, 2, 3, 0)
-    else:
-        # [T, C, H, W] -> [T, H, W, C]
-        samples = samples.permute(0, 2, 3, 1)
+    samples = video_compressor.decode(samples)
     samples = unnormalize_video(samples).to(torch.uint8)
 
     write_video("sample.mp4", samples.cpu(), args.fps)
@@ -131,13 +121,16 @@ def main(args):
         required=True,
         help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).",
     )
-    parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
+    parser.add_argument(
+        "-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw"
+    )
     parser.add_argument(
         "--text_model", type=str, default="openai/clip-vit-base-patch32"
     )
     parser.add_argument("--width", type=int, default=480)
     parser.add_argument("--height", type=int, default=320)
     parser.add_argument("--fps", type=int, default=15)
     parser.add_argument("--sec", type=int, default=8)
+    parser.add_argument("--disable-cfg", action="store_true", default=False)
     args = parser.parse_args()
     main(args)
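With these changes, sampling without classifier-free guidance amounts to passing the new flags, for example `python sample.py --disable-cfg -c raw ...` (the remaining arguments such as the model and checkpoint flags keep their existing names and are elided here).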