Skip to content

Commit

Permalink
generate-cypher is working with llm configurations, can upload direct…
Browse files Browse the repository at this point in the history
…ories or files.
  • Loading branch information
njfio committed Jul 22, 2024
1 parent 30428c9 commit 30c5853
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 78 deletions.
216 changes: 139 additions & 77 deletions crates/fluent-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ use fluent_core::spinner_configuration::SpinnerConfig;
use fluent_core::traits::Engine;
use fluent_core::types::Request;
use fluent_engines::anthropic::AnthropicEngine;
use fluent_engines::cohere::CohereEngine;
use fluent_engines::google_gemini::GoogleGeminiEngine;
use fluent_engines::groqlpu::GroqLPUEngine;
use fluent_engines::openai::OpenAIEngine;
use fluent_engines::perplexity::PerplexityEngine;

pub mod cli {
use std::pin::Pin;
Expand All @@ -36,7 +40,7 @@ pub mod cli {
use std::io::IsTerminal;
use indicatif::{ProgressBar, ProgressStyle};

use log::{debug, error};
use log::{debug, error, info};
use serde_json::Value;
use tokio::io::AsyncReadExt;
use tokio::time::Instant;
Expand All @@ -54,7 +58,7 @@ pub mod cli {
use fluent_engines::leonardoai::LeonardoAIEngine;
use fluent_engines::mistral::MistralEngine;
use fluent_engines::stabilityai::StabilityAIEngine;
use crate::{create_llm_engine, generate_and_execute_cypher};
use crate::{create_engine, create_llm_engine, generate_and_execute_cypher};


fn parse_key_value_pair(s: &str) -> Option<(String, String)> {
Expand Down Expand Up @@ -246,6 +250,7 @@ pub mod cli {
let spinner_config = config.engines[0].spinner.clone().unwrap_or_default();
let pb = ProgressBar::new_spinner();
let engine_config = &config.engines[0];
let start_time = Instant::now();

let spinner_style = ProgressStyle::default_spinner()
.tick_chars(&spinner_config.frames)
Expand All @@ -255,86 +260,135 @@ pub mod cli {
pb.set_style(spinner_style);
pb.set_message(format!("Processing {} request...", engine_name));
pb.enable_steady_tick(Duration::from_millis(spinner_config.interval));
pb.set_length(100);

if matches.get_one::<String>("generate-cypher").is_some() {
let neo4j_config = config.engines.iter()
.find(|e| e.engine == "neo4j")
.ok_or_else(|| anyhow!("No Neo4j engine configured"))?;
let query_llm = neo4j_config.neo4j.as_ref()
.and_then(|c| c.query_llm.as_ref())
.ok_or_else(|| anyhow!("No query LLM configured"))?;
debug!("config: {:?}", config.engines);
debug!("neo4j_config: {:?}", neo4j_config);
debug!("query_llm: {:?}", query_llm);
if let Some(cypher_query) = matches.get_one::<String>("generate-cypher") {
let neo4j_config = engine_config.neo4j.as_ref()
.ok_or_else(|| anyhow!("Neo4j configuration not found in the engine config"))?;

let query_llm_name = neo4j_config.query_llm.as_ref()
.ok_or_else(|| anyhow!("No query LLM specified for Neo4j"))?;

let query_llm_full_config = load_config(&config_path, query_llm, &overrides)?;
// Load the configuration for the query LLM
let query_llm_config = load_config(&config_path, query_llm_name, &HashMap::new())?;
let query_llm_engine_config = &query_llm_config.engines[0];

let llm_config = query_llm_full_config.engines.iter()
.find(|e| e.name == query_llm.to_string())
.ok_or_else(|| anyhow!("No configuration found for query LLM: {}", query_llm))?;

let engine = create_llm_engine(llm_config).await?;
debug!("llm_config: {:?}", llm_config);

let cypher_query = matches.get_one::<String>("generate-cypher").unwrap();
debug!("cypher_query: {:?}", cypher_query);
let neo4j_config = config.engines.iter()
.find(|e| e.engine == "neo4j")
.and_then(|e| e.neo4j.as_ref())
.ok_or_else(|| anyhow!("Neo4j configuration not found"))?;
let query_llm_engine = create_llm_engine(query_llm_engine_config).await?;

let cypher_result = generate_and_execute_cypher(
neo4j_config,
llm_config,
query_llm_engine_config,
cypher_query,
&*engine
&*query_llm_engine
).await?;
println!("{}", cypher_result);
}

if engine_config.engine == "neo4j" {
println!("{}", cypher_result);
} else {
let engine: Box<dyn Engine> = create_engine(engine_config).await?;

let max_tokens = engine_config.parameters.get("max_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(-1);

let user_request = matches.get_one::<String>("request")
.map(|s| s.to_string())
.unwrap_or_else(String::new);

let mut combined_request = format!(
"Cypher query: {}\n\nCypher result:\n{}\n\nBased on the above Cypher query and its result, please provide an analysis or answer the following question: {}",
cypher_query, cypher_result, user_request
);

// Truncate the combined request if it exceeds the max tokens
if max_tokens > 0 && combined_request.len() > max_tokens as usize {
combined_request.truncate(max_tokens as usize);
combined_request += "... [truncated]";
}
info!("Combined request: {}", combined_request);
let request = Request {
flowname: engine_name.to_string(),
payload: combined_request,
};

let response = Pin::from(engine.execute(&request)).await?;
let mut output = response.content.clone();

if let Some(download_dir) = matches.get_one::<String>("download-media") {
let download_path = PathBuf::from(download_dir);
OutputProcessor::download_media_files(&response.content, &download_path).await?;
}

if matches.get_flag("parse-code") {
debug!("Parsing code blocks");
let code_blocks = OutputProcessor::parse_code(&output);
debug!("Code blocks: {:?}", code_blocks);
output = code_blocks.join("\n\n");
}

if matches.get_flag("execute-output") {
debug!("Executing output code");
debug!("Attempting to execute : {}", output);
output = OutputProcessor::execute_code(&output).await?;
}

if matches.get_flag("markdown") {
debug!("Formatting output as markdown");
//output = format_markdown(&output);
}

let response_time = start_time.elapsed().as_secs_f64();

if let Some(neo4j_client) = engine.get_neo4j_client() {
let session_id = engine.get_session_id()
.unwrap_or_else(|| Uuid::new_v4().to_string());

let stats = InteractionStats {
prompt_tokens: response.usage.prompt_tokens,
completion_tokens: response.usage.completion_tokens,
total_tokens: response.usage.total_tokens,
response_time,
finish_reason: response.finish_reason.clone().unwrap_or_else(|| "unknown".to_string()),
};

debug!("Attempting to create interaction in Neo4j");
debug!("Using session ID: {}", session_id);
match neo4j_client.create_interaction(
&session_id,
&request.payload,
&response.content,
&response.model,
&stats
).await {
Ok(interaction_id) => debug!("Successfully created interaction with id: {}", interaction_id),
Err(e) => error!("Failed to create interaction in Neo4j: {:?}", e),
}
} else {
debug!("Neo4j client not available, skipping interaction logging");
}

if matches.get_flag("upsert") {
pb.finish_and_clear();
eprintln!();
println!("{}", output);

let use_colors = std::io::stderr().is_terminal();
let response_time_str = format!("{:.2}s", response_time);

eprintln!(
"{} | {} | Time: {} | Usage: {}↑ {}↓ {}Σ | {}\n",
spinner_config.success_symbol,
if use_colors { response.model.cyan().to_string() } else { response.model },
if use_colors { response_time_str.bright_blue().to_string() } else { response_time_str },
if use_colors { response.usage.prompt_tokens.to_string().yellow().to_string() } else { response.usage.prompt_tokens.to_string() },
if use_colors { response.usage.completion_tokens.to_string().yellow().to_string() } else { response.usage.completion_tokens.to_string() },
if use_colors { response.usage.total_tokens.to_string().yellow().to_string() } else { response.usage.total_tokens.to_string() },
if use_colors { response.finish_reason.as_deref().unwrap_or("No finish reason").italic().to_string() } else { response.finish_reason.as_deref().unwrap_or("No finish reason").to_string() }
);
}
} else if matches.get_flag("upsert") {
debug!("Upsert mode enabled");
handle_upsert(engine_config, &matches).await?;
} else if matches.get_one::<String>("generate-cypher").is_some() && engine_config.engine == "neo4j" {
debug!("Generate Cypher mode enabled");
let neo4j_config = config.engines.iter()
.find(|e| e.engine == "neo4j")
.ok_or_else(|| anyhow!("No Neo4j engine configured"))?;
let query_llm = neo4j_config.neo4j.as_ref()
.and_then(|c| c.query_llm.as_ref())
.ok_or_else(|| anyhow!("No query LLM configured"))?;
debug!("config: {:?}", config.engines);
debug!("neo4j_config: {:?}", neo4j_config);
debug!("query_llm: {:?}", query_llm);


let query_llm_full_config = load_config(&config_path, query_llm, &overrides)?;

let llm_config = query_llm_full_config.engines.iter()
.find(|e| e.name == query_llm.to_string())
.ok_or_else(|| anyhow!("No configuration found for query LLM: {}", query_llm))?;

let engine = create_llm_engine(llm_config).await?;
debug!("llm_config: {:?}", llm_config);

let cypher_query = matches.get_one::<String>("generate-cypher").unwrap();
debug!("cypher_query: {:?}", cypher_query);
let neo4j_config = config.engines.iter()
.find(|e| e.engine == "neo4j")
.and_then(|e| e.neo4j.as_ref())
.ok_or_else(|| anyhow!("Neo4j configuration not found"))?;

let cypher_result = generate_and_execute_cypher(
neo4j_config,
llm_config,
cypher_query,
&*engine
).await?;
println!("{}", cypher_result);
return Ok(());
} else {
} else {
let request = matches.get_one::<String>("request").unwrap();

let engine: Box<dyn Engine> = match engine_config.engine.as_str() {
Expand Down Expand Up @@ -399,7 +453,6 @@ pub mod cli {
};
debug!("Combined Request: {:?}", request);

let start_time = Instant::now();

let response = if let Some(file_path) = matches.get_one::<String>("upload-image-file") {
debug!("Processing request with file: {}", file_path);
Expand Down Expand Up @@ -571,13 +624,6 @@ async fn get_neo4j_query_llm(config: &Config) -> Option<Box<dyn Engine>> {
create_llm_engine(llm_config).await.ok()
}

async fn create_llm_engine(engine_config: &EngineConfig) -> Result<Box<dyn Engine>, Error> {
match engine_config.engine.as_str() {
"openai" => Ok(Box::new(OpenAIEngine::new(engine_config.clone()).await?)),
"anthropic" => Ok(Box::new(AnthropicEngine::new(engine_config.clone()).await?)),
_ => Err(anyhow!("Unsupported LLM engine: {}", engine_config.engine)),
}
}
async fn generate_and_execute_cypher(
neo4j_config: &Neo4jConfig,
llm_config: &EngineConfig,
Expand Down Expand Up @@ -653,3 +699,19 @@ fn format_as_csv(result: &Value) -> String {
// For now, we'll just return the JSON as a string
result.to_string()
}

async fn create_engine(engine_config: &EngineConfig) -> Result<Box<dyn Engine>, Error> {
match engine_config.engine.as_str() {
"openai" => Ok(Box::new(OpenAIEngine::new(engine_config.clone()).await?)),
"anthropic" => Ok(Box::new(AnthropicEngine::new(engine_config.clone()).await?)),
"cohere" => Ok(Box::new(CohereEngine::new(engine_config.clone()).await?)),
"google_gemini" => Ok(Box::new(GoogleGeminiEngine::new(engine_config.clone()).await?)),
"perplexity" => Ok(Box::new(PerplexityEngine::new(engine_config.clone()).await?)),
"groq_lpu" => Ok(Box::new(GroqLPUEngine::new(engine_config.clone()).await?)),
_ => Err(anyhow!("Unsupported engine: {}", engine_config.engine)),
}
}

async fn create_llm_engine(engine_config: &EngineConfig) -> Result<Box<dyn Engine>, Error> {
create_engine(engine_config).await
}
2 changes: 1 addition & 1 deletion default_config_test.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"api_key": "AMBER_FLUENT_VOYAGE_AI_KEY",
"model": "voyage-large-2"
},
"query_llm": "anthropic",
"query_llm": "sonnet3.5",
"parameters": {

}
Expand Down

0 comments on commit 30c5853

Please sign in to comment.