Skip to content

Commit

Permalink
WIP: extract the request logic of the generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
ktiays committed Mar 22, 2023
1 parent 688a752 commit 411ac01
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 345 deletions.
176 changes: 176 additions & 0 deletions crates/cursor-core/src/generate/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
mod models;

use crate::GenerateInput;
use futures::StreamExt;
use models::*;
use node_bridge::http_client::{HttpMethod, HttpRequest};
use wasm_bindgen::prelude::*;

/// Splits `code` on `\n` and regroups the resulting lines into blocks of at
/// most 20 lines, each block re-joined with `\n`.
///
/// Only the final block may hold fewer than 20 lines. Because splitting an
/// empty string still yields one (empty) line, the result always contains at
/// least one block.
fn split_code_into_blocks(code: &str) -> Vec<String> {
    let all_lines: Vec<&str> = code.split('\n').collect();
    all_lines
        .chunks(20)
        .map(|group| group.join("\n"))
        .collect()
}

#[wasm_bindgen(js_name = generateCode)]
pub async fn generate_code(input: &GenerateInput) -> Result<(), JsValue> {
let file_path = input.file_path();
let file_dir = file_path
.split("/")
.take(file_path.split("/").count() - 1)
.collect::<Vec<&str>>()
.join("/");
node_bridge::bindings::console::log_str(&format!("file_dir: {}", file_dir));
let workspace_directory = input.workspace_directory();
let selection = input.selection_range();
let document_text_utf16: Vec<u16> = input.document_text().encode_utf16().collect();

let selection_text = if selection.length() > 0 {
Some(String::from_utf16_lossy(
&document_text_utf16[selection.offset()..selection.offset() + selection.length()],
))
} else {
None
};
let preceding_code = String::from_utf16_lossy(&document_text_utf16[0..selection.offset()]);
let following_code =
String::from_utf16_lossy(&document_text_utf16[selection.offset() + selection.length()..]);

let message_type = if selection_text.is_some() {
MessageType::Edit
} else {
MessageType::Generate
};

let prompt = input.prompt();

let user_request = UserRequest::new(
prompt,
file_dir,
file_path.to_owned(),
input.document_text(),
split_code_into_blocks(&preceding_code),
split_code_into_blocks(&following_code),
selection_text,
message_type,
);
let mut request_body = RequestBody::new(user_request, vec![], workspace_directory);

let result_stream = input.result_stream();

// A Boolean value indicating whether the conversation is finished.
let mut finished = false;
// If the conversation was interrupted, we need to send a "continue" request.
let mut interrupted = false;
// Handle the SSE stream.
let mut message_started = false;
let mut first_newline_dropped = false;

let mut conversation_id: Option<String> = None;
// The last message received from the server.
let mut previous_message: String = "".to_owned();
let mut last_token = "".to_owned();

while !finished {
if interrupted {
// Generate an UUID as conversation ID.
if conversation_id.is_none() {
conversation_id = Some(node_bridge::bindings::uuid::uuid_v4());
}
let bot_message = BotMessage::new(
conversation_id.clone().unwrap(),
message_type,
previous_message.clone(),
last_token.clone(),
file_path.to_owned(),
);
request_body.bot_messages = vec![bot_message];
}

node_bridge::bindings::console::log_str(&serde_json::to_string(&request_body).unwrap());

let request = HttpRequest::new(&format!(
"https://aicursor.com/{}",
if interrupted {
"continue"
} else {
"conversation"
}
))
.set_method(HttpMethod::Post)
.set_body(serde_json::to_string(&request_body).unwrap())
.add_header("authority", "aicursor.com")
.add_header("accept", "*/*")
.add_header("content-type", "application/json")
.add_header("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Cursor/0.1.0 Chrome/108.0.5359.62 Electron/22.0.0 Safari/537.36");

let mut response = request.send().await?;

let body = response.body();
while let Some(chunk) = body.next().await {
let chunk = chunk.to_string("utf-8");
node_bridge::bindings::console::log_str(&chunk);
let lines = chunk.split("\n").filter(|l| l.len() > 0);
let mut message_ended = false;
for line in lines {
if !line.starts_with("data: ") {
continue;
}
// A string can be JSON to parse.
let data_str = &line["data: ".len()..];
let mut data = serde_json::from_str::<String>(data_str).unwrap();
if data == "<|BEGIN_message|>" {
message_started = true;
continue;
} else if data.contains("<|END_interrupt|>") {
interrupted = true;
last_token = data.clone();
// `END_interrupt` is included in valid data,
// we cannot discard it.
data = data.replace("<|END_interrupt|>", "");
} else if data == "<|END_message|>" {
if !interrupted {
finished = true;
}
// We cannot exit the loop here because we're in a nested loop.
message_ended = true;
break;
}

if message_started {
// Server may produce newlines at the head of response, we need
// to do this trick to ignore them in the final edit.
if !first_newline_dropped && data.trim().len() == 0 {
first_newline_dropped = true;
continue;
}
previous_message.push_str(&data);
result_stream.write(&data);
}
}
// If we've reached the end of the message, break out of the loop.
if message_ended {
break;
}
}

response.await?;
}

node_bridge::bindings::console::log_str("done");

result_stream.end();
Ok(())
}
63 changes: 63 additions & 0 deletions crates/cursor-core/src/generate/models/bot_message.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use serde::Serialize;

use super::{random, request_body::MessageType};

/// A message attributed to the bot, serialized with camelCase JSON keys.
///
/// Two keys deviate from the camelCase form of their field name and are
/// renamed explicitly: `type` and `maxOrigLine`.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BotMessage {
    /// Author of the message; [`BotMessage::new`] sets it to `"bot"`.
    pub sender: String,

    /// Creation timestamp; [`BotMessage::new`] fills in the current UTC time
    /// in milliseconds.
    pub send_at: i64,

    /// Identifier of the conversation this message belongs to.
    pub conversation_id: String,

    /// Kind of request the message relates to; serialized as `type`.
    #[serde(rename = "type")]
    pub message_type: MessageType,

    /// Text of the message.
    pub message: String,

    /// Serialized as `lastToken`.
    pub last_token: String,

    /// [`BotMessage::new`] sets this to `false`.
    pub finished: bool,

    /// Serialized as `currentFile`.
    pub current_file: String,

    /// [`BotMessage::new`] sets this to `true`.
    pub interrupted: bool,

    /// Serialized as `maxOrigLine`; [`BotMessage::new`] fills in a random value.
    #[serde(rename = "maxOrigLine")]
    pub max_original_line: i32,

    /// [`BotMessage::new`] sets this to `true`.
    pub hit_token_limit: bool,
}

impl BotMessage {
pub fn new(
conversation_id: String,
message_type: MessageType,
message: String,
last_token: String,
current_file: String,
) -> Self {
Self {
sender: "bot".to_owned(),
send_at: chrono::Utc::now().timestamp_millis(),
conversation_id,
message_type,
message,
last_token,
finished: false,
current_file,
interrupted: true,
max_original_line: random(),
hit_token_limit: true,
}
}
}
11 changes: 11 additions & 0 deletions crates/cursor-core/src/generate/models/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
pub(crate) mod bot_message;
pub(crate) mod request_body;
pub(crate) mod user_request;

pub(crate) use bot_message::*;
pub(crate) use request_body::*;
pub(crate) use user_request::*;

/// Returns `Math.random()` scaled by 1000 and floored, as an `i32`.
fn random() -> i32 {
    let scaled = js_sys::Math::random() * 1000.0;
    js_sys::Math::floor(scaled) as i32
}
44 changes: 44 additions & 0 deletions crates/cursor-core/src/generate/models/request_body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use serde::Serialize;

use super::{bot_message::BotMessage, UserRequest};

/// The kind of request being made.
///
/// Variants are serialized in snake_case, i.e. `"edit"` and `"generate"`.
#[derive(Clone, Copy, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
    /// Serialized as `"edit"`.
    Edit,
    /// Serialized as `"generate"`.
    Generate,
}

/// The JSON request payload, serialized with camelCase keys.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestBody {
    /// Serialized as `userRequest`.
    pub user_request: UserRequest,

    /// Serialized as `botMessages`.
    pub bot_messages: Vec<BotMessage>,

    /// Serialized as `userMessages`; [`RequestBody::new`] leaves it empty.
    pub user_messages: Vec<String>,

    /// Serialized as `contextType`; [`RequestBody::new`] sets it to `"copilot"`.
    pub context_type: String,

    /// Serialized as `rootPath`; `None` becomes JSON `null`.
    pub root_path: Option<String>,
}

impl RequestBody {
pub fn new(
user_request: UserRequest,
bot_messages: Vec<BotMessage>,
root_path: Option<String>,
) -> Self {
Self {
user_request,
bot_messages,
user_messages: vec![],
context_type: "copilot".to_owned(),
root_path,
}
}
}
69 changes: 69 additions & 0 deletions crates/cursor-core/src/generate/models/user_request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use serde::Serialize;

use super::{random, request_body::MessageType};

/// The user's request together with its file context, serialized with
/// camelCase JSON keys.
///
/// Two keys deviate from the camelCase form of their field name and are
/// renamed explicitly: `msgType` and `maxOrigLine`.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UserRequest {
    /// Serialized as `message`.
    pub message: String,

    /// Serialized as `currentRootPath`.
    pub current_root_path: String,

    /// Serialized as `currentFileName`.
    pub current_file_name: String,

    /// Serialized as `currentFileContents`.
    pub current_file_contents: String,

    /// Serialized as `precedingCode`.
    pub preceding_code: Vec<String>,

    /// Serialized as `suffixCode`.
    pub suffix_code: Vec<String>,

    /// Serialized as `currentSelection`; `None` becomes JSON `null`.
    pub current_selection: Option<String>,

    /// Serialized as `copilotCodeBlocks`; [`UserRequest::new`] leaves it empty.
    pub copilot_code_blocks: Vec<String>,

    /// Serialized as `customCodeBlocks`; [`UserRequest::new`] leaves it empty.
    pub custom_code_blocks: Vec<String>,

    /// Serialized as `codeBlockIdentifiers`; [`UserRequest::new`] leaves it
    /// empty.
    pub code_block_identifiers: Vec<String>,

    /// Kind of request; serialized as `msgType`.
    #[serde(rename = "msgType")]
    pub message_type: MessageType,

    /// Serialized as `maxOrigLine`; [`UserRequest::new`] fills in a random
    /// value.
    #[serde(rename = "maxOrigLine")]
    pub max_original_line: i32,
}

impl UserRequest {
pub fn new(
message: String,
current_root_path: String,
current_file_name: String,
current_file_contents: String,
preceding_code: Vec<String>,
suffix_code: Vec<String>,
current_selection: Option<String>,
message_type: MessageType,
) -> Self {
Self {
message,
current_root_path,
current_file_name,
current_file_contents,
preceding_code,
suffix_code,
current_selection,
copilot_code_blocks: vec![],
custom_code_blocks: vec![],
code_block_identifiers: vec![],
message_type,
max_original_line: random(),
}
}
}
Loading

0 comments on commit 411ac01

Please sign in to comment.