diff --git a/src/client/remote_model.py b/src/client/remote_model.py
index 749e96e78..dc4e7a64d 100644
--- a/src/client/remote_model.py
+++ b/src/client/remote_model.py
@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import hivemind
 import torch
@@ -17,6 +17,7 @@
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -33,6 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[str] = None  # One of the fine-tuning options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
 class DistributedBloomModel(BloomModel):
@@ -60,10 +62,41 @@ def __init__(self, config: DistributedBloomConfig):
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            if config.tuning_mode == "deep_ptune":
+                self.intermediate_prompt_embeddings = nn.Embedding(
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
+                )
+                self.intermediate_prompt_embeddings.weight.data.zero_()
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")
+
     def set_requires_grad(self, value):
         for p in self.parameters():
             p.requires_grad = value
 
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+        return prompts, intermediate_prompts
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -90,10 +123,22 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        # Note: it supports only float32 or bfloat16 inputs
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        hidden_states = self.h(hidden_states)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
@@ -106,55 +151,6 @@ def forward(
         )
 
 
-class DistributedBloomPrefix(DistributedBloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-        self.pre_seq_len = config.pre_seq_len
-
-        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-        return prompts
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        if attention_mask is not None:
-            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-
-        prompts = self.get_prompt(batch_size)
-        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
-
-        # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
-        transformer_outputs["last_hidden_state"] = last_hidden_state
-        return transformer_outputs
-
-
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
@@ -162,10 +158,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
@@ -195,10 +188,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
 
         # Initialize weights and apply final processing
         self.post_init()
diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py
index 5834e4a42..86acfe1d5 100644
--- a/src/client/remote_sequential.py
+++ b/src/client/remote_sequential.py
@@ -15,6 +15,7 @@
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
+from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -52,8 +53,8 @@ def __init__(
         assert isinstance(sequence_manager.block_uids, list)
         self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
-    def forward(self, inputs: torch.Tensor):
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, self.sequence_manager)
+    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py
index 22504ee50..98d9f7539 100644
--- a/src/client/sequential_autograd.py
+++ b/src/client/sequential_autograd.py
@@ -12,6 +12,7 @@
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 MAX_TOKENS_IN_BATCH = 1024
 
@@ -33,7 +34,13 @@ async def run_expert_forward(
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)
 
-    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: remove this assert once an arbitrary number of input tensors is supported
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
     forward_inputs = nested_flatten(forward_inputs)
@@ -44,7 +51,7 @@
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
         )
     )
 
@@ -57,8 +64,9 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor],
+    inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "expert_backward".
@@ -67,8 +75,14 @@ async def run_expert_backward(
     """
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
@@ -84,7 +98,11 @@ async def run_expert_backward(
 
 
 async def sequential_forward(
-    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -96,6 +114,9 @@
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert is_dummy(prompts) or len(prompts) == len(
+        sequence_manager.block_uids
+    )  # should be n_layers - 1 but add extra prompts for convenience
 
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
@@ -107,7 +128,9 @@
             span = sequences.pop(0)
             span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-            (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+            inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+            (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
             assert isinstance(outputs, torch.Tensor)
             assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -119,7 +142,7 @@
                 inputs = outputs
                 break
             except Exception as e:
-                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
                 sequences = backup_sequences
@@ -129,58 +152,68 @@ async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],
-    forward_sequences: Sequence[RemoteSpanInfo],
+    intermediate_inputs: List[torch.Tensor],
+    prompts: torch.Tensor,
+    forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
-    assert len(intermediate_inputs) == len(forward_sequences)
-    # TODO think about grads w.r.t. deep prompts
+    grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
+            inputs = intermediate_inputs.pop(-1)
+            span = forward_sequences.pop(-1)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
-                inputs = intermediate_inputs.pop(-1)
-                span = forward_sequences.pop(-1)
-
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-
-                grad_outputs = await run_expert_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                grad_outputs, *span_grad_prompts = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                 )
+                grad_outputs = [grad_outputs]
+                grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end
+                    inputs, prompts[span.start : span.end], sequence_manager, start_index=span.start, end_index=span.end
                 )
-                assert len(intermediate_inputs) == len(forward_sequences)
 
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end
                 forward_sequences.extend(backup_forward_sequences)
                 intermediate_inputs.extend(backup_intermediate_inputs)
 
-    return grad_outputs
+
+    # For now, we do not support mixed dummy and grad prompts
+    # Concat in num_layer dimension
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+    return grad_outputs, grad_prompts
 
 
-async def _gather_forward(input_batches, sequence_manager):
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
-    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    return await asyncio.gather(
+        *[
+            sequential_forward(input_batch, prompt_batch, sequence_manager)
+            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+        ]
+    )
 
 
-async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
-            for grad_output, input_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, prompt_batch, spans in zip(
                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
             )
         ]
     )
@@ -193,18 +226,23 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if is_dummy(prompts):
+            prompt_batches = [DUMMY] * len(input_batches)
+        else:
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
         intemediate_input_batches = [output[1] for output in outputs]
         sequences_for_batches = [output[2] for output in outputs]
 
+        ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
@@ -220,9 +258,19 @@ def backward(ctx, grad_outputs: torch.Tensor):
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
         assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
 
-        grad_input_batches = RemoteExpertWorker.run_coroutine(
-            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
+        outputs = RemoteExpertWorker.run_coroutine(
+            _gather_backward(
+                grad_output_batches,
+                intermediate_input_batches,
+                ctx.prompt_batches,
+                forward_sequences,
+                ctx.sequence_manager,
+            )
         )
-        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
-        grad_inputs = torch.cat(grad_inputs, dim=0)
-        return (grad_inputs, None)
+        grad_input_batches = [output[0][0] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
+        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
+        return (grad_inputs, grad_prompts, None)
diff --git a/src/server/handler.py b/src/server/handler.py
index 46dfef32f..5a8a432fa 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -1,8 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -12,6 +20,7 @@
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.utils.misc import DUMMY, is_dummy
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -33,7 +42,7 @@ async def rpc_inference(
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
@@ -80,27 +89,18 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
 
@@ -108,29 +108,20 @@ async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
@@ -139,36 +130,25 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )
 
@@ -176,36 +156,23 @@ async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -215,19 +182,9 @@
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])
 
-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -252,3 +209,83 @@ async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_s
             handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
 
         yield handles
+
+
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+
+    # Serialize the overall output
+    return hidden_states
+
+
+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, *prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
diff --git a/src/utils/misc.py b/src/utils/misc.py
new file mode 100644
index 000000000..2f6720230
--- /dev/null
+++ b/src/utils/misc.py
@@ -0,0 +1,7 @@
+import torch
+
+DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
+
+
+def is_dummy(tensor: torch.Tensor):
+    return tensor.numel() == 0
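Note on the DUMMY convention introduced above: torch.autograd.Function.apply and the RPC tensor schema expect a real tensor in every argument slot, so "no prompts" is encoded as a zero-element tensor rather than None and detected with numel() == 0. The sketch below is illustrative only and is not part of the patch; the helper name add_prompt_if_given is made up, but it mirrors what _rpc_forward does per backend.

import torch

DUMMY = torch.empty(0)  # same convention as src/utils/misc.py

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

def add_prompt_if_given(hidden_states: torch.Tensor, prompt: torch.Tensor = DUMMY) -> torch.Tensor:
    # hypothetical helper: add a per-layer prompt to the first pre_seq_len positions
    if is_dummy(prompt):
        return hidden_states
    pre_seq_len = prompt.shape[-2]
    hidden_states = hidden_states.clone()
    hidden_states[:, :pre_seq_len] += prompt
    return hidden_states

# usage: a dummy prompt leaves activations untouched, a real prompt shifts the prefix positions
x = torch.randn(2, 8, 16)
assert torch.equal(add_prompt_if_given(x), x)
y = add_prompt_if_given(x, prompt=torch.randn(2, 3, 16))
assert y.shape == x.shape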
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index 1ae4c38fe..678ec01ee 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -4,6 +4,7 @@
 from test_utils import *
 
 from src import RemoteSequential
+from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
@@ -41,3 +42,48 @@ def test_remote_sequential():
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
+
+
+@pytest.mark.forked
+def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    remote_sequential = RemoteSequential(config, dht)
+
+    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
+    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
+    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+
+    input_prompts = input_prompts.detach().requires_grad_(True)
+    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
+
+    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
+    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
+
+    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
+
+    (outputs * output_proj).sum().backward()
+    assert intermediate_prompts.grad is not None
+
+    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
+    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
+
+    assert input_prompts_ref.grad is None
+    assert intermediate_prompts_ref.grad is None
+
+    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
+    for block_index in range(config.n_layer):
+        block_prompt = intermediate_prompts_ref[block_index]
+        outputs_ref[:, : block_prompt.shape[1]] += block_prompt
+
+        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
+        (outputs_ref,) = block(outputs_ref)
+
+    assert torch.allclose(outputs_ref, outputs)
+
+    (outputs_ref * output_proj).sum().backward()
+    assert input_prompts_ref.grad is not None
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert intermediate_prompts_ref.grad is not None
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
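Usage note (not part of the patch): a minimal sketch of how the client-side tuning_mode / pre_seq_len options introduced here are expected to be driven. The model name, peer multiaddress, optimizer settings, and the transformers-style forwarding of extra from_pretrained keyword arguments into the config are assumptions for illustration, as is the usual causal-LM .logits interface.

import torch
import torch.nn.functional as F
from src.client.remote_model import DistributedBloomForCausalLM

# assumed: unknown from_pretrained kwargs are forwarded into DistributedBloomConfig
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/test-bloomd-6b3",                         # placeholder model name
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/..."],   # placeholder peer list
    tuning_mode="deep_ptune",                             # or "shallow_ptune"
    pre_seq_len=16,                                       # number of trainable prefix tokens
)

# Swarm-hosted blocks and local embeddings are frozen via set_requires_grad(False),
# so only the prompt embeddings created for the "ptune" modes receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.AdamW(trainable_params, lr=1e-2)

input_ids = torch.randint(0, model.config.vocab_size, (2, 8))  # toy batch
logits = model(input_ids=input_ids).logits  # prefix positions are already stripped inside forward
loss = F.cross_entropy(logits[:, :-1].flatten(0, 1), input_ids[:, 1:].flatten())
loss.backward()  # gradients flow to prompt_embeddings / intermediate_prompt_embeddings only
opt.step()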