forked from Helixform/CodeCursor
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: extract the request logic of the generated code
- Loading branch information
Showing
7 changed files
with
364 additions
and
345 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
mod models; | ||
|
||
use crate::GenerateInput; | ||
use futures::StreamExt; | ||
use models::*; | ||
use node_bridge::http_client::{HttpMethod, HttpRequest}; | ||
use wasm_bindgen::prelude::*; | ||
|
||
// Split the code into chunks of 20 line blocks. | ||
fn split_code_into_blocks(code: &str) -> Vec<String> { | ||
let lines = code.split("\n"); | ||
let mut blocks = vec![]; | ||
let mut current_block = vec![]; | ||
for line in lines { | ||
current_block.push(line.to_string()); | ||
if current_block.len() >= 20 { | ||
blocks.push(current_block.join("\n")); | ||
current_block = vec![]; | ||
} | ||
} | ||
if current_block.len() > 0 { | ||
blocks.push(current_block.join("\n")); | ||
} | ||
blocks | ||
} | ||
|
||
#[wasm_bindgen(js_name = generateCode)] | ||
pub async fn generate_code(input: &GenerateInput) -> Result<(), JsValue> { | ||
let file_path = input.file_path(); | ||
let file_dir = file_path | ||
.split("/") | ||
.take(file_path.split("/").count() - 1) | ||
.collect::<Vec<&str>>() | ||
.join("/"); | ||
node_bridge::bindings::console::log_str(&format!("file_dir: {}", file_dir)); | ||
let workspace_directory = input.workspace_directory(); | ||
let selection = input.selection_range(); | ||
let document_text_utf16: Vec<u16> = input.document_text().encode_utf16().collect(); | ||
|
||
let selection_text = if selection.length() > 0 { | ||
Some(String::from_utf16_lossy( | ||
&document_text_utf16[selection.offset()..selection.offset() + selection.length()], | ||
)) | ||
} else { | ||
None | ||
}; | ||
let preceding_code = String::from_utf16_lossy(&document_text_utf16[0..selection.offset()]); | ||
let following_code = | ||
String::from_utf16_lossy(&document_text_utf16[selection.offset() + selection.length()..]); | ||
|
||
let message_type = if selection_text.is_some() { | ||
MessageType::Edit | ||
} else { | ||
MessageType::Generate | ||
}; | ||
|
||
let prompt = input.prompt(); | ||
|
||
let user_request = UserRequest::new( | ||
prompt, | ||
file_dir, | ||
file_path.to_owned(), | ||
input.document_text(), | ||
split_code_into_blocks(&preceding_code), | ||
split_code_into_blocks(&following_code), | ||
selection_text, | ||
message_type, | ||
); | ||
let mut request_body = RequestBody::new(user_request, vec![], workspace_directory); | ||
|
||
let result_stream = input.result_stream(); | ||
|
||
// A Boolean value indicating whether the conversation is finished. | ||
let mut finished = false; | ||
// If the conversation was interrupted, we need to send a "continue" request. | ||
let mut interrupted = false; | ||
// Handle the SSE stream. | ||
let mut message_started = false; | ||
let mut first_newline_dropped = false; | ||
|
||
let mut conversation_id: Option<String> = None; | ||
// The last message received from the server. | ||
let mut previous_message: String = "".to_owned(); | ||
let mut last_token = "".to_owned(); | ||
|
||
while !finished { | ||
if interrupted { | ||
// Generate an UUID as conversation ID. | ||
if conversation_id.is_none() { | ||
conversation_id = Some(node_bridge::bindings::uuid::uuid_v4()); | ||
} | ||
let bot_message = BotMessage::new( | ||
conversation_id.clone().unwrap(), | ||
message_type, | ||
previous_message.clone(), | ||
last_token.clone(), | ||
file_path.to_owned(), | ||
); | ||
request_body.bot_messages = vec![bot_message]; | ||
} | ||
|
||
node_bridge::bindings::console::log_str(&serde_json::to_string(&request_body).unwrap()); | ||
|
||
let request = HttpRequest::new(&format!( | ||
"https://aicursor.com/{}", | ||
if interrupted { | ||
"continue" | ||
} else { | ||
"conversation" | ||
} | ||
)) | ||
.set_method(HttpMethod::Post) | ||
.set_body(serde_json::to_string(&request_body).unwrap()) | ||
.add_header("authority", "aicursor.com") | ||
.add_header("accept", "*/*") | ||
.add_header("content-type", "application/json") | ||
.add_header("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.1.0 Chrome/108.0.5359.62 Electron/22.0.0 Safari/537.36"); | ||
|
||
let mut response = request.send().await?; | ||
|
||
let body = response.body(); | ||
while let Some(chunk) = body.next().await { | ||
let chunk = chunk.to_string("utf-8"); | ||
node_bridge::bindings::console::log_str(&chunk); | ||
let lines = chunk.split("\n").filter(|l| l.len() > 0); | ||
let mut message_ended = false; | ||
for line in lines { | ||
if !line.starts_with("data: ") { | ||
continue; | ||
} | ||
// A string can be JSON to parse. | ||
let data_str = &line["data: ".len()..]; | ||
let mut data = serde_json::from_str::<String>(data_str).unwrap(); | ||
if data == "<|BEGIN_message|>" { | ||
message_started = true; | ||
continue; | ||
} else if data.contains("<|END_interrupt|>") { | ||
interrupted = true; | ||
last_token = data.clone(); | ||
// `END_interrupt` is included in valid data, | ||
// we cannot discard it. | ||
data = data.replace("<|END_interrupt|>", ""); | ||
} else if data == "<|END_message|>" { | ||
if !interrupted { | ||
finished = true; | ||
} | ||
// We cannot exit the loop here because we're in a nested loop. | ||
message_ended = true; | ||
break; | ||
} | ||
|
||
if message_started { | ||
// Server may produce newlines at the head of response, we need | ||
// to do this trick to ignore them in the final edit. | ||
if !first_newline_dropped && data.trim().len() == 0 { | ||
first_newline_dropped = true; | ||
continue; | ||
} | ||
previous_message.push_str(&data); | ||
result_stream.write(&data); | ||
} | ||
} | ||
// If we've reached the end of the message, break out of the loop. | ||
if message_ended { | ||
break; | ||
} | ||
} | ||
|
||
response.await?; | ||
} | ||
|
||
node_bridge::bindings::console::log_str("done"); | ||
|
||
result_stream.end(); | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
use serde::Serialize; | ||
|
||
use super::{random, request_body::MessageType}; | ||
|
||
#[derive(Debug, Serialize, Clone)] | ||
pub struct BotMessage { | ||
#[serde(rename = "sender")] | ||
pub sender: String, | ||
|
||
#[serde(rename = "sendAt")] | ||
pub send_at: i64, | ||
|
||
#[serde(rename = "conversationId")] | ||
pub conversation_id: String, | ||
|
||
#[serde(rename = "type")] | ||
pub message_type: MessageType, | ||
|
||
#[serde(rename = "message")] | ||
pub message: String, | ||
|
||
#[serde(rename = "lastToken")] | ||
pub last_token: String, | ||
|
||
#[serde(rename = "finished")] | ||
pub finished: bool, | ||
|
||
#[serde(rename = "currentFile")] | ||
pub current_file: String, | ||
|
||
#[serde(rename = "interrupted")] | ||
pub interrupted: bool, | ||
|
||
#[serde(rename = "maxOrigLine")] | ||
pub max_original_line: i32, | ||
|
||
#[serde(rename = "hitTokenLimit")] | ||
pub hit_token_limit: bool, | ||
} | ||
|
||
impl BotMessage { | ||
pub fn new( | ||
conversation_id: String, | ||
message_type: MessageType, | ||
message: String, | ||
last_token: String, | ||
current_file: String, | ||
) -> Self { | ||
Self { | ||
sender: "bot".to_owned(), | ||
send_at: chrono::Utc::now().timestamp_millis(), | ||
conversation_id, | ||
message_type, | ||
message, | ||
last_token, | ||
finished: false, | ||
current_file, | ||
interrupted: true, | ||
max_original_line: random(), | ||
hit_token_limit: true, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
pub(crate) mod bot_message; | ||
pub(crate) mod request_body; | ||
pub(crate) mod user_request; | ||
|
||
pub(crate) use bot_message::*; | ||
pub(crate) use request_body::*; | ||
pub(crate) use user_request::*; | ||
|
||
fn random() -> i32 { | ||
js_sys::Math::floor(js_sys::Math::random() * 1000.0) as i32 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
use serde::Serialize; | ||
|
||
use super::{bot_message::BotMessage, UserRequest}; | ||
|
||
#[derive(Debug, Clone, Copy, Serialize)] | ||
#[serde(rename_all = "snake_case")] | ||
pub enum MessageType { | ||
Edit, | ||
Generate, | ||
} | ||
|
||
#[derive(Debug, Serialize, Clone)] | ||
pub struct RequestBody { | ||
#[serde(rename = "userRequest")] | ||
pub user_request: UserRequest, | ||
|
||
#[serde(rename = "botMessages")] | ||
pub bot_messages: Vec<BotMessage>, | ||
|
||
#[serde(rename = "userMessages")] | ||
pub user_messages: Vec<String>, | ||
|
||
#[serde(rename = "contextType")] | ||
pub context_type: String, | ||
|
||
#[serde(rename = "rootPath")] | ||
pub root_path: Option<String>, | ||
} | ||
|
||
impl RequestBody { | ||
pub fn new( | ||
user_request: UserRequest, | ||
bot_messages: Vec<BotMessage>, | ||
root_path: Option<String>, | ||
) -> Self { | ||
Self { | ||
user_request, | ||
bot_messages, | ||
user_messages: vec![], | ||
context_type: "copilot".to_owned(), | ||
root_path, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use serde::Serialize; | ||
|
||
use super::{random, request_body::MessageType}; | ||
|
||
#[derive(Debug, Serialize, Clone)] | ||
pub struct UserRequest { | ||
pub message: String, | ||
|
||
#[serde(rename = "currentRootPath")] | ||
pub current_root_path: String, | ||
|
||
#[serde(rename = "currentFileName")] | ||
pub current_file_name: String, | ||
|
||
#[serde(rename = "currentFileContents")] | ||
pub current_file_contents: String, | ||
|
||
#[serde(rename = "precedingCode")] | ||
pub preceding_code: Vec<String>, | ||
|
||
#[serde(rename = "suffixCode")] | ||
pub suffix_code: Vec<String>, | ||
|
||
#[serde(rename = "currentSelection")] | ||
pub current_selection: Option<String>, | ||
|
||
#[serde(rename = "copilotCodeBlocks")] | ||
pub copilot_code_blocks: Vec<String>, | ||
|
||
#[serde(rename = "customCodeBlocks")] | ||
pub custom_code_blocks: Vec<String>, | ||
|
||
#[serde(rename = "codeBlockIdentifiers")] | ||
pub code_block_identifiers: Vec<String>, | ||
|
||
#[serde(rename = "msgType")] | ||
pub message_type: MessageType, | ||
|
||
#[serde(rename = "maxOrigLine")] | ||
pub max_original_line: i32, | ||
} | ||
|
||
impl UserRequest { | ||
pub fn new( | ||
message: String, | ||
current_root_path: String, | ||
current_file_name: String, | ||
current_file_contents: String, | ||
preceding_code: Vec<String>, | ||
suffix_code: Vec<String>, | ||
current_selection: Option<String>, | ||
message_type: MessageType, | ||
) -> Self { | ||
Self { | ||
message, | ||
current_root_path, | ||
current_file_name, | ||
current_file_contents, | ||
preceding_code, | ||
suffix_code, | ||
current_selection, | ||
copilot_code_blocks: vec![], | ||
custom_code_blocks: vec![], | ||
code_block_identifiers: vec![], | ||
message_type, | ||
max_original_line: random(), | ||
} | ||
} | ||
} |
Oops, something went wrong.