This repository has been archived by the owner on Jan 4, 2024. It is now read-only.

Commit 0efe380

chore: update to latest ort, fixes #32

decahedron1 committed Oct 30, 2023 · 1 parent: c7c392b
Showing 8 changed files with 46 additions and 180 deletions.
Cargo.toml (2 additions & 2 deletions)

```diff
@@ -30,7 +30,7 @@ regex = "1.7"
 once_cell = "1.17"
 image = { version = "0.24", default-features = false }
 cfg-if = "1.0"
-ort = { git = "https://github.com/pykeio/ort", rev = "5d064ceb675104c6f7e72770b3e48d43fdc450b2", default-features = false }
+ort = { git = "https://github.com/pykeio/ort", rev = "965712dbf4d1cce4deff5f1655144e9e7621e4ea", default-features = false }
 ndarray_einsum_beta = "0.7"
 byteorder = "1"
@@ -42,7 +42,7 @@ tokenizers = { version = "0.13", default-features = false, features = [ "onig" ]
 [dev-dependencies]
 tokio = { version = "1.0", features = [ "full" ] }
 image = { version = "0.24", default-features = false, features = [ "png" ] }
-ort = { git = "https://github.com/pykeio/ort", rev = "5d064ceb675104c6f7e72770b3e48d43fdc450b2", default-features = false, features = [ "download-binaries" ] }
+ort = { git = "https://github.com/pykeio/ort", rev = "965712dbf4d1cce4deff5f1655144e9e7621e4ea", default-features = false, features = [ "download-binaries" ] }
 tracing-subscriber = "0.3"
 
 requestty = "0.5"
```
examples/stable-diffusion-interactive.rs (4 additions & 4 deletions)

```diff
@@ -16,8 +16,8 @@ use std::{cell::RefCell, env};
 
 use kdam::{tqdm, BarExt};
 use pyke_diffusers::{
-	ArenaExtendStrategy, CUDADeviceOptions, DPMSolverMultistepScheduler, DiffusionDevice, DiffusionDeviceControl, OrtEnvironment, SchedulerOptimizedDefaults,
-	StableDiffusionOptions, StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
+	ArenaExtendStrategy, CUDAExecutionProviderOptions, DPMSolverMultistepScheduler, DiffusionDevice, DiffusionDeviceControl, OrtEnvironment,
+	SchedulerOptimizedDefaults, StableDiffusionOptions, StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
 };
 use requestty::Question;
 use show_image::{ImageInfo, ImageView, WindowOptions};
@@ -39,8 +39,8 @@ fn main() -> anyhow::Result<()> {
 		devices: DiffusionDeviceControl {
 			unet: DiffusionDevice::CUDA(
 				0,
-				Some(CUDADeviceOptions {
-					memory_limit: Some(3500000000),
+				Some(CUDAExecutionProviderOptions {
+					gpu_mem_limit: Some(3500000000),
 					arena_extend_strategy: Some(ArenaExtendStrategy::SameAsRequested),
 					..Default::default()
 				})
```
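The migration for callers is mechanical: the crate's old `CUDADeviceOptions` wrapper is replaced by ort's `CUDAExecutionProviderOptions` re-export, and `memory_limit` becomes `gpu_mem_limit`. A minimal sketch of the updated low-VRAM device setup, assuming `StableDiffusionOptions` and `DiffusionDeviceControl` implement `Default` as the `..Default::default()` spreads above suggest:

```rust
use pyke_diffusers::{
	ArenaExtendStrategy, CUDAExecutionProviderOptions, DiffusionDevice, DiffusionDeviceControl, StableDiffusionOptions
};

// Hypothetical helper (not part of this commit): pin the UNet to GPU 0 with
// a ~3.5 GB arena cap, matching the settings used in the examples above.
fn low_vram_options() -> StableDiffusionOptions {
	StableDiffusionOptions {
		devices: DiffusionDeviceControl {
			unet: DiffusionDevice::CUDA(
				0,
				Some(CUDAExecutionProviderOptions {
					// Renamed from `memory_limit`; still a per-session byte budget.
					gpu_mem_limit: Some(3_500_000_000),
					// Grow the arena only by the requested amount, not powers of two.
					arena_extend_strategy: Some(ArenaExtendStrategy::SameAsRequested),
					..Default::default()
				})
			),
			..Default::default()
		},
		..Default::default()
	}
}
```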
examples/stable-diffusion.rs (4 additions & 4 deletions)

```diff
@@ -13,8 +13,8 @@
 // limitations under the License.
 
 use pyke_diffusers::{
-	ArenaExtendStrategy, CUDADeviceOptions, DiffusionDevice, DiffusionDeviceControl, EulerDiscreteScheduler, OrtEnvironment, SchedulerOptimizedDefaults,
-	StableDiffusionOptions, StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
+	ArenaExtendStrategy, CUDAExecutionProviderOptions, DiffusionDevice, DiffusionDeviceControl, EulerDiscreteScheduler, OrtEnvironment,
+	SchedulerOptimizedDefaults, StableDiffusionOptions, StableDiffusionPipeline, StableDiffusionTxt2ImgOptions
 };
 
 fn main() -> anyhow::Result<()> {
@@ -27,8 +27,8 @@ fn main() -> anyhow::Result<()> {
 		devices: DiffusionDeviceControl {
 			unet: DiffusionDevice::CUDA(
 				0,
-				Some(CUDADeviceOptions {
-					memory_limit: Some(3500000000),
+				Some(CUDAExecutionProviderOptions {
+					gpu_mem_limit: Some(3500000000),
 					arena_extend_strategy: Some(ArenaExtendStrategy::SameAsRequested),
 					..Default::default()
 				})
```
src/lib.rs (12 additions & 143 deletions)

```diff
@@ -59,146 +59,19 @@ pub mod pipelines;
 pub mod schedulers;
 pub(crate) mod util;
 
-use ort::execution_providers::ArenaExtendStrategy as OrtArenaExtendStrategy;
-use ort::execution_providers::CPUExecutionProviderOptions;
-use ort::execution_providers::CUDAExecutionProviderCuDNNConvAlgoSearch;
-use ort::execution_providers::CUDAExecutionProviderOptions;
-use ort::execution_providers::CoreMLExecutionProviderOptions;
-use ort::execution_providers::DirectMLExecutionProviderOptions;
-use ort::execution_providers::OneDNNExecutionProviderOptions;
-use ort::execution_providers::ROCmExecutionProviderOptions;
+use ort::CPUExecutionProviderOptions;
+use ort::CoreMLExecutionProviderOptions;
+use ort::DirectMLExecutionProviderOptions;
 pub use ort::Environment as OrtEnvironment;
 use ort::ExecutionProvider;
+use ort::OneDNNExecutionProviderOptions;
+use ort::ROCmExecutionProviderOptions;
+pub use ort::{ArenaExtendStrategy, CUDAExecutionProviderCuDNNConvAlgoSearch, CUDAExecutionProviderOptions};
 
 pub use self::pipelines::*;
 pub use self::schedulers::*;
 pub use self::util::prompting;
 
-/// The strategy to use for extending the device memory arena.
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum ArenaExtendStrategy {
-	/// Subsequent memory allocations extend by larger amounts (multiplied by powers of two)
-	PowerOfTwo,
-	/// Memory allocations extend only by the requested amount.
-	SameAsRequested
-}
-
-impl Default for ArenaExtendStrategy {
-	fn default() -> Self {
-		Self::PowerOfTwo
-	}
-}
-
-impl From<ArenaExtendStrategy> for String {
-	fn from(val: ArenaExtendStrategy) -> Self {
-		match val {
-			ArenaExtendStrategy::PowerOfTwo => "kNextPowerOfTwo".to_string(),
-			ArenaExtendStrategy::SameAsRequested => "kSameAsRequested".to_string()
-		}
-	}
-}
-
-impl From<OrtArenaExtendStrategy> for ArenaExtendStrategy {
-	fn from(val: OrtArenaExtendStrategy) -> Self {
-		match val {
-			OrtArenaExtendStrategy::SameAsRequested => ArenaExtendStrategy::SameAsRequested,
-			OrtArenaExtendStrategy::NextPowerOfTwo => ArenaExtendStrategy::PowerOfTwo
-		}
-	}
-}
-
-impl From<ArenaExtendStrategy> for OrtArenaExtendStrategy {
-	fn from(val: ArenaExtendStrategy) -> Self {
-		match val {
-			ArenaExtendStrategy::SameAsRequested => OrtArenaExtendStrategy::SameAsRequested,
-			ArenaExtendStrategy::PowerOfTwo => OrtArenaExtendStrategy::NextPowerOfTwo
-		}
-	}
-}
-
-/// The type of search done for cuDNN convolution algorithms.
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum CuDNNConvolutionAlgorithmSearch {
-	/// Exhaustive kernel search. Will spend more time and memory to find the most optimal kernel for this GPU.
-	/// This is the **default** value set by ONNX Runtime.
-	Exhaustive,
-	/// Heuristic kernel search. Will spend a small amount of time and memory to find an optimal kernel for this
-	/// GPU.
-	Heuristic,
-	/// Uses the default cuDNN kernels that may not be optimized for this GPU. **This is NOT the actual default
-	/// value set by ONNX Runtime, the default is set to `Exhaustive`.**
-	Default
-}
-
-impl Default for CuDNNConvolutionAlgorithmSearch {
-	fn default() -> Self {
-		Self::Exhaustive
-	}
-}
-
-impl From<CuDNNConvolutionAlgorithmSearch> for String {
-	fn from(val: CuDNNConvolutionAlgorithmSearch) -> Self {
-		match val {
-			CuDNNConvolutionAlgorithmSearch::Exhaustive => "EXHAUSTIVE".to_string(),
-			CuDNNConvolutionAlgorithmSearch::Heuristic => "HEURISTIC".to_string(),
-			CuDNNConvolutionAlgorithmSearch::Default => "DEFAULT".to_string()
-		}
-	}
-}
-
-impl From<CuDNNConvolutionAlgorithmSearch> for CUDAExecutionProviderCuDNNConvAlgoSearch {
-	fn from(val: CuDNNConvolutionAlgorithmSearch) -> Self {
-		match val {
-			CuDNNConvolutionAlgorithmSearch::Exhaustive => CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive,
-			CuDNNConvolutionAlgorithmSearch::Heuristic => CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic,
-			CuDNNConvolutionAlgorithmSearch::Default => CUDAExecutionProviderCuDNNConvAlgoSearch::Default
-		}
-	}
-}
-
-/// Device options for the CUDA execution provider.
-///
-/// For low-VRAM devices running Stable Diffusion v1, it's best to use a float16 model with the following parameters:
-/// ```
-/// # use pyke_diffusers::{ArenaExtendStrategy, CUDADeviceOptions};
-/// let options = CUDADeviceOptions {
-/// 	memory_limit: Some(3000000000),
-/// 	arena_extend_strategy: Some(ArenaExtendStrategy::SameAsRequested),
-/// 	..Default::default()
-/// };
-/// ```
-#[derive(Default, Debug, Clone, PartialEq, Eq)]
-pub struct CUDADeviceOptions {
-	/// The strategy to use for extending the device memory arena. See [`ArenaExtendStrategy`] for more info.
-	pub arena_extend_strategy: Option<ArenaExtendStrategy>,
-	/// Per-session (aka per-model) memory limit. Models may use all available VRAM if a memory limit is not set.
-	/// VRAM usage may be higher than the memory limit (though typically not by much).
-	pub memory_limit: Option<usize>,
-	/// The type of search done for cuDNN convolution algorithms. See [`CuDNNConvolutionAlgorithmSearch`] for
-	/// more info.
-	///
-	/// **NOTE**: Setting this to any value other than `Exhaustive` seems to break float16 models!
-	pub cudnn_conv_algorithm_search: Option<CuDNNConvolutionAlgorithmSearch>
-}
-
-impl From<CUDADeviceOptions> for CUDAExecutionProviderOptions {
-	fn from(val: CUDADeviceOptions) -> Self {
-		let defs = CUDAExecutionProviderOptions::default();
-		Self {
-			gpu_mem_limit: val.memory_limit.unwrap_or(defs.gpu_mem_limit),
-			arena_extend_strategy: val.arena_extend_strategy.map(|x| x.into()).unwrap_or(defs.arena_extend_strategy),
-			cudnn_conv_algo_search: val.cudnn_conv_algorithm_search.map(|x| x.into()).unwrap_or(defs.cudnn_conv_algo_search),
-			..Default::default()
-		}
-	}
-}
-
-impl From<CUDADeviceOptions> for ExecutionProvider {
-	fn from(val: CUDADeviceOptions) -> Self {
-		ExecutionProvider::CUDA(val.into())
-	}
-}
-
 /// A device on which to place a diffusion model on.
 ///
 /// If a device is not specified, or a configured execution provider is not available, the model will be placed on the
@@ -214,7 +87,7 @@ pub enum DiffusionDevice {
 	/// provider parameters. These options can be fine tuned for inference on low-VRAM GPUs
 	/// (~3 GB free seems to be a good number for the Stable Diffusion v1 float16 UNet at 512x512 resolution); see
 	/// [`CUDADeviceOptions`] for an example.
-	CUDA(u32, Option<CUDADeviceOptions>),
+	CUDA(u32, Option<CUDAExecutionProviderOptions>),
 	/// Use NVIDIA TensorRT as a device. Requires an NVIDIA Kepler GPU or later.
 	TensorRT,
 	/// Use Windows DirectML as a device. Requires a DirectX 12 compatible GPU.
@@ -239,21 +112,17 @@ impl From<DiffusionDevice> for ExecutionProvider {
 			DiffusionDevice::CPU => ExecutionProvider::CPU(CPUExecutionProviderOptions::default()),
 			DiffusionDevice::CUDA(device, options) => {
 				let options = options.unwrap_or_default();
-				let op = CUDAExecutionProviderOptions { device_id: device, ..options.into() };
+				let op = CUDAExecutionProviderOptions {
+					device_id: Some(device),
+					..options.into()
+				};
 				ExecutionProvider::CUDA(op)
 			}
 			DiffusionDevice::TensorRT => ExecutionProvider::TensorRT(Default::default()),
 			DiffusionDevice::DirectML(device) => ExecutionProvider::DirectML(DirectMLExecutionProviderOptions { device_id: device }),
 			DiffusionDevice::ROCm(device) => ExecutionProvider::ROCm(ROCmExecutionProviderOptions {
 				device_id: device,
-				miopen_conv_exhaustive_search: 0,
-				gpu_mem_limit: 0,
-				arena_extend_strategy: 0,
-				do_copy_in_default_stream: 0,
-				has_user_compute_stream: 0,
-				user_compute_stream: std::ptr::null_mut(),
-				default_memory_arena_cfg: std::ptr::null_mut(),
-				tunable_op_enabled: 0
+				..Default::default()
 			}),
 			DiffusionDevice::OneDNN => ExecutionProvider::OneDNN(OneDNNExecutionProviderOptions::default()),
 			DiffusionDevice::CoreML => ExecutionProvider::CoreML(CoreMLExecutionProviderOptions::default()),
```

GitHub Actions raised the same clippy warning on line 117 (`..options.into()`) in all three workflow jobs (all-schedulers, default, stable-diffusion):

```
warning: useless conversion to the same type: `ort::CUDAExecutionProviderOptions`
   --> src/lib.rs:117:8
    |
117 |                   ..options.into()
    |                     ^^^^^^^^^^^^^^ help: consider removing `.into()`: `options`
    |
    = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#useless_conversion
note: the lint level is defined here
   --> src/lib.rs:52:50
    |
52  | #![warn(clippy::correctness, clippy::suspicious, clippy::complexity, clippy::perf, clippy::style)]
    |                                                  ^^^^^^^^^^^^^^^^^^
    = note: `#[warn(clippy::useless_conversion)]` implied by `#[warn(clippy::complexity)]`
```
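The warning is a side effect of deleting the wrapper type: `options` is already a `CUDAExecutionProviderOptions`, so the `.into()` converts a type into itself. A sketch of the fix clippy suggests (not applied in this commit):

```rust
use ort::CUDAExecutionProviderOptions;

// Per clippy's help text, drop the no-op `.into()` and spread `options` directly.
fn cuda_options(device: u32, options: CUDAExecutionProviderOptions) -> CUDAExecutionProviderOptions {
	CUDAExecutionProviderOptions {
		device_id: Some(device),
		..options
	}
}
```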
src/pipelines/stable_diffusion/impl_main.rs (5 additions & 6 deletions)

```diff
@@ -19,9 +19,9 @@ use std::{
 };
 
 use image::{DynamicImage, Rgb32FImage};
-use ndarray::{concatenate, Array2, Array4, ArrayD, ArrayView4, Axis, IxDyn};
+use ndarray::{concatenate, Array2, Array4, ArrayD, ArrayView4, Axis};
 use ndarray_einsum_beta::einsum;
-use ort::{tensor::OrtOwnedTensor, Environment, OrtResult, Session, SessionBuilder};
+use ort::{Environment, OrtOwnedTensor, OrtResult, Session, SessionBuilder};
 
 use crate::{
 	clip::CLIPStandardTokenizer,
@@ -93,7 +93,7 @@ impl StableDiffusionPipeline {
 	///
 	/// This is not recommended for fine-tuned models, e.g. Waifu Diffusion or AnythingV3 (a negative prompt of simply
 	/// `(nsfw:1.05)` would probably work better for these models)
-	pub const SAFETY_CONCEPT: &str = "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty";
+	pub const SAFETY_CONCEPT: &'static str = "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty";
 
 	/// Creates a new Stable Diffusion pipeline, loading models from `root`.
 	///
@@ -400,9 +400,8 @@ impl StableDiffusionPipeline {
 
 		let mut images = Vec::new();
 		for latent_chunk in latents.axis_iter(Axis(0)) {
-			let latent_chunk = latent_chunk.into_dyn().insert_axis(Axis(0));
-			let image = self.vae_decoder.run(&[latent_chunk.to_owned().into()])?;
-			let image: OrtOwnedTensor<'_, f32, IxDyn> = image[0].try_extract()?;
+			let image = self.vae_decoder.run(ort::inputs![latent_chunk.insert_axis(Axis(0))]?)?;
+			let image: OrtOwnedTensor<f32> = image[0].extract_tensor()?;
 			let f_image: Array4<f32> = image.view().to_owned().into_dimensionality()?;
 			let f_image = f_image.permuted_axes([0, 2, 3, 1]) / 2.0 + 0.5;
 
```
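The VAE decode loop above shows the new ort call shape: inputs are built with the fallible `ort::inputs!` macro (note the inner `?`), outputs are read with `extract_tensor` instead of `try_extract`, and `OrtOwnedTensor` now takes a single type parameter. A self-contained sketch of the same pattern; the session type and array shapes are assumed from the diff, and exact signatures follow the ort rev pinned in Cargo.toml:

```rust
use ndarray::{Array4, ArrayView4, Axis};
use ort::{OrtOwnedTensor, Session};

// Decode a batch of latents chunk by chunk, mirroring the loop above.
fn decode_latents(vae_decoder: &Session, latents: ArrayView4<'_, f32>) -> anyhow::Result<Vec<Array4<f32>>> {
	let mut images = Vec::new();
	for latent_chunk in latents.axis_iter(Axis(0)) {
		// `insert_axis` restores the batch dimension that `axis_iter` strips.
		let outputs = vae_decoder.run(ort::inputs![latent_chunk.insert_axis(Axis(0))]?)?;
		let image: OrtOwnedTensor<f32> = outputs[0].extract_tensor()?;
		let image: Array4<f32> = image.view().to_owned().into_dimensionality()?;
		// NCHW -> NHWC, then map [-1, 1] to [0, 1], as the pipeline does.
		images.push(image.permuted_axes([0, 2, 3, 1]) / 2.0 + 0.5);
	}
	Ok(images)
}
```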
src/pipelines/stable_diffusion/impl_txt2img.rs (7 additions & 9 deletions)

```diff
@@ -1,12 +1,12 @@
 use image::DynamicImage;
-use ndarray::{concatenate, s, Array1, Array4, ArrayD, Axis, IxDyn};
+use ndarray::{concatenate, s, Array1, Array4, Axis, CowArray, IxDyn};
 use ndarray_rand::{
 	rand::{self, rngs::StdRng, Rng, SeedableRng},
 	rand_distr::StandardNormal,
 	RandomExt
 };
 use num_traits::ToPrimitive;
-use ort::tensor::OrtOwnedTensor;
+use ort::OrtOwnedTensor;
 
 use crate::{DiffusionScheduler, Prompt, StableDiffusionCallback, StableDiffusionPipeline};
 
@@ -276,14 +276,12 @@ impl StableDiffusionTxt2ImgOptions {
 				latents.clone()
 			};
 			let latent_model_input = scheduler.scale_model_input(latent_model_input.view(), *t);
-			let latent_model_input: ArrayD<f32> = latent_model_input.into_dyn();
-			let timestep: ArrayD<f32> = Array1::from_iter([t.to_f32().unwrap()]).into_dyn();
-			let encoder_hidden_states: ArrayD<f32> = text_embeddings.clone().into_dyn();
+			let latent_model_input: CowArray<f32, IxDyn> = CowArray::from(latent_model_input.into_dyn());
+			let timestep: CowArray<f32, IxDyn> = CowArray::from(Array1::from_iter([t.to_f32().unwrap()]).into_dyn());
+			let encoder_hidden_states: CowArray<f32, IxDyn> = CowArray::from(text_embeddings.clone().into_dyn());
 
-			let noise_pred = session
-				.unet
-				.run(&[latent_model_input.into(), timestep.into(), encoder_hidden_states.into()])?;
-			let noise_pred: OrtOwnedTensor<'_, f32, IxDyn> = noise_pred[0].try_extract()?;
+			let noise_pred = session.unet.run(ort::inputs![&latent_model_input, &timestep, &encoder_hidden_states]?)?;
+			let noise_pred: OrtOwnedTensor<f32> = noise_pred[0].extract_tensor()?;
 			let noise_pred: Array4<f32> = noise_pred.view().to_owned().into_dimensionality()?;
 
 			let mut noise_pred: Array4<f32> = noise_pred.clone();
```
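Here the owned `ArrayD` inputs become `CowArray<f32, IxDyn>` values passed by reference, which is the form `ort::inputs!` borrows at this rev, so the three tensors no longer have to be moved into the session call on every denoising step. A sketch of one step in that style; the UNet session handle and dynamic-dimension input types are assumed from the diff:

```rust
use ndarray::{Array1, Array4, ArrayD, CowArray, IxDyn};
use ort::{OrtOwnedTensor, Session};

// One UNet noise prediction, mirroring the updated loop body above.
fn predict_noise(unet: &Session, latent: ArrayD<f32>, t: f32, text_embeddings: ArrayD<f32>) -> anyhow::Result<Array4<f32>> {
	let latent: CowArray<f32, IxDyn> = CowArray::from(latent);
	let timestep: CowArray<f32, IxDyn> = CowArray::from(Array1::from_iter([t]).into_dyn());
	let encoder_hidden_states: CowArray<f32, IxDyn> = CowArray::from(text_embeddings);

	let outputs = unet.run(ort::inputs![&latent, &timestep, &encoder_hidden_states]?)?;
	let noise_pred: OrtOwnedTensor<f32> = outputs[0].extract_tensor()?;
	Ok(noise_pred.view().to_owned().into_dimensionality()?)
}
```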
src/pipelines/stable_diffusion/lpw.rs (11 additions & 11 deletions)

```diff
@@ -16,7 +16,7 @@ use std::num::ParseFloatError;
 
 use ndarray::{s, Array2, Array3, Axis, NewAxis};
 use once_cell::sync::Lazy;
-use ort::{OrtResult, Session};
+use ort::{OrtResult, Session, Value};
 use regex::Regex;
 
 use crate::{text_embeddings::TextEmbeddings, Prompt};
@@ -202,14 +202,14 @@ pub fn get_unweighted_text_embeddings(
 
 			let text_input_chunk = if embeddings.is_empty() {
 				// no external embeds
-				text_input_chunk.into_dyn().into()
+				Value::from_array(text_input_chunk)
 			} else {
 				// pre-embed
-				embeddings.embed(text_input_chunk).into_dyn().into()
-			};
+				Value::from_array(embeddings.embed(text_input_chunk))
+			}?;
 
-			let chunk_embeddings = text_encoder.run(&[text_input_chunk])?;
-			let chunk_embeddings: Array3<f32> = chunk_embeddings[0].try_extract()?.view().to_owned().into_dimensionality().unwrap();
+			let chunk_embeddings = text_encoder.run(ort::inputs![text_input_chunk]?)?;
+			let chunk_embeddings: Array3<f32> = chunk_embeddings[0].extract_tensor()?.view().to_owned().into_dimensionality().unwrap();
 
 			#[allow(clippy::reversed_empty_ranges)]
 			let view = if no_boseos_middle {
@@ -235,14 +235,14 @@
 	} else {
 		let text_input = if embeddings.is_empty() {
 			// no external embeds
-			text_input.into_dyn().into()
+			Value::from_array(text_input)
 		} else {
 			// pre-embed
-			embeddings.embed(text_input).into_dyn().into()
-		};
+			Value::from_array(embeddings.embed(text_input))
+		}?;
 
-		let text_embeddings = text_encoder.run(&[text_input])?;
-		Ok(text_embeddings[0].try_extract()?.view().to_owned().into_dimensionality().unwrap())
+		let text_embeddings = text_encoder.run(ort::inputs![text_input]?)?;
+		Ok(text_embeddings[0].extract_tensor()?.view().to_owned().into_dimensionality().unwrap())
 	}
 }
 
```
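Both branches now build an `ort::Value` with `Value::from_array`, and because each branch yields a `Result`, the single `?` after the `if`/`else` surfaces the conversion error once instead of per branch. A sketch of that shape; the `i32` token-id element type and the output dimensionality are assumptions, not confirmed by the diff:

```rust
use ndarray::{Array2, Array3};
use ort::{OrtOwnedTensor, OrtResult, Session, Value};

// Encode one tokenized chunk, mirroring the branch structure above. With
// external embeddings this would instead wrap `embeddings.embed(token_ids)`.
fn encode_chunk(text_encoder: &Session, token_ids: Array2<i32>) -> OrtResult<Array3<f32>> {
	let input = Value::from_array(token_ids)?;
	let outputs = text_encoder.run(ort::inputs![input]?)?;
	let embeddings: OrtOwnedTensor<f32> = outputs[0].extract_tensor()?;
	Ok(embeddings.view().to_owned().into_dimensionality().unwrap())
}
```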
src/pipelines/stable_diffusion/text_embeddings.rs (1 addition & 1 deletion)

```diff
@@ -2,7 +2,7 @@ use std::{
 	collections::HashMap,
 	fs::File,
 	io::{self, BufRead, BufReader},
-	path::{Path, PathBuf}
+	path::Path
 };
 
 use byteorder::{LittleEndian, ReadBytesExt};
```

0 comments on commit 0efe380
