Skip to content

Commit

Permalink
more info for preshard (mlc-ai#2027)
Browse files Browse the repository at this point in the history
* Previously, when the pre-sharded version of a model was not available, the program silently fell back to the normal workflow without issuing any alert. Now, when someone attempts to convert to a pre-sharded model but cannot, the program emits a warning to inform users that it will revert to the standard model conversion process.

* format fix.

* Black reformatted; I did not see any diff.

* Black reformatted.
  • Loading branch information
na20215 authored Mar 25, 2024
1 parent ab9fa81 commit f04cd3e
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions python/mlc_llm/support/preshard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions for pre-sharding weights"""
import logging
from typing import Any, Dict, List

from tvm import IRModule
Expand All @@ -8,6 +9,8 @@
from tvm.runtime import Device
from tvm.target import Target

logger = logging.getLogger("preshard")


def _sharded_param_name(param_name, worker_id):
return f"{param_name}_shard-{worker_id}"
Expand Down Expand Up @@ -93,10 +96,7 @@ def _compile_shard_funcs(mod: IRModule, device: Device):


def apply_preshard(
quantize_map: Any,
named_params: Dict[str, nn.Parameter],
tensor_parallel_shards: int,
args: Any,
quantize_map: Any, named_params: Dict[str, nn.Parameter], tensor_parallel_shards: int, args: Any
):
"""Update quantize_map and named_params, create shard functions based on shard strategies."""
model_config = args.model.config.from_file(args.config)
Expand All @@ -107,17 +107,24 @@ def apply_preshard(
bb = relax.BlockBuilder()
param_to_shard_func = {}
shard_func_names = set()
has_shard_strategy = False
for name, param in model.state_dict().items():
shard_strategy = param.attrs.get("shard_strategy", None)
if shard_strategy is not None:
has_shard_strategy = True
_update_quantize_map(quantize_map, named_params, name, tensor_parallel_shards)

# create shard functions
param_to_shard_func[name] = shard_strategy.name
if shard_strategy.name not in shard_func_names:
_create_shard_func(bb, param, tensor_parallel_shards)
shard_func_names.add(shard_strategy.name)

if not has_shard_strategy:
logger.warning(
"No parameters with 'shard_strategy' found."
"At least one parameter must have a 'shard_strategy' for presharding. "
"The model will continue to convert weights in a non-presharded manner."
)
mod = bb.finalize()
vm = _compile_shard_funcs(mod, args.device)

Expand Down

0 comments on commit f04cd3e

Please sign in to comment.