Skip to content
This repository has been archived by the owner on Jan 4, 2024. It is now read-only.

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed May 8, 2023
1 parent 70374a1 commit 34c6d84
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 52 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ tracing = "0.1"
regex = "1.7"
once_cell = "1.17"
image = { version = "0.24", default-features = false }
rand = "0.8"
cfg-if = "1.0"
ort = { version = "1.14", default-features = false }
ort = { git = "https://github.com/pykeio/ort", branch = "io-rework", default-features = false }
ndarray_einsum_beta = "0.7"
byteorder = "1"

Expand All @@ -43,7 +42,7 @@ tokenizers = { version = "0.13", default-features = false, features = [ "onig" ]
[dev-dependencies]
tokio = { version = "1.0", features = [ "full" ] }
image = { version = "0.24", default-features = false, features = [ "png" ] }
ort = { version = "1.14", default-features = false, features = [ "download-binaries" ] }
ort = { git = "https://github.com/pykeio/ort", branch = "io-rework", default-features = false, features = [ "download-binaries" ] }
tracing-subscriber = "0.3"

requestty = "0.5"
Expand Down
19 changes: 9 additions & 10 deletions src/pipelines/stable_diffusion/impl_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::Path;
use std::{fs, path::PathBuf, sync::Arc};
use std::{
fs,
path::{Path, PathBuf},
sync::Arc
};

use image::{DynamicImage, Rgb32FImage};
use ndarray::{concatenate, Array2, Array4, ArrayD, ArrayView4, Axis, IxDyn};
use ndarray_einsum_beta::einsum;
use ort::{
tensor::{FromArray, InputTensor, OrtOwnedTensor},
Environment, OrtResult, Session, SessionBuilder
};
use ort::{tensor::OrtOwnedTensor, Environment, OrtResult, Session, SessionBuilder};

use super::{StableDiffusionOptions, StableDiffusionTxt2ImgOptions};
use crate::text_embeddings::TextEmbeddings;
use crate::{
clip::CLIPStandardTokenizer,
config::{DiffusionFramework, DiffusionPipeline, StableDiffusionConfig, TokenizerConfig},
schedulers::DiffusionScheduler,
pipelines::StableDiffusionOptions,
text_embeddings::TextEmbeddings,
Prompt
};

Expand Down Expand Up @@ -402,7 +401,7 @@ impl StableDiffusionPipeline {
let mut images = Vec::new();
for latent_chunk in latents.axis_iter(Axis(0)) {
let latent_chunk = latent_chunk.into_dyn().insert_axis(Axis(0));
let image = self.vae_decoder.run(vec![InputTensor::from_array(latent_chunk.to_owned())])?;
let image = self.vae_decoder.run(&[latent_chunk.to_owned().into()])?;
let image: OrtOwnedTensor<'_, f32, IxDyn> = image[0].try_extract()?;
let f_image: Array4<f32> = image.view().to_owned().into_dimensionality()?;
let f_image = f_image.permuted_axes([0, 2, 3, 1]) / 2.0 + 0.5;
Expand Down
19 changes: 9 additions & 10 deletions src/pipelines/stable_diffusion/impl_txt2img.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use image::DynamicImage;
use ndarray::{concatenate, s, Array1, Array4, ArrayD, Axis, IxDyn};
use ndarray_rand::rand_distr::StandardNormal;
use ndarray_rand::RandomExt;
use ndarray_rand::{
rand::{self, rngs::StdRng, Rng, SeedableRng},
rand_distr::StandardNormal,
RandomExt
};
use num_traits::ToPrimitive;
use ort::tensor::{FromArray, InputTensor, OrtOwnedTensor};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use ort::tensor::OrtOwnedTensor;

use crate::{DiffusionScheduler, Prompt, StableDiffusionCallback, StableDiffusionPipeline};

Expand Down Expand Up @@ -275,11 +276,9 @@ impl StableDiffusionTxt2ImgOptions {
let timestep: ArrayD<f32> = Array1::from_iter([t.to_f32().unwrap()]).into_dyn();
let encoder_hidden_states: ArrayD<f32> = text_embeddings.clone().into_dyn();

let noise_pred = session.unet.run(vec![
InputTensor::from_array(latent_model_input),
InputTensor::from_array(timestep),
InputTensor::from_array(encoder_hidden_states),
])?;
let noise_pred = session
.unet
.run(&[latent_model_input.into(), timestep.into(), encoder_hidden_states.into()])?;
let noise_pred: OrtOwnedTensor<'_, f32, IxDyn> = noise_pred[0].try_extract()?;
let noise_pred: Array4<f32> = noise_pred.view().to_owned().into_dimensionality()?;

Expand Down
17 changes: 7 additions & 10 deletions src/pipelines/stable_diffusion/lpw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ use std::num::ParseFloatError;

use ndarray::{s, Array2, Array3, Axis, NewAxis};
use once_cell::sync::Lazy;
use ort::{
tensor::{FromArray, InputTensor},
OrtResult, Session
};
use ort::{OrtResult, Session};
use regex::Regex;

use crate::{text_embeddings::TextEmbeddings, Prompt};
Expand Down Expand Up @@ -205,13 +202,13 @@ pub fn get_unweighted_text_embeddings(

let text_input_chunk = if embeddings.is_empty() {
// no external embeds
InputTensor::from_array(text_input_chunk.into_dyn())
text_input_chunk.into_dyn().into()
} else {
// pre-embed
InputTensor::from_array(embeddings.embed(text_input_chunk).into_dyn())
embeddings.embed(text_input_chunk).into_dyn().into()
};

let chunk_embeddings = text_encoder.run(vec![text_input_chunk])?;
let chunk_embeddings = text_encoder.run(&[text_input_chunk])?;
let chunk_embeddings: Array3<f32> = chunk_embeddings[0].try_extract()?.view().to_owned().into_dimensionality().unwrap();

#[allow(clippy::reversed_empty_ranges)]
Expand All @@ -238,13 +235,13 @@ pub fn get_unweighted_text_embeddings(
} else {
let text_input = if embeddings.is_empty() {
// no external embeds
InputTensor::from_array(text_input.into_dyn())
text_input.into_dyn().into()
} else {
// pre-embed
InputTensor::from_array(embeddings.embed(text_input).into_dyn())
embeddings.embed(text_input).into_dyn().into()
};

let text_embeddings = text_encoder.run(vec![text_input])?;
let text_embeddings = text_encoder.run(&[text_input])?;
Ok(text_embeddings[0].try_extract()?.view().to_owned().into_dimensionality().unwrap())
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/pipelines/stable_diffusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub(crate) mod text_embeddings;
pub use self::impl_img2img::{ImagePreprocessing, StableDiffusionImg2ImgOptions};
pub use self::impl_main::StableDiffusionPipeline;
pub use self::impl_txt2img::StableDiffusionTxt2ImgOptions;
use crate::{DiffusionDeviceControl, Prompt};
use crate::DiffusionDeviceControl;

/// Options for the Stable Diffusion pipeline. This includes options like device control and long prompt weighting.
#[derive(Default, Debug, Clone)]
Expand Down
9 changes: 5 additions & 4 deletions src/schedulers/ddim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
// limitations under the License.

use ndarray::{s, Array1, Array4, ArrayView4};
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
use rand::Rng;
use ndarray_rand::{rand::Rng, rand_distr::StandardNormal, RandomExt};

use super::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput};
use crate::{SchedulerOptimizedDefaults, SchedulerPredictionType};
use crate::{
schedulers::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput},
SchedulerOptimizedDefaults, SchedulerPredictionType
};

/// Additional configuration for the [`DDIMScheduler`].
#[derive(Debug, Clone)]
Expand Down
3 changes: 1 addition & 2 deletions src/schedulers/ddpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

use anyhow::Context;
use ndarray::{s, Array1, Array4, ArrayView4};
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
use rand::Rng;
use ndarray_rand::{rand::Rng, rand_distr::StandardNormal, RandomExt};

use super::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput};
use crate::{SchedulerOptimizedDefaults, SchedulerPredictionType};
Expand Down
8 changes: 5 additions & 3 deletions src/schedulers/dpm_solver_multistep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ use std::collections::VecDeque;

use anyhow::Context;
use ndarray::{Array1, Array4, ArrayView4};
use rand::Rng;
use ndarray_rand::rand::Rng;

use super::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput};
use crate::{SchedulerOptimizedDefaults, SchedulerPredictionType};
use crate::{
schedulers::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput},
SchedulerOptimizedDefaults, SchedulerPredictionType
};

/// The algorithm type for the solver.
///
Expand Down
10 changes: 6 additions & 4 deletions src/schedulers/euler_ancestral_discrete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

use anyhow::{anyhow, Context};
use ndarray::{concatenate, s, Array1, Array4, ArrayView4, Axis, Zip};
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
use rand::Rng;
use ndarray_rand::{rand::Rng, rand_distr::StandardNormal, RandomExt};

use super::{BetaSchedule, DiffusionScheduler, SchedulerStepOutput};
use crate::{util::interpolation::LinearInterpolatorAccelerated, SchedulerOptimizedDefaults};
use crate::{
schedulers::{BetaSchedule, DiffusionScheduler, SchedulerStepOutput},
util::interpolation::LinearInterpolatorAccelerated,
SchedulerOptimizedDefaults
};

/// Ancestral sampling with Euler method steps.
///
Expand Down
11 changes: 7 additions & 4 deletions src/schedulers/euler_discrete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

use anyhow::{anyhow, Context};
use ndarray::{concatenate, s, Array1, Array4, ArrayView4, Axis, Zip};
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
use rand::Rng;
use ndarray_rand::{rand::Rng, rand_distr::StandardNormal, RandomExt};

use super::{BetaSchedule, DiffusionScheduler, SchedulerStepOutput};
use crate::{util::interpolation::LinearInterpolatorAccelerated, SchedulerOptimizedDefaults};
use crate::{
schedulers::{betas_for_alpha_bar, BetaSchedule, DiffusionScheduler, SchedulerStepOutput},
util::interpolation::LinearInterpolatorAccelerated,
SchedulerOptimizedDefaults
};

/// Euler scheduler (Algorithm 2) from [Karras et al. (2022)](https://arxiv.org/abs/2206.00364).
///
Expand Down Expand Up @@ -77,6 +79,7 @@ impl EulerDiscreteScheduler {
betas.par_map_inplace(|f| *f = f.powi(2));
betas
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(num_train_timesteps, 0.999),
_ => anyhow::bail!("{beta_schedule:?} not implemented for EulerDiscreteScheduler")
};

Expand Down
2 changes: 1 addition & 1 deletion src/schedulers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
//! exceptionally creative and can produce high quality results in as few as 20 steps.

use ndarray::{Array1, Array4, ArrayBase, ArrayView1, ArrayView4};
use ndarray_rand::rand::Rng;
use num_traits::ToPrimitive;
use rand::Rng;

cfg_if::cfg_if! {
if #[cfg(feature = "scheduler-euler")] {
Expand Down

0 comments on commit 34c6d84

Please sign in to comment.