diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..b2736b8 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,18 @@ +[profile.rust-analyzer] +inherits = "dev" +debug-assertions = true + +[profile.dev] +debug = true +debug-assertions = false +incremental = true + + +[target.x86_64-unknown-linux-musl] +linker = "rust-lld" + +[profile.release] +opt-level = "z" # Optimize for size +lto = true # Enable Link Time Optimization +codegen-units = 1 # Reduce codegen units to increase optimization +strip = "debuginfo" # Strip debug information diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..6f080a2 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,58 @@ +# Use the official Amazon Linux image +FROM amazonlinux:latest + +ARG USERNAME=amazonlinux +ARG USER_UID=1000 +ARG USER_GID=$USER_UID + +RUN yum update -y \ + && yum install -y shadow-utils sudo \ + && yum clean all + +# Create the user +RUN groupadd --gid $USER_GID $USERNAME \ + && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ + && echo "$USERNAME ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/$USERNAME \ + && chmod 0440 /etc/sudoers.d/$USERNAME + +# Install development tools and dependencies +RUN yum groupinstall -y "Development Tools" \ + && yum install -y \ + git \ + vim \ + tar \ + gzip \ + openssl-devel \ + perl \ + perl-core \ + perl-IPC-Cmd \ + wget \ + glibc-langpack-en \ + && yum clean all + +# Switch to the created user +USER $USERNAME + +# Install Rust and Cargo tools as the amazonlinux user +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \ + && ~/.cargo/bin/cargo install cargo-lambda \ + && ~/.cargo/bin/cargo install just \ + && ~/.cargo/bin/cargo install --git https://github.com/fpco/amber \ + && ~/.cargo/bin/cargo install pqrs \ + && ~/.cargo/bin/cargo install --git https://github.com/helix-editor/helix helix-term --locked \ + && ~/.cargo/bin/rustup install nightly \ + && 
~/.cargo/bin/rustup target add x86_64-unknown-linux-musl \ + && ~/.cargo/bin/cargo install cargo-udeps + +# Set environment variables +ENV PATH="/home/$USERNAME/.cargo/bin:${PATH}" +ENV LANG=en_US.UTF-8 +ENV LANGUAGE=en_US:en +ENV LC_ALL=en_US.UTF-8 +ENV OPENSSL_DIR=/usr \ + OPENSSL_LIB_DIR=/usr/lib64 \ + OPENSSL_INCLUDE_DIR=/usr/include \ + CC=gcc + +# Set the working directory +WORKDIR /workspace \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..6721e4d --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,73 @@ +{ + "name": "Amazon Linux Rust Dev Container", + "context": "..", + "dockerFile": "Dockerfile", + "postCreateCommand": "rustup update", + "runArgs": [ + "--env-file", + ".devcontainer/.env" + ], + "mounts": [ + "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=cached", + "source=/home/ggraciani/.ssh/id_ed25519hit,target=/home/amazonlinux/.ssh/id_ed25519hit,type=bind,consistency=cached", + "source=/home/ggraciani/.ssh/known_hosts,target=/home/amazonlinux/.ssh/known_hosts,type=bind,consistency=cached" + ], + "customizations": { + "vscode": { + "extensions": [ + "amazonwebservices.aws-toolkit-vscode", + "arrterian.nix-env-selector", + "bbenoist.nix", + "bierner.markdown-preview-github-styles", + "bradlc.vscode-tailwindcss", + "brettm12345.nixfmt-vscode", + "dbaeumer.vscode-eslint", + "earthly.earthfile-syntax-highlighting", + "editorconfig.editorconfig", + "github.copilot", + "github.copilot-chat", + "github.vscode-github-actions", + "golang.go", + "hashicorp.terraform", + "hediet.vscode-drawio", + "janisdd.vscode-edit-csv", + "jeremyrajan.webpack", + "mark-hansen.hledger-vscode", + "marp-team.marp-vscode", + "mechatroner.rainbow-csv", + "mkhl.shfmt", + "ms-azuretools.vscode-docker", + "ms-playwright.playwright", + "ms-python.debugpy", + "ms-python.isort", + "ms-python.python", + "ms-python.vscode-pylance", + "ms-toolsai.jupyter", + 
"ms-toolsai.jupyter-keymap", + "ms-toolsai.jupyter-renderers", + "ms-toolsai.vscode-jupyter-cell-tags", + "ms-toolsai.vscode-jupyter-slideshow", + "ms-vscode-remote.remote-containers", + "ms-vscode.hexeditor", + "ms-vscode.vscode-speech", + "octref.vscode-ts-config-plugin", + "phu1237.vs-browser", + "polymeilex.wgsl", + "rangav.vscode-thunder-client", + "rebornix.ruby", + "redhat.vscode-yaml", + "rust-lang.rust-analyzer", + "skellock.just", + "tamasfe.even-better-toml", + "timonwong.shellcheck", + "tomoki1207.pdf", + "usernamehw.errorlens", + "vadimcn.vscode-lldb", + "wgsl-analyzer.wgsl-analyzer", + "wingrunr21.vscode-ruby", + "zxh404.vscode-proto3", + "fill-labs.dependi" + ] + } + } +} \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 542ee51..8225c89 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -29,13 +29,15 @@ jobs: ./target key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - name: Run tests - run: cargo test --verbose + run: cargo test --verbose --target x86_64-unknown-linux-gnu build: strategy: fail-fast: false matrix: include: + - TARGET: x86_64-unknown-linux-musl + OS: ubuntu-latest - TARGET: x86_64-unknown-linux-gnu OS: ubuntu-latest - TARGET: x86_64-apple-darwin @@ -74,6 +76,11 @@ jobs: sudo apt-get update sudo apt-get install -qq crossbuild-essential-arm64 crossbuild-essential-armhf fi + + - name: Add musl target + if: ${{ matrix.TARGET == 'x86_64-unknown-linux-musl' }} + run: sudo apt-get update && sudo apt-get install -y musl-dev musl-tools + - name: Run build run: cargo build --release --verbose --target $TARGET - name: Run tests diff --git a/.gitignore b/.gitignore index 15d9edc..00ac6a1 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,8 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk +**/.env + # MSVC Windows builds of rustc generate these, which store debugging information *.pdb .idea/vcs.xml diff --git 
a/crates/fluent-cli/Cargo.toml b/crates/fluent-cli/Cargo.toml index 9d11398..8425d4a 100644 --- a/crates/fluent-cli/Cargo.toml +++ b/crates/fluent-cli/Cargo.toml @@ -7,16 +7,15 @@ edition = "2021" clap = { version = "4.5.8", features = ["derive"] } fluent-core = { path = "../fluent-core" } fluent-engines = { path = "../fluent-engines" } -fluent-storage = { path = "../fluent-storage" } - tokio = { version = "1", features = ["full"] } anyhow = "1.0" log = "0.4.22" -atty = "0.2.14" uuid = { version = "1.9.1", features = ["v4"] } -clap_complete = "4.5.1" serde_json = "1.0.120" indicatif = "0.17.8" -owo-colors = "3.5.0" +owo-colors = "4.0.0" regex = "1.10.5" -serde_yaml = "0.9.34+deprecated" \ No newline at end of file +serde_yaml = "0.9.34+deprecated" +#fluent-storage = { path = "../fluent-storage" } # is not used +#clap_complete = "4.5.1" #is not used +#atty = "0.2.14" "use standard std::io::IsTerminal" diff --git a/crates/fluent-cli/src/lib.rs b/crates/fluent-cli/src/lib.rs index a1d6bef..e2b8413 100644 --- a/crates/fluent-cli/src/lib.rs +++ b/crates/fluent-cli/src/lib.rs @@ -1,16 +1,10 @@ - - use std::pin::Pin; use anyhow::{anyhow, Error}; - -use fluent_core::config::{ EngineConfig, Neo4jConfig}; +use fluent_core::config::{EngineConfig, Neo4jConfig}; use fluent_core::neo4j_client::Neo4jClient; -use log::{debug}; -use regex::Regex; -use serde_json::{ Value}; -use fluent_core::traits::{Engine}; +use fluent_core::traits::Engine; use fluent_core::types::Request; use fluent_engines::anthropic::AnthropicEngine; use fluent_engines::cohere::CohereEngine; @@ -18,33 +12,35 @@ use fluent_engines::google_gemini::GoogleGeminiEngine; use fluent_engines::groqlpu::GroqLPUEngine; use fluent_engines::openai::OpenAIEngine; use fluent_engines::perplexity::PerplexityEngine; +use log::debug; +use regex::Regex; +use serde_json::Value; pub mod cli { - use std::pin::Pin; - use std::fs; - use std::env; - use clap::{Command, Arg, ArgAction, ArgMatches}; + use anyhow::{anyhow, Error, Result}; 
+ use clap::{Arg, ArgAction, ArgMatches, Command}; use fluent_core::config::{load_config, Config, EngineConfig}; - use fluent_engines::openai::OpenAIEngine; - use fluent_engines::anthropic::AnthropicEngine; use fluent_core::traits::Engine; use fluent_core::types::{Request, Response}; - use anyhow::{Result, anyhow, Error}; + use fluent_engines::anthropic::AnthropicEngine; + use fluent_engines::openai::OpenAIEngine; + use indicatif::{ProgressBar, ProgressStyle}; + use owo_colors::OwoColorize; use std::collections::{HashMap, HashSet}; + use std::fs; + use std::io::IsTerminal; use std::path::{Path, PathBuf}; + use std::pin::Pin; use std::time::Duration; - use owo_colors::OwoColorize; - use std::io::IsTerminal; - use indicatif::{ProgressBar, ProgressStyle}; + use std::{env, io}; use log::{debug, error, info}; use serde_json::Value; use tokio::io::AsyncReadExt; - use tokio::time::Instant; - use uuid::Uuid; + use crate::{create_engine, create_llm_engine, generate_and_execute_cypher}; use fluent_core::neo4j_client::{InteractionStats, Neo4jClient}; - use fluent_core::output_processor::{ OutputProcessor}; + use fluent_core::output_processor::OutputProcessor; use fluent_engines::cohere::CohereEngine; use fluent_engines::dalle::DalleEngine; use fluent_engines::flowise_chain::FlowiseChainEngine; @@ -52,14 +48,16 @@ pub mod cli { use fluent_engines::groqlpu::GroqLPUEngine; use fluent_engines::imagepro::ImagineProEngine; use fluent_engines::langflow::LangflowEngine; - use fluent_engines::perplexity::PerplexityEngine; - use fluent_engines::webhook::WebhookEngine; use fluent_engines::leonardoai::LeonardoAIEngine; use fluent_engines::mistral::MistralEngine; - use fluent_engines::pipeline_executor::{FileStateStore, Pipeline, PipelineExecutor, StateStore}; + use fluent_engines::perplexity::PerplexityEngine; + use fluent_engines::pipeline_executor::{ + FileStateStore, Pipeline, PipelineExecutor, StateStore, + }; use fluent_engines::stabilityai::StabilityAIEngine; - use 
crate::{create_engine, create_llm_engine, generate_and_execute_cypher}; - + use fluent_engines::webhook::WebhookEngine; + use tokio::time::Instant; + use uuid::Uuid; fn parse_key_value_pair(s: &str) -> Option<(String, String)> { let parts: Vec<&str> = s.splitn(2, '=').collect(); @@ -98,7 +96,11 @@ pub mod cli { Ok((engines, parameters)) } - pub async fn process_request_with_file(engine: &dyn Engine, request_content: &str, file_path: &str) -> Result { + pub async fn process_request_with_file( + engine: &dyn Engine, + request_content: &str, + file_path: &str, + ) -> Result { let file_id = Pin::from(engine.upload_file(Path::new(file_path))).await?; println!("File uploaded successfully. File ID: {}", file_id); @@ -110,14 +112,14 @@ pub mod cli { Pin::from(engine.execute(&request)).await } - pub async fn process_request(engine: &dyn Engine, request_content: &str) -> Result { let request = Request { flowname: "default".to_string(), payload: request_content.to_string(), }; - Pin::from(engine.execute(&request)).await } + Pin::from(engine.execute(&request)).await + } pub fn print_response(response: &Response, response_time: f64) { println!("Response: {}", response.content); @@ -132,124 +134,162 @@ pub mod cli { } } - pub fn build_cli() -> Command { Command::new("Fluent CLI") .version("2.0") .author("Your Name ") .about("A powerful CLI for interacting with various AI engines") - .arg(Arg::new("config") - .short('c') - .long("config") - .value_name("FILE") - .help("Sets a custom config file") - .required(false)) - .arg(Arg::new("engine") - .help("The engine to use (openai or anthropic)") - .required(true)) - .arg(Arg::new("request") - .help("The request to process") - .required(false)) - .arg(Arg::new("override") - .short('o') - .long("override") - .value_name("KEY=VALUE") - .help("Override configuration values") - .action(ArgAction::Append) - .num_args(1..)) - .arg(Arg::new("additional-context-file") - .long("additional-context-file") - .short('a') - .help("Specifies a file 
from which additional request context is loaded") - .action(ArgAction::Set) - .value_hint(clap::ValueHint::FilePath) - .required(false)) - .arg(Arg::new("upsert") - .long("upsert") - .help("Enables upsert mode") - .action(ArgAction::SetTrue) - .conflicts_with("request")) - .arg(Arg::new("input") - .long("input") - .short('i') - .value_name("FILE") - .help("Input file or directory to process (required for upsert)") - .required(false)) - .arg(Arg::new("metadata") - .long("metadata") - .short('t') - .value_name("TERMS") - .help("Comma-separated list of metadata terms (for upsert)") - .required(false)) - .arg(Arg::new("upload-image-file") - .short('l') - .long("upload_image_file") - .value_name("FILE") - .help("Upload a media file") - .action(ArgAction::Set) - .required(false)) - .arg(Arg::new("download-media") - .short('d') - .long("download-media") - .value_name("DIR") - .help("Download media files from the output") - .action(ArgAction::Set) - .required(false)) - .arg(Arg::new("parse-code") - .short('p') - .long("parse-code") - .help("Parse and display code blocks from the output") - .action(ArgAction::SetTrue)) - .arg(Arg::new("execute-output") - .short('x') - .long("execute-output") - .help("Execute code blocks from the output") - .action(ArgAction::SetTrue)) - .arg(Arg::new("markdown") - .short('m') - .long("markdown") - .help("Format output as markdown") - .action(ArgAction::SetTrue)) - .arg(Arg::new("generate-cypher") - .long("generate-cypher") - .value_name("QUERY") - .help("Generate and execute a Cypher query based on the given string") - .action(ArgAction::Set) - .required(false)) - .subcommand(Command::new("pipeline") - .about("Execute a pipeline") - .arg(Arg::new("file") - .short('f') - .long("file") - .help("The YAML file containing the pipeline definition") - .required(true)) - .arg(Arg::new("input") - .short('i') + .arg( + Arg::new("config") + .short('c') + .long("config") + .value_name("FILE") + .help("Sets a custom config file") + .required(false), + ) 
+ .arg( + Arg::new("engine") + .help("The engine to use (openai or anthropic)") + .required(true), + ) + .arg( + Arg::new("request") + .help("The request to process") + .required(false), + ) + .arg( + Arg::new("override") + .short('o') + .long("override") + .value_name("KEY=VALUE") + .help("Override configuration values") + .action(ArgAction::Append) + .num_args(1..), + ) + .arg( + Arg::new("additional-context-file") + .long("additional-context-file") + .short('a') + .help("Specifies a file from which additional request context is loaded") + .action(ArgAction::Set) + .value_hint(clap::ValueHint::FilePath) + .required(false), + ) + .arg( + Arg::new("upsert") + .long("upsert") + .help("Enables upsert mode") + .action(ArgAction::SetTrue) + .conflicts_with("request"), + ) + .arg( + Arg::new("input") .long("input") - .help("The input for the pipeline") - .required(true)) - .arg(Arg::new("force_fresh") - .long("force-fresh") - .help("Force a fresh execution of the pipeline") - .action(ArgAction::SetTrue)) - .arg(Arg::new("run_id") - .long("run-id") - .help("Specify a run ID for the pipeline")) - .arg(Arg::new("json_output") - .long("json-output") - .help("Output only the JSON result, suppressing PrintOutput steps") - .action(ArgAction::SetTrue))) + .short('i') + .value_name("FILE") + .help("Input file or directory to process (required for upsert)") + .required(false), + ) + .arg( + Arg::new("metadata") + .long("metadata") + .short('t') + .value_name("TERMS") + .help("Comma-separated list of metadata terms (for upsert)") + .required(false), + ) + .arg( + Arg::new("upload-image-file") + .short('l') + .long("upload_image_file") + .value_name("FILE") + .help("Upload a media file") + .action(ArgAction::Set) + .required(false), + ) + .arg( + Arg::new("download-media") + .short('d') + .long("download-media") + .value_name("DIR") + .help("Download media files from the output") + .action(ArgAction::Set) + .required(false), + ) + .arg( + Arg::new("parse-code") + .short('p') + 
.long("parse-code") + .help("Parse and display code blocks from the output") + .action(ArgAction::SetTrue), + ) + .arg( + Arg::new("execute-output") + .short('x') + .long("execute-output") + .help("Execute code blocks from the output") + .action(ArgAction::SetTrue), + ) + .arg( + Arg::new("markdown") + .short('m') + .long("markdown") + .help("Format output as markdown") + .action(ArgAction::SetTrue), + ) + .arg( + Arg::new("generate-cypher") + .long("generate-cypher") + .value_name("QUERY") + .help("Generate and execute a Cypher query based on the given string") + .action(ArgAction::Set) + .required(false), + ) + .subcommand( + Command::new("pipeline") + .about("Execute a pipeline") + .arg( + Arg::new("file") + .short('f') + .long("file") + .help("The YAML file containing the pipeline definition") + .required(true), + ) + .arg( + Arg::new("input") + .short('i') + .long("input") + .help("The input for the pipeline") + .required(true), + ) + .arg( + Arg::new("force_fresh") + .long("force-fresh") + .help("Force a fresh execution of the pipeline") + .action(ArgAction::SetTrue), + ) + .arg( + Arg::new("run_id") + .long("run-id") + .help("Specify a run ID for the pipeline"), + ) + .arg( + Arg::new("json_output") + .long("json-output") + .help("Output only the JSON result, suppressing PrintOutput steps") + .action(ArgAction::SetTrue), + ), + ) } pub async fn get_neo4j_query_llm(config: &Config) -> Option<(Box, &EngineConfig)> { let neo4j_config = config.engines.iter().find(|e| e.engine == "neo4j")?; let query_llm = neo4j_config.neo4j.as_ref()?.query_llm.as_ref()?; - let llm_config = config.engines.iter().find(|e| e.name == query_llm.to_string())?; + let llm_config = config.engines.iter().find(|e| e.name == *query_llm)?; let engine = create_llm_engine(llm_config).await.ok()?; Some((engine, llm_config)) } - pub async fn run() -> Result<()> { let matches = build_cli().get_matches(); @@ -261,17 +301,26 @@ pub mod cli { let run_id = sub_matches.get_one::("run_id").cloned(); 
let json_output = sub_matches.get_flag("json_output"); - let pipeline: Pipeline = serde_yaml::from_str(&std::fs::read_to_string(pipeline_file)?)?; + let pipeline: Pipeline = + serde_yaml::from_str(&std::fs::read_to_string(pipeline_file)?)?; let state_store_dir = PathBuf::from("./pipeline_states"); tokio::fs::create_dir_all(&state_store_dir).await?; - let state_store = FileStateStore { directory: state_store_dir }; + let state_store = FileStateStore { + directory: state_store_dir, + }; let executor = PipelineExecutor::new(state_store.clone(), json_output); - executor.execute(&pipeline, input, force_fresh, run_id.clone()).await?; + executor + .execute(&pipeline, input, force_fresh, run_id.clone()) + .await?; if json_output { // Read the state file and print its contents to stdout - let state_key = format!("{}-{}", pipeline.name, run_id.unwrap_or_else(|| "unknown".to_string())); + let state_key = format!( + "{}-{}", + pipeline.name, + run_id.unwrap_or_else(|| "unknown".to_string()) + ); if let Some(state) = state_store.load_state(&state_key).await? { println!("{}", serde_json::to_string_pretty(&state)?); } else { @@ -281,7 +330,7 @@ pub mod cli { } std::process::exit(0); - }, + } // ... other commands ... 
_ => Ok(()), // Default case, do nothing }; @@ -293,7 +342,8 @@ pub mod cli { let engine_name = matches.get_one::("engine").unwrap(); - let overrides: HashMap = matches.get_many::("override") + let overrides: HashMap = matches + .get_many::("override") .map(|values| values.filter_map(|s| parse_key_value_pair(s)).collect()) .unwrap_or_default(); @@ -313,13 +363,15 @@ pub mod cli { pb.enable_steady_tick(Duration::from_millis(spinner_config.interval)); pb.set_length(100); - - if let Some(cypher_query) = matches.get_one::("generate-cypher") { - let neo4j_config = engine_config.neo4j.as_ref() + let neo4j_config = engine_config + .neo4j + .as_ref() .ok_or_else(|| anyhow!("Neo4j configuration not found in the engine config"))?; - let query_llm_name = neo4j_config.query_llm.as_ref() + let query_llm_name = neo4j_config + .query_llm + .as_ref() .ok_or_else(|| anyhow!("No query LLM specified for Neo4j"))?; // Load the configuration for the query LLM @@ -332,21 +384,25 @@ pub mod cli { neo4j_config, query_llm_engine_config, cypher_query, - &*query_llm_engine - ).await?; + &*query_llm_engine, + ) + .await?; if engine_config.engine == "neo4j" { println!("{}", cypher_result); } else { let engine: Box = create_engine(engine_config).await?; - let max_tokens = engine_config.parameters.get("max_tokens") + let max_tokens = engine_config + .parameters + .get("max_tokens") .and_then(|v| v.as_i64()) .unwrap_or(-1); - let user_request = matches.get_one::("request") + let user_request = matches + .get_one::("request") .map(|s| s.to_string()) - .unwrap_or_else(String::new); + .unwrap_or_default(); let mut combined_request = format!( "Cypher query: {}\n\nCypher result:\n{}\n\nBased on the above Cypher query and its result, please provide an analysis or answer the following question: {}", @@ -369,7 +425,8 @@ pub mod cli { if let Some(download_dir) = matches.get_one::("download-media") { let download_path = PathBuf::from(download_dir); - OutputProcessor::download_media_files(&response.content, 
&download_path).await?; + OutputProcessor::download_media_files(&response.content, &download_path) + .await?; } if matches.get_flag("parse-code") { @@ -393,7 +450,8 @@ pub mod cli { let response_time = start_time.elapsed().as_secs_f64(); if let Some(neo4j_client) = engine.get_neo4j_client() { - let session_id = engine.get_session_id() + let session_id = engine + .get_session_id() .unwrap_or_else(|| Uuid::new_v4().to_string()); let stats = InteractionStats { @@ -401,19 +459,28 @@ pub mod cli { completion_tokens: response.usage.completion_tokens, total_tokens: response.usage.total_tokens, response_time, - finish_reason: response.finish_reason.clone().unwrap_or_else(|| "unknown".to_string()), + finish_reason: response + .finish_reason + .clone() + .unwrap_or_else(|| "unknown".to_string()), }; debug!("Attempting to create interaction in Neo4j"); debug!("Using session ID: {}", session_id); - match neo4j_client.create_interaction( - &session_id, - &request.payload, - &response.content, - &response.model, - &stats - ).await { - Ok(interaction_id) => debug!("Successfully created interaction with id: {}", interaction_id), + match neo4j_client + .create_interaction( + &session_id, + &request.payload, + &response.content, + &response.model, + &stats, + ) + .await + { + Ok(interaction_id) => debug!( + "Successfully created interaction with id: {}", + interaction_id + ), Err(e) => error!("Failed to create interaction in Neo4j: {:?}", e), } } else { @@ -430,12 +497,55 @@ pub mod cli { eprintln!( "{} | {} | Time: {} | Usage: {}↑ {}↓ {}Σ | {}\n", spinner_config.success_symbol, - if use_colors { response.model.cyan().to_string() } else { response.model }, - if use_colors { response_time_str.bright_blue().to_string() } else { response_time_str }, - if use_colors { response.usage.prompt_tokens.to_string().yellow().to_string() } else { response.usage.prompt_tokens.to_string() }, - if use_colors { response.usage.completion_tokens.to_string().yellow().to_string() } else { 
response.usage.completion_tokens.to_string() }, - if use_colors { response.usage.total_tokens.to_string().yellow().to_string() } else { response.usage.total_tokens.to_string() }, - if use_colors { response.finish_reason.as_deref().unwrap_or("No finish reason").italic().to_string() } else { response.finish_reason.as_deref().unwrap_or("No finish reason").to_string() } + if use_colors { + response.model.cyan().to_string() + } else { + response.model + }, + if use_colors { + response_time_str.bright_blue().to_string() + } else { + response_time_str + }, + if use_colors { + response + .usage + .prompt_tokens + .to_string() + .yellow() + .to_string() + } else { + response.usage.prompt_tokens.to_string() + }, + if use_colors { + response + .usage + .completion_tokens + .to_string() + .yellow() + .to_string() + } else { + response.usage.completion_tokens.to_string() + }, + if use_colors { + response.usage.total_tokens.to_string().yellow().to_string() + } else { + response.usage.total_tokens.to_string() + }, + if use_colors { + response + .finish_reason + .as_deref() + .unwrap_or("No finish reason") + .italic() + .to_string() + } else { + response + .finish_reason + .as_deref() + .unwrap_or("No finish reason") + .to_string() + } ); } } else if matches.get_flag("upsert") { @@ -463,7 +573,7 @@ pub mod cli { engine.set_download_dir(download_dir.to_string()); } engine - }, + } "leonardo_ai" => Box::new(LeonardoAIEngine::new(engine_config.clone()).await?), "imagine_pro" => { let mut engine = Box::new(ImagineProEngine::new(engine_config.clone()).await?); @@ -477,7 +587,7 @@ pub mod cli { // Read context from stdin if available let mut context = String::new(); - if !atty::is(atty::Stream::Stdin) { + if !io::stdin().is_terminal() { tokio::io::stdin().read_to_string(&mut context).await?; } @@ -497,7 +607,8 @@ pub mod cli { } // Add file contents if it's not empty if !file_contents.trim().is_empty() { - combined_request_parts.push(format!("Additional Context:\n{}", 
file_contents.trim())); + combined_request_parts + .push(format!("Additional Context:\n{}", file_contents.trim())); } // Join all parts with a separator let combined_request = combined_request_parts.join("\n\n----\n\n"); @@ -509,7 +620,6 @@ pub mod cli { }; debug!("Combined Request: {:?}", request); - let response = if let Some(file_path) = matches.get_one::("upload-image-file") { debug!("Processing request with file: {}", file_path); pb.set_message("Processing request with file..."); @@ -547,7 +657,8 @@ pub mod cli { let response_time = start_time.elapsed().as_secs_f64(); if let Some(neo4j_client) = engine.get_neo4j_client() { - let session_id = engine.get_session_id() + let session_id = engine + .get_session_id() .unwrap_or_else(|| Uuid::new_v4().to_string()); let stats = InteractionStats { @@ -555,19 +666,28 @@ pub mod cli { completion_tokens: response.usage.completion_tokens, total_tokens: response.usage.total_tokens, response_time, - finish_reason: response.finish_reason.clone().unwrap_or_else(|| "unknown".to_string()), + finish_reason: response + .finish_reason + .clone() + .unwrap_or_else(|| "unknown".to_string()), }; debug!("Attempting to create interaction in Neo4j"); debug!("Using session ID: {}", session_id); - match neo4j_client.create_interaction( - &session_id, - &request.payload, - &response.content, - &response.model, - &stats - ).await { - Ok(interaction_id) => debug!("Successfully created interaction with id: {}", interaction_id), + match neo4j_client + .create_interaction( + &session_id, + &request.payload, + &response.content, + &response.model, + &stats, + ) + .await + { + Ok(interaction_id) => debug!( + "Successfully created interaction with id: {}", + interaction_id + ), Err(e) => error!("Failed to create interaction in Neo4j: {:?}", e), } } else { @@ -584,33 +704,80 @@ pub mod cli { eprintln!( "{} | {} | Time: {} | Usage: {}↑ {}↓ {}Σ | {}\n", spinner_config.success_symbol, - if use_colors { response.model.cyan().to_string() } else { 
response.model }, - if use_colors { response_time_str.bright_blue().to_string() } else { response_time_str }, - if use_colors { response.usage.prompt_tokens.to_string().yellow().to_string() } else { response.usage.prompt_tokens.to_string() }, - if use_colors { response.usage.completion_tokens.to_string().yellow().to_string() } else { response.usage.completion_tokens.to_string() }, - if use_colors { response.usage.total_tokens.to_string().yellow().to_string() } else { response.usage.total_tokens.to_string() }, - if use_colors { response.finish_reason.as_deref().unwrap_or("No finish reason").italic().to_string() } else { response.finish_reason.as_deref().unwrap_or("No finish reason").to_string() } + if use_colors { + response.model.cyan().to_string() + } else { + response.model + }, + if use_colors { + response_time_str.bright_blue().to_string() + } else { + response_time_str + }, + if use_colors { + response + .usage + .prompt_tokens + .to_string() + .yellow() + .to_string() + } else { + response.usage.prompt_tokens.to_string() + }, + if use_colors { + response + .usage + .completion_tokens + .to_string() + .yellow() + .to_string() + } else { + response.usage.completion_tokens.to_string() + }, + if use_colors { + response.usage.total_tokens.to_string().yellow().to_string() + } else { + response.usage.total_tokens.to_string() + }, + if use_colors { + response + .finish_reason + .as_deref() + .unwrap_or("No finish reason") + .italic() + .to_string() + } else { + response + .finish_reason + .as_deref() + .unwrap_or("No finish reason") + .to_string() + } ); } Ok(()) } - async fn handle_upsert(engine_config: &EngineConfig, matches: &ArgMatches) -> Result<()> { if let Some(neo4j_config) = &engine_config.neo4j { let neo4j_client = Neo4jClient::new(neo4j_config).await?; - let input = matches.get_one::("input") + let input = matches + .get_one::("input") .ok_or_else(|| anyhow!("Input is required for upsert mode"))?; - let metadata = matches.get_one::("metadata") + let 
metadata = matches + .get_one::("metadata") .map(|s| s.split(',').map(String::from).collect::>()) .unwrap_or_default(); let input_path = Path::new(input); if input_path.is_file() { let document_id = neo4j_client.upsert_document(input_path, &metadata).await?; - eprintln!("Uploaded document with ID: {}. Embeddings and chunks created.", document_id); + eprintln!( + "Uploaded document with ID: {}. Embeddings and chunks created.", + document_id + ); } else if input_path.is_dir() { let mut uploaded_count = 0; for entry in fs::read_dir(input_path)? { @@ -618,11 +785,18 @@ pub mod cli { let path = entry.path(); if path.is_file() { let document_id = neo4j_client.upsert_document(&path, &metadata).await?; - eprintln!("Uploaded document {} with ID: {}. Embeddings and chunks created.", path.display(), document_id); + eprintln!( + "Uploaded document {} with ID: {}. Embeddings and chunks created.", + path.display(), + document_id + ); uploaded_count += 1; } } - eprintln!("Uploaded {} documents with embeddings and chunks", uploaded_count); + eprintln!( + "Uploaded {} documents with embeddings and chunks", + uploaded_count + ); } else { return Err(anyhow!("Input is neither a file nor a directory")); } @@ -641,12 +815,14 @@ pub mod cli { Ok(()) } - pub async fn generate_cypher_query(query: &str, config: &EngineConfig) -> Result { // Use the configured LLM to generate a Cypher query let llm_request = Request { flowname: "cypher_generation".to_string(), - payload: format!("Generate a Cypher query for Neo4j based on this request: {}", query), + payload: format!( + "Generate a Cypher query for Neo4j based on this request: {}", + query + ), }; debug!("Sending request to LLM engine: {:?}", llm_request); let llm_engine: Box = match config.engine.as_str() { @@ -656,21 +832,18 @@ pub mod cli { _ => return Err(anyhow!("Unsupported LLM engine for Cypher generation")), }; - let response = Pin::from(llm_engine.execute(&llm_request)).await?; - debug!("Response from LLM engine: {:?}", response); 
Ok(response.content) } } - async fn generate_and_execute_cypher( neo4j_config: &Neo4jConfig, _llm_config: &EngineConfig, query_string: &str, - llm_engine: &dyn Engine + llm_engine: &dyn Engine, ) -> Result { debug!("Generating Cypher query using LLM"); debug!("Neo4j configuration: {:#?}", neo4j_config); @@ -701,7 +874,6 @@ async fn generate_and_execute_cypher( Ok(format_as_csv(&cypher_result)) } - fn extract_cypher_query(content: &str) -> Result { // First, try to extract content between triple backticks let backtick_re = Regex::new(r"```(?:cypher)?\s*([\s\S]*?)\s*```").unwrap(); @@ -715,7 +887,8 @@ fn extract_cypher_query(content: &str) -> Result { } // If not found, look for common Cypher keywords to identify the query - let cypher_re = Regex::new(r"(?i)(MATCH|CREATE|MERGE|DELETE|REMOVE|SET|RETURN)[\s\S]+").unwrap(); + let cypher_re = + Regex::new(r"(?i)(MATCH|CREATE|MERGE|DELETE|REMOVE|SET|RETURN)[\s\S]+").unwrap(); if let Some(captures) = cypher_re.captures(content) { if let Some(query) = captures.get(0) { let extracted = query.as_str().trim(); @@ -731,8 +904,12 @@ fn extract_cypher_query(content: &str) -> Result { fn is_valid_cypher(query: &str) -> bool { // Basic validation: check if the query contains common Cypher clauses - let valid_clauses = ["MATCH", "CREATE", "MERGE", "DELETE", "REMOVE", "SET", "RETURN", "WITH", "WHERE"]; - valid_clauses.iter().any(|&clause| query.to_uppercase().contains(clause)) + let valid_clauses = [ + "MATCH", "CREATE", "MERGE", "DELETE", "REMOVE", "SET", "RETURN", "WITH", "WHERE", + ]; + valid_clauses + .iter() + .any(|&clause| query.to_uppercase().contains(clause)) } fn format_as_csv(result: &Value) -> String { @@ -746,8 +923,12 @@ async fn create_engine(engine_config: &EngineConfig) -> Result, "openai" => Ok(Box::new(OpenAIEngine::new(engine_config.clone()).await?)), "anthropic" => Ok(Box::new(AnthropicEngine::new(engine_config.clone()).await?)), "cohere" => Ok(Box::new(CohereEngine::new(engine_config.clone()).await?)), - 
"google_gemini" => Ok(Box::new(GoogleGeminiEngine::new(engine_config.clone()).await?)), - "perplexity" => Ok(Box::new(PerplexityEngine::new(engine_config.clone()).await?)), + "google_gemini" => Ok(Box::new( + GoogleGeminiEngine::new(engine_config.clone()).await?, + )), + "perplexity" => Ok(Box::new( + PerplexityEngine::new(engine_config.clone()).await?, + )), "groq_lpu" => Ok(Box::new(GroqLPUEngine::new(engine_config.clone()).await?)), _ => Err(anyhow!("Unsupported engine: {}", engine_config.engine)), } @@ -756,5 +937,3 @@ async fn create_engine(engine_config: &EngineConfig) -> Result, async fn create_llm_engine(engine_config: &EngineConfig) -> Result, Error> { create_engine(engine_config).await } - - diff --git a/crates/fluent-core/Cargo.toml b/crates/fluent-core/Cargo.toml index 9f5a0bb..89bf422 100644 --- a/crates/fluent-core/Cargo.toml +++ b/crates/fluent-core/Cargo.toml @@ -5,6 +5,10 @@ version = "0.1.0" edition = "2021" [dependencies] +reqwest = { version = "0.12.5", default-features = false, features = [ + "json", + "rustls-tls", +] } serde = { version = "1.0", features = ["derive"] } anyhow = "1.0" serde_json = "1.0.120" @@ -13,25 +17,20 @@ log = "0.4.22" chrono = "0.4.38" uuid = { version = "1.9.1", features = ["v4"] } neo4rs = "0.7.1" -reqwest = { version = "0.12.5", features = ["json"] } unicode-segmentation = "1.11.0" - - -#rust-bert = { version = "0.18.0" } -tokenizers = "0.14.0" rust-stemmers = "1.2.0" stop-words = "0.8.0" base64 = "0.22.1" -tokio = "1.38.0" -regex = "1.10.5" +tokio = "1.39.2" +regex = "1.10.6" url = "2.5.2" -termimad = "0.29.2" -crossterm = "0.27.0" +termimad = "0.29.4" syntect = "5.2.0" -indicatif = "0.17.8" -owo-colors = "3.5.0" +owo-colors = "4.0.0" pdf-extract = "0.7.7" -[patch.crates-io] -jetscii = "0.5.1" \ No newline at end of file +#rust-bert = { version = "0.18.0" } #Is not used +#indicatif = "0.17.8" #Is not used +#jetscii = "0.5.3" #Is not used +#tokenizers = "0.19.1" #Is not used diff --git 
a/crates/fluent-core/src/config.rs b/crates/fluent-core/src/config.rs index 0c5abea..e725b33 100644 --- a/crates/fluent-core/src/config.rs +++ b/crates/fluent-core/src/config.rs @@ -1,14 +1,13 @@ -use anyhow::{Result, anyhow}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::{env, fs}; +use crate::neo4j_client::VoyageAIConfig; +use crate::spinner_configuration::SpinnerConfig; +use anyhow::{anyhow, Result}; use log::{debug, info}; +use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::collections::HashMap; use std::process::Command; use std::sync::Arc; -use crate::neo4j_client::VoyageAIConfig; -use crate::spinner_configuration::SpinnerConfig; - +use std::{env, fs}; #[derive(Debug, Deserialize, Serialize, Clone)] pub struct EngineConfig { @@ -16,10 +15,9 @@ pub struct EngineConfig { pub engine: String, pub connection: ConnectionConfig, pub parameters: HashMap, - pub session_id: Option, // New field for sessionID + pub session_id: Option, // New field for sessionID pub neo4j: Option, pub spinner: Option, - } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -33,7 +31,6 @@ pub struct Neo4jConfig { pub parameters: Option>, } - #[derive(Debug, Deserialize, Serialize, Clone)] pub struct ConnectionConfig { pub protocol: String, @@ -42,7 +39,7 @@ pub struct ConnectionConfig { pub request_path: String, } -#[derive( Clone)] +#[derive(Clone)] pub struct Config { pub engines: Vec, _env_guard: Arc, // This keeps the guard alive @@ -56,7 +53,11 @@ impl Config { } } } -pub fn load_config(config_path: &str, engine_name: &str, overrides: &HashMap) -> Result { +pub fn load_config( + config_path: &str, + engine_name: &str, + overrides: &HashMap, +) -> Result { let config_content = fs::read_to_string(config_path)?; let mut config: Value = serde_json::from_str(&config_content)?; @@ -80,15 +81,20 @@ pub fn load_config(config_path: &str, engine_name: &str, overrides: &HashMap 
value.parse::().map(Value::from).unwrap_or(Value::String(value.clone())), - Value::Bool(_) => value.parse::().map(Value::from).unwrap_or(Value::String(value.clone())), + Value::Number(_) => value + .parse::() + .map(Value::from) + .unwrap_or(Value::String(value.clone())), + Value::Bool(_) => value + .parse::() + .map(Value::from) + .unwrap_or(Value::String(value.clone())), _ => Value::String(value.clone()), }; parameters.insert(key.clone(), parsed_value); } } - debug!("Loaded and processed config for engine: {}", engine_name); let engine_config: EngineConfig = serde_json::from_value(engine_config.clone())?; @@ -113,18 +119,16 @@ pub fn apply_overrides(config: &mut EngineConfig, overrides: &[(String, String)] Ok(()) } +#[derive(Default)] pub struct AmberEnvVarGuard { keys: Vec, } impl AmberEnvVarGuard { pub fn new() -> Self { - AmberEnvVarGuard { - keys: Vec::new(), - } + AmberEnvVarGuard::default() } - fn decrypt_amber_keys_in_value(&mut self, value: &mut Value) -> Result<()> { match value { Value::String(s) if s.starts_with("AMBER_") => { @@ -132,20 +136,20 @@ impl AmberEnvVarGuard { self.set_env_var_from_amber(s, &decrypted)?; *s = decrypted; Ok(()) - }, + } Value::Object(map) => { for (_, v) in map.iter_mut() { self.decrypt_amber_keys_in_value(v)?; } Ok(()) - }, + } Value::Array(arr) => { for item in arr.iter_mut() { self.decrypt_amber_keys_in_value(item)?; } Ok(()) - }, - _ => Ok(()) + } + _ => Ok(()), } } @@ -156,11 +160,8 @@ impl AmberEnvVarGuard { Ok(()) } - fn get_amber_value(&self, key: &str) -> Result { - let output = Command::new("amber") - .arg("print") - .output()?; + let output = Command::new("amber").arg("print").output()?; if !output.status.success() { return Err(anyhow!("Failed to run amber print command")); @@ -202,24 +203,22 @@ pub fn replace_with_env_var(value: &mut Value) { Ok(env_value) => { debug!("Environment value found for: {}", env_key); *s = env_value; - }, + } Err(e) => { debug!("Failed to find environment variable '{}': {}", env_key, 
e); } } - }, + } Value::Object(map) => { for (_, v) in map.iter_mut() { replace_with_env_var(v); } - }, + } Value::Array(arr) => { for item in arr.iter_mut() { replace_with_env_var(item); } - }, + } _ => {} } } - - diff --git a/crates/fluent-core/src/neo4j_client.rs b/crates/fluent-core/src/neo4j_client.rs index 9d92df6..748c274 100644 --- a/crates/fluent-core/src/neo4j_client.rs +++ b/crates/fluent-core/src/neo4j_client.rs @@ -1,30 +1,30 @@ -use anyhow::{Result, anyhow, Error}; -use neo4rs::{Graph, query, ConfigBuilder, BoltMap, BoltList, BoltString, BoltType, BoltInteger, BoltFloat, Database, BoltNull, Row }; +use anyhow::{anyhow, Error, Result}; +use neo4rs::{ + query, BoltFloat, BoltInteger, BoltList, BoltMap, BoltNull, BoltString, BoltType, + ConfigBuilder, Database, Graph, Row, +}; use chrono::Duration as ChronoDuration; use chrono::{DateTime, Utc}; +use log::{debug, error, info, warn}; +use pdf_extract::extract_text; use serde_json::{json, Value}; -use uuid::Uuid; use std::collections::{HashMap, HashSet}; use std::path::Path; -use log::{debug, error, info, warn}; -use std::sync::{ RwLock}; -use pdf_extract::extract_text; - +use std::sync::RwLock; +use uuid::Uuid; use rust_stemmers::{Algorithm, Stemmer}; use serde::{Deserialize, Serialize}; - use crate::config::Neo4jConfig; +use crate::traits::{DocumentProcessor, DocxProcessor}; use crate::types::DocumentStatistics; use crate::utils::chunking::chunk_document; -use crate::voyageai_client::{EMBEDDING_DIMENSION, get_voyage_embedding}; +use crate::voyageai_client::{get_voyage_embedding, EMBEDDING_DIMENSION}; use tokio::fs::File; use tokio::io::AsyncReadExt; -use crate::traits::{DocumentProcessor, DocxProcessor}; - #[derive(Debug, Deserialize, Serialize, Clone)] pub struct VoyageAIConfig { @@ -39,13 +39,29 @@ pub struct Neo4jClient { voyage_ai_config: Option, query_llm: Option, } +impl Neo4jClient { + pub fn get_document_count(&self) -> usize { + *self.document_count.read().unwrap() + } + pub fn 
get_word_document_count_for_word(&self, word: &str) -> usize { + *self + .word_document_count + .read() + .unwrap() + .get(word) + .unwrap_or(&0) + } + pub fn get_query_llm(&self) -> Option<&String> { + self.query_llm.as_ref() + } +} #[derive(Debug, Clone)] pub struct InteractionStats { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, - pub response_time: f64, // in seconds + pub response_time: f64, // in seconds pub finish_reason: String, } @@ -56,8 +72,6 @@ pub struct Embedding { pub model: String, } - - #[derive(Debug)] pub struct EnrichmentConfig { pub themes_keywords_interval: ChronoDuration, @@ -72,18 +86,13 @@ pub struct EnrichmentStatus { pub last_sentiment_update: Option>, } - - - - - impl Neo4jClient { pub async fn new(config: &Neo4jConfig) -> Result { let graph_config = ConfigBuilder::default() .uri(&config.uri) .user(&config.user) .password(&config.password) - .db(Database::from(config.database.as_str())) // Convert string to Database instance + .db(Database::from(config.database.as_str())) // Convert string to Database instance .build()?; let graph = Graph::connect(graph_config).await?; @@ -97,7 +106,6 @@ impl Neo4jClient { }) } - pub async fn ensure_indexes(&self) -> Result<()> { let index_queries = vec![ "CREATE INDEX IF NOT EXISTS FOR (s:Session) ON (s.id)", @@ -115,10 +123,8 @@ impl Neo4jClient { for query_str in index_queries { debug!("Executing index creation query: {}", query_str); let _ = self.graph.execute(query(query_str)).await?; - } - // Create a vector index for embeddings let vector_index_query = format!( "CALL db.index.vector.createNodeIndex( @@ -136,7 +142,6 @@ impl Neo4jClient { Err(e) => warn!("Failed to create vector index for Document Embedding nodes: {}. 
This might be normal if the index already exists.", e), } - // Optionally, we can also create full-text indexes for content fields if needed let fulltext_index_queries = vec![ "CALL db.index.fulltext.createNodeIndex('questionContentIndex', ['Question'], ['content'])", @@ -170,14 +175,18 @@ impl Neo4jClient { RETURN s.id as session_id "#; - let mut result = self.graph.execute(query(query_str) - .param("id", session.id.to_string()) - .param("start_time", session.start_time.to_rfc3339()) - .param("end_time", session.end_time.to_rfc3339()) - .param("context", session.context.to_string()) - .param("session_id", session.session_id.to_string()) - .param("user_id", session.user_id.to_string()) - ).await?; + let mut result = self + .graph + .execute( + query(query_str) + .param("id", session.id.to_string()) + .param("start_time", session.start_time.to_rfc3339()) + .param("end_time", session.end_time.to_rfc3339()) + .param("context", session.context.to_string()) + .param("session_id", session.session_id.to_string()) + .param("user_id", session.user_id.to_string()), + ) + .await?; if let Some(row) = result.next().await? { Ok(row.get("session_id")?) 
@@ -186,14 +195,13 @@ impl Neo4jClient { } } - pub async fn create_interaction( &self, session_id: &str, request: &str, response: &str, model: &str, - stats: &InteractionStats + stats: &InteractionStats, ) -> Result { let query_str = r#" MERGE (s:Session {id: $session_id}) @@ -241,22 +249,56 @@ impl Neo4jClient { let stats_id = Uuid::new_v4().to_string(); let timestamp = Utc::now(); - let mut result = self.graph.execute(query(query_str) - .param("session_id", BoltType::String(BoltString::from(session_id))) - .param("id", BoltType::String(BoltString::from(interaction_id.as_str()))) - .param("question_id", BoltType::String(BoltString::from(question_id.as_str()))) - .param("response_id", BoltType::String(BoltString::from(response_id.as_str()))) - .param("stats_id", BoltType::String(BoltString::from(stats_id.as_str()))) - .param("timestamp", BoltType::String(BoltString::from(timestamp.to_rfc3339().as_str()))) - .param("request", BoltType::String(BoltString::from(request))) - .param("response", BoltType::String(BoltString::from(response))) - .param("model", BoltType::String(BoltString::from(model))) - .param("prompt_tokens", BoltType::Integer(BoltInteger::new(stats.prompt_tokens as i64))) - .param("completion_tokens", BoltType::Integer(BoltInteger::new(stats.completion_tokens as i64))) - .param("total_tokens", BoltType::Integer(BoltInteger::new(stats.total_tokens as i64))) - .param("response_time", BoltType::Float(BoltFloat::new(stats.response_time))) - .param("finish_reason", BoltType::String(BoltString::from(stats.finish_reason.as_str()))) - ).await?; + let mut result = self + .graph + .execute( + query(query_str) + .param("session_id", BoltType::String(BoltString::from(session_id))) + .param( + "id", + BoltType::String(BoltString::from(interaction_id.as_str())), + ) + .param( + "question_id", + BoltType::String(BoltString::from(question_id.as_str())), + ) + .param( + "response_id", + BoltType::String(BoltString::from(response_id.as_str())), + ) + .param( + 
"stats_id", + BoltType::String(BoltString::from(stats_id.as_str())), + ) + .param( + "timestamp", + BoltType::String(BoltString::from(timestamp.to_rfc3339().as_str())), + ) + .param("request", BoltType::String(BoltString::from(request))) + .param("response", BoltType::String(BoltString::from(response))) + .param("model", BoltType::String(BoltString::from(model))) + .param( + "prompt_tokens", + BoltType::Integer(BoltInteger::new(stats.prompt_tokens as i64)), + ) + .param( + "completion_tokens", + BoltType::Integer(BoltInteger::new(stats.completion_tokens as i64)), + ) + .param( + "total_tokens", + BoltType::Integer(BoltInteger::new(stats.total_tokens as i64)), + ) + .param( + "response_time", + BoltType::Float(BoltFloat::new(stats.response_time)), + ) + .param( + "finish_reason", + BoltType::String(BoltString::from(stats.finish_reason.as_str())), + ), + ) + .await?; if let Some(row) = result.next().await? { let interaction_id: String = row.get("interaction_id")?; @@ -272,9 +314,15 @@ impl Neo4jClient { if let Some(voyage_config) = &self.voyage_ai_config { debug!("Voyage AI config found, creating embeddings"); - match self.create_embeddings(request, response, &question_id, &response_id, voyage_config).await { + match self + .create_embeddings(request, response, &question_id, &response_id, voyage_config) + .await + { Ok(_) => debug!("Created embeddings for interaction {}", interaction_id), - Err(e) => warn!("Failed to create embeddings for interaction {}: {:?}", interaction_id, e), + Err(e) => warn!( + "Failed to create embeddings for interaction {}: {:?}", + interaction_id, e + ), } // Call enrich_document_incrementally for both question and response @@ -284,12 +332,18 @@ impl Neo4jClient { sentiment_interval: ChronoDuration::hours(1), }; - match self.enrich_document_incrementally(&question_id, "Question", &enrichment_config).await { + match self + .enrich_document_incrementally(&question_id, "Question", &enrichment_config) + .await + { Ok(_) => debug!("Enriched 
question {}", question_id), Err(e) => warn!("Failed to enrich question {}: {:?}", question_id, e), } - match self.enrich_document_incrementally(&response_id, "Response", &enrichment_config).await { + match self + .enrich_document_incrementally(&response_id, "Response", &enrichment_config) + .await + { Ok(_) => debug!("Enriched response {}", response_id), Err(e) => warn!("Failed to enrich response {}: {:?}", response_id, e), } @@ -303,14 +357,13 @@ impl Neo4jClient { } } - async fn create_embeddings( &self, request: &str, response: &str, question_id: &str, response_id: &str, - voyage_config: &VoyageAIConfig + voyage_config: &VoyageAIConfig, ) -> Result<()> { let question_embedding = get_voyage_embedding(request, voyage_config).await?; let response_embedding = get_voyage_embedding(response, voyage_config).await?; @@ -327,14 +380,20 @@ impl Neo4jClient { model: voyage_config.model.clone(), }; - self.create_embedding(&question_embedding_node, question_id, "Question").await?; - self.create_embedding(&response_embedding_node, response_id, "Response").await?; + self.create_embedding(&question_embedding_node, question_id, "Question") + .await?; + self.create_embedding(&response_embedding_node, response_id, "Response") + .await?; Ok(()) } - - pub async fn create_embedding(&self, embedding: &Embedding, parent_id: &str, parent_type: &str) -> Result { + pub async fn create_embedding( + &self, + embedding: &Embedding, + parent_id: &str, + parent_type: &str, + ) -> Result { let query_str = r#" MATCH (parent {id: $parent_id}) WHERE labels(parent)[0] = $parent_type @@ -355,13 +414,26 @@ impl Neo4jClient { vector_list.push(BoltType::Float(BoltFloat::new(value as f64))); } - let mut result = self.graph.execute(query(query_str) - .param("parent_id", BoltType::String(BoltString::from(parent_id))) - .param("parent_type", BoltType::String(BoltString::from(parent_type))) - .param("id", BoltType::String(BoltString::from(embedding.id.as_str()))) - .param("vector", 
BoltType::List(vector_list)) - .param("model", BoltType::String(BoltString::from(embedding.model.as_str()))) - ).await?; + let mut result = self + .graph + .execute( + query(query_str) + .param("parent_id", BoltType::String(BoltString::from(parent_id))) + .param( + "parent_type", + BoltType::String(BoltString::from(parent_type)), + ) + .param( + "id", + BoltType::String(BoltString::from(embedding.id.as_str())), + ) + .param("vector", BoltType::List(vector_list)) + .param( + "model", + BoltType::String(BoltString::from(embedding.model.as_str())), + ), + ) + .await?; if let Some(row) = result.next().await? { Ok(row.get("embedding_id")?) @@ -370,14 +442,14 @@ impl Neo4jClient { } } - pub async fn upsert_document(&self, file_path: &Path, metadata: &[String]) -> Result { debug!("Upserting document from file: {:?}", file_path); let content = self.extract_content(file_path).await?; let document_id = Uuid::new_v4().to_string(); - let query = query(" + let query = query( + " MERGE (d:Document {content: $content}) ON CREATE SET d.id = $id, @@ -387,11 +459,12 @@ impl Neo4jClient { d.metadata = d.metadata + $new_metadata, d.updated_at = datetime() RETURN d.id as document_id - ") - .param("id", document_id.clone()) - .param("content", content.clone()) // Clone here - .param("metadata", metadata) - .param("new_metadata", metadata); + ", + ) + .param("id", document_id.clone()) + .param("content", content.clone()) // Clone here + .param("metadata", metadata) + .param("new_metadata", metadata); let mut result = self.graph.execute(query).await?; @@ -407,41 +480,51 @@ impl Neo4jClient { sentiment_interval: ChronoDuration::hours(1), }; - let chunks = chunk_document(&content); // Now we can use content here - self.create_chunks_and_embeddings(&document_id, &chunks).await?; - self.enrich_document_incrementally(&document_id, "Document", &config).await?; + let chunks = chunk_document(&content); // Now we can use content here + self.create_chunks_and_embeddings(&document_id, &chunks) + 
.await?; + self.enrich_document_incrementally(&document_id, "Document", &config) + .await?; Ok(document_id) } async fn extract_content(&self, file_path: &Path) -> Result { - let extension = file_path.extension() + let extension = file_path + .extension() .and_then(|ext| ext.to_str()) .ok_or_else(|| anyhow!("Unable to determine file type"))?; match extension.to_lowercase().as_str() { "pdf" => { let path_buf = file_path.to_path_buf(); - Ok(tokio::task::spawn_blocking(move || { - extract_text(&path_buf) - }).await??) - }, - "txt" | "json" | "csv" | "tsv" | "md" | "html" | "xml" | "yml" | "yaml" | "json5" | "py" | "rb" | "rs" | "js" | "ts" | "php" | "java" | "c" | "cpp" | "go" | "sh" | "bat" | "ps1" | "psm1" | "psd1" | "ps1xml" | "psc1" | "pssc" | "pss1" | "psh" => { + Ok(tokio::task::spawn_blocking(move || extract_text(&path_buf)).await??) + } + "txt" | "json" | "csv" | "tsv" | "md" | "html" | "xml" | "yml" | "yaml" | "json5" + | "py" | "rb" | "rs" | "js" | "ts" | "php" | "java" | "c" | "cpp" | "go" | "sh" + | "bat" | "ps1" | "psm1" | "psd1" | "ps1xml" | "psc1" | "pssc" | "pss1" | "psh" => { let mut file = File::open(file_path).await?; let mut content = String::new(); file.read_to_string(&mut content).await?; Ok(content) - }, + } "docx" => { let processor = DocxProcessor; let (content, _metadata) = processor.process(file_path).await?; Ok(content) - }, + } // Add more file types here as needed _ => Err(anyhow!("Unsupported file type: {}", extension)), } } - async fn create_chunks_and_embeddings(&self, document_id: &str, chunks: &[String]) -> Result<()> { - debug!("Creating chunks and embeddings for document {}", document_id); + async fn create_chunks_and_embeddings( + &self, + document_id: &str, + chunks: &[String], + ) -> Result<()> { + debug!( + "Creating chunks and embeddings for document {}", + document_id + ); if let Some(voyage_config) = &self.voyage_ai_config { for (i, chunk) in chunks.iter().enumerate() { let embedding = get_voyage_embedding(chunk, 
voyage_config).await?; @@ -450,7 +533,8 @@ impl Neo4jClient { return Err(anyhow!("Embedding dimension mismatch")); } - let query = query(" + let query = query( + " MATCH (d:Document {id: $document_id}) MERGE (c:Chunk {content: $content}) ON CREATE SET @@ -467,18 +551,34 @@ impl Neo4jClient { MERGE (prev)-[:NEXT]->(c) ) RETURN c.id as chunk_id, e.id as embedding_id - ") - .param("document_id", BoltType::String(BoltString::from(document_id))) - .param("chunk_id", BoltType::String(BoltString::from(Uuid::new_v4().to_string()))) - .param("content", BoltType::String(BoltString::from(chunk.as_str()))) - .param("index", BoltType::Integer(BoltInteger::new(i as i64))) - .param("embedding_id", BoltType::String(BoltString::from(Uuid::new_v4().to_string()))) - .param("vector", embedding) - .param("prev_chunk_id", if i > 0 { + ", + ) + .param( + "document_id", + BoltType::String(BoltString::from(document_id)), + ) + .param( + "chunk_id", + BoltType::String(BoltString::from(Uuid::new_v4().to_string())), + ) + .param( + "content", + BoltType::String(BoltString::from(chunk.as_str())), + ) + .param("index", BoltType::Integer(BoltInteger::new(i as i64))) + .param( + "embedding_id", + BoltType::String(BoltString::from(Uuid::new_v4().to_string())), + ) + .param("vector", embedding) + .param( + "prev_chunk_id", + if i > 0 { BoltType::String(BoltString::from(chunks[i - 1].as_str())) } else { - BoltType::Null(BoltNull::default()) - }); + BoltType::Null(BoltNull) + }, + ); let mut result = self.graph.execute(query).await?; @@ -493,7 +593,8 @@ impl Neo4jClient { } pub async fn get_document_statistics(&self) -> Result { - let query = query(" + let query = query( + " MATCH (d:Document) OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c) OPTIONAL MATCH (c)-[:HAS_EMBEDDING]->(e) @@ -502,7 +603,8 @@ impl Neo4jClient { avg(size(d.content)) as avg_content_length, count(DISTINCT c) as chunk_count, count(DISTINCT e) as embedding_count - "); + ", + ); let mut result = self.graph.execute(query).await?; @@ -518,41 
+620,63 @@ impl Neo4jClient { } } - pub async fn enrich_document_incrementally(&self, node_id: &str, node_type: &str, config: &EnrichmentConfig) -> Result<()> { + pub async fn enrich_document_incrementally( + &self, + node_id: &str, + node_type: &str, + config: &EnrichmentConfig, + ) -> Result<()> { debug!("Enriching {} {}", node_type, node_id); let status = self.get_enrichment_status(node_id, node_type).await?; let now = Utc::now(); if let Some(voyage_config) = &self.voyage_ai_config { - if status.last_themes_keywords_update.map_or(true, |last| now - last > config.themes_keywords_interval) { - self.update_themes_and_keywords(node_id, node_type, voyage_config).await?; + if status + .last_themes_keywords_update + .map_or(true, |last| now - last > config.themes_keywords_interval) + { + self.update_themes_and_keywords(node_id, node_type, voyage_config) + .await?; } - if status.last_clustering_update.map_or(true, |last| now - last > config.clustering_interval) { + if status + .last_clustering_update + .map_or(true, |last| now - last > config.clustering_interval) + { self.update_clustering(node_id, node_type).await?; } - if status.last_sentiment_update.map_or(true, |last| now - last > config.sentiment_interval) { + if status + .last_sentiment_update + .map_or(true, |last| now - last > config.sentiment_interval) + { self.update_sentiment(node_id, node_type).await?; } - self.update_enrichment_status(node_id, node_type, &now).await?; + self.update_enrichment_status(node_id, node_type, &now) + .await?; Ok(()) } else { Err(anyhow!("VoyageAI configuration not found")) } } - async fn get_enrichment_status(&self, node_id: &str, node_type: &str) -> Result { + async fn get_enrichment_status( + &self, + node_id: &str, + node_type: &str, + ) -> Result { debug!("Getting enrichment status for {} {}", node_type, node_id); - let query = query(" + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) AND n.id = $node_id RETURN n.last_themes_keywords_update AS 
themes_keywords, n.last_clustering_update AS clustering, n.last_sentiment_update AS sentiment - ") - .param("node_id", BoltType::String(BoltString::from(node_id))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))); let mut result = self.graph.execute(query).await?; if let Some(row) = result.next().await? { @@ -570,34 +694,61 @@ impl Neo4jClient { } } - - async fn update_themes_and_keywords(&self, node_id: &str, node_type: &str, voyage_config: &VoyageAIConfig) -> Result<()> { + async fn update_themes_and_keywords( + &self, + node_id: &str, + node_type: &str, + voyage_config: &VoyageAIConfig, + ) -> Result<()> { debug!("Updating themes and keywords for {} {}", node_type, node_id); let content = self.get_node_content(node_id, node_type).await?; - let (themes, keywords) = self.extract_themes_and_keywords(&content, voyage_config).await?; - self.create_theme_and_keyword_nodes(node_id, node_type, &themes, &keywords).await?; + let (themes, keywords) = self + .extract_themes_and_keywords(&content, voyage_config) + .await?; + self.create_theme_and_keyword_nodes(node_id, node_type, &themes, &keywords) + .await?; Ok(()) } - - async fn extract_sentiment(&self, content: &str) -> Result { // Define a simple sentiment lexicon let lexicon: HashMap<&str, f32> = [ - ("good", 1.0), ("great", 1.5), ("excellent", 2.0), ("amazing", 2.0), ("wonderful", 1.5), - ("bad", -1.0), ("terrible", -1.5), ("awful", -2.0), ("horrible", -2.0), ("poor", -1.0), - ("like", 0.5), ("love", 1.0), ("hate", -1.0), ("dislike", -0.5), - ("happy", 1.0), ("sad", -1.0), ("angry", -1.0), ("joyful", 1.5), - ("interesting", 0.5), ("boring", -0.5), ("exciting", 1.0), ("dull", -0.5) - ].iter().cloned().collect(); - - let words: Vec = content.to_lowercase() + ("good", 1.0), + ("great", 1.5), + ("excellent", 2.0), + ("amazing", 2.0), + ("wonderful", 1.5), + ("bad", -1.0), + ("terrible", -1.5), + ("awful", -2.0), + ("horrible", -2.0), + ("poor", -1.0), + ("like", 0.5), + ("love", 1.0), + ("hate", 
-1.0), + ("dislike", -0.5), + ("happy", 1.0), + ("sad", -1.0), + ("angry", -1.0), + ("joyful", 1.5), + ("interesting", 0.5), + ("boring", -0.5), + ("exciting", 1.0), + ("dull", -0.5), + ] + .iter() + .cloned() + .collect(); + + let words: Vec = content + .to_lowercase() .split_whitespace() .map(String::from) .collect(); let total_words = words.len() as f32; - let sentiment_sum: f32 = words.iter() + let sentiment_sum: f32 = words + .iter() .filter_map(|word| lexicon.get(word.as_str())) .sum(); @@ -608,17 +759,30 @@ impl Neo4jClient { Ok(sentiment.clamp(-1.0, 1.0)) } - async fn create_and_assign_sentiment(&self, node_id: &str, node_type: &str, sentiment: f32) -> Result<()> { - debug!("Creating and assigning sentiment node for {} {}", node_type, node_id); - let query = query(" + async fn create_and_assign_sentiment( + &self, + node_id: &str, + node_type: &str, + sentiment: f32, + ) -> Result<()> { + debug!( + "Creating and assigning sentiment node for {} {}", + node_type, node_id + ); + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) AND n.id = $node_id MERGE (s:Sentiment {value: $sentiment}) MERGE (n)-[:HAS_SENTIMENT]->(s) RETURN count(s) AS sentiment_count, s.value AS sentiment_value, n.id AS node_id - ") - .param("node_id", BoltType::String(BoltString::from(node_id))) - .param("sentiment", BoltType::Float(BoltFloat::new(sentiment as f64))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))) + .param( + "sentiment", + BoltType::Float(BoltFloat::new(sentiment as f64)), + ); debug!("Executing query with sentiment: {}", sentiment); @@ -629,20 +793,42 @@ impl Neo4jClient { let sentiment_count: i64 = row.get("sentiment_count")?; let sentiment_value: f64 = row.get("sentiment_value")?; let db_node_id: String = row.get("node_id")?; - debug!("Created and assigned {} sentiment node with value {} for {} {}", - sentiment_count, sentiment_value, node_type, db_node_id); + debug!( + "Created and assigned {} sentiment node with 
value {} for {} {}", + sentiment_count, sentiment_value, node_type, db_node_id + ); if sentiment_count == 0 { - warn!("No sentiment was created or assigned for {} {}", node_type, node_id); - return Err(anyhow!("Failed to create or assign sentiment for {} {}", node_type, node_id)); + warn!( + "No sentiment was created or assigned for {} {}", + node_type, node_id + ); + return Err(anyhow!( + "Failed to create or assign sentiment for {} {}", + node_type, + node_id + )); } } else { - warn!("No result returned from sentiment creation and assignment query for {} {}", node_type, node_id); - return Err(anyhow!("No result returned from sentiment creation query for {} {}", node_type, node_id)); + warn!( + "No result returned from sentiment creation and assignment query for {} {}", + node_type, node_id + ); + return Err(anyhow!( + "No result returned from sentiment creation query for {} {}", + node_type, + node_id + )); } - }, + } Err(e) => { - error!("Error executing sentiment creation and assignment query for {} {}: {:?}", node_type, node_id, e); - return Err(anyhow!("Failed to create and assign sentiment node: {:?}", e)); + error!( + "Error executing sentiment creation and assignment query for {} {}: {:?}", + node_type, node_id, e + ); + return Err(anyhow!( + "Failed to create and assign sentiment node: {:?}", + e + )); } } @@ -653,11 +839,13 @@ impl Neo4jClient { } async fn verify_sentiment(&self, node_id: &str, expected_sentiment: f32) -> Result<()> { - let query = query(" + let query = query( + " MATCH (n {id: $node_id})-[:HAS_SENTIMENT]->(s:Sentiment) RETURN n.id as node_id, s.value as sentiment - ") - .param("node_id", BoltType::String(BoltString::from(node_id))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))); let mut result = self.graph.execute(query).await?; if let Some(row) = result.next().await? 
{ @@ -668,7 +856,10 @@ impl Neo4jClient { debug!("Sentiment in DB: {}", db_sentiment); if (db_sentiment as f32 - expected_sentiment).abs() > 1e-6 { - warn!("Sentiment mismatch for node {}: expected {}, found {}", db_node_id, expected_sentiment, db_sentiment); + warn!( + "Sentiment mismatch for node {}: expected {}, found {}", + db_node_id, expected_sentiment, db_sentiment + ); return Err(anyhow!("Sentiment mismatch for node {}", db_node_id)); } else { debug!("Sentiment verified successfully for node {}", db_node_id); @@ -680,14 +871,14 @@ impl Neo4jClient { Ok(()) } - - async fn get_all_documents(&self) -> Result> { - let query = query(" + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) RETURN n.content AS content - "); + ", + ); let mut result = self.graph.execute(query).await?; let mut documents = Vec::new(); @@ -710,30 +901,45 @@ impl Neo4jClient { let sentiment = self.extract_sentiment(&content).await?; // Create and assign sentiment to the node - match self.create_and_assign_sentiment(node_id, node_type, sentiment).await { + match self + .create_and_assign_sentiment(node_id, node_type, sentiment) + .await + { Ok(_) => { - debug!("Successfully created and assigned sentiment for {} {}", node_type, node_id); + debug!( + "Successfully created and assigned sentiment for {} {}", + node_type, node_id + ); Ok(()) - }, + } Err(e) => { - error!("Failed to create and assign sentiment for {} {}: {:?}", node_type, node_id, e); + error!( + "Failed to create and assign sentiment for {} {}: {:?}", + node_type, node_id, e + ); Err(e) } } } - - async fn update_enrichment_status(&self, node_id: &str, node_type: &str, now: &DateTime) -> Result<()> { + async fn update_enrichment_status( + &self, + node_id: &str, + node_type: &str, + now: &DateTime, + ) -> Result<()> { debug!("Updating enrichment status for {} {}", node_type, node_id); - let query = query(" + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) AND n.id = 
$node_id SET n.last_themes_keywords_update = $now, n.last_clustering_update = $now, n.last_sentiment_update = $now - ") - .param("node_id", BoltType::String(BoltString::from(node_id))) - .param("now", BoltType::String(BoltString::from(now.to_rfc3339()))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))) + .param("now", BoltType::String(BoltString::from(now.to_rfc3339()))); let _ = self.graph.execute(query).await?; Ok(()) @@ -741,12 +947,14 @@ impl Neo4jClient { async fn get_node_content(&self, node_id: &str, node_type: &str) -> Result { debug!("Getting content for {} {}", node_type, node_type); - let query = query(" + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) AND n.id = $node_id RETURN n.content AS content - ") - .param("node_id", BoltType::String(BoltString::from(node_id))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))); let mut result = self.graph.execute(query).await?; if let Some(row) = result.next().await? 
{ @@ -756,10 +964,19 @@ impl Neo4jClient { } } - - async fn create_theme_and_keyword_nodes(&self, node_id: &str, node_type: &str, themes: &[String], keywords: &[String]) -> Result<()> { - debug!("Creating theme and keyword nodes for {} {}", node_type, node_id); - let query = query(" + async fn create_theme_and_keyword_nodes( + &self, + node_id: &str, + node_type: &str, + themes: &[String], + keywords: &[String], + ) -> Result<()> { + debug!( + "Creating theme and keyword nodes for {} {}", + node_type, node_id + ); + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) AND n.id = $node_id WITH n @@ -772,40 +989,63 @@ impl Neo4jClient { MERGE (n)-[:HAS_KEYWORD]->(k) WITH n, themes, collect(k) AS keywords RETURN size(themes) + size(keywords) AS total_count - ") - .param("node_id", BoltType::String(BoltString::from(node_id))) - .param("themes", themes) - .param("keywords", keywords); - - debug!("Executing query with themes: {:?} and keywords: {:?}", themes, keywords); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))) + .param("themes", themes) + .param("keywords", keywords); + + debug!( + "Executing query with themes: {:?} and keywords: {:?}", + themes, keywords + ); let result = self.graph.execute(query).await; match result { Ok(mut stream) => { if let Some(row) = stream.next().await? 
{ let total_count: i64 = row.get("total_count")?; - debug!("Created {} theme and keyword nodes for {} {}", total_count, node_type, node_id); + debug!( + "Created {} theme and keyword nodes for {} {}", + total_count, node_type, node_id + ); if total_count == 0 { - warn!("No themes or keywords were created for {} {}", node_type, node_id); + warn!( + "No themes or keywords were created for {} {}", + node_type, node_id + ); } } else { - warn!("No result returned from theme and keyword creation query for {} {}", node_type, node_id); + warn!( + "No result returned from theme and keyword creation query for {} {}", + node_type, node_id + ); } - }, + } Err(e) => { - error!("Error executing theme and keyword creation query for {} {}: {:?}", node_type, node_id, e); + error!( + "Error executing theme and keyword creation query for {} {}: {:?}", + node_type, node_id, e + ); return Err(anyhow!("Failed to create theme and keyword nodes: {:?}", e)); } } // Verification step - self.verify_themes_and_keywords(node_id, themes, keywords).await?; + self.verify_themes_and_keywords(node_id, themes, keywords) + .await?; Ok(()) } - async fn verify_themes_and_keywords(&self, node_id: &str, themes: &[String], keywords: &[String]) -> Result<()> { - let query = query(" + async fn verify_themes_and_keywords( + &self, + node_id: &str, + themes: &[String], + keywords: &[String], + ) -> Result<()> { + let query = query( + " MATCH (n {id: $node_id}) OPTIONAL MATCH (n)-[:HAS_THEME]->(t:Theme) OPTIONAL MATCH (n)-[:HAS_KEYWORD]->(k:Keyword) @@ -815,8 +1055,9 @@ impl Neo4jClient { collect(distinct k.name) as keywords, count(distinct t) as theme_count, count(distinct k) as keyword_count - ") - .param("node_id", BoltType::String(BoltString::from(node_id))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))); let mut result = self.graph.execute(query).await?; if let Some(row) = result.next().await? 
{ @@ -828,22 +1069,59 @@ impl Neo4jClient { debug!("Verification for node {}", db_node_id); debug!("Themes in DB: {:?} (count: {})", db_themes, theme_count); - debug!("Keywords in DB: {:?} (count: {})", db_keywords, keyword_count); - - let missing_themes: Vec<_> = themes.iter().filter(|t| !db_themes.contains(t)).cloned().collect(); - let extra_themes: Vec<_> = db_themes.iter().filter(|t| !themes.contains(t)).cloned().collect(); - let missing_keywords: Vec<_> = keywords.iter().filter(|k| !db_keywords.contains(k)).cloned().collect(); - let extra_keywords: Vec<_> = db_keywords.iter().filter(|k| !keywords.contains(k)).cloned().collect(); - - if !missing_themes.is_empty() || !missing_keywords.is_empty() || !extra_themes.is_empty() || !extra_keywords.is_empty() { + debug!( + "Keywords in DB: {:?} (count: {})", + db_keywords, keyword_count + ); + + let missing_themes: Vec<_> = themes + .iter() + .filter(|t| !db_themes.contains(t)) + .cloned() + .collect(); + let extra_themes: Vec<_> = db_themes + .iter() + .filter(|t| !themes.contains(t)) + .cloned() + .collect(); + let missing_keywords: Vec<_> = keywords + .iter() + .filter(|k| !db_keywords.contains(k)) + .cloned() + .collect(); + let extra_keywords: Vec<_> = db_keywords + .iter() + .filter(|k| !keywords.contains(k)) + .cloned() + .collect(); + + if !missing_themes.is_empty() + || !missing_keywords.is_empty() + || !extra_themes.is_empty() + || !extra_keywords.is_empty() + { warn!("Discrepancies found for node {}:", db_node_id); - if !missing_themes.is_empty() { warn!("Missing themes: {:?}", missing_themes); } - if !extra_themes.is_empty() { warn!("Extra themes in DB: {:?}", extra_themes); } - if !missing_keywords.is_empty() { warn!("Missing keywords: {:?}", missing_keywords); } - if !extra_keywords.is_empty() { warn!("Extra keywords in DB: {:?}", extra_keywords); } - return Err(anyhow!("Discrepancies found in themes or keywords for node {}", db_node_id)); + if !missing_themes.is_empty() { + warn!("Missing themes: {:?}", 
missing_themes); + } + if !extra_themes.is_empty() { + warn!("Extra themes in DB: {:?}", extra_themes); + } + if !missing_keywords.is_empty() { + warn!("Missing keywords: {:?}", missing_keywords); + } + if !extra_keywords.is_empty() { + warn!("Extra keywords in DB: {:?}", extra_keywords); + } + return Err(anyhow!( + "Discrepancies found in themes or keywords for node {}", + db_node_id + )); } else { - debug!("All themes and keywords verified successfully for node {}", db_node_id); + debug!( + "All themes and keywords verified successfully for node {}", + db_node_id + ); } } else { warn!("No node found with ID: {}", node_id); @@ -864,14 +1142,24 @@ impl Neo4jClient { let clusters = self.extract_clusters(&content, &all_documents).await?; // Create and assign clusters to the node - self.create_and_assign_clusters(node_id, node_type, &clusters).await?; + self.create_and_assign_clusters(node_id, node_type, &clusters) + .await?; Ok(()) } - async fn create_and_assign_clusters(&self, node_id: &str, node_type: &str, clusters: &[String]) -> Result<()> { - debug!("Creating and assigning cluster nodes for {} {}", node_type, node_id); - let query = query(" + async fn create_and_assign_clusters( + &self, + node_id: &str, + node_type: &str, + clusters: &[String], + ) -> Result<()> { + debug!( + "Creating and assigning cluster nodes for {} {}", + node_type, node_id + ); + let query = query( + " MATCH (n) WHERE (n:Document OR n:Question OR n:Response) AND n.id = $node_id WITH n @@ -880,9 +1168,10 @@ impl Neo4jClient { MERGE (n)-[:BELONGS_TO]->(c) WITH n, collect(c) AS clusters RETURN size(clusters) AS total_count - ") - .param("node_id", BoltType::String(BoltString::from(node_id))) - .param("clusters", clusters); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))) + .param("clusters", clusters); debug!("Executing query with clusters: {:?}", clusters); @@ -891,17 +1180,32 @@ impl Neo4jClient { Ok(mut stream) => { if let Some(row) = stream.next().await? 
{ let total_count: i64 = row.get("total_count")?; - debug!("Created and assigned {} cluster nodes for {} {}", total_count, node_type, node_id); + debug!( + "Created and assigned {} cluster nodes for {} {}", + total_count, node_type, node_id + ); if total_count == 0 { - warn!("No clusters were created or assigned for {} {}", node_type, node_id); + warn!( + "No clusters were created or assigned for {} {}", + node_type, node_id + ); } } else { - warn!("No result returned from cluster creation and assignment query for {} {}", node_type, node_id); + warn!( + "No result returned from cluster creation and assignment query for {} {}", + node_type, node_id + ); } - }, + } Err(e) => { - error!("Error executing cluster creation and assignment query for {} {}: {:?}", node_type, node_id, e); - return Err(anyhow!("Failed to create and assign cluster nodes: {:?}", e)); + error!( + "Error executing cluster creation and assignment query for {} {}: {:?}", + node_type, node_id, e + ); + return Err(anyhow!( + "Failed to create and assign cluster nodes: {:?}", + e + )); } } @@ -912,15 +1216,17 @@ impl Neo4jClient { } async fn verify_clusters(&self, node_id: &str, expected_clusters: &[String]) -> Result<()> { - let query = query(" + let query = query( + " MATCH (n {id: $node_id}) OPTIONAL MATCH (n)-[:BELONGS_TO]->(c:Cluster) RETURN n.id as node_id, collect(distinct c.name) as clusters, count(distinct c) as cluster_count - ") - .param("node_id", BoltType::String(BoltString::from(node_id))); + ", + ) + .param("node_id", BoltType::String(BoltString::from(node_id))); let mut result = self.graph.execute(query).await?; if let Some(row) = result.next().await? 
{ @@ -929,16 +1235,34 @@ impl Neo4jClient { let cluster_count: i64 = row.get("cluster_count")?; debug!("Verification for node {}", db_node_id); - debug!("Clusters in DB: {:?} (count: {})", db_clusters, cluster_count); - - let missing_clusters: Vec<_> = expected_clusters.iter().filter(|c| !db_clusters.contains(c)).cloned().collect(); - let extra_clusters: Vec<_> = db_clusters.iter().filter(|c| !expected_clusters.contains(c)).cloned().collect(); + debug!( + "Clusters in DB: {:?} (count: {})", + db_clusters, cluster_count + ); + + let missing_clusters: Vec<_> = expected_clusters + .iter() + .filter(|c| !db_clusters.contains(c)) + .cloned() + .collect(); + let extra_clusters: Vec<_> = db_clusters + .iter() + .filter(|c| !expected_clusters.contains(c)) + .cloned() + .collect(); if !missing_clusters.is_empty() || !extra_clusters.is_empty() { warn!("Discrepancies found for node {}:", db_node_id); - if !missing_clusters.is_empty() { warn!("Missing clusters: {:?}", missing_clusters); } - if !extra_clusters.is_empty() { warn!("Extra clusters in DB: {:?}", extra_clusters); } - return Err(anyhow!("Discrepancies found in clusters for node {}", db_node_id)); + if !missing_clusters.is_empty() { + warn!("Missing clusters: {:?}", missing_clusters); + } + if !extra_clusters.is_empty() { + warn!("Extra clusters in DB: {:?}", extra_clusters); + } + return Err(anyhow!( + "Discrepancies found in clusters for node {}", + db_node_id + )); } else { debug!("All clusters verified successfully for node {}", db_node_id); } @@ -948,9 +1272,11 @@ impl Neo4jClient { Ok(()) } - - - pub async fn create_or_update_question(&self, question: &Neo4jQuestion, interaction_id: &str) -> Result { + pub async fn create_or_update_question( + &self, + question: &Neo4jQuestion, + interaction_id: &str, + ) -> Result { let query_str = r#" MERGE (q:Question {content: $content}) ON CREATE SET @@ -965,8 +1291,14 @@ impl Neo4jClient { "#; let mut props = BoltMap::new(); - props.put(BoltString::from("id"), 
BoltType::String(BoltString::from(question.id.as_str()))); - props.put(BoltString::from("content"), BoltType::String(BoltString::from(question.content.as_str()))); + props.put( + BoltString::from("id"), + BoltType::String(BoltString::from(question.id.as_str())), + ); + props.put( + BoltString::from("content"), + BoltType::String(BoltString::from(question.content.as_str())), + ); let mut vector_list = BoltList::new(); for &value in &question.vector { @@ -974,13 +1306,20 @@ impl Neo4jClient { } props.put(BoltString::from("vector"), BoltType::List(vector_list)); - props.put(BoltString::from("timestamp"), BoltType::String(BoltString::from(question.timestamp.to_rfc3339().as_str()))); + props.put( + BoltString::from("timestamp"), + BoltType::String(BoltString::from(question.timestamp.to_rfc3339().as_str())), + ); - let mut result = self.graph.execute(query(query_str) - .param("content", question.content.as_str()) - .param("props", BoltType::Map(props)) - .param("interaction_id", interaction_id) - ).await?; + let mut result = self + .graph + .execute( + query(query_str) + .param("content", question.content.as_str()) + .param("props", BoltType::Map(props)) + .param("interaction_id", interaction_id), + ) + .await?; if let Some(row) = result.next().await? { Ok(row.get("question_id")?) 
@@ -989,7 +1328,12 @@ impl Neo4jClient { } } - pub async fn create_response(&self, response: &Neo4jResponse, interaction_id: &str, model_id: &str) -> Result { + pub async fn create_response( + &self, + response: &Neo4jResponse, + interaction_id: &str, + model_id: &str, + ) -> Result { let query_str = r#" CREATE (r:Response { id: $id, @@ -1007,16 +1351,23 @@ impl Neo4jClient { RETURN r.id as response_id "#; - let mut result = self.graph.execute(query(query_str) - .param("id", response.id.clone()) - .param("content", response.content.clone()) - .param("vector", BoltType::List(response.vector.clone())) - .param("timestamp", response.timestamp.to_rfc3339()) - .param("confidence", response.confidence) - .param("llm_specific_data", serde_json::to_string(&response.llm_specific_data)?) - .param("interaction_id", interaction_id) - .param("model_id", model_id) - ).await?; + let mut result = self + .graph + .execute( + query(query_str) + .param("id", response.id.clone()) + .param("content", response.content.clone()) + .param("vector", BoltType::List(response.vector.clone())) + .param("timestamp", response.timestamp.to_rfc3339()) + .param("confidence", response.confidence) + .param( + "llm_specific_data", + serde_json::to_string(&response.llm_specific_data)?, + ) + .param("interaction_id", interaction_id) + .param("model_id", model_id), + ) + .await?; if let Some(row) = result.next().await? { Ok(row.get("response_id")?) 
@@ -1025,8 +1376,11 @@ impl Neo4jClient { } } - - async fn extract_themes_and_keywords(&self, content: &str, _config: &VoyageAIConfig) -> Result<(Vec, Vec)> { + async fn extract_themes_and_keywords( + &self, + content: &str, + _config: &VoyageAIConfig, + ) -> Result<(Vec, Vec)> { debug!("Extracting themes and keywords"); debug!("content: {}", content); let stemmer = Stemmer::create(Algorithm::English); @@ -1053,14 +1407,16 @@ impl Neo4jClient { let mut sorted_words: Vec<_> = word_freq.into_iter().collect(); sorted_words.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0))); - let themes: Vec = sorted_words.iter() - .take(3) // Extract top 5 as themes + let themes: Vec = sorted_words + .iter() + .take(3) // Extract top 5 as themes .map(|(word, count)| format!("{}:{}", word, count)) .collect(); - let keywords: Vec = sorted_words.iter() + let keywords: Vec = sorted_words + .iter() .skip(5) - .take(3) // Extract next 10 as keywords + .take(3) // Extract next 10 as keywords .map(|(word, count)| format!("{}:{}", word, count)) .collect(); @@ -1070,12 +1426,16 @@ impl Neo4jClient { Ok((themes, keywords)) } - - - async fn extract_clusters(&self, content: &str, all_documents: &[String]) -> Result> { + async fn extract_clusters( + &self, + content: &str, + all_documents: &[String], + ) -> Result> { debug!("Extracting clusters"); let stemmer = Stemmer::create(Algorithm::English); - let stop_words: HashSet<_> = stop_words::get(stop_words::LANGUAGE::English).into_iter().collect(); + let stop_words: HashSet<_> = stop_words::get(stop_words::LANGUAGE::English) + .into_iter() + .collect(); // Function to tokenize and clean text let tokenize = |text: &str| -> Vec { @@ -1116,7 +1476,8 @@ impl Neo4jClient { } // Calculate TF-IDF - let mut tfidf: Vec<(String, f64)> = tf.into_iter() + let mut tfidf: Vec<(String, f64)> = tf + .into_iter() .map(|(word, tf_value)| { let df_value = df.get(&word).unwrap_or(&1.0); let idf = (n_docs / df_value).ln(); @@ -1128,7 +1489,8 @@ impl Neo4jClient { 
tfidf.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); // Extract top terms as clusters - let clusters: Vec = tfidf.into_iter() + let clusters: Vec = tfidf + .into_iter() .take(3) // Take top 5 terms as clusters .map(|(word, score)| format!("{}:{:.2}", word, score)) .collect(); @@ -1137,7 +1499,6 @@ impl Neo4jClient { Ok(clusters) } - pub async fn execute_cypher(&self, cypher_query: &str) -> Result { info!("Executing Cypher query: {}", cypher_query); @@ -1159,7 +1520,8 @@ impl Neo4jClient { } fn row_to_json(&self, row: &Row) -> Result { - row.to::().map_err(|e| anyhow!("Failed to convert row to JSON: {}", e)) + row.to::() + .map_err(|e| anyhow!("Failed to convert row to JSON: {}", e)) } pub async fn get_database_schema(&self) -> Result { @@ -1176,9 +1538,6 @@ impl Neo4jClient { Ok(schema_str) } - - - } // Define the necessary structs @@ -1235,4 +1594,4 @@ pub struct Neo4jTokenUsage { pub total_tokens: i32, } -// Implement other necessary structs and methods... \ No newline at end of file +// Implement other necessary structs and methods... 
diff --git a/crates/fluent-core/src/output_processor.rs b/crates/fluent-core/src/output_processor.rs index ed26f8b..4699951 100644 --- a/crates/fluent-core/src/output_processor.rs +++ b/crates/fluent-core/src/output_processor.rs @@ -1,26 +1,25 @@ -use std::env; -use anyhow::{Result, anyhow, Context}; -use regex::Regex; -use std::path::PathBuf; -use crossterm::style::Color; +use anyhow::{anyhow, Context, Result}; use log::{debug, info}; +use regex::Regex; use reqwest::Client; use serde_json::Value; +use std::env; +use std::path::Path; use syntect::easy::HighlightLines; use syntect::highlighting::{Style, ThemeSet}; use syntect::parsing::SyntaxSet; use syntect::util::{as_24_bit_terminal_escaped, LinesWithEndings}; +use termimad::crossterm::style::Color; +use termimad::{MadSkin, StyledChar}; use tokio::fs; use tokio::process::Command; use url::Url; use uuid::Uuid; -use termimad::{MadSkin, StyledChar}; pub struct OutputProcessor; impl OutputProcessor { - - pub async fn download_media_files(content: &str, directory: &PathBuf) -> Result<()> { + pub async fn download_media_files(content: &str, directory: &Path) -> Result<()> { debug!("Starting media file download process"); // Try to parse the content as JSON @@ -30,10 +29,12 @@ impl OutputProcessor { } else { debug!("Content is not valid JSON, proceeding with regex-based URL extraction"); // Corrected regex for URL matching, including query parameters - let url_regex = Regex::new(r#"(https?://[^\s"']+\.(?:jpg|jpeg|png|gif|bmp|svg|mp4|webm|ogg)(?:\?[^\s"']+)?)"#)?; + let url_regex = Regex::new( + r#"(https?://[^\s"']+\.(?:jpg|jpeg|png|gif|bmp|svg|mp4|webm|ogg)(?:\?[^\s"']+)?)"#, + )?; for cap in url_regex.captures_iter(content) { - let url = &cap[1]; // This includes the full URL with query parameters + let url = &cap[1]; // This includes the full URL with query parameters debug!("Found URL in content: {}", url); Self::download_file(url, directory, None).await?; } @@ -42,14 +43,21 @@ impl OutputProcessor { Ok(()) } - async 
fn download_file(url: &str, directory: &PathBuf, suggested_name: Option) -> Result<()> { + async fn download_file( + url: &str, + directory: &Path, + suggested_name: Option, + ) -> Result<()> { debug!("Attempting to download file from URL: {}", url); let client = Client::new(); let response = client.get(url).send().await?; if !response.status().is_success() { - return Err(anyhow!("Failed to download file: HTTP status {}", response.status())); + return Err(anyhow!( + "Failed to download file: HTTP status {}", + response.status() + )); } let file_name = if let Some(name) = suggested_name { @@ -66,12 +74,16 @@ impl OutputProcessor { let content = response.bytes().await?; fs::write(&file_path, &content).await?; - info!("Downloaded: {} ({} bytes)", file_path.display(), content.len()); + info!( + "Downloaded: {} ({} bytes)", + file_path.display(), + content.len() + ); Ok(()) } - async fn download_from_json(json_content: &Value, directory: &PathBuf) -> Result<()> { + async fn download_from_json(json_content: &Value, directory: &Path) -> Result<()> { if let Some(data) = json_content.get("data") { if let Some(data_array) = data.as_array() { for item in data_array { @@ -86,10 +98,10 @@ impl OutputProcessor { } fn extract_file_name_from_url(url: &str) -> Option { - Url::parse(url).ok()? + Url::parse(url) + .ok()? .path_segments()? .last()? 
- .to_string() .split('?') .next() .map(|s| s.to_string()) @@ -97,7 +109,8 @@ impl OutputProcessor { pub fn parse_code(content: &str) -> Vec { let code_block_regex = Regex::new(r"```(?:\w+)?\n([\s\S]*?)\n```").unwrap(); - code_block_regex.captures_iter(content) + code_block_regex + .captures_iter(content) .filter_map(|cap| cap.get(1)) .map(|m| m.as_str().trim().to_string()) .collect() @@ -115,7 +128,7 @@ impl OutputProcessor { output.push_str(&Self::execute_commands(&block).await?); } - output.push_str("\n"); + output.push('\n'); } Ok(output) } @@ -127,11 +140,17 @@ impl OutputProcessor { async fn execute_script(script: &str) -> Result { // Use a platform-agnostic way to get the temp directory let temp_dir = env::temp_dir(); - let file_name = format!("script_{}.{}", Uuid::new_v4(), if cfg!(windows) { "bat" } else { "sh" }); + let file_name = format!( + "script_{}.{}", + Uuid::new_v4(), + if cfg!(windows) { "bat" } else { "sh" } + ); let temp_file = temp_dir.join(file_name); // Write the script to the temporary file - fs::write(&temp_file, script).await.context("Failed to write script to temporary file")?; + fs::write(&temp_file, script) + .await + .context("Failed to write script to temporary file")?; // Set executable permissions on Unix-like systems #[cfg(unix)] @@ -139,25 +158,23 @@ impl OutputProcessor { use std::os::unix::fs::PermissionsExt; let mut perms = fs::metadata(&temp_file).await?.permissions(); perms.set_mode(0o755); - fs::set_permissions(&temp_file, perms).await.context("Failed to set file permissions")?; + fs::set_permissions(&temp_file, perms) + .await + .context("Failed to set file permissions")?; } // Execute the script let result = if cfg!(windows) { - Command::new("cmd") - .arg("/C") - .arg(&temp_file) - .output() - .await + Command::new("cmd").arg("/C").arg(&temp_file).output().await } else { - Command::new("sh") - .arg(&temp_file) - .output() - .await - }.context("Failed to execute script")?; + 
Command::new("sh").arg(&temp_file).output().await + } + .context("Failed to execute script")?; // Remove the temporary file - fs::remove_file(&temp_file).await.context("Failed to remove temporary file")?; + fs::remove_file(&temp_file) + .await + .context("Failed to remove temporary file")?; // Collect and format the output let stdout = String::from_utf8_lossy(&result.stdout); @@ -176,11 +193,7 @@ impl OutputProcessor { output.push_str(&format!("Executing: {}\n", trimmed)); - let result = Command::new("sh") - .arg("-c") - .arg(trimmed) - .output() - .await?; + let result = Command::new("sh").arg("-c").arg(trimmed).output().await?; let stdout = String::from_utf8_lossy(&result.stdout); let stderr = String::from_utf8_lossy(&result.stderr); @@ -189,12 +202,10 @@ impl OutputProcessor { if !stderr.is_empty() { output.push_str(&format!("Errors:\n{}\n", stderr)); } - output.push_str("\n"); + output.push('\n'); } Ok(output) } - - } pub struct MarkdownFormatter { @@ -203,20 +214,47 @@ pub struct MarkdownFormatter { theme_set: ThemeSet, } -impl MarkdownFormatter { - pub fn new() -> Self { +impl Default for MarkdownFormatter { + fn default() -> Self { let mut skin = MadSkin::default(); - skin.set_bg(Color::Rgb { r: 10, g: 10, b: 10 }); - skin.set_headers_fg(Color::Rgb { r: 255, g: 187, b: 0 }); - skin.bold.set_fg(Color::Rgb { r: 255, g: 215, b: 0 }); - skin.italic.set_fg(Color::Rgb { r: 173, g: 216, b: 230 }); - skin.paragraph.set_fg(Color::Rgb { r: 220, g: 220, b: 220 }); // Light grey for normal text + skin.set_bg(Color::Rgb { + r: 10, + g: 10, + b: 10, + }); + skin.set_headers_fg(Color::Rgb { + r: 255, + g: 187, + b: 0, + }); + skin.bold.set_fg(Color::Rgb { + r: 255, + g: 215, + b: 0, + }); + skin.italic.set_fg(Color::Rgb { + r: 173, + g: 216, + b: 230, + }); + skin.paragraph.set_fg(Color::Rgb { + r: 220, + g: 220, + b: 220, + }); // Light grey for normal text skin.bullet = StyledChar::from_fg_char(Color::Rgb { r: 0, g: 255, b: 0 }, '•'); // Set code block colors - 
skin.code_block.set_bg(Color::Rgb { r: 30, g: 30, b: 30 }); // Slightly lighter than main background - skin.code_block.set_fg(Color::Rgb { r: 255, g: 255, b: 255 }); // White text for code - + skin.code_block.set_bg(Color::Rgb { + r: 30, + g: 30, + b: 30, + }); // Slightly lighter than main background + skin.code_block.set_fg(Color::Rgb { + r: 255, + g: 255, + b: 255, + }); // White text for code MarkdownFormatter { skin, @@ -224,6 +262,11 @@ impl MarkdownFormatter { theme_set: ThemeSet::load_defaults(), } } +} +impl MarkdownFormatter { + pub fn new() -> Self { + Self::default() + } pub fn format(&self, content: &str) -> Result { debug!("Formatting markdown"); @@ -262,8 +305,12 @@ impl MarkdownFormatter { fn highlight_code(&self, lang: &str, code: &str) -> Result { debug!("highlight_code: {}", lang); - let syntax = self.syntax_set.find_syntax_by_extension(lang).unwrap_or_else(|| self.syntax_set.find_syntax_plain_text()); - let mut highlighter = HighlightLines::new(syntax, &self.theme_set.themes["base16-ocean.dark"]); + let syntax = self + .syntax_set + .find_syntax_by_extension(lang) + .unwrap_or_else(|| self.syntax_set.find_syntax_plain_text()); + let mut highlighter = + HighlightLines::new(syntax, &self.theme_set.themes["base16-ocean.dark"]); let mut output = String::new(); for line in LinesWithEndings::from(code) { @@ -294,4 +341,4 @@ pub fn format_markdown(content: &str) -> Result { debug!("format_markdown"); let formatter = MarkdownFormatter::new(); formatter.format(content) -} \ No newline at end of file +} diff --git a/crates/fluent-engines/Cargo.toml b/crates/fluent-engines/Cargo.toml index ba59f9e..549635b 100644 --- a/crates/fluent-engines/Cargo.toml +++ b/crates/fluent-engines/Cargo.toml @@ -6,7 +6,12 @@ edition = "2021" [dependencies] fluent-core = { path = "../fluent-core" } -reqwest = { version = "0.11", features = ["json", "stream", "multipart"] } +reqwest = { version = "0.12.5", default-features = false, features = [ + "json", + "stream", + 
"multipart", + "rustls-tls", +] } serde_json = "1.0.120" anyhow = "1.0.86" async-trait = "0.1.80" @@ -14,12 +19,12 @@ log = "0.4.22" tokio = "1.38.0" tokio-util = "0.7.11" base64 = "0.22.1" - mime_guess = "2.0.3" serde = { version = "1.0.204", features = ["derive"] } -indicatif = "0.17.8" uuid = "1.9.1" -clap = "4.5.8" futures-util = "0.3.30" tempfile = "3.10.1" -futures = "0.3.30" \ No newline at end of file +futures = "0.3.30" + +#indicatif = "0.17.8" +#clap = "4.5.8" diff --git a/crates/fluent-engines/src/cohere.rs b/crates/fluent-engines/src/cohere.rs index 54291ae..dd1fe54 100644 --- a/crates/fluent-engines/src/cohere.rs +++ b/crates/fluent-engines/src/cohere.rs @@ -1,19 +1,21 @@ +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use base64::{engine::general_purpose::STANDARD, Engine as _}; +use fluent_core::config::EngineConfig; +use fluent_core::neo4j_client::Neo4jClient; +use fluent_core::traits::Engine; +use fluent_core::types::{ + ExtractedContent, Request, Response, UpsertRequest, UpsertResponse, Usage, +}; +use log::debug; +use reqwest::Client; +use serde_json::{json, Value}; use std::future::Future; use std::path::Path; use std::pin::Pin; use std::sync::Arc; -use anyhow::{Result, anyhow, Context}; -use async_trait::async_trait; -use serde_json::{json, Value}; use tokio::fs::File; use tokio::io::AsyncReadExt; -use base64::{Engine as _, engine::general_purpose::STANDARD}; -use fluent_core::types::{ExtractedContent, Request, Response, UpsertRequest, UpsertResponse, Usage}; -use fluent_core::neo4j_client::Neo4jClient; -use fluent_core::traits::Engine; -use fluent_core::config::EngineConfig; -use log::debug; -use reqwest::Client; pub struct CohereEngine { config: EngineConfig, @@ -39,13 +41,17 @@ impl CohereEngine { #[async_trait] impl Engine for CohereEngine { - fn execute<'a>(&'a self, request: &'a Request) -> Box> + Send + 'a> { + fn execute<'a>( + &'a self, + request: &'a Request, + ) -> Box> + Send + 'a> { Box::new(async move { - 
let url = format!("{}://{}:{}{}", - self.config.connection.protocol, - self.config.connection.hostname, - self.config.connection.port, - self.config.connection.request_path + let url = format!( + "{}://{}:{}{}", + self.config.connection.protocol, + self.config.connection.hostname, + self.config.connection.port, + self.config.connection.request_path ); let payload = json!({ @@ -71,11 +77,16 @@ impl Engine for CohereEngine { debug!("Cohere Payload: {:?}", payload); - let auth_token = self.config.parameters.get("bearer_token") + let auth_token = self + .config + .parameters + .get("bearer_token") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow!("Bearer token not found in configuration"))?; - let response = self.client.post(&url) + let response = self + .client + .post(&url) .header("Authorization", format!("Bearer {}", auth_token)) .header("Content-Type", "application/json") .json(&payload) @@ -96,22 +107,31 @@ impl Engine for CohereEngine { .to_string(); let usage = Usage { - prompt_tokens: response.get("meta") + prompt_tokens: response + .get("meta") .and_then(|meta| meta.get("billed_units")) .and_then(|billed_units| billed_units.get("input_tokens")) .and_then(|input_tokens| input_tokens.as_u64()) .unwrap_or(0) as u32, - completion_tokens: response.get("meta") + completion_tokens: response + .get("meta") .and_then(|meta| meta.get("billed_units")) .and_then(|billed_units| billed_units.get("output_tokens")) .and_then(|output_tokens| output_tokens.as_u64()) .unwrap_or(0) as u32, - total_tokens: response.get("meta") + total_tokens: response + .get("meta") .and_then(|meta| meta.get("billed_units")) - .and_then(|billed_units| { - let input = billed_units.get("input_tokens").and_then(|t| t.as_u64()).unwrap_or(0); - let output = billed_units.get("output_tokens").and_then(|t| t.as_u64()).unwrap_or(0); - Some(input + output) + .map(|billed_units| { + let input = billed_units + .get("input_tokens") + .and_then(|t| t.as_u64()) + .unwrap_or(0); + let output = billed_units + 
.get("output_tokens") + .and_then(|t| t.as_u64()) + .unwrap_or(0); + input + output }) .unwrap_or(0) as u32, }; @@ -128,7 +148,10 @@ impl Engine for CohereEngine { }) } - fn upsert<'a>(&'a self, _request: &'a UpsertRequest) -> Box> + Send + 'a> { + fn upsert<'a>( + &'a self, + _request: &'a UpsertRequest, + ) -> Box> + Send + 'a> { Box::new(async move { // Cohere doesn't have a direct upsert functionality, so we return an empty response Ok(UpsertResponse { @@ -143,11 +166,16 @@ impl Engine for CohereEngine { } fn get_session_id(&self) -> Option { - self.config.parameters.get("sessionID").and_then(|v| v.as_str()).map(String::from) + self.config + .parameters + .get("sessionID") + .and_then(|v| v.as_str()) + .map(String::from) } fn extract_content(&self, value: &Value) -> Option { - value.get("text") + value + .get("text") .and_then(|text| text.as_str()) .map(|content| ExtractedContent { main_content: content.to_string(), @@ -158,26 +186,36 @@ impl Engine for CohereEngine { }) } - fn upload_file<'a>(&'a self, file_path: &'a Path) -> Box> + Send + 'a> { + fn upload_file<'a>( + &'a self, + file_path: &'a Path, + ) -> Box> + Send + 'a> { Box::new(async move { // Cohere doesn't have a direct file upload API, so we'll read the file and return its content as a base64 string let mut file = File::open(file_path).await.context("Failed to open file")?; let mut buffer = Vec::new(); - file.read_to_end(&mut buffer).await.context("Failed to read file")?; + file.read_to_end(&mut buffer) + .await + .context("Failed to read file")?; let base64_content = STANDARD.encode(&buffer); Ok(base64_content) }) } - fn process_request_with_file<'a>(&'a self, request: &'a Request, file_path: &'a Path) -> Box> + Send + 'a> { + fn process_request_with_file<'a>( + &'a self, + request: &'a Request, + file_path: &'a Path, + ) -> Box> + Send + 'a> { Box::new(async move { let base64_content = Pin::from(self.upload_file(file_path)).await?; - let url = format!("{}://{}:{}{}", - 
self.config.connection.protocol, - self.config.connection.hostname, - self.config.connection.port, - self.config.connection.request_path + let url = format!( + "{}://{}:{}{}", + self.config.connection.protocol, + self.config.connection.hostname, + self.config.connection.port, + self.config.connection.request_path ); let payload = json!({ @@ -192,11 +230,16 @@ impl Engine for CohereEngine { debug!("Cohere Payload with file: {:?}", payload); - let auth_token = self.config.parameters.get("bearer_token") + let auth_token = self + .config + .parameters + .get("bearer_token") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow!("Bearer token not found in configuration"))?; - let response = self.client.post(&url) + let response = self + .client + .post(&url) .header("Authorization", format!("Bearer {}", auth_token)) .header("Content-Type", "application/json") .json(&payload) @@ -217,22 +260,31 @@ impl Engine for CohereEngine { .to_string(); let usage = Usage { - prompt_tokens: response.get("meta") + prompt_tokens: response + .get("meta") .and_then(|meta| meta.get("billed_units")) .and_then(|billed_units| billed_units.get("input_tokens")) .and_then(|input_tokens| input_tokens.as_u64()) .unwrap_or(0) as u32, - completion_tokens: response.get("meta") + completion_tokens: response + .get("meta") .and_then(|meta| meta.get("billed_units")) .and_then(|billed_units| billed_units.get("output_tokens")) .and_then(|output_tokens| output_tokens.as_u64()) .unwrap_or(0) as u32, - total_tokens: response.get("meta") + total_tokens: response + .get("meta") .and_then(|meta| meta.get("billed_units")) - .and_then(|billed_units| { - let input = billed_units.get("input_tokens").and_then(|t| t.as_u64()).unwrap_or(0); - let output = billed_units.get("output_tokens").and_then(|t| t.as_u64()).unwrap_or(0); - Some(input + output) + .map(|billed_units| { + let input = billed_units + .get("input_tokens") + .and_then(|t| t.as_u64()) + .unwrap_or(0); + let output = billed_units + .get("output_tokens") + 
.and_then(|t| t.as_u64()) + .unwrap_or(0); + input + output }) .unwrap_or(0) as u32, }; @@ -248,4 +300,4 @@ impl Engine for CohereEngine { }) }) } -} \ No newline at end of file +} diff --git a/crates/fluent-engines/src/flowise_chain.rs b/crates/fluent-engines/src/flowise_chain.rs index b35ef45..5c86196 100644 --- a/crates/fluent-engines/src/flowise_chain.rs +++ b/crates/fluent-engines/src/flowise_chain.rs @@ -1,22 +1,22 @@ -use std::collections::HashMap; -use std::future::Future; -use std::path::Path; -use std::sync::Arc; -use fluent_core::types::{ExtractedContent, Request, Response, UpsertRequest, UpsertResponse, Usage}; -use fluent_core::traits::{Engine, EngineConfigProcessor}; +use anyhow::{anyhow, Context, Result}; +use base64::engine::general_purpose::STANDARD as Base64; +use base64::Engine as Base64Engine; use fluent_core::config::EngineConfig; use fluent_core::neo4j_client::Neo4jClient; -use anyhow::{Result, anyhow, Context}; -use reqwest::Client; -use serde_json::{json, Value}; +use fluent_core::traits::{Engine, EngineConfigProcessor}; +use fluent_core::types::{ + ExtractedContent, Request, Response, UpsertRequest, UpsertResponse, Usage, +}; use log::{debug, warn}; use mime_guess::from_path; +use reqwest::Client; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::future::Future; +use std::path::Path; +use std::sync::Arc; use tokio::fs::File; use tokio::io::AsyncReadExt; -use base64::Engine as Base64Engine; -use base64::engine::general_purpose::STANDARD as Base64; - - pub struct FlowiseChainEngine { config: EngineConfig, @@ -39,18 +39,19 @@ impl FlowiseChainEngine { }) } - async fn create_upload_payload(file_path: &Path) -> Result { + async fn _create_upload_payload(file_path: &Path) -> Result { debug!("Creating upload payload for file: {}", file_path.display()); let mut file = File::open(file_path).await.context("Failed to open file")?; let mut buffer = Vec::new(); - file.read_to_end(&mut buffer).await.context("Failed to read 
file")?; + file.read_to_end(&mut buffer) + .await + .context("Failed to read file")?; let base64_image = Base64.encode(&buffer); - let mime_type = from_path(file_path) - .first_or_octet_stream() - .to_string(); + let mime_type = from_path(file_path).first_or_octet_stream().to_string(); - let file_name = file_path.file_name() + let file_name = file_path + .file_name() .and_then(|n| n.to_str()) .unwrap_or("unknown.file") .to_string(); @@ -67,7 +68,6 @@ impl FlowiseChainEngine { pub struct FlowiseChainConfigProcessor; impl EngineConfigProcessor for FlowiseChainConfigProcessor { - fn process_config(&self, config: &EngineConfig) -> Result { debug!("FlowiseConfigProcessor::process_config"); debug!("Config: {:#?}", config); @@ -84,7 +84,7 @@ impl EngineConfigProcessor for FlowiseChainConfigProcessor { // Handle nested objects (like openAIApiKey with multiple keys) let nested_config: HashMap = obj.clone().into_iter().collect(); payload["overrideConfig"][key] = json!(nested_config); - }, + } _ => { // For non-object values, add them directly payload["overrideConfig"][key] = value.clone(); @@ -104,30 +104,67 @@ impl Engine for FlowiseChainEngine { } fn get_session_id(&self) -> Option { - self.config.parameters.get("sessionID").and_then(|v| v.as_str()).map(String::from) + self.config + .parameters + .get("sessionID") + .and_then(|v| v.as_str()) + .map(String::from) } fn extract_content(&self, value: &Value) -> Option { let mut content = ExtractedContent::default(); if let Some(outputs) = value.get("outputs").and_then(|v| v.as_array()) { - if let Some(first_output) = outputs.get(0) { - content.main_content = first_output.get("output").and_then(|v| v.as_str()) - .unwrap_or_default().to_string(); - content.sentiment = first_output.get("sentiment").and_then(|v| v.as_str()).map(String::from); - content.clusters = first_output.get("clusters").and_then(|v| v.as_array()) - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()); - content.themes = 
first_output.get("themes").and_then(|v| v.as_array()) - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()); - content.keywords = first_output.get("keywords").and_then(|v| v.as_array()) - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()); + if let Some(first_output) = outputs.first() { + content.main_content = first_output + .get("output") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + content.sentiment = first_output + .get("sentiment") + .and_then(|v| v.as_str()) + .map(String::from); + content.clusters = + first_output + .get("clusters") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }); + content.themes = first_output + .get("themes") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }); + content.keywords = + first_output + .get("keywords") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }); } } - if content.main_content.is_empty() { None } else { Some(content) } + if content.main_content.is_empty() { + None + } else { + Some(content) + } } - fn upsert<'a>(&'a self, _request: &'a UpsertRequest) -> Box> + Send + 'a> { + fn upsert<'a>( + &'a self, + _request: &'a UpsertRequest, + ) -> Box> + Send + 'a> { Box::new(async move { // Implement FlowiseAI-specific upsert logic here if needed Ok(UpsertResponse { @@ -137,7 +174,10 @@ impl Engine for FlowiseChainEngine { }) } - fn execute<'a>(&'a self, request: &'a Request) -> Box> + Send + 'a> { + fn execute<'a>( + &'a self, + request: &'a Request, + ) -> Box> + Send + 'a> { Box::new(async move { let client = Client::new(); debug!("Config: {:?}", self.config); @@ -147,17 +187,15 @@ impl Engine for FlowiseChainEngine { // Add the user's request to the payload payload["question"] = json!(request.payload); - let url = 
format!("{}://{}:{}{}", - self.config.connection.protocol, - self.config.connection.hostname, - self.config.connection.port, - self.config.connection.request_path + let url = format!( + "{}://{}:{}{}", + self.config.connection.protocol, + self.config.connection.hostname, + self.config.connection.port, + self.config.connection.request_path ); - let res = client.post(&url) - .json(&payload) - .send() - .await?; + let res = client.post(&url).json(&payload).send().await?; let response_body = res.json::().await?; debug!("Response: {:?}", response_body); @@ -174,7 +212,7 @@ impl Engine for FlowiseChainEngine { // FlowiseAI doesn't provide token usage, so we'll estimate it based on content length let estimated_tokens = (content.len() as f32 / 4.0).ceil() as u32; let usage = Usage { - prompt_tokens: estimated_tokens / 2, // Rough estimate + prompt_tokens: estimated_tokens / 2, // Rough estimate completion_tokens: estimated_tokens / 2, // Rough estimate total_tokens: estimated_tokens, }; @@ -191,22 +229,34 @@ impl Engine for FlowiseChainEngine { }) } - fn upload_file<'a>(&'a self, _file_path: &'a Path) -> Box> + Send + 'a> { + fn upload_file<'a>( + &'a self, + _file_path: &'a Path, + ) -> Box> + Send + 'a> { Box::new(async move { - Err(anyhow!("File upload not implemented for Flowise Chain engine")) + Err(anyhow!( + "File upload not implemented for Flowise Chain engine" + )) }) } - fn process_request_with_file<'a>(&'a self, request: &'a Request, file_path: &'a Path) -> Box> + Send + 'a> { + fn process_request_with_file<'a>( + &'a self, + request: &'a Request, + file_path: &'a Path, + ) -> Box> + Send + 'a> { Box::new(async move { let mut file = File::open(file_path).await.context("Failed to open file")?; let mut buffer = Vec::new(); - file.read_to_end(&mut buffer).await.context("Failed to read file")?; + file.read_to_end(&mut buffer) + .await + .context("Failed to read file")?; let encoded_image = base64::engine::general_purpose::STANDARD.encode(&buffer); debug!("Encoded 
image length: {} bytes", encoded_image.len()); - let file_name = file_path.file_name() + let file_name = file_path + .file_name() .and_then(|n| n.to_str()) .unwrap_or("unknown.file") .to_string(); @@ -221,24 +271,33 @@ impl Engine for FlowiseChainEngine { }] }); - debug!("Data field prefix: {}", &payload["uploads"][0]["data"].as_str().unwrap_or("").split(',').next().unwrap_or("")); - debug!("Uploads array length: {}", payload["uploads"].as_array().map_or(0, |arr| arr.len())); + debug!( + "Data field prefix: {}", + &payload["uploads"][0]["data"] + .as_str() + .unwrap_or("") + .split(',') + .next() + .unwrap_or("") + ); + debug!( + "Uploads array length: {}", + payload["uploads"].as_array().map_or(0, |arr| arr.len()) + ); debug!("File name in payload: {}", &payload["uploads"][0]["name"]); let client = reqwest::Client::new(); - let url = format!("{}://{}:{}{}", - self.config.connection.protocol, - self.config.connection.hostname, - self.config.connection.port, - self.config.connection.request_path + let url = format!( + "{}://{}:{}{}", + self.config.connection.protocol, + self.config.connection.hostname, + self.config.connection.port, + self.config.connection.request_path ); debug!("Sending request to URL: {}", url); - let response = client.post(&url) - .json(&payload) - .send() - .await?; + let response = client.post(&url).json(&payload).send().await?; debug!("Response status: {}", response.status()); @@ -246,11 +305,17 @@ impl Engine for FlowiseChainEngine { debug!("FlowiseAI Response: {:?}", response_body); - if response_body.get("error").is_some() || response_body["text"].as_str().map_or(false, |s| s.contains("no image provided")) { - warn!("FlowiseAI did not process the image. Full response: {:?}", response_body); + if response_body.get("error").is_some() + || response_body["text"] + .as_str() + .map_or(false, |s| s.contains("no image provided")) + { + warn!( + "FlowiseAI did not process the image. 
Full response: {:?}", + response_body + ); } - let content = response_body["text"] .as_str() .ok_or_else(|| anyhow!("Failed to extract content from FlowiseAI response"))? @@ -259,7 +324,7 @@ impl Engine for FlowiseChainEngine { // FlowiseAI doesn't provide token usage, so we'll estimate it based on content length let estimated_tokens = (content.len() as f32 / 4.0).ceil() as u32; let usage = Usage { - prompt_tokens: estimated_tokens / 2, // Rough estimate + prompt_tokens: estimated_tokens / 2, // Rough estimate completion_tokens: estimated_tokens / 2, // Rough estimate total_tokens: estimated_tokens, }; @@ -275,4 +340,4 @@ impl Engine for FlowiseChainEngine { }) }) } -} \ No newline at end of file +} diff --git a/crates/fluent-engines/src/pipeline_executor.rs b/crates/fluent-engines/src/pipeline_executor.rs index 1de5ec8..a4e2fbb 100644 --- a/crates/fluent-engines/src/pipeline_executor.rs +++ b/crates/fluent-engines/src/pipeline_executor.rs @@ -1,4 +1,3 @@ - use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -6,19 +5,17 @@ use std::future::Future; use std::path::{Path, PathBuf}; use std::pin::Pin; -use tokio::sync::{Mutex}; -use tokio::process::Command as TokioCommand; use std::io::Write; +use tokio::process::Command as TokioCommand; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tempfile; -use std::sync::{Arc}; use anyhow::{anyhow, Error}; -use tokio::process::Command; -use log::{info, error, warn, debug}; use async_trait::async_trait; - +use log::{debug, error, info, warn}; +use std::sync::Arc; +use tokio::process::Command; use tokio::task::JoinSet; use tokio::time::timeout; @@ -32,20 +29,74 @@ pub struct Pipeline { #[derive(Debug, Deserialize, Serialize, Clone)] pub enum PipelineStep { - Command { name: String, command: String, save_output: Option, retry: Option }, - ShellCommand { name: String, command: String, save_output: Option, retry: Option }, - Condition { name: String, condition: String, if_true: String, if_false: String }, - 
Loop { name: String, steps: Vec, condition: String }, - SubPipeline { name: String, pipeline: String, with: HashMap }, - Map { name: String, input: String, command: String, save_output: String }, - HumanInTheLoop { name: String, prompt: String, save_output: String }, - RepeatUntil { name: String, steps: Vec, condition: String }, - PrintOutput { name: String, value: String }, - ForEach { name: String, items: String, steps: Vec }, - TryCatch { name: String, try_steps: Vec, catch_steps: Vec, finally_steps: Vec }, - Parallel { name: String, steps: Vec }, - Timeout { name: String, duration: u64, step: Box }, - + Command { + name: String, + command: String, + save_output: Option, + retry: Option, + }, + ShellCommand { + name: String, + command: String, + save_output: Option, + retry: Option, + }, + Condition { + name: String, + condition: String, + if_true: String, + if_false: String, + }, + Loop { + name: String, + steps: Vec, + condition: String, + }, + SubPipeline { + name: String, + pipeline: String, + with: HashMap, + }, + Map { + name: String, + input: String, + command: String, + save_output: String, + }, + HumanInTheLoop { + name: String, + prompt: String, + save_output: String, + }, + RepeatUntil { + name: String, + steps: Vec, + condition: String, + }, + PrintOutput { + name: String, + value: String, + }, + ForEach { + name: String, + items: String, + steps: Vec, + }, + TryCatch { + name: String, + try_steps: Vec, + catch_steps: Vec, + finally_steps: Vec, + }, + Parallel { + name: String, + steps: Vec, + }, + Timeout { + name: String, + duration: u64, + step: Box, + }, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -70,30 +121,34 @@ pub trait StateStore { pub struct PipelineExecutor { // Change state to Arc> - state: Arc>, state_store: S, json_output: bool, } +type PipelineFuture<'a> = + Pin, Error>> + Send + 'a>>; + impl PipelineExecutor { pub fn new(state_store: S, _json_output: bool) -> Self { Self { - state: Arc::new(Mutex::new(PipelineState { - 
current_step: 0, - data: HashMap::new(), - run_id: "".to_string(), - start_time: 0, - })), state_store, json_output: false, } } - - pub async fn execute(&self, pipeline: &Pipeline, initial_input: &str, force_fresh: bool, provided_run_id: Option) -> Result { + pub async fn execute( + &self, + pipeline: &Pipeline, + initial_input: &str, + force_fresh: bool, + provided_run_id: Option, + ) -> Result { let run_id = provided_run_id.unwrap_or_else(|| Uuid::new_v4().to_string()); let state_key = format!("{}-{}", pipeline.name, run_id); - debug!("Executing pipeline {} with run_id {}", pipeline.name, run_id); + debug!( + "Executing pipeline {} with run_id {}", + pipeline.name, run_id + ); let mut state = if force_fresh { debug!("Forcing fresh state"); @@ -106,9 +161,12 @@ impl PipelineExec } else { match self.state_store.load_state(&state_key).await? { Some(saved_state) => { - debug!("Resuming from saved state at step {}", saved_state.current_step); + debug!( + "Resuming from saved state at step {}", + saved_state.current_step + ); saved_state - }, + } None => { debug!("No saved state found, starting fresh"); PipelineState { @@ -122,14 +180,18 @@ impl PipelineExec }; if state.current_step == 0 { - state.data.insert("input".to_string(), initial_input.to_string()); + state + .data + .insert("input".to_string(), initial_input.to_string()); } state.data.insert("run_id".to_string(), run_id.clone()); for (index, step) in pipeline.steps.iter().enumerate().skip(state.current_step) { debug!("Processing step {} (index {})", step.name(), index); - state.data.insert("step".to_string(), step.name().to_string()); + state + .data + .insert("step".to_string(), step.name().to_string()); state.current_step = index; match self.execute_step(step, &mut state).await { @@ -161,39 +223,58 @@ impl PipelineExec Ok(serde_json::to_string_pretty(&output)?) 
} - - - - fn execute_step<'a>(&'a self, step: &'a PipelineStep, state: &'a mut PipelineState) -> Pin, Error>> + Send + 'a>> { + fn execute_step<'a>( + &'a self, + step: &'a PipelineStep, + state: &'a mut PipelineState, + ) -> PipelineFuture<'a> { Box::pin(async move { match step { - - PipelineStep::Command { name, command, save_output, retry } => { + PipelineStep::Command { + name, + command, + save_output, + retry, + } => { debug!("Executing Command step: {}", name); debug!("Command: {}", command); let expanded_command = self.expand_variables(command, &state.data).await?; - self.execute_command(&expanded_command, save_output, retry).await + self.execute_command(&expanded_command, save_output, retry) + .await } - PipelineStep::ShellCommand { name, command, save_output, retry } => { + PipelineStep::ShellCommand { + name, + command, + save_output, + retry, + } => { debug!("Executing ShellCommand step: {}", name); debug!("Command: {}", command); let expanded_command = self.expand_variables(command, &state.data).await?; - self.execute_shell_command(&expanded_command, save_output, retry).await + self.execute_shell_command(&expanded_command, save_output, retry) + .await } - PipelineStep::Condition { name, condition, if_true, if_false } => { + PipelineStep::Condition { + name, + condition, + if_true, + if_false, + } => { debug!("Evaluating Condition step: {}", name); debug!("Condition: {}", condition); let expanded_condition = self.expand_variables(condition, &state.data).await?; if self.evaluate_condition(&expanded_condition).await? 
{ debug!("Condition is true, executing: {}", if_true); let expanded_command = self.expand_variables(if_true, &state.data).await?; - self.execute_shell_command(&expanded_command, &None, &None).await + self.execute_shell_command(&expanded_command, &None, &None) + .await } else { debug!("Condition is false, executing: {}", if_false); let expanded_command = self.expand_variables(if_false, &state.data).await?; - self.execute_shell_command(&expanded_command, &None, &None).await + self.execute_shell_command(&expanded_command, &None, &None) + .await } } @@ -201,12 +282,17 @@ impl PipelineExec debug!("Executing PrintOutput step: {}", name); let expanded_value = self.expand_variables(value, &state.data).await?; if !self.json_output { - eprintln!("{}", expanded_value); // Print to stderr instead of stdout + eprintln!("{}", expanded_value); // Print to stderr instead of stdout } Ok(HashMap::new()) } - PipelineStep::Map { name, input, command, save_output } => { + PipelineStep::Map { + name, + input, + command, + save_output, + } => { debug!("Executing Map step: {}", name); debug!("Input: {}", input); debug!("Command: {}", command); @@ -220,13 +306,16 @@ impl PipelineExec let expanded_command = self.expand_variables(command, &state.data).await?; let item_command = expanded_command.replace("${ITEM}", item); debug!("Executing command: {}", item_command); - match self.execute_shell_command(&item_command, &None, &None).await { + match self + .execute_shell_command(&item_command, &None, &None) + .await + { Ok(output) => { let new_string = String::new(); let result = output.values().next().unwrap_or(&new_string); debug!("Command output: {}", result); results.push(result.to_string()); - }, + } Err(e) => { error!("Error executing command for item {}: {:?}", item, e); } @@ -238,7 +327,11 @@ impl PipelineExec Ok([(save_output.clone(), output)].into_iter().collect()) } - PipelineStep::HumanInTheLoop { name, prompt, save_output } => { + PipelineStep::HumanInTheLoop { + name, + prompt, + 
save_output, + } => { debug!("Executing HumanInTheLoop step: {}", name); debug!("Prompt: {}", prompt); let expanded_prompt = self.expand_variables(prompt, &state.data).await?; @@ -247,10 +340,16 @@ impl PipelineExec let mut input = String::new(); std::io::stdin().read_line(&mut input)?; - Ok([(save_output.clone(), input.trim().to_string())].into_iter().collect()) + Ok([(save_output.clone(), input.trim().to_string())] + .into_iter() + .collect()) } - PipelineStep::RepeatUntil { name, steps, condition } => { + PipelineStep::RepeatUntil { + name, + steps, + condition, + } => { debug!("Executing RepeatUntil step: {}", name); debug!("Steps: {:?}", steps); debug!("Condition: {}", condition); @@ -260,7 +359,8 @@ impl PipelineExec state.data.extend(step_result); } - let expanded_condition = self.expand_variables(condition, &state.data).await?; + let expanded_condition = + self.expand_variables(condition, &state.data).await?; if self.evaluate_condition(&expanded_condition).await? { break; } @@ -291,7 +391,12 @@ impl PipelineExec Ok(HashMap::from([(name.clone(), results.join(", "))])) } - PipelineStep::TryCatch { name, try_steps, catch_steps, finally_steps } => { + PipelineStep::TryCatch { + name, + try_steps, + catch_steps, + finally_steps, + } => { debug!("Executing TryCatch step: {}", name); debug!("Try Steps: {:?}", try_steps); debug!("Catch Steps: {:?}", catch_steps); @@ -303,7 +408,8 @@ impl PipelineExec state.data.extend(step_result); } Ok(()) as Result<(), Error> - }.await; + } + .await; match try_result { Ok(_) => { @@ -329,9 +435,13 @@ impl PipelineExec PipelineStep::Parallel { name, steps } => { self.execute_parallel_steps(name, steps, state).await - }, + } - PipelineStep::Timeout { name, duration, step } => { + PipelineStep::Timeout { + name, + duration, + step, + } => { debug!("Executing Timeout step: {}", name); debug!("Duration: {}", duration); debug!("Step: {:?}", step); @@ -344,27 +454,34 @@ impl PipelineExec let result = step_result?; Ok(result) } - Err(_) 
=> { - Err(anyhow!("Step timed out after {} seconds", duration.as_secs())) - } + Err(_) => Err(anyhow!( + "Step timed out after {} seconds", + duration.as_secs() + )), } } - _ => { - Ok(HashMap::new()) - } + _ => Ok(HashMap::new()), } }) } - async fn execute_parallel_steps(&self, name: &str, steps: &[PipelineStep], state: &mut PipelineState) -> Result, Error> { + async fn execute_parallel_steps( + &self, + name: &str, + steps: &[PipelineStep], + state: &mut PipelineState, + ) -> Result, Error> { debug!("Executing Parallel step: {}", name); debug!("Steps: {:?}", steps); // Pre-expand variables for all steps let expanded_steps: Vec = futures::future::try_join_all( - steps.iter().map(|step| self.expand_variables_in_step(step, &state.data)) - ).await?; + steps + .iter() + .map(|step| self.expand_variables_in_step(step, &state.data)), + ) + .await?; let state_arc = Arc::new(tokio::sync::Mutex::new(state.clone())); let mut set = JoinSet::new(); @@ -385,10 +502,14 @@ impl PipelineExec combined_results.extend(step_result); } Ok(Err(e)) => { - combined_results.insert(format!("error_{}", combined_results.len()), e.to_string()); + combined_results + .insert(format!("error_{}", combined_results.len()), e.to_string()); } Err(e) => { - combined_results.insert(format!("join_error_{}", combined_results.len()), e.to_string()); + combined_results.insert( + format!("join_error_{}", combined_results.len()), + e.to_string(), + ); } } } @@ -398,12 +519,24 @@ impl PipelineExec state.data.extend(state_guard.data.clone()); state.data.extend(combined_results); - Ok(HashMap::from([(name.to_string(), "Parallel execution completed".to_string())])) + Ok(HashMap::from([( + name.to_string(), + "Parallel execution completed".to_string(), + )])) } - async fn expand_variables_in_step(&self, step: &PipelineStep, state_data: &HashMap) -> Result { + async fn expand_variables_in_step( + &self, + step: &PipelineStep, + state_data: &HashMap, + ) -> Result { match step { - PipelineStep::Command { name, 
command, save_output, retry } => { + PipelineStep::Command { + name, + command, + save_output, + retry, + } => { let expanded_command = self.expand_variables(command, state_data).await?; Ok(PipelineStep::Command { name: name.clone(), @@ -411,8 +544,13 @@ impl PipelineExec save_output: save_output.clone(), retry: retry.clone(), }) - }, - PipelineStep::ShellCommand { name, command, save_output, retry } => { + } + PipelineStep::ShellCommand { + name, + command, + save_output, + retry, + } => { let expanded_command = self.expand_variables(command, state_data).await?; Ok(PipelineStep::ShellCommand { name: name.clone(), @@ -420,25 +558,25 @@ impl PipelineExec save_output: save_output.clone(), retry: retry.clone(), }) - }, + } // For other step types, we can simply clone them as they are _ => Ok(step.clone()), } } - fn execute_single_step<'a>( step: &'a PipelineStep, state: &'a mut PipelineState, - ) -> Pin, Error>> + Send + 'a>> { + ) -> PipelineFuture<'a> { Box::pin(async move { match step { - PipelineStep::Command { name: _, command, save_output, retry: _ } => { - let output = Command::new("sh") - .arg("-c") - .arg(command) - .output() - .await?; + PipelineStep::Command { + name: _, + command, + save_output, + retry: _, + } => { + let output = Command::new("sh").arg("-c").arg(command).output().await?; let stdout = String::from_utf8(output.stdout)?; let mut result = HashMap::new(); @@ -447,12 +585,13 @@ impl PipelineExec } Ok(result) } - PipelineStep::ShellCommand { name: _, command, save_output, retry: _ } => { - let output = Command::new("sh") - .arg("-c") - .arg(command) - .output() - .await?; + PipelineStep::ShellCommand { + name: _, + command, + save_output, + retry: _, + } => { + let output = Command::new("sh").arg("-c").arg(command).output().await?; let stdout = String::from_utf8(output.stdout)?; let mut result = HashMap::new(); @@ -461,7 +600,12 @@ impl PipelineExec } Ok(result) } - PipelineStep::Condition { name, condition, if_true, if_false } => { + 
PipelineStep::Condition { + name, + condition, + if_true, + if_false, + } => { let condition_result = Command::new("sh") .arg("-c") .arg(condition) @@ -483,7 +627,11 @@ impl PipelineExec println!("{}", value); Ok(HashMap::new()) } - PipelineStep::RepeatUntil { name: _, steps, condition } => { + PipelineStep::RepeatUntil { + name: _, + steps, + condition, + } => { let result = HashMap::new(); loop { for sub_step in steps { @@ -507,7 +655,9 @@ impl PipelineExec PipelineStep::ForEach { name, items, steps } => { let mut result = Vec::new(); for item in items.split(',') { - state.data.insert("ITEM".to_string(), item.trim().to_string()); + state + .data + .insert("ITEM".to_string(), item.trim().to_string()); for sub_step in steps { let step_result = Self::execute_single_step(sub_step, state).await?; state.data.extend(step_result); @@ -517,7 +667,12 @@ impl PipelineExec state.data.remove("ITEM"); Ok(HashMap::from([(name.clone(), result.join(", "))])) } - PipelineStep::TryCatch { name: _, try_steps, catch_steps, finally_steps } => { + PipelineStep::TryCatch { + name: _, + try_steps, + catch_steps, + finally_steps, + } => { let mut result = HashMap::new(); let try_result = async { for sub_step in try_steps { @@ -525,7 +680,8 @@ impl PipelineExec state.data.extend(step_result); } Ok(()) as Result<(), Error> - }.await; + } + .await; match try_result { Ok(_) => { @@ -535,7 +691,8 @@ impl PipelineExec result.insert("try_result".to_string(), "failure".to_string()); result.insert("error".to_string(), e.to_string()); for sub_step in catch_steps { - let step_result = Self::execute_single_step(sub_step, state).await?; + let step_result = + Self::execute_single_step(sub_step, state).await?; state.data.extend(step_result); } } @@ -548,13 +705,21 @@ impl PipelineExec Ok(result) } - PipelineStep::Timeout { name: _, duration, step } => { + PipelineStep::Timeout { + name: _, + duration, + step, + } => { let duration = Duration::from_secs(*duration); - let timeout_result = 
timeout(duration, Self::execute_single_step(step, state)).await; + let timeout_result = + timeout(duration, Self::execute_single_step(step, state)).await; match timeout_result { Ok(step_result) => step_result, - Err(_) => Err(anyhow!("Step timed out after {} seconds", duration.as_secs())), + Err(_) => Err(anyhow!( + "Step timed out after {} seconds", + duration.as_secs() + )), } } PipelineStep::Parallel { name: _, steps } => { @@ -566,16 +731,22 @@ impl PipelineExec } Ok(result) } - _ => { - Err(anyhow!("Unknown step type")) - } + _ => Err(anyhow!("Unknown step type")), } }) } - async fn execute_command(&self, command: &str, save_output: &Option, retry: &Option) -> Result, Error> { + async fn execute_command( + &self, + command: &str, + save_output: &Option, + retry: &Option, + ) -> Result, Error> { debug!("Executing command: {}", command); - let retry_config = retry.clone().unwrap_or(RetryConfig { max_attempts: 1, delay_ms: 0 }); + let retry_config = retry.clone().unwrap_or(RetryConfig { + max_attempts: 1, + delay_ms: 0, + }); let mut attempts = 0; loop { @@ -588,17 +759,26 @@ impl PipelineExec Err(e) if attempts < retry_config.max_attempts => { attempts += 1; warn!("Attempt {} failed: {:?}. 
Retrying...", attempts, e); - tokio::time::sleep(std::time::Duration::from_millis(retry_config.delay_ms)).await; + tokio::time::sleep(std::time::Duration::from_millis(retry_config.delay_ms)) + .await; } Err(e) => { - error!("Command execution failed after {} attempts: {:?}", attempts + 1, e); + error!( + "Command execution failed after {} attempts: {:?}", + attempts + 1, + e + ); return Err(e); } } } } - async fn run_command(&self, command: &str, save_output: &Option) -> Result, Error> { + async fn run_command( + &self, + command: &str, + save_output: &Option, + ) -> Result, Error> { debug!("Running command: {}", command); let output = TokioCommand::new("sh") .arg("-c") @@ -609,7 +789,11 @@ impl PipelineExec if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - return Err(anyhow!("Command failed with exit code {:?}. Stderr: {}", output.status.code(), stderr)); + return Err(anyhow!( + "Command failed with exit code {:?}. Stderr: {}", + output.status.code(), + stderr + )); } let stdout = String::from_utf8(output.stdout) @@ -626,18 +810,22 @@ impl PipelineExec Ok(result) } - - - - - async fn execute_shell_command(&self, command: &str, save_output: &Option, retry: &Option) -> Result, Error> { + async fn execute_shell_command( + &self, + command: &str, + save_output: &Option, + retry: &Option, + ) -> Result, Error> { debug!("Executing shell command: {}", command); // Create a temporary file let mut temp_file = tempfile::NamedTempFile::new()?; writeln!(temp_file.as_file_mut(), "{}", command)?; - let retry_config = retry.clone().unwrap_or(RetryConfig { max_attempts: 1, delay_ms: 0 }); + let retry_config = retry.clone().unwrap_or(RetryConfig { + max_attempts: 1, + delay_ms: 0, + }); let mut attempts = 0; loop { @@ -656,17 +844,21 @@ impl PipelineExec Err(e) if attempts < retry_config.max_attempts => { attempts += 1; warn!("Attempt {} failed: {:?}. 
Retrying...", attempts, e); - tokio::time::sleep(std::time::Duration::from_millis(retry_config.delay_ms)).await; + tokio::time::sleep(std::time::Duration::from_millis(retry_config.delay_ms)) + .await; } Err(e) => { - error!("Shell command execution failed after {} attempts: {:?}", attempts + 1, e); + error!( + "Shell command execution failed after {} attempts: {:?}", + attempts + 1, + e + ); return Err(e); } } } } - async fn run_shell_command(&self, script_path: &Path) -> Result { debug!("Running shell command from file: {:?}", script_path); let output = TokioCommand::new("bash") @@ -677,7 +869,11 @@ impl PipelineExec if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - return Err(anyhow!("Shell command failed with exit code {:?}. Stderr: {}", output.status.code(), stderr)); + return Err(anyhow!( + "Shell command failed with exit code {:?}. Stderr: {}", + output.status.code(), + stderr + )); } let stdout = String::from_utf8(output.stdout) @@ -688,22 +884,29 @@ impl PipelineExec Ok(stdout.trim().to_string()) } - - async fn evaluate_condition(&self, condition: &str) -> Result { - let expanded_condition = self.expand_variables(condition, &Default::default()).await?; + let expanded_condition = self + .expand_variables(condition, &Default::default()) + .await?; debug!("Evaluating expanded condition: {}", expanded_condition); let output = TokioCommand::new("bash") .arg("-c") - .arg(format!("if {}; then exit 0; else exit 1; fi", expanded_condition)) + .arg(format!( + "if {}; then exit 0; else exit 1; fi", + expanded_condition + )) .output() .await?; Ok(output.status.success()) } - async fn expand_variables(&self, input: &str, state_data: &HashMap) -> Result { + async fn expand_variables( + &self, + input: &str, + state_data: &HashMap, + ) -> Result { debug!("Expanding variables in input: {}", input); let mut result = input.to_string(); for (key, value) in state_data { @@ -711,10 +914,8 @@ impl PipelineExec } Ok(result) } - } - impl 
PipelineStep { fn name(&self) -> &str { match self { @@ -728,7 +929,7 @@ impl PipelineStep { PipelineStep::RepeatUntil { name, .. } => name, PipelineStep::PrintOutput { name, .. } => name, PipelineStep::ForEach { name, .. } => name, - PipelineStep::TryCatch {name, .. } => name, + PipelineStep::TryCatch { name, .. } => name, PipelineStep::Parallel { name, .. } => name, PipelineStep::Timeout { name, .. } => name, } @@ -759,4 +960,4 @@ impl StateStore for FileStateStore { Ok(None) } } -} \ No newline at end of file +}