
Commit

update model
Dai Zuozhuo committed Dec 14, 2023
1 parent 86bb6c3 commit 8bee17e
Showing 2 changed files with 16 additions and 21 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -11,7 +11,6 @@ A girl is talking.
| ![Input image](docs/fish.jpg) | ![](docs/fish_mask.png) | ![](docs/fish.gif) The fish and tadpoles are playing.|



## Getting Started
This repository is based on [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning.git).

@@ -33,13 +32,15 @@ pip install -r requirements.txt
```

### Pretrained models
Download the pretrained [motion mask and motion strength model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/mask_motion_v1.tar) and unpack it into the directory output/latent/mask_motion_v1.

| Resolution | Model Path | Description |
| ------------- | ------------- | -------|
| 384x384 | [animate_anything_384](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/aimate_anything_384_v1.01.tar) | Finetuned on 60K video clips (2 s at 8 fps) |
| 512x512 | [animate_anything_512](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/aimate_anything_512_v1.01.tar) | Finetuned on 60K video clips (2 s at 8 fps) |
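
For example, the motion mask model linked above could be fetched and unpacked like this (a minimal sketch; the assumption that the archive extracts into a `mask_motion_v1` directory should be checked against the tar contents):

```bash
# Minimal download/unpack sketch; assumes the tar extracts into a
# mask_motion_v1 directory (verify with `tar -tf` before extracting).
mkdir -p output/latent
wget https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/mask_motion_v1.tar
tar -xf mask_motion_v1.tar -C output/latent
ls output/latent/mask_motion_v1  # expect config.yaml and model weights
```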

## Running inference
- Please download the checkpoints to output/latent, then run the following command:
+ Please download the pretrained models to output/latent, then run the following command, replacing {download_model} with the name of the model directory you downloaded:
```bash
- python train.py --config output/latent/mask_motion_v1/config.yaml --eval validation_data.prompt_image=example/barbie2.jpg validation_data.prompt='A cartoon girl is talking.'
+ python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/barbie2.jpg validation_data.prompt='A cartoon girl is talking.'
```
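
For instance, if the 384x384 archive unpacks into output/latent/animate_anything_384_v1.01 (a hypothetical directory name; use whatever the tar actually produces), the substituted command would be:

```bash
# Hypothetical example; the checkpoint directory name is an assumption.
python train.py --config output/latent/animate_anything_384_v1.01/config.yaml --eval \
    validation_data.prompt_image=example/barbie2.jpg \
    validation_data.prompt='A cartoon girl is talking.'
```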

To control the motion area, we can use labelme to generate a binary mask. First, we use labelme to draw a polygon on the reference image.
@@ -55,14 +56,14 @@ labelme_json_to_dataset qingming2.json
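
A minimal sketch of that workflow, assuming labelme is installed via pip (the exact steps in the collapsed block above may differ, and the output paths here are assumptions):

```bash
# Sketch of the labelme mask workflow; the commands come from the labelme
# project, but file names and output locations are assumptions.
pip install labelme
labelme example/qingming2.jpg            # draw a polygon over the motion area, save qingming2.json
labelme_json_to_dataset qingming2.json   # rasterizes the polygon into a label.png mask
```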

Then run the following command for inference:
```bash
- python train.py --config output/latent/mask_motion_v1/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='People are walking on the street.' validation_data.mask=example/qingming2_label.jpg
+ python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='People are walking on the street.' validation_data.mask=example/qingming2_label.jpg
```
![](docs/qingming2.gif)


Users can adjust the motion strength by using the mask motion model:
```bash
- python train.py --config output/latent/mask_motion_v1/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='People are walking on the street.' validation_data.mask=example/qingming2_label.jpg validation_data.strength=5
+ python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='People are walking on the street.' validation_data.mask=example/qingming2_label.jpg validation_data.strength=5
```
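
To compare results across several strengths, the value can be swept in a loop (a sketch; substitute {download_model} as above):

```bash
# Render the same prompt and mask at several motion strengths (a sketch).
for s in 2 5 10; do
    python train.py --config output/latent/{download_model}/config.yaml --eval \
        validation_data.prompt_image=example/qingming2.jpg \
        validation_data.prompt='People are walking on the street.' \
        validation_data.mask=example/qingming2_label.jpg \
        validation_data.strength=$s
done
```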
## Video super resolution
22 changes: 8 additions & 14 deletions train.py
@@ -1038,21 +1038,17 @@ def eval(pipeline, vae_processor, validation_data, out_file, index, forward_t=25
        mask = mask.resize((validation_data.width, validation_data.height))
        np_mask = np.array(mask)
        np_mask[np_mask!=0]=255
-        out_mask_path = out_file.replace(".gif", "_mask.jpg")
-        Image.fromarray(np_mask).save(out_mask_path)
    else:
-        mask = generate_center_mask(input_image[0])
-        np_mask = mask[0]
-        np_mask[:] = 255
-        out_mask_path = out_file.replace(".gif", "_mask.jpg")
-        Image.fromarray(np_mask).save(out_mask_path)
+        np_mask = np.ones([validation_data.height, validation_data.width], dtype=np.uint8)*255
+    out_mask_path = os.path.splitext(out_file)[0] + "_mask.jpg"
+    Image.fromarray(np_mask).save(out_mask_path)

    initial_latents, timesteps = DDPM_forward_timesteps(input_image_latents, forward_t, validation_data.num_frames, diffusion_scheduler)
    mask = T.ToTensor()(np_mask).to(dtype).to(device)
    b, c, f, h, w = initial_latents.shape
    mask = T.Resize([h, w], antialias=False)(mask)
    mask = rearrange(mask, 'b h w -> b 1 1 h w')
-    motion_strength = validation_data.get('strength', index+2)
+    motion_strength = validation_data.get("strength", index+2)
    with torch.no_grad():
        video_frames, video_latents = pipeline(
            prompt=prompt,
@@ -1064,14 +1060,12 @@
            guidance_scale=validation_data.guidance_scale,
            condition_latent=input_image_latents,
            mask=mask,
-            motion=motion_strength,
+            motion=[motion_strength],
            return_dict=False,
            timesteps=timesteps,
        )
    #export_to_video(video_frames, out_file, train_data.get('fps', 8))
-    if preview:
-        imageio.mimwrite(out_file, video_frames, duration=125, loop=0)
-        imageio.mimwrite(out_file.replace(".gif", ".mp4"), video_frames, fps=8, quality=9)
+    imageio.mimwrite(out_file, video_frames, fps=validation_data.get('fps', 8))
    real_motion_strength = calculate_latent_motion_score(video_latents).cpu().numpy()[0]
    precision = calculate_motion_precision(video_frames, np_mask)
    print(f"save file {out_file}, motion strength {motion_strength} -> {real_motion_strength}, motion precision {precision}")
@@ -1110,13 +1104,13 @@ def batch_eval(unet, text_encoder, vae, vae_processor, lora_manager, pretrained_

    motion_errors = []
    motion_precisions = []
-    iters = 10
+    iters = 5
    motion_precision = 0
    for t in range(iters):
        name = os.path.basename(validation_data.prompt_image)
        out_file_dir = f"{output_dir}/{name.split('.')[0]}"
        os.makedirs(out_file_dir, exist_ok=True)
-        out_file = f"{out_file_dir}/{global_step+t}.gif"
+        out_file = f"{out_file_dir}/{global_step+t}.mp4"
        precision = eval(pipeline, vae_processor,
                         validation_data, out_file, t, forward_t=validation_data.num_inference_steps, preview=preview)
        motion_precision += precision
