Document interactive mode, run rustfmt on rllama_main.rs
Noeda committed Apr 2, 2023
1 parent ec514ec commit 39f69f2
Showing 2 changed files with 52 additions and 32 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -92,6 +92,19 @@ rllama --tokenizer-model /path/to/tokenizer.model \

Use `rllama --help` to see all the options.

## Interactive mode

There is a simple interactive mode for having a back-and-forth discussion with the model.

```shell
rllama ... --start-interactive
```

In this mode, you type your prompt before the AI starts generating. Whenever
the AI outputs the stop token sequence `[EOF]` (you can change it with the
`--interactive-stop` switch), you can type a new prompt, which is appended to
the existing token sequence and generation continues.
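
For example, to use a custom stop sequence instead of the default `[EOF]`, the
two switches can be combined roughly like this (a sketch only: `...` stands for
your usual model and tokenizer flags, `[DONE]` is a made-up marker, and it is
assumed here that `--interactive-stop` takes the stop string as its argument):

```shell
# Start an interactive session that hands control back to you
# whenever the model emits the custom stop sequence.
rllama ... --start-interactive --interactive-stop '[DONE]'
```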

## Inference server

`rllama` can run in an inference server mode with a simple HTTP JSON API. You
71 changes: 39 additions & 32 deletions src/rllama_main.rs
@@ -162,27 +162,32 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let params: ModelParams = serde_json::from_slice(&bs)?;
pln!("Loaded model parameters from {}.", param_path);

let prompt: String = match (&cli.prompt, &cli.prompt_file) {
(Some(ref prompt), None) => {
let prompt: String = match (&cli.prompt, &cli.prompt_file, start_interactive) {
(Some(ref prompt), None, _) => {
pln!("Using prompt: {}", prompt);
prompt.clone()
}
(None, Some(ref prompt_file)) => {
(None, Some(ref prompt_file), _) => {
pln!("Using prompt file: {}", prompt_file);
let mut fs = std::fs::File::open(prompt_file)?;
let mut bs = Vec::new();
fs.read_to_end(&mut bs)?;
std::mem::drop(fs);
String::from_utf8(bs)?
}
_ => {
(_, _, false) => {
if cli.inference_server {
"".to_string()
} else {
eprintln!("Please provide either a prompt or a prompt file.");
return Err("Please provide either a prompt or a prompt file.".into());
}
}
(None, None, true) => "".to_string(),
(_, _, true) => {
eprintln!("Please provide either a prompt or a prompt file.");
return Err("Please provide either a prompt or a prompt file.".into());
}
};

pln!("Starting up. Loading tokenizer from {}...", tokenizer_path);
@@ -749,10 +754,13 @@ fn command_line_inference(
"{}",
" This is the color of the generated text".truecolor(128, 255, 128)
);
pln!("stop keywords are {}", interactive_stop.as_str());
pln!(
"Interactive mode stop token sequence: {}",
interactive_stop.as_str()
);
pln!("---");
print!("{}", prompt.as_str().truecolor(128, 128, 255));

let _ = std::io::stdout().flush();

let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0);
@@ -765,28 +773,27 @@
while toks_id.len() < max_seq_len {
let now = std::time::Instant::now();
let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches);

if interactive {
let mut newinput = String::new();
std::io::stdin().read_line(&mut newinput)?;
let _ = newinput.pop();
//removing new line from input
user_token = tok.tokenize_to_ids(newinput);

//removing [1, ... , end of token]
// removing new line from input
if newinput.ends_with('\n') {
let _ = newinput.pop();
}
user_token = tok.tokenize_to_ids(newinput.clone());
println!("#{:?}# #{}#", user_token, newinput);

// removing [start token] as it is already in the prompt, and tokenize_to_ids adds it.
let _ = user_token.remove(0);
let _ = user_token.pop();
//toks_id.append(&mut user_token);
interactive = false;
}
let (mut highest_pred_idx, mut token_prob);

if user_token.len() > 0 {
highest_pred_idx = user_token.remove(0);
token_prob = 0.0;
}
else {
(highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
} else {
(highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id);
}
toks_id.push(highest_pred_idx as TokenId);

@@ -821,11 +828,16 @@ fn command_line_inference(
tok_print.truecolor(128 + redness / 2, 255 - redness / 2, 128)
);
};
if !first && tok_id == stop_tokens.last().unwrap()
&& tok_idx+prev_pos > stop_tokens.len()
&& toks_id[prev_pos+1+tok_idx-(stop_tokens.len()-1) .. prev_pos+1+tok_idx+1] == stop_tokens
if !first
&& tok_id == stop_tokens.last().unwrap()
&& tok_idx + prev_pos > stop_tokens.len()
&& toks_id
[prev_pos + 1 + tok_idx - (stop_tokens.len() - 1)..prev_pos + 1 + tok_idx + 1]
== stop_tokens
{
if start_interactive {
interactive = true;
}
}
}
if first {
@@ -839,9 +851,6 @@
if stop_seen {
break;
}



}
println!();
if stop_seen && !be_quiet {
@@ -854,16 +863,14 @@
first_token_time.as_millis()
);
if times_per_token.len() > 0 {
println!(
"Time taken per token (excluding first token): {:?}ms",
times_per_token.iter().map(|t| t.as_millis()).sum::<u128>()
/ times_per_token.len() as u128
println!(
"Time taken per token (excluding first token): {:?}ms",
times_per_token.iter().map(|t| t.as_millis()).sum::<u128>()
/ times_per_token.len() as u128
);
} else {
println!("No token generated");
}
else {
println!( "No token generated");
}

}
Ok(())
}
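
The interactive-mode change in the hunks above boils down to two pieces:
reading a new prompt from stdin when the stop sequence is seen, and detecting
that stop sequence by checking whether the tail of the generated token ids
equals the tokenized stop string. A minimal standalone sketch of the latter
check (the `TokenId` alias and the example ids are assumptions for
illustration, not values from the repository):

```rust
// Sketch of tail-of-sequence stop detection; `TokenId` here is an assumed
// plain integer alias, not the repository's actual type.
type TokenId = i32;

/// Returns true when `toks` ends with the full `stop_tokens` sequence.
fn tail_matches_stop(toks: &[TokenId], stop_tokens: &[TokenId]) -> bool {
    if stop_tokens.is_empty() || toks.len() < stop_tokens.len() {
        return false;
    }
    // Compare the trailing window of generated tokens against the stop ids.
    &toks[toks.len() - stop_tokens.len()..] == stop_tokens
}

fn main() {
    // Hypothetical token ids standing in for a stop sequence like "[EOF]".
    let stop: Vec<TokenId> = vec![91, 4720, 29943, 93];
    let generated: Vec<TokenId> = vec![1, 15043, 91, 4720, 29943, 93];
    assert!(tail_matches_stop(&generated, &stop));
    assert!(!tail_matches_stop(&generated[..2], &stop));
    println!("stop sequence detected");
}
```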
