diff --git a/vllm/assets/base.py b/vllm/assets/base.py index 18ca2fe638cb..f97e8c218f65 100644 --- a/vllm/assets/base.py +++ b/vllm/assets/base.py @@ -1,11 +1,39 @@ +from functools import lru_cache from pathlib import Path +from typing import Optional import vllm.envs as envs +from vllm.connections import global_http_connection +from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT +vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com" -def get_cache_dir(): + +def get_cache_dir() -> Path: """Get the path to the cache for storing downloaded assets.""" path = Path(envs.VLLM_ASSETS_CACHE) path.mkdir(parents=True, exist_ok=True) return path + + +@lru_cache +def get_vllm_public_assets(filename: str, + s3_prefix: Optional[str] = None) -> Path: + """ + Download an asset file from ``s3://vllm-public-assets`` + and return the path to the downloaded file. + """ + asset_directory = get_cache_dir() / "vllm_public_assets" + asset_directory.mkdir(parents=True, exist_ok=True) + + asset_path = asset_directory / filename + if not asset_path.exists(): + if s3_prefix is not None: + filename = s3_prefix + "/" + filename + global_http_connection.download_file( + f"{vLLM_S3_BUCKET_URL}/{filename}", + asset_path, + timeout=VLLM_IMAGE_FETCH_TIMEOUT) + + return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py index b865b1b3a549..376150574e3b 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -1,33 +1,11 @@ from dataclasses import dataclass -from functools import lru_cache from typing import Literal from PIL import Image -from vllm.connections import global_http_connection -from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT +from vllm.assets.base import get_vllm_public_assets -from .base import get_cache_dir - - -@lru_cache -def get_air_example_data_2_asset(filename: str) -> Image.Image: - """ - Download and open an image from - ``s3://air-example-data-2/vllm_opensource_llava/``. - """ - image_directory = get_cache_dir() / "air-example-data-2" - image_directory.mkdir(parents=True, exist_ok=True) - - image_path = image_directory / filename - if not image_path.exists(): - base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava" - - global_http_connection.download_file(f"{base_url}/{filename}", - image_path, - timeout=VLLM_IMAGE_FETCH_TIMEOUT) - - return Image.open(image_path) +VLM_IMAGES_DIR = "vision_model_images" @dataclass(frozen=True) @@ -36,4 +14,7 @@ class ImageAsset: @property def pil_image(self) -> Image.Image: - return get_air_example_data_2_asset(f"{self.name}.jpg") + + image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", + s3_prefix=VLM_IMAGES_DIR) + return Image.open(image_path)