Skip to content

Commit

Permalink
Add interactive mode in cli
Browse files Browse the repository at this point in the history
  • Loading branch information
kazord committed Apr 1, 2023
1 parent f6249e8 commit ec514ec
Showing 1 changed file with 68 additions and 17 deletions.
85 changes: 68 additions & 17 deletions src/rllama_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ struct Cli {
#[arg(long)]
prompt_file: Option<String>,

#[arg(long)]
interactive_stop: Option<String>,
#[arg(long, action)]
start_interactive: bool,

#[arg(long)]
max_seq_len: Option<usize>,

Expand Down Expand Up @@ -94,7 +99,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = cli.model_path.clone();
let tokenizer_path = cli.tokenizer_path.clone();
let param_path = cli.param_path.clone();

let interactive_stop = cli.interactive_stop.clone().unwrap_or("[EOF]".to_string());
let start_interactive = cli.start_interactive;
#[cfg(not(feature = "server"))]
if cli.inference_server {
eprintln!("Inference server is not enabled in this build.");
Expand Down Expand Up @@ -268,6 +274,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
tr.clone(),
tok.clone(),
prompt.clone(),
interactive_stop.clone(),
start_interactive,
be_quiet,
max_seq_len,
params.clone(),
Expand Down Expand Up @@ -676,6 +684,8 @@ fn command_line_inference(
tr: Arc<Transformer>,
tok: Arc<Tokenizer>,
prompt: String,
interactive_stop: String,
start_interactive: bool,
be_quiet: bool,
max_seq_len: usize,
params: ModelParams,
Expand Down Expand Up @@ -710,7 +720,8 @@ fn command_line_inference(
if let Some(repetition_penalty) = cli.repetition_penalty {
token_sampler = token_sampler.repetition_penalty(repetition_penalty);
}

let mut stop_tokens = tok.tokenize_to_ids(interactive_stop.clone());
stop_tokens.remove(0);
pln!("---");
pln!(" dim: {}", params.dim);
pln!(" multiple_of: {}", params.multiple_of);
Expand Down Expand Up @@ -738,39 +749,64 @@ fn command_line_inference(
"{}",
" This is the color of the generated text".truecolor(128, 255, 128)
);
pln!("stop keywords are {}", interactive_stop.as_str());
pln!("---");
print!("{}", prompt.as_str().truecolor(128, 128, 255));

let _ = std::io::stdout().flush();

let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0);
let mut times_per_token: Vec<std::time::Duration> = vec![];
let mut caches = tr.make_caches();
let mut first: bool = true;
let mut stop_seen: bool = false;
let mut interactive = start_interactive;
let mut user_token: Vec<TokenId> = vec![];
while toks_id.len() < max_seq_len {
let now = std::time::Instant::now();
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);

let (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
if interactive {
let mut newinput = String::new();
std::io::stdin().read_line(&mut newinput)?;
let _ = newinput.pop();
//removing new line from input
user_token = tok.tokenize_to_ids(newinput);

//removing [1, ... , end of token]
let _ = user_token.remove(0);
let _ = user_token.pop();
//toks_id.append(&mut user_token);
interactive = false;
}
let (mut highest_pred_idx, mut token_prob);

if user_token.len() > 0 {
highest_pred_idx = user_token.remove(0);
token_prob = 0.0;
}
else {
(highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
}
toks_id.push(highest_pred_idx as TokenId);

for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() {
if *tok_id == 1 {
continue;
}
let mut tok_str: String = "".to_string();
let tok = tok.id_to_str(*tok_id);
if tok == "</s>" {
tok_str += "";
let mut tok_print: String = "".to_string();
let tok_str = tok.id_to_str(*tok_id);
if tok_str == "</s>" {
tok_print += "";
stop_seen = true;
}
if tok == "<0x0A>" {
tok_str += "\n";
if tok_str == "<0x0A>" {
tok_print += "\n";
} else {
tok_str += tok.replace('▁', " ").as_str();
tok_print += tok_str.replace('▁', " ").as_str();
}
if first && tok_idx < toks_id.len() - 2 {
// intentionally left empty
// intentionally left empty, already print
} else {
let redness: f32 = token_prob * 255.0;
let redness = if redness > 255.0 {
Expand All @@ -782,8 +818,14 @@ fn command_line_inference(
};
print!(
"{}",
tok_str.truecolor(128 + redness / 2, 255 - redness / 2, 128)
tok_print.truecolor(128 + redness / 2, 255 - redness / 2, 128)
);
};
if !first && tok_id == stop_tokens.last().unwrap()
&& tok_idx+prev_pos > stop_tokens.len()
&& toks_id[prev_pos+1+tok_idx-(stop_tokens.len()-1) .. prev_pos+1+tok_idx+1] == stop_tokens
{
interactive = true;
}
}
if first {
Expand All @@ -797,6 +839,9 @@ fn command_line_inference(
if stop_seen {
break;
}



}
println!();
if stop_seen && !be_quiet {
Expand All @@ -808,11 +853,17 @@ fn command_line_inference(
"Time taken to generate first token: {:?}ms",
first_token_time.as_millis()
);
println!(
"Time taken per token (excluding first token): {:?}ms",
times_per_token.iter().map(|t| t.as_millis()).sum::<u128>()
/ times_per_token.len() as u128
);
if times_per_token.len() > 0 {
println!(
"Time taken per token (excluding first token): {:?}ms",
times_per_token.iter().map(|t| t.as_millis()).sum::<u128>()
/ times_per_token.len() as u128
);
}
else {
println!( "No token generated");
}

}
Ok(())
}

0 comments on commit ec514ec

Please sign in to comment.