Added docker image for standalone worker (LAION-AI#2300)
yk committed Apr 9, 2023
1 parent c316fec commit 5a17d93
Showing 7 changed files with 146 additions and 30 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/release.yaml
@@ -67,6 +67,14 @@ jobs:
       context: .
       dockerfile: docker/inference/Dockerfile.worker-hf
       build-args: ""
+  build-inference-worker-standalone:
+    uses: ./.github/workflows/docker-build.yaml
+    needs: pre-commit
+    with:
+      image-name: oasst-inference-worker-standalone
+      context: .
+      dockerfile: docker/inference/Dockerfile.worker-standalone
+      build-args: ""
   deploy-to-node:
     needs: [build-backend, build-web, build-bot, build-inference-server]
     uses: ./.github/workflows/deploy-to-node.yaml
72 changes: 72 additions & 0 deletions docker/inference/Dockerfile.worker-standalone
@@ -0,0 +1,72 @@
# syntax=docker/dockerfile:1

ARG MODULE="inference"
ARG SERVICE="worker"

ARG APP_USER="${MODULE}-${SERVICE}"
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"


FROM python:3.10-slim as build
ARG APP_RELATIVE_PATH

WORKDIR /build

RUN apt-get update && apt-get install -y --no-install-recommends \
git \
&& rm -rf /var/lib/apt/lists/*

COPY ./${APP_RELATIVE_PATH}/requirements.txt requirements.txt

RUN --mount=type=cache,target=/var/cache/pip \
pip install \
--cache-dir=/var/cache/pip \
--target=lib \
-r requirements.txt



FROM python:3.10-slim as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
ARG MODULE
ARG SERVICE

ENV APP_BASE="/opt/${MODULE}"
ENV APP_ROOT="${APP_BASE}/${SERVICE}"
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
ENV SHARED_LIBS_BASE="${APP_BASE}/lib"

ENV PATH="${PATH}:${APP_LIBS}/bin"
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"


RUN adduser \
--disabled-password \
"${APP_USER}"

USER ${APP_USER}

WORKDIR ${APP_ROOT}


COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py .


CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"

FROM base-env as prod
ARG APP_USER


COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared /tmp/lib/oasst-shared
RUN --mount=type=cache,target=/var/cache/pip,from=base-env \
pip install \
--cache-dir=/var/cache/pip \
--target="${APP_LIBS}" \
/tmp/lib/oasst-shared

COPY --chown="${APP_USER}:${APP_USER}" ./inference/worker/worker_standalone_main.sh /entrypoint.sh

CMD ["/entrypoint.sh"]
11 changes: 9 additions & 2 deletions inference/worker/__main__.py
@@ -25,6 +25,7 @@ def main():
     logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")
 
     model_config = model_configs.MODEL_CONFIGS.get(settings.model_config_name)
+    logger.warning(f"Model config: {model_config}")
     if model_config is None:
         logger.error(f"Unknown model config name: {settings.model_config_name}")
         sys.exit(2)
@@ -33,12 +34,18 @@
         tokenizer = None
     else:
         tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")
+        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")
 
+    inference_http = utils.HttpClient(
+        base_url=settings.inference_server_url,
+        basic_auth_username=settings.basic_auth_username,
+        basic_auth_password=settings.basic_auth_password,
+    )
+
     while True:
         try:
             if not model_config.is_lorem:
-                utils.wait_for_inference_server(settings.inference_server_url)
+                utils.wait_for_inference_server(inference_http)
 
             if settings.perform_oom_test:
                 work.perform_oom_test(tokenizer)
3 changes: 3 additions & 0 deletions inference/worker/settings.py
@@ -20,5 +20,8 @@ class Settings(pydantic.BaseSettings):
     # for hf basic server
     quantize: bool = False
 
+    basic_auth_username: str | None = None
+    basic_auth_password: str | None = None
+
 
 settings = Settings()
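
Because Settings extends pydantic.BaseSettings, the two new fields can be supplied through environment variables without any code change, which is how the standalone container is expected to be configured. A minimal sketch of that behavior, assuming pydantic v1 where BaseSettings lives in the pydantic package itself:

# Sketch: pydantic BaseSettings reads fields from matching env vars
# (case-insensitive by default). WorkerSettings is a stand-in for the
# worker's Settings class, not the real one.
import os
import pydantic

class WorkerSettings(pydantic.BaseSettings):
    basic_auth_username: str | None = None
    basic_auth_password: str | None = None

os.environ["BASIC_AUTH_USERNAME"] = "worker"
os.environ["BASIC_AUTH_PASSWORD"] = "secret"

print(WorkerSettings().basic_auth_username)  # -> worker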
25 changes: 22 additions & 3 deletions inference/worker/utils.py
@@ -6,6 +6,7 @@
 
 import interface
 import lorem
+import pydantic
 import requests
 import websocket
 from loguru import logger
@@ -57,12 +58,11 @@ def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]) -> Ite
         yield from self.tokens
 
 
-def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
-    health_url = f"{inference_server_url}/health"
+def wait_for_inference_server(http: "HttpClient", timeout: int = 600):
     time_limit = time.time() + timeout
     while True:
         try:
-            response = requests.get(health_url)
+            response = http.get("/health")
             response.raise_for_status()
         except (requests.HTTPError, requests.ConnectionError):
             if time.time() > time_limit:
@@ -118,3 +118,22 @@ def send_response(
     msg = repsonse.json()
     with ws_lock:
         ws.send(msg)
+
+
+class HttpClient(pydantic.BaseModel):
+    base_url: str
+    basic_auth_username: str | None = None
+    basic_auth_password: str | None = None
+
+    @property
+    def auth(self):
+        if self.basic_auth_username and self.basic_auth_password:
+            return (self.basic_auth_username, self.basic_auth_password)
+        else:
+            return None
+
+    def get(self, path: str, **kwargs):
+        return requests.get(self.base_url + path, auth=self.auth, **kwargs)
+
+    def post(self, path: str, **kwargs):
+        return requests.post(self.base_url + path, auth=self.auth, **kwargs)
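
HttpClient centralizes the inference server base URL and optional HTTP basic auth: the auth property returns a (username, password) tuple only when both credentials are set, so unauthenticated servers keep working unchanged. A short usage sketch with placeholder values, mirroring how __main__.py and wait_for_inference_server use it:

# Sketch: HttpClient usage with placeholder URL and credentials.
http = HttpClient(
    base_url="http://localhost:8300",   # placeholder inference server address
    basic_auth_username="worker",       # optional; omit for unauthenticated servers
    basic_auth_password="secret",
)

# Equivalent to requests.get("http://localhost:8300/health", auth=("worker", "secret"))
response = http.get("/health")
response.raise_for_status()

# The worker now passes the client itself to the readiness check:
# wait_for_inference_server(http)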
36 changes: 11 additions & 25 deletions inference/worker/work.py
@@ -92,17 +92,14 @@ def handle_work_request(
 
     stream_response = None
     token_buffer = utils.TokenBuffer(stop_sequences=parameters.stop)
-    stream_request = interface.GenerateStreamRequest(
-        inputs=prompt,
-        parameters=parameters,
-    )
     if model_config.is_lorem:
         stream_events = utils.lorem_events(parameters.seed)
-    # elif model_config.is_llama:
-    #     prompt = truncate_prompt(tokenizer, worker_config, parameters, prompt)
-    #     stream_events = get_hf_stream_events(stream_request)
     else:
         prompt = truncate_prompt(tokenizer, worker_config, parameters, prompt)
+        stream_request = interface.GenerateStreamRequest(
+            inputs=prompt,
+            parameters=parameters,
+        )
         stream_events = get_inference_server_stream_events(stream_request)
 
     generated_ids = []
@@ -162,25 +159,14 @@ def handle_work_request(
     logger.debug("Work complete. Waiting for more work...")
 
 
-def get_hf_stream_events(request: interface.GenerateStreamRequest):
-    response = requests.post(
-        f"{settings.inference_server_url}/generate",
-        json=request.dict(),
-    )
-    try:
-        response.raise_for_status()
-    except requests.HTTPError:
-        logger.exception("Failed to get response from inference server")
-        logger.error(f"Response: {response.text}")
-        raise
-    data = response.json()
-    output = data["text"]
-    yield from utils.text_to_events(output, pause=settings.hf_pause)
-
-
 def get_inference_server_stream_events(request: interface.GenerateStreamRequest):
-    response = requests.post(
-        f"{settings.inference_server_url}/generate_stream",
+    http = utils.HttpClient(
+        base_url=settings.inference_server_url,
+        basic_auth_username=settings.basic_auth_username,
+        basic_auth_password=settings.basic_auth_password,
+    )
+    response = http.post(
+        "/generate_stream",
         json=request.dict(),
         stream=True,
         headers={"Accept": "text/event-stream"},
21 changes: 21 additions & 0 deletions inference/worker/worker_standalone_main.sh
@@ -0,0 +1,21 @@
#!/bin/bash

export HF_HOME=${HF_HOME:-"$HOME/.cache/huggingface"}
load_sleep=${LOAD_SLEEP:-0}

mkdir -p $HF_HOME
echo -n "$HF_TOKEN" > $HF_HOME/token

export HUGGING_FACE_HUB_TOKEN=$HF_TOKEN

export MODEL_CONFIG_NAME=${MODEL_CONFIG_NAME:-"OA_SFT_Pythia_12B"}
export MODEL_ID=$(python get_model_config_prop.py model_id)
export QUANTIZE=$(python get_model_config_prop.py quantized)

echo "Downloading model $MODEL_ID"
CUDA_VISIBLE_DEVICES="" python download_model_hf.py

export MAX_PARALLEL_REQUESTS=${MAX_PARALLEL_REQUESTS:-1}

echo "starting worker"
python3 __main__.py
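
The script relies on a helper, get_model_config_prop.py, that is not part of this diff. A purely hypothetical sketch of what such a helper might look like, assuming it resolves MODEL_CONFIG_NAME against the model_configs.MODEL_CONFIGS mapping referenced in __main__.py and prints a single attribute:

# Hypothetical sketch of get_model_config_prop.py (the real file is not shown here).
import os
import sys

import model_configs  # worker module referenced in __main__.py

prop = sys.argv[1]  # e.g. "model_id"
config_name = os.environ.get("MODEL_CONFIG_NAME", "OA_SFT_Pythia_12B")
config = model_configs.MODEL_CONFIGS[config_name]
print(getattr(config, prop))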
