Skip to content

Commit

Permalink
Add support for extra_model_paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Anson2048 committed Dec 9, 2023
1 parent ccd6773 commit 8782280
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

logger = logging.getLogger('comfyui_segment_anything')

sam_model_dir = os.path.join(folder_paths.models_dir, "sams")
sam_model_dir_name = "sams"
sam_model_list = {
"sam_vit_h (2.56GB)": {
"model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
Expand All @@ -47,8 +47,7 @@
}
}

groundingdino_model_dir = os.path.join(
folder_paths.models_dir, "grounding-dino")
groundingdino_model_dir_name = "grounding-dino"
groundingdino_model_list = {
"GroundingDINO_SwinT_OGC (694MB)": {
"config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
Expand All @@ -71,7 +70,7 @@ def list_sam_model():

def load_sam_model(model_name):
sam_checkpoint_path = get_local_filepath(
sam_model_list[model_name]["model_url"], sam_model_dir)
sam_model_list[model_name]["model_url"], sam_model_dir_name)
model_file_name = os.path.basename(sam_checkpoint_path)
model_type = model_file_name.split('.')[0]
if 'hq' not in model_type and 'mobile' not in model_type:
Expand All @@ -85,14 +84,22 @@ def load_sam_model(model_name):


def get_local_filepath(url, dirname, local_file_name=None):
if not os.path.exists(dirname):
os.makedirs(dirname)
if not local_file_name:
parsed_url = urlparse(url)
local_file_name = os.path.basename(parsed_url.path)
destination = os.path.join(dirname, local_file_name)

destination = folder_paths.get_full_path(dirname, local_file_name)
if destination:
logger.warn(f'using extra model: {destination}')
return destination

folder = os.path.join(folder_paths.models_dir, dirname)
if not os.path.exists(folder):
os.makedirs(folder)

destination = os.path.join(folder, local_file_name)
if not os.path.exists(destination):
logging.warn(f'downloading {url} to {destination}')
logger.warn(f'downloading {url} to {destination}')
download_url_to_file(url, destination)
return destination

Expand All @@ -101,15 +108,15 @@ def load_groundingdino_model(model_name):
dino_model_args = local_groundingdino_SLConfig.fromfile(
get_local_filepath(
groundingdino_model_list[model_name]["config_url"],
groundingdino_model_dir
groundingdino_model_dir_name
),

)
dino = local_groundingdino_build_model(dino_model_args)
checkpoint = torch.load(
get_local_filepath(
groundingdino_model_list[model_name]["model_url"],
groundingdino_model_dir,
groundingdino_model_dir_name,
),
)
dino.load_state_dict(local_groundingdino_clean_state_dict(
Expand Down

0 comments on commit 8782280

Please sign in to comment.