This repository has been archived by the owner on Jan 4, 2024. It is now read-only.

refactor: cleanup and fix integration tests
decahedron1 committed Apr 7, 2023
1 parent 1f373a3 commit a835118
Showing 13 changed files with 143 additions and 219 deletions.
7 changes: 3 additions & 4 deletions examples/stable-diffusion-interactive.rs
@@ -16,8 +16,8 @@ use std::{cell::RefCell, env};

use kdam::{tqdm, BarExt};
use pyke_diffusers::{
- ArenaExtendStrategy, CUDADeviceOptions, DPMSolverMultistepScheduler, DiffusionDevice, DiffusionDeviceControl, EulerDiscreteScheduler, OrtEnvironment,
- SchedulerOptimizedDefaults, StableDiffusionOptions, StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
+ ArenaExtendStrategy, CUDADeviceOptions, DPMSolverMultistepScheduler, DiffusionDevice, DiffusionDeviceControl, OrtEnvironment, SchedulerOptimizedDefaults,
+ StableDiffusionOptions, StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
};
use requestty::Question;
use show_image::{ImageInfo, ImageView, WindowOptions};
@@ -46,8 +46,7 @@ fn main() -> anyhow::Result<()> {
})
),
..Default::default()
- },
- lpw: true
+ }
}
)?;

3 changes: 1 addition & 2 deletions examples/stable-diffusion.rs
@@ -34,8 +34,7 @@ fn main() -> anyhow::Result<()> {
})
),
..Default::default()
- },
- ..Default::default()
+ }
}
)?;

38 changes: 10 additions & 28 deletions src/clip.rs
@@ -18,7 +18,7 @@ use std::path::PathBuf;

use ndarray::Array2;
use serde::{Deserialize, Serialize};
- use tokenizers::{models::bpe::BPE, EncodeInput, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
+ use tokenizers::{models::bpe::BPE, EncodeInput, Tokenizer};

#[derive(Serialize, Deserialize)]
pub struct CLIPStandardTokenizerWrapper {
@@ -33,7 +33,7 @@ pub struct CLIPStandardTokenizerWrapper {
///
/// CLIP is used by many diffusion models, including Stable Diffusion, for prompt tokenization and feature extraction.
pub struct CLIPStandardTokenizer {
- pub tokenizer: Tokenizer,
+ pub inner: Tokenizer,
model_max_length: usize,
bos_token_id: u32,
eos_token_id: u32
@@ -44,35 +44,17 @@ unsafe impl Sync for CLIPStandardTokenizer {}

impl CLIPStandardTokenizer {
/// Loads a CLIP tokenizer from a file.
- pub fn new(path: impl Into<PathBuf>, reconfigure: bool, model_max_length: usize, bos_token_id: u32, eos_token_id: u32) -> anyhow::Result<Self> {
+ pub fn new(path: impl Into<PathBuf>, model_max_length: usize, bos_token_id: u32, eos_token_id: u32) -> anyhow::Result<Self> {
let path = path.into();
let bytes = std::fs::read(path)?;
- Self::from_bytes(bytes, reconfigure, model_max_length, bos_token_id, eos_token_id)
+ Self::from_bytes(bytes, model_max_length, bos_token_id, eos_token_id)
}

/// Loads a CLIP tokenizer from a byte array.
- pub fn from_bytes<B: AsRef<[u8]>>(bytes: B, reconfigure: bool, model_max_length: usize, bos_token_id: u32, eos_token_id: u32) -> anyhow::Result<Self> {
- let mut tokenizer: Tokenizer = serde_json::from_slice(bytes.as_ref())?;
- // `reconfigure` is disabled in long prompt weighting; LPW has its own padding and truncation strategy that would
- // conflict with this configuration.
- if reconfigure {
- // For some reason, CLIP tokenizers lose their padding and truncation config when converting from the old HF tokenizers
- // format, so we have to add them back here.
- tokenizer
- .with_padding(Some(PaddingParams {
- strategy: PaddingStrategy::Fixed(model_max_length),
- // `clip-vit-base-patch32` and (maybe) all Stable Diffusion models use `"pad_token": "<|endoftext|>"`
- // This info is also lost in translation in HF tokenizers.
- pad_id: eos_token_id,
- ..Default::default()
- }))
- .with_truncation(Some(TruncationParams {
- max_length: model_max_length,
- ..Default::default()
- }));
- }
+ pub fn from_bytes<B: AsRef<[u8]>>(bytes: B, model_max_length: usize, bos_token_id: u32, eos_token_id: u32) -> anyhow::Result<Self> {
+ let tokenizer: Tokenizer = serde_json::from_slice(bytes.as_ref())?;
Ok(Self {
- tokenizer,
+ inner: tokenizer,
model_max_length,
bos_token_id,
eos_token_id
@@ -87,7 +69,7 @@ impl CLIPStandardTokenizer {
/// tokenizer, which should be impossible).
#[allow(dead_code)]
pub fn model(&self) -> &BPE {
- match self.tokenizer.get_model() {
+ match self.inner.get_model() {
tokenizers::ModelWrapper::BPE(ref bpe) => bpe,
_ => unreachable!()
}
@@ -117,7 +99,7 @@ impl CLIPStandardTokenizer {
E: Into<EncodeInput<'s>> + Send
{
Ok(self
- .tokenizer
+ .inner
.encode_batch(enc, true)
.map_err(|e| anyhow::anyhow!("{e:?}"))?
.iter()
@@ -133,7 +115,7 @@ impl CLIPStandardTokenizer {
let batch_size = enc.len();
Ok(Array2::from_shape_vec(
(batch_size, self.len()),
- self.tokenizer
+ self.inner
.encode_batch(enc, true)
.map_err(|e| anyhow::anyhow!("{e:?}"))?
.iter()
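Note: a minimal sketch of the tokenizer constructor after this change, with the `reconfigure` flag removed. The import path assumes `CLIPStandardTokenizer` is re-exported from the crate root, the file path is a placeholder, and 77/49406/49407 are the standard CLIP context length and `<|startoftext|>`/`<|endoftext|>` IDs rather than values taken from this commit.

```rust
use pyke_diffusers::CLIPStandardTokenizer; // assumed re-export; the type lives in src/clip.rs

fn main() -> anyhow::Result<()> {
	// New signature: path, model_max_length, bos_token_id, eos_token_id (no `reconfigure`).
	let tokenizer = CLIPStandardTokenizer::new("stable-diffusion/tokenizer.json", 77, 49406, 49407)?;

	// The wrapped `tokenizers::Tokenizer` now lives in the renamed `inner` field;
	// `model()` still exposes the underlying BPE model.
	let _bpe = tokenizer.model();
	Ok(())
}
```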
7 changes: 0 additions & 7 deletions src/pipelines/mod.rs
@@ -23,13 +23,6 @@ cfg_if::cfg_if! {
}
}

- cfg_if::cfg_if! {
- if #[cfg(feature = "safe-stable-diffusion")] {
- mod safe_stable_diffusion;
- pub use self::safe_stable_diffusion::*;
- }
- }

/// Text prompt(s) used as input in diffusion pipelines.
///
/// Can be converted from one or more prompts:
118 changes: 0 additions & 118 deletions src/pipelines/safe_stable_diffusion.rs

This file was deleted.

38 changes: 33 additions & 5 deletions src/pipelines/stable_diffusion/impl_main.rs
@@ -59,14 +59,42 @@ pub struct StableDiffusionPipeline {
vae_encoder: Option<Session>,
vae_decoder: Session,
text_encoder: Session,
- text_embeddings: TextEmbeddings,
+ /// The [text embeddings](TextEmbeddings) used by the text encoder. This can be used to add textual inversion
+ /// weights.
+ pub text_embeddings: TextEmbeddings,
pub(crate) unet: Session,
safety_checker: Option<Session>,
#[allow(dead_code)]
feature_extractor: Option<()>
}

impl StableDiffusionPipeline {
+ /// A recommended 'safety concept' for original Stable Diffusion models. This prompt is designed to be used as a
+ /// negative prompt to prevent the model from generating potentially harmful content.
+ ///
+ /// ```
+ /// # fn main() -> anyhow::Result<()> {
+ /// use pyke_diffusers::{
+ /// EulerDiscreteScheduler, OrtEnvironment, SchedulerOptimizedDefaults, StableDiffusionOptions,
+ /// StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
+ /// };
+ ///
+ /// let environment = OrtEnvironment::default().into_arc();
+ /// let mut scheduler = EulerDiscreteScheduler::stable_diffusion_v1_optimized_default()?;
+ /// let pipeline =
+ /// 	StableDiffusionPipeline::new(&environment, "tests/stable-diffusion", StableDiffusionOptions::default())?;
+ ///
+ /// let imgs = StableDiffusionTxt2ImgOptions::default()
+ /// 	.with_prompts("photo of a red fox", Some(StableDiffusionPipeline::SAFETY_CONCEPT.into()))
+ /// 	.run(&pipeline, &mut scheduler)?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// This is not recommended for fine-tuned models, e.g. Waifu Diffusion or AnythingV3 (a negative prompt of simply
+ /// `(nsfw:1.05)` would probably work better for these models)
+ pub const SAFETY_CONCEPT: &str = "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty";

/// Creates a new Stable Diffusion pipeline, loading models from `root`.
///
/// ```
@@ -99,7 +127,7 @@ impl StableDiffusionPipeline {
model_max_length,
bos_token,
eos_token
- } => CLIPStandardTokenizer::new(root.join(path.clone()), !options.lpw, *model_max_length, *bos_token, *eos_token)?,
+ } => CLIPStandardTokenizer::new(root.join(path.clone()), *model_max_length, *bos_token, *eos_token)?,
#[allow(unreachable_patterns)]
_ => anyhow::bail!("not a clip tokenizer")
};
@@ -208,7 +236,7 @@ impl StableDiffusionPipeline {
model_max_length,
bos_token,
eos_token
- } => CLIPStandardTokenizer::new(new_root.join(path.clone()), !options.lpw, *model_max_length, *bos_token, *eos_token)?,
+ } => CLIPStandardTokenizer::new(new_root.join(path.clone()), *model_max_length, *bos_token, *eos_token)?,
#[allow(unreachable_patterns)]
_ => anyhow::bail!("not a clip tokenizer")
};
@@ -233,7 +261,7 @@ impl StableDiffusionPipeline {
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use pyke_diffusers::{OrtEnvironment, StableDiffusionOptions, StableDiffusionPipeline};
- /// let environment = OrtEnvironment::default().into_arc();
+ /// # let environment = OrtEnvironment::default().into_arc();
/// let mut pipeline =
/// StableDiffusionPipeline::new(&environment, "./stable-diffusion-v1-5/", StableDiffusionOptions::default())?;
/// pipeline.replace_unet("./anything/unet.onnx")?;
@@ -268,7 +296,7 @@ impl StableDiffusionPipeline {
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use pyke_diffusers::{StableDiffusionOptions, StableDiffusionPipeline, OrtEnvironment};
- /// let environment = OrtEnvironment::default().into_arc();
+ /// # let environment = OrtEnvironment::default().into_arc();
/// let mut pipeline =
/// StableDiffusionPipeline::new(&environment, "./stable-diffusion-v1-5/", StableDiffusionOptions::default())?;
/// pipeline.replace_vae("./anything/vae-decoder.onnx", Some("./anything/vae-encoder.onnx"))?;
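Note: a sketch of the new `SAFETY_CONCEPT` constant in use, following the doc example added above. The model path is a placeholder, and the second call simply applies the `(nsfw:1.05)` negative prompt that the doc comment suggests for fine-tuned models; neither value comes from this commit.

```rust
use pyke_diffusers::{
	EulerDiscreteScheduler, OrtEnvironment, SchedulerOptimizedDefaults, StableDiffusionOptions, StableDiffusionPipeline,
	StableDiffusionTxt2ImgOptions
};

fn main() -> anyhow::Result<()> {
	let environment = OrtEnvironment::default().into_arc();
	let mut scheduler = EulerDiscreteScheduler::stable_diffusion_v1_optimized_default()?;
	let pipeline = StableDiffusionPipeline::new(&environment, "./stable-diffusion-v1-5/", StableDiffusionOptions::default())?;

	// Original Stable Diffusion checkpoints: use the recommended safety concept as the negative prompt.
	let _images = StableDiffusionTxt2ImgOptions::default()
		.with_prompts("photo of a red fox", Some(StableDiffusionPipeline::SAFETY_CONCEPT.into()))
		.run(&pipeline, &mut scheduler)?;

	// Fine-tuned models (e.g. AnythingV3): a lighter negative prompt is suggested instead.
	let _images = StableDiffusionTxt2ImgOptions::default()
		.with_prompts("photo of a red fox", Some("(nsfw:1.05)".into()))
		.run(&pipeline, &mut scheduler)?;

	Ok(())
}
```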
26 changes: 19 additions & 7 deletions src/pipelines/stable_diffusion/lpw.rs
@@ -185,7 +185,7 @@ fn pad_tokens_and_weights(
}

pub fn get_unweighted_text_embeddings(
- embeddings: &TextEmbeddings,
+ #[cfg_attr(test, allow(unused))] embeddings: &TextEmbeddings,
text_encoder: &Session,
text_input: Array2<i32>,
chunk_length: usize,
@@ -202,10 +202,16 @@
text_input_chunk.slice_mut(s![.., 0]).assign(&text_input.slice(s![0, 0]));
text_input_chunk.slice_mut(s![.., -1]).assign(&text_input.slice(s![0, -1]));

- let text_input_chunk = text_input_chunk.into_raw_vec();
- let text_input_chunk = embeddings.embed(text_input_chunk.iter().map(|f| *f as u32).collect());
+ let text_input_chunk = if embeddings.is_empty() {
+ // no external embeds
+ InputTensor::from_array(text_input_chunk.into_dyn())
+ } else {
+ // pre-embed
+ let text_input_chunk = text_input_chunk.into_raw_vec();
+ InputTensor::from_array(embeddings.embed(text_input_chunk.iter().map(|f| *f as u32).collect()).into_dyn())
+ };

- let chunk_embeddings = text_encoder.run(vec![InputTensor::from_array(text_input_chunk.into_dyn())])?;
+ let chunk_embeddings = text_encoder.run(vec![text_input_chunk])?;
let chunk_embeddings: Array3<f32> = chunk_embeddings[0].try_extract()?.view().to_owned().into_dimensionality().unwrap();

#[allow(clippy::reversed_empty_ranges)]
@@ -230,10 +236,16 @@
}
Ok(x1)
} else {
- let text_input = text_input.into_raw_vec();
- let text_input = embeddings.embed(text_input.iter().map(|f| *f as u32).collect());
+ let text_input = if embeddings.is_empty() {
+ // no external embeds
+ InputTensor::from_array(text_input.into_dyn())
+ } else {
+ // pre-embed
+ let text_input = text_input.into_raw_vec();
+ InputTensor::from_array(embeddings.embed(text_input.iter().map(|f| *f as u32).collect()).into_dyn())
+ };

- let text_embeddings = text_encoder.run(vec![InputTensor::from_array(text_input.into_dyn())])?;
+ let text_embeddings = text_encoder.run(vec![text_input])?;
Ok(text_embeddings[0].try_extract()?.view().to_owned().into_dimensionality().unwrap())
}
}
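Note: a condensed view of the branch added above, written as a hypothetical helper. It assumes the crate-internal `TextEmbeddings` type and the `InputTensor`/`ndarray` imports already present in `lpw.rs`; the function name is not part of this commit.

```rust
// Choose what to feed the ONNX text encoder, mirroring the logic added above.
fn encoder_input(embeddings: &TextEmbeddings, token_ids: Array2<i32>) -> InputTensor {
	if embeddings.is_empty() {
		// No external (e.g. textual inversion) embeddings are loaded: feed the raw
		// token IDs and let the text encoder's own embedding layer handle them.
		InputTensor::from_array(token_ids.into_dyn())
	} else {
		// External embeddings are present: look the tokens up here and feed the
		// pre-embedded tensor to the text encoder instead.
		let ids = token_ids.into_raw_vec();
		InputTensor::from_array(embeddings.embed(ids.iter().map(|id| *id as u32).collect()).into_dyn())
	}
}
```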