Skip to content

Commit

Permalink
[CI/Build] Minor refactoring for vLLM assets (vllm-project#7407)
Browse files Browse the repository at this point in the history
  • Loading branch information
ywang96 committed Aug 12, 2024
1 parent f020a62 commit 86ab567
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
30 changes: 29 additions & 1 deletion vllm/assets/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
from functools import lru_cache
from pathlib import Path
from typing import Optional

import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT

vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"

def get_cache_dir():

def get_cache_dir() -> Path:
"""Get the path to the cache for storing downloaded assets."""
path = Path(envs.VLLM_ASSETS_CACHE)
path.mkdir(parents=True, exist_ok=True)

return path


@lru_cache
def get_vllm_public_assets(filename: str,
s3_prefix: Optional[str] = None) -> Path:
"""
Download an asset file from ``s3://vllm-public-assets``
and return the path to the downloaded file.
"""
asset_directory = get_cache_dir() / "vllm_public_assets"
asset_directory.mkdir(parents=True, exist_ok=True)

asset_path = asset_directory / filename
if not asset_path.exists():
if s3_prefix is not None:
filename = s3_prefix + "/" + filename
global_http_connection.download_file(
f"{vLLM_S3_BUCKET_URL}/{filename}",
asset_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)

return asset_path
31 changes: 6 additions & 25 deletions vllm/assets/image.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,11 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal

from PIL import Image

from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.assets.base import get_vllm_public_assets

from .base import get_cache_dir


@lru_cache
def get_air_example_data_2_asset(filename: str) -> Image.Image:
"""
Download and open an image from
``s3://air-example-data-2/vllm_opensource_llava/``.
"""
image_directory = get_cache_dir() / "air-example-data-2"
image_directory.mkdir(parents=True, exist_ok=True)

image_path = image_directory / filename
if not image_path.exists():
base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"

global_http_connection.download_file(f"{base_url}/{filename}",
image_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)

return Image.open(image_path)
VLM_IMAGES_DIR = "vision_model_images"


@dataclass(frozen=True)
Expand All @@ -36,4 +14,7 @@ class ImageAsset:

@property
def pil_image(self) -> Image.Image:
return get_air_example_data_2_asset(f"{self.name}.jpg")

image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR)
return Image.open(image_path)

0 comments on commit 86ab567

Please sign in to comment.