Add simple HTTP API support.
It took an annoying amount of effort just to make this simple server.

I tried the rouille web framework first, but it didn't support sending
chunked output to the client line by line. (It seems that if it exposed
more details about the underlying tiny-http package, I could have hacked
it to work.)

I went with Rocket because it had less async stuff and seemed decent.

I ran into weird issues where memory use seemed to keep increasing. I
may have fixed that, but I couldn't figure out what made it use so much
memory; even tools like valgrind and heaptrack told me there isn't that
much memory allocated, yet I can see RES increasing in `htop`.

Switched to MiMalloc as it seems to slightly decrease memory use.

Added details about the inference server to README.md, and also added an
example Python script for it.

I want to use this feature later to investigate how much quantization
or f16/f32 affects the output. Such things are easier to do in Python.
Noeda committed Mar 21, 2023
1 parent 9c86c17 commit b9be485
Showing 10 changed files with 1,335 additions and 54 deletions.
661 changes: 636 additions & 25 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions Cargo.toml
@@ -32,10 +32,14 @@ indicatif = "0.17"
colored = "2"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
mimalloc = "0.1"
ocl = { version = "0.19", optional = true }
rocket = { version = "0.4", features = ["sse"], optional = true }
lazy_static = "1.4"

[features]
opencl = ["ocl"]
server = ["rocket"]

# We need protobuf compiler
[build-dependencies]
@@ -46,6 +50,7 @@ protobuf-parse = "3.2"
criterion = "0.4"

[profile.release]
panic = 'abort'
debug = true

[[bench]]
86 changes: 86 additions & 0 deletions README.md
@@ -90,6 +90,92 @@ rllama --tokenizer-model /path/to/tokenizer.model \

Use `rllama --help` to see all the options.

## Inference server

`rllama` can run in an inference server mode with a simple HTTP JSON API.

The command line flags for this are as follows; an example invocation is shown after the list:

* `--inference-server` using this will turn on the inference server.
* `--inference-server-port` sets the port. Default port is 8080.
* `--inference-server-host` sets the host. The default host is 127.0.0.1.
* `--inference-server-max-concurrent-inferences` sets how many concurrent
requests are allowed to be actively doing inference at the same time. The
default is 5.
* `--inference-server-api-path` sets which path serves the API requests. The
default path is `/rllama/v1/inference`.
* `--inference-server-prompt-cache-size` sets how many previous prompt
calculations should be cached. Default is 1000. This speeds up token
generation for prompts that were already requested before.
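
For example, the server could be started like this (a sketch; substitute the model and
tokenizer flags from the usage example above):

```
rllama <model and tokenizer flags as shown above> \
  --inference-server \
  --inference-server-port 8080 \
  --inference-server-max-concurrent-inferences 5
```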

Prompts and flags related to token sampling are all ignored in inference server
mode. Instead, they are obtained from each HTTP JSON API request.

### Inference server API

There is a minimal example of using the API in `examples/api_hello_world.py`.

```
POST /rllama/v1/inference
```

Expects a JSON body and `Accept: application/json` or `Accept: text/jsonl`.

The expected JSON is as follows:

```json
{
"temperature": <number, optional>
"top_k": <integer, optional, default 20>
"top_p": <number, optional, default: 1.0>
"repetition_penalty": <number, optional, default: 1.0>
"stop_at_end_token": <bool, optional, default: true>
"max_seq_len": <integer, optional, default: 1024. Clamped to
be at highest the same as --max-seq-len command line option.>
"max_new_tokens": <integer, optional, default: 1024>
"no_token_sampling": <bool, optional, default: false>
"prompt": <string, required>
}
```
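
For example, a request that overrides only a few of these fields might look like this
(the values are purely illustrative):

```json
{
  "temperature": 0.8,
  "top_k": 20,
  "max_new_tokens": 200,
  "prompt": "Hello world!"
}
```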

The form of the response depends on whether `no_token_sampling` is set to true or false. The
response is in JSONL, i.e. multiple JSON dictionaries separated by newlines.

`no_token_sampling` can turn off `rllama`'s own token sampling. In this case,
the probabilities for every token are returned instead.

When no\_token\_sampling = false:

```json
{<token string>: {"p": <number>, "is_end_token": bool, might not be present}}
```

* `token` contains the new token to be appended to the output. It does not
include the string you fed to the system originally.
* `p` is the probability that this token was chosen. For example, if this
value is 0.1, it means that this particular token had a 10% chance of being
selected with the current token sampling settings.
* `is_end_token` is `true` if the given token signifies the end of output. This
field is not present otherwise.
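
Below is a minimal sketch (not part of the repository) of how a client might consume this
stream with Python's `requests` library, assuming the default host, port, API path and an
arbitrary prompt:

```python
import json
import requests

url = 'http://127.0.0.1:8080/rllama/v1/inference'
req = {
    'prompt': 'Hello world!',
    'max_new_tokens': 50,
    'no_token_sampling': False
}

output = []
with requests.post(url, json=req, stream=True) as res:
    for line in res.iter_lines():
        if not line:
            continue
        # Each line is one JSON dictionary with a single key: the new token.
        token, info = next(iter(json.loads(line).items()))
        if info.get('is_end_token', False):
            break
        output.append(token)

print(''.join(output))
```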

When no\_token\_sampling = true:

```json
{<token string>: {"p": <number>, "is_end_token": bool, might not be present} \
,<token string>: {"p": <number>, "is_end_token": bool, might not be present} \
,...}
```

Tokens where `p = 0` will not be present in the JSON output.

If you want to implement your own token sampling, you may want to set
`max_new_tokens=1` and `stop_at_end_token=false` to suppress rllama's own
sampling behavior entirely.
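
As a sketch of what that could look like (not part of the repository; it assumes the default
host, port and API path, and greedily picks the most likely token instead of sampling from
the distribution):

```python
import json
import requests

URL = 'http://127.0.0.1:8080/rllama/v1/inference'

def next_token_probs(prompt):
    # Ask for exactly one new token and request the full probability
    # distribution instead of a server-sampled token.
    req = {
        'prompt': prompt,
        'max_new_tokens': 1,
        'stop_at_end_token': False,
        'no_token_sampling': True
    }
    res = requests.post(URL, json=req, stream=True)
    for line in res.iter_lines():
        if line:
            return json.loads(line)
    return {}

prompt = 'Hello world!'
for _ in range(20):
    probs = next_token_probs(prompt)
    if not probs:
        break
    # Greedy choice; a real sampler would draw from the distribution.
    token, info = max(probs.items(), key=lambda kv: kv[1]['p'])
    if info.get('is_end_token', False):
        break
    # Naive continuation: append the chosen token text to the prompt.
    prompt += token

print(prompt)
```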

`rllama` internally caches recently queried prompts and their intermediate
computations, so it can continue quickly if you issue a query that is either
the same as a previous query or a continuation of one.

## How to turn on OpenCL

Use `opencl` Cargo feature.
25 changes: 25 additions & 0 deletions examples/api_hello_world.py
@@ -0,0 +1,25 @@
#!/usr/bin/env python3

"""
This script uses the rllama API to generate tokens.
It does not print the tokens nicely.
"""

import requests

def main():
    url = 'http://127.0.0.1:8080/rllama/v1/inference'
    req = {
        'prompt': 'Hello world!',
        'max_seq_len': 1024,
        'max_new_tokens': 200,
        'no_token_sampling': False
    }
    res = requests.post(url, json=req, stream=True)
    for line in res.iter_lines():
        print(line.decode('utf-8'))


if __name__ == '__main__':
    main()
5 changes: 5 additions & 0 deletions src/lib.rs
@@ -1,8 +1,10 @@
#![feature(stdsimd)]
#![feature(decl_macro)]

pub mod embedding;
pub mod protomodels;
pub mod rllama_main;
pub mod semaphore;
pub mod simd_support;
pub mod tensor;
#[cfg(feature = "opencl")]
@@ -11,3 +13,6 @@ pub mod token_sampler;
pub mod tokenizer;
pub mod transformer;
pub mod unpickler;
#[cfg(feature = "server")]
#[macro_use]
extern crate rocket;
5 changes: 5 additions & 0 deletions src/main.rs
@@ -7,6 +7,11 @@ compile_error!("This library assumes availability of AVX and must be compiled wi
#[cfg(not(target_feature = "avx"))]
compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2");

use mimalloc::MiMalloc;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    rllama::rllama_main::main()
}
