Skip to content
This repository has been archived by the owner on Jan 4, 2024. It is now read-only.

Commit

Permalink
fix: batch generation
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Apr 7, 2023
1 parent c7064aa commit f299dde
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<div align=center>
<img src="https://parcel.pyke.io/v2/cdn/assetdelivery/diffusers/doc/diffusers.webp" width="100%" alt="pyke Diffusers">
<a href="https://parcel.pyke.io/v2/cdn/assetdelivery/diffusers/doc/gallery0.webp" target="_blank"><img src="https://parcel.pyke.io/v2/cdn/assetdelivery/diffusers/doc/gallery0.webp" width="100%" alt="Gallery of generated images"></a>
<a href="https://github.com/pykeio/diffusers/actions/workflows/test.yml"><img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pykeio/diffusers/test.yml?branch=v2&style=for-the-badge"></a> <a href="https://crates.io/crates/pyke-diffusers" target="_blank"><img alt="Crates.io" src="https://img.shields.io/crates/d/ort?style=for-the-badge"></a> <a href="https://discord.gg/BAkXJ6VjCz"><img alt="Discord" src="https://img.shields.io/discord/1029216970027049072?style=for-the-badge&logo=discord&logoColor=white"></a>
<hr />
</div>

Expand Down Expand Up @@ -137,5 +138,5 @@ A combination of 256x256 image generation via `StableDiffusionMemoryOptimizedPip
- [ ] Rewrite scheduler system ([#16](https://github.com/pykeio/diffusers/issues/16))
- [x] Acceleration for M1 Macs ([#14](https://github.com/pykeio/diffusers/issues/14))
- [ ] Web interface
-- [ ] Batch generation
+- [x] Batch generation
- [ ] Explore other backends (pyke's DragonML, [tract](https://github.com/sonos/tract))
9 changes: 5 additions & 4 deletions src/pipelines/stable_diffusion/impl_txt2img.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use image::DynamicImage;
-use ndarray::{concatenate, Array1, Array4, ArrayD, Axis, IxDyn};
+use ndarray::{concatenate, s, Array1, Array4, ArrayD, Axis, IxDyn};
use ndarray_rand::rand_distr::StandardNormal;
use ndarray_rand::RandomExt;
use num_traits::ToPrimitive;
Expand Down Expand Up @@ -181,9 +181,10 @@ impl StableDiffusionTxt2ImgOptions {

let mut noise_pred: Array4<f32> = noise_pred.clone();
if do_classifier_free_guidance {
-let mut noise_pred_chunks = noise_pred.axis_iter(Axis(0));
-let (noise_pred_uncond, noise_pred_text) = (noise_pred_chunks.next().unwrap(), noise_pred_chunks.next().unwrap());
-let (noise_pred_uncond, noise_pred_text) = (noise_pred_uncond.insert_axis(Axis(0)), noise_pred_text.insert_axis(Axis(0)));
+assert!(noise_pred.shape()[0] % 2 == 0);
+let split_len = (noise_pred.shape()[0] / 2) as isize;
+let noise_pred_uncond = noise_pred.slice(s![..split_len, .., .., ..]);
+let noise_pred_text = noise_pred.slice(s![split_len.., .., .., ..]);
noise_pred = &noise_pred_uncond + self.guidance_scale * (&noise_pred_text - &noise_pred_uncond);
}

Expand Down
10 changes: 4 additions & 6 deletions src/pipelines/stable_diffusion/lpw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ pub fn get_unweighted_text_embeddings(
InputTensor::from_array(text_input_chunk.into_dyn())
} else {
// pre-embed
-let text_input_chunk = text_input_chunk.into_raw_vec();
-InputTensor::from_array(embeddings.embed(text_input_chunk.iter().map(|f| *f as u32).collect()).into_dyn())
+InputTensor::from_array(embeddings.embed(text_input_chunk).into_dyn())
};

let chunk_embeddings = text_encoder.run(vec![text_input_chunk])?;
Expand Down Expand Up @@ -241,8 +240,7 @@ pub fn get_unweighted_text_embeddings(
InputTensor::from_array(text_input.into_dyn())
} else {
// pre-embed
-let text_input = text_input.into_raw_vec();
-InputTensor::from_array(embeddings.embed(text_input.iter().map(|f| *f as u32).collect()).into_dyn())
+InputTensor::from_array(embeddings.embed(text_input).into_dyn())
};

let text_embeddings = text_encoder.run(vec![text_input])?;
Expand Down Expand Up @@ -300,7 +298,7 @@ pub fn get_weighted_text_embeddings(
.slice(s![.., .., NewAxis])
.to_owned();
let current_mean = text_embeddings.mean_axis(Axis(2)).unwrap().mean_axis(Axis(1)).unwrap();
-let text_embeddings = text_embeddings * (previous_mean / current_mean);
+let text_embeddings = text_embeddings * (previous_mean / current_mean).insert_axis(Axis(1)).insert_axis(Axis(2));

let uncond_embeddings = if let Some((uncond_tokens, uncond_weights)) = uncond_padded {
let uncond_embeddings = get_unweighted_text_embeddings(
Expand All @@ -316,7 +314,7 @@ pub fn get_weighted_text_embeddings(
.slice(s![.., .., NewAxis])
.to_owned();
let current_mean = uncond_embeddings.mean_axis(Axis(2)).unwrap().mean_axis(Axis(1)).unwrap();
-Some(uncond_embeddings * (previous_mean / current_mean))
+Some(uncond_embeddings * (previous_mean / current_mean).insert_axis(Axis(1)).insert_axis(Axis(2)))
} else {
None
};
Expand Down
21 changes: 13 additions & 8 deletions src/pipelines/stable_diffusion/text_embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

use byteorder::{LittleEndian, ReadBytesExt};
-use ndarray::{concatenate, Array2, Array3, Axis};
+use ndarray::{concatenate, stack, Array2, Array3, Axis};

use crate::clip::CLIPStandardTokenizer;

Expand Down Expand Up @@ -89,14 +89,19 @@ impl TextEmbeddings {
AddedToken { tok, tid: token_id }
}

-pub fn embed(&self, token_ids: Vec<u32>) -> Array3<f32> {
-let mut embeds = Vec::with_capacity(token_ids.len());
-for tok in token_ids {
-let tok = self.tokens.get(&tok).unwrap();
-embeds.push(tok.view());
+pub fn embed(&self, token_ids: Array2<i32>) -> Array3<f32> {
+let (batch_size, max_len) = (token_ids.shape()[0], token_ids.shape()[1]);
+let mut embeds = Vec::with_capacity(batch_size);
+for batch in token_ids.axis_iter(Axis(0)) {
+let mut batch_embeds = Vec::with_capacity(max_len);
+for tok_id in batch {
+let tok = self.tokens.get(&(*tok_id as u32)).unwrap();
+batch_embeds.push(tok.view());
+}
+embeds.push(concatenate(Axis(0), &batch_embeds).unwrap());
}
-let embeds = concatenate(Axis(0), &embeds).unwrap();
-embeds.insert_axis(Axis(0))
+
+stack(Axis(0), &embeds.iter().map(|f| f.view()).collect::<Vec<_>>()).unwrap()
}

pub fn len(&self) -> usize {
Expand Down

0 comments on commit f299dde

Please sign in to comment.