Skip to content
This repository has been archived by the owner on Jan 4, 2024. It is now read-only.

Commit

Permalink
feat: prompting utils
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Apr 10, 2023
1 parent f299dde commit dfe1fed
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod interpolation;
pub(crate) mod interpolation;
pub mod prompting;
80 changes: 80 additions & 0 deletions src/util/prompting.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//! Utilities for cleaning and modifying prompts.

use regex::{Captures, Regex};

/// Cleans up a potetnailly dirty prompt. This removes extraneous parentheses and commas, and cleans up trailing commas
/// and whitespace.
///
/// ```
/// use pyke_diffusers::util::prompting::cleanup_prompt;
///
/// assert_eq!(
/// cleanup_prompt("(masterpiece,, best quality,:1.1)), 1girl,").as_str(),
/// "(masterpiece, best quality:1.1), 1girl"
/// );
/// ```
pub fn cleanup_prompt<S: AsRef<str>>(prompt: S) -> String {
let split_regex: Regex = Regex::new(r#"\(*?(?:\([^)(]*(?:\([^)(]*(?:\([^)(]*(?:\([^)(]*\)[^)(]*)*\)[^)(]*)*\)[^)(]*)*\))\)*?|\b[^,]+\b"#).unwrap();
let cleanup_emphasis_regex: Regex = Regex::new(r#"\(*?(\([^)(]*(?:\([^)(]*(?:\([^)(]*(?:\([^)(]*\)[^)(]*)*\)[^)(]*)*\)[^)(]*)*\))\)*"#).unwrap();
let emphasis_trailing_comma_regex: Regex = Regex::new(r#"(\(+)([^:]*?),+(:\d[^)]+)?(\)+)"#).unwrap();
let comma_regex: Regex = Regex::new(r#"\s*,+\s*"#).unwrap();
let whitespace_regex: Regex = Regex::new(r#"\s+"#).unwrap();
let trailing_leading_comma: Regex = Regex::new(r#"^,+\s*|,+\s*$"#).unwrap();

fn emphasis_trailing_comma(cap: &Captures<'_>) -> String {
cap.get(1).unwrap().as_str().to_owned() + cap.get(2).unwrap().as_str() + cap.get(3).unwrap().as_str() + cap.get(4).unwrap().as_str()
}
fn cleanup_emphasis(cap: &Captures<'_>) -> String {
cap.get(1).unwrap().as_str().to_string()
}
fn cleanup_concept(cap: &Captures<'_>) -> String {
cap.get(0).unwrap().as_str().trim().to_string()
}

let prompt = cleanup_emphasis_regex.replace_all(prompt.as_ref(), cleanup_emphasis);
let prompt = emphasis_trailing_comma_regex.replace_all(prompt.as_ref(), emphasis_trailing_comma);
let prompt = split_regex.replace_all(prompt.as_ref(), cleanup_concept);
let prompt = comma_regex.replace_all(prompt.as_ref(), ", ");
let prompt = whitespace_regex.replace_all(prompt.as_ref(), " ");
let prompt = trailing_leading_comma.replace_all(prompt.as_ref(), "");
prompt.trim().to_string()
}

/// Combines 2 concepts into one prompt.
///
/// The output prompt is only minimally cleaned (removing extraneous/trailing commas). You should pass the output prompt
/// into [`cleanup_prompt`] for best results.
///
/// ```
/// use pyke_diffusers::util::prompting::combine_concepts;
///
/// assert_eq!(
/// combine_concepts("masterpiece, best quality,,", "1girl, solo, blue hair, ").as_str(),
/// "masterpiece, best quality, 1girl, solo, blue hair"
/// );
/// ```
pub fn combine_concepts<A: AsRef<str>, B: AsRef<str>>(a: A, b: B) -> String {
let trailing_leading_comma: Regex = Regex::new(r#"^,+\s*|,+\s*$"#).unwrap();

let a = trailing_leading_comma.replace_all(a.as_ref(), "");
let b = trailing_leading_comma.replace_all(b.as_ref(), "");
a.trim().to_string() + ", " + b.trim()
}

#[cfg(test)]
mod tests {
use super::{cleanup_prompt, combine_concepts};

#[test]
fn test_cleanup_prompt() {
assert_eq!(
cleanup_prompt("(best quality,, masterpiece,:1.3)), 1girl, solo, blue hair, ").as_str(),
"(best quality, masterpiece:1.3), 1girl, solo, blue hair"
);
}

#[test]
fn test_combine_concepts() {
assert_eq!(combine_concepts("masterpiece, best quality,,", "1girl, solo, blue hair, ").as_str(), "masterpiece, best quality, 1girl, solo, blue hair");
}
}

0 comments on commit dfe1fed

Please sign in to comment.