Adjust default interactive settings so they match what Vicuna-13B expects.
Noeda committed Apr 6, 2023
1 parent 4faf07a commit 7aa1051
Showing 2 changed files with 88 additions and 30 deletions.
36 changes: 23 additions & 13 deletions README.md
@@ -114,30 +114,40 @@ back-and-forth discussion with the model.

```shell
rllama ... --start-interactive \
--interactive-prompt-postfix " AI:" \ # (optional)
--interactive-stop "Human: " # (optional but you probably want to set it)
--interactive-system-prompt "Helpful assistant helps curious human." \ # (optional)
--interactive-prompt-postfix " ###Assistant:" \ # (optional)
--interactive-stop "###Human: " # (optional)
```

In this mode, you need to type your prompt before the AI starts doing its work.
If the AI outputs the token sequence given in `--interactive-stop` (defaults to
`[EOF]`) then it will ask for another input. You probably want to have `"Human:
"` or something similar, see example below.
`###Human:`) then it will ask for another input.

`--interactive-prompt-postfix` is appended automatically to your answers. You
can use this to force the AI to follow a pattern. Here is a full example of
interactive mode command line:
The defaults match the Vicuna-13B model:

```
--interactive-system-prompt "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
--interactive-prompt-postfix " Assistant:"
--interactive-prompt-prefix " "
--interactive-stop "###Human:"
```

`--interactive-prompt-postfix` is appended automatically to your typed text and
`--interactive-prompt-prefix` is prepended to the start of it. Here is an
example of an interactive mode command line with the default settings:

```shell
rllama --f16 \
--param-path /LLaMA/7B/params.json \
--model-path /LLaMA/7B \
--param-path /models/vicuna13b/params.json \
--model-path /models/vicuna13b \
--tokenizer-path /stonks/LLaMA/tokenizer.model \
--prompt "This is an interactive session between human and AI assistant. AI: Hi! How can I help you? Human:" \
--start-interactive \
--interactive-stop "Human:" \
--interactive-prompt-postfix " AI:"
--start-interactive
```
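With these defaults, whatever you type gets wrapped before it is fed to the model: the prefix is prepended and the postfix appended, so typing `What is Rust?` sends ` What is Rust?### Assistant:` (using the prefix `" "` and postfix `"### Assistant:"` defaults from `src/rllama_main.rs` in this commit).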

As of this writing, the output is not formatted prettily for chat and there is
no visual indication of when you are supposed to be typing. That will come
later.

## Inference server

`rllama` can run in an inference server mode with a simple HTTP JSON API. You
82 changes: 65 additions & 17 deletions src/rllama_main.rs
@@ -44,9 +44,13 @@ struct Cli {
prompt_file: Option<String>,

#[arg(long)]
interactive_stop: Option<String>,
interactive_system_prompt: Option<String>,
#[arg(long)]
interactive_stop: Vec<String>,
#[arg(long)]
interactive_prompt_postfix: Option<String>,
#[arg(long)]
interactive_prompt_prefix: Option<String>,
#[arg(long, action)]
start_interactive: bool,

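For context: the `Cli` struct appears to use clap's derive API, so changing `interactive_stop` to a `Vec<String>` lets the flag be passed multiple times, each occurrence appending one stop sequence, and passing it zero times leaves the vector empty, which is what the fallback list later in this diff checks for. A minimal standalone sketch of that behavior (assuming clap 4 with the `derive` feature; `StopArgs` is an illustrative name, not the project's struct):

```rust
use clap::Parser;

#[derive(Parser)]
struct StopArgs {
    /// Each `--interactive-stop` occurrence appends one stop sequence;
    /// if the flag is never given, the vector stays empty.
    #[arg(long)]
    interactive_stop: Vec<String>,
}

fn main() {
    let args = StopArgs::parse();
    // e.g. `--interactive-stop "### Human:" --interactive-stop "###Human:"`
    // prints: stop sequences: ["### Human:", "###Human:"]
    println!("stop sequences: {:?}", args.interactive_stop);
}
```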
@@ -103,11 +107,37 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = cli.model_path.clone();
let tokenizer_path = cli.tokenizer_path.clone();
let param_path = cli.param_path.clone();
let interactive_stop = cli.interactive_stop.clone().unwrap_or("[EOF]".to_string());
let interactive_system_prompt = cli.interactive_system_prompt.clone().unwrap_or("A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, terse answers to the human's questions.### Human:".to_string());
let mut interactive_stop = cli.interactive_stop.clone();
if interactive_stop.is_empty() {
// Desperate attempt to catch all the weird variants of ###Human the model might spit out.
interactive_stop = vec![
"### Human:".to_string(),
"###Human:".to_string(),
"### Human: ".to_string(),
"###Human: ".to_string(),
" ### Human:".to_string(),
" ###Human:".to_string(),
" ### Human: ".to_string(),
" ###Human: ".to_string(),
"\n### Human:".to_string(),
"\n###Human:".to_string(),
"\n### Human: ".to_string(),
"\n###Human: ".to_string(),
"\n ### Human:".to_string(),
"\n ###Human:".to_string(),
"\n ### Human: ".to_string(),
"\n ###Human: ".to_string(),
];
}
let interactive_prompt_prefix = cli
.interactive_prompt_prefix
.clone()
.unwrap_or(" ".to_string());
let interactive_prompt_postfix = cli
.interactive_prompt_postfix
.clone()
.unwrap_or("".to_string());
.unwrap_or("### Assistant:".to_string());
let start_interactive = cli.start_interactive;
#[cfg(not(feature = "server"))]
if cli.inference_server {
@@ -273,6 +303,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
tok.clone(),
prompt.clone(),
interactive_stop.clone(),
interactive_system_prompt.clone(),
interactive_prompt_prefix.clone(),
interactive_prompt_postfix.clone(),
start_interactive,
be_quiet,
@@ -684,7 +716,9 @@ fn command_line_inference(
tr: Arc<Transformer>,
tok: Arc<Tokenizer>,
prompt: String,
interactive_stop: String,
interactive_stop: Vec<String>,
interactive_system_prompt: String,
interactive_prompt_prefix: String,
interactive_prompt_postfix: String,
start_interactive: bool,
be_quiet: bool,
@@ -701,7 +735,18 @@
};
}

let mut prompt = prompt;

if start_interactive && !prompt.is_empty() {
return Err(
"Cannot start interactive mode with a prompt. Use --interactive-system-prompt instead."
.into(),
);
}
prompt = interactive_system_prompt.clone();

let mut toks_id: Vec<TokenId> = tok.tokenize_to_ids(prompt.clone());
let mut toks_str: String = prompt.clone();
let mut prev_pos = 0;
let mut token_sampler = TokenSampler::new()
.temperature(1.0)
@@ -721,8 +766,6 @@
if let Some(repetition_penalty) = cli.repetition_penalty {
token_sampler = token_sampler.repetition_penalty(repetition_penalty);
}
let mut stop_tokens = tok.tokenize_to_ids(interactive_stop.clone());
stop_tokens.remove(0);
pln!("---");
pln!(" dim: {}", params.dim);
pln!(" n_heads: {}", params.n_heads);
@@ -742,9 +785,15 @@
);
if start_interactive {
pln!(
" Interactive mode stop token sequence: {}",
interactive_stop.as_str()
" Interactive mode stop token sequences: {:?}",
interactive_stop
);
pln!("---");
pln!("System prompt:");
pln!(" {}", interactive_system_prompt);
pln!("---");
pln!("Interactive prompt prefix: {}", interactive_prompt_prefix);
pln!("Interactive prompt postfix: {}", interactive_prompt_postfix);
}
pln!("---");
pln!(
@@ -777,6 +826,7 @@ fn command_line_inference(
if newinput.ends_with('\n') {
let _ = newinput.pop();
}
newinput = interactive_prompt_prefix.clone() + &newinput;
newinput += &interactive_prompt_postfix;
user_token = tok.tokenize_to_ids(newinput.clone());

@@ -809,6 +859,7 @@ fn command_line_inference(
} else {
tok_print += tok_str.replace('▁', " ").as_str();
}
toks_str += tok_print.as_str();
if first && tok_idx < toks_id.len() - 2 {
// intentionally left empty, already printed
} else {
@@ -825,15 +876,12 @@ fn command_line_inference(
tok_print.truecolor(128 + redness / 2, 255 - redness / 2, 128)
);
};
if !first
&& tok_id == stop_tokens.last().unwrap()
&& tok_idx + prev_pos > stop_tokens.len()
&& toks_id
[prev_pos + 1 + tok_idx - (stop_tokens.len() - 1)..prev_pos + 1 + tok_idx + 1]
== stop_tokens
{
if start_interactive {
interactive = true;
for stop_str in interactive_stop.iter() {
if !first && toks_str.ends_with(stop_str.as_str()) {
if start_interactive {
interactive = true;
}
break;
}
}
}
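As a side note, the replacement logic above matches on the accumulated output string (`toks_str`) with `ends_with` instead of comparing token ids, which is why the whitespace variants can simply be listed as plain strings. A minimal standalone sketch of the same idea (the function and variable names here are illustrative, not part of the project's API):

```rust
/// Returns true when the generated text ends with any configured stop sequence.
/// `generated` plays the role of the accumulated `toks_str` in the diff above.
fn hit_stop_sequence(generated: &str, stop_sequences: &[String]) -> bool {
    stop_sequences.iter().any(|s| generated.ends_with(s.as_str()))
}

fn main() {
    let stops = vec!["### Human:".to_string(), "###Human:".to_string()];
    assert!(hit_stop_sequence("A terse answer.\n### Human:", &stops));
    assert!(!hit_stop_sequence("A terse answer, still going", &stops));
    println!("stop-sequence check behaves as expected");
}
```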
