Distributed prompt tuning #42

Merged · 31 commits · Aug 12, 2022

Changes shown from 1 commit (of 31).

Commits:
d87b8d6  distributed deep & shallow ptune  (dbaranchuk, Aug 7, 2022)
16ab6fb  debug  (dbaranchuk, Aug 8, 2022)
15e49f9  fix bug  (dbaranchuk, Aug 8, 2022)
78d50a4  make all computations in handler  (dbaranchuk, Aug 10, 2022)
f66547b  rm redundant kwarg  (dbaranchuk, Aug 10, 2022)
6d0c018  Merge remote-tracking branch 'origin/main' into distributed-deep-ptune  (dbaranchuk, Aug 10, 2022)
24bdf33  Update src/client/sequential_autograd.py  (dbaranchuk, Aug 10, 2022)
a3e8c41  fix mini bug  (dbaranchuk, Aug 10, 2022)
cb9d9f8  Merge branch 'distributed-deep-ptune' of github.com:learning-at-home/…  (dbaranchuk, Aug 10, 2022)
4e64597  crutch fix for tests  (justheuristic, Aug 10, 2022)
64b04c3  black-isort  (justheuristic, Aug 10, 2022)
1ce22a9  make new api backwards compatible  (justheuristic, Aug 10, 2022)
b3b3264  actually fix tests  (justheuristic, Aug 10, 2022)
c0b7dde  Merge branch 'main' into distributed-deep-ptune  (justheuristic, Aug 10, 2022)
c316ca3  make comments more polite ;)  (justheuristic, Aug 10, 2022)
4365fc8  Merge remote-tracking branch 'origin/distributed-deep-ptune' into dis…  (justheuristic, Aug 10, 2022)
185f914  rm unused  (justheuristic, Aug 10, 2022)
e96a791  more polite comments  (justheuristic, Aug 10, 2022)
f54eaba  rollback  (justheuristic, Aug 10, 2022)
50a647d  fix type warning  (justheuristic, Aug 10, 2022)
cb52c3f  reuse existing variable  (justheuristic, Aug 10, 2022)
0cfda12  black-isort  (justheuristic, Aug 10, 2022)
7eb4ca3  hotfix for grads  (justheuristic, Aug 11, 2022)
e364a1e  fix for batch sizes != 1  (justheuristic, Aug 11, 2022)
a66ff41  hotfix for grads  (justheuristic, Aug 12, 2022)
7d22978  grad prompts reversed  (justheuristic, Aug 12, 2022)
48f65a1  [test remotely in CI]  (justheuristic, Aug 12, 2022)
169e619  fix prompts for partial chains  (justheuristic, Aug 12, 2022)
0791f85  fix prompts for partial chains  (justheuristic, Aug 12, 2022)
cdde27a  black-isort  (justheuristic, Aug 12, 2022)
1f5afd1  unused lines  (justheuristic, Aug 12, 2022)
make all computations in handler
dbaranchuk committed Aug 10, 2022
commit 78d50a4a033c07800a94b695a52aa662011b7f22
src/bloom/block.py (6 changes: 0 additions & 6 deletions)

@@ -18,7 +18,6 @@
     pre_process_alibi_for_pad,
     split_tensor_along_last_dim,
 )
-from src.utils.misc import is_dummy_batch


 class BloomAttention(nn.Module):
@@ -249,11 +248,6 @@ def forward(
         # MLP.
         output = self.mlp(layernorm_output, residual)

-        batch_size = hidden_states.shape[0]
-        if prompts is not None and not is_dummy_batch(prompts, batch_size):
-            pre_seq_len = prompts.shape[1]
-            output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts
-
         if use_cache:
             outputs = (output,) + outputs
         else:
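Note on this file: the commit removes the in-block prompt addition from BloomBlock.forward; prompts are now added to the hidden states by the connection handler before the block runs. A minimal sketch of the equivalent operation, assuming a prompt of shape (batch, pre_seq_len, hidden_size) and hidden states of shape (batch, seq_len, hidden_size); the tensor names and sizes below are illustrative, not taken from the diff:

import torch

batch, pre_seq_len, seq_len, hidden_size = 2, 4, 16, 8
hidden_states = torch.randn(batch, seq_len, hidden_size)
prompt = torch.randn(batch, pre_seq_len, hidden_size)

# The handler adds the (deep) prompt to the first pre_seq_len positions
# before submitting the hidden states to the block:
hidden_states[:, :pre_seq_len] += prompt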
src/client/sequential_autograd.py (19 changes: 16 additions & 3 deletions)

@@ -34,7 +34,13 @@ async def run_expert_forward(
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)

-    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

     forward_inputs = nested_flatten(forward_inputs)
@@ -45,7 +51,7 @@
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
         )
     )

@@ -69,7 +75,14 @@ async def run_expert_backward(

     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
     inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu)))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    backward_schema = tuple(nested_flatten((forward_schema_with_prompts, rpc_info["outputs_schema"])))

     # Asynchronous serialization
     loop = asyncio.get_running_loop()
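Note on this file: the client-side change works around a single-tensor forward_schema. The server's rpc_info advertises one positional input descriptor (the hidden states), while the client now sends two tensors (hidden states and prompts), so the descriptor is duplicated to make nested_compare and serialization see a matching schema. A rough sketch of the idea with plain tuples standing in for hivemind tensor descriptors; the DescriptorStub type and the string placeholders are hypothetical, only for illustration:

from collections import namedtuple

# Hypothetical stand-in for a tensor descriptor
DescriptorStub = namedtuple("DescriptorStub", ["dtype", "compression"])

args_schema = (DescriptorStub("float32", "NONE"),)   # server advertises one positional input
kwargs_schema = {}
inputs = ("hidden_states_tensor", "prompts_tensor")  # client now sends two tensors

# Duplicate the single descriptor so the schema matches the two inputs;
# this relies on both tensors sharing the same dtype/compression settings.
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
assert len(forward_schema_with_prompts[0]) == len(inputs)

This is also why the TODO in the diff notes that the assert should go away once an arbitrary number of input tensors is supported.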
src/server/handler.py (59 changes: 41 additions & 18 deletions)

@@ -12,7 +12,7 @@

 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
-from src.utils.misc import DUMMY, is_dummy, is_dummy_batch, make_dummy_batch
+from src.utils.misc import DUMMY, is_dummy


 class TransformerConnectionHandler(ConnectionHandler):
@@ -128,11 +128,18 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon

         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)

+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )

@@ -146,10 +153,17 @@ async def rpc_backward_stream(

         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)

+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -200,17 +214,20 @@ async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_s

 async def _rpc_forward(inputs, requested_backends):
     # Cast inputs to backend dtype
-    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
-    assert len(hidden_states) == 2 and hidden_states[0].ndim == 3
-    hidden_states, prompts = hidden_states
+    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
+    assert len(inputs) == 2 and inputs[0].ndim == 3
+    hidden_states, prompts = inputs

     if is_dummy(prompts):
-        batch_size = hidden_states.shape[0]
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        pre_seq_len = prompts.shape[2]

     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, prompt)
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -225,31 +242,37 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    batch_size = inputs.shape[0]

     if is_dummy(prompts):
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
     else:
+        pre_seq_len = prompts.shape[2]
         prompts = [p.squeeze(0) for p in prompts.split(1)]

     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = [inputs]
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        (inputs,) = await backend.forward_pool.submit_task(inputs, prompt)
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
         inter_inputs.append(inputs)

     grad_prompts = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
-        grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
-        assert isinstance(grads, (list, tuple)) and len(grads) == 2
-        grad_outputs, grad_prompt = grads
-        grad_prompts.append(grad_prompt[None])
-
-    is_dummy_grad_prompts = [is_dummy_batch(grad_param, batch_size) for grad_param in grad_prompts]
+        if not is_dummy(prompt):
+            inp[:, :pre_seq_len] += prompt
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+        else:
+            grad_prompts.append(DUMMY)
+
+    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
     grads = [grad_outputs, grad_prompts]
     return grads
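Note on this file: in _rpc_backward, because the prompt enters each block additively at the first pre_seq_len positions, the gradient with respect to the prompt is simply the corresponding slice of the gradient with respect to that block's input, which is what the handler extracts via grad_outputs[:, :pre_seq_len]. A small self-contained sketch of that identity, assuming a single linear layer as a stand-in for one transformer block; all names and shapes are illustrative:

import torch

batch, pre_seq_len, seq_len, hidden = 2, 3, 8, 4
x = torch.randn(batch, seq_len, hidden, requires_grad=True)
prompt = torch.randn(batch, pre_seq_len, hidden, requires_grad=True)
block = torch.nn.Linear(hidden, hidden)  # stand-in for one transformer block

# Additive prompt injection at the first pre_seq_len positions
inp_with_prompt = torch.cat([x[:, :pre_seq_len] + prompt, x[:, pre_seq_len:]], dim=1)
out = block(inp_with_prompt)
out.sum().backward()

# The prompt gradient equals the input gradient restricted to the prompt positions.
assert torch.allclose(prompt.grad, x.grad[:, :pre_seq_len])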
src/server/server.py (3 changes: 0 additions & 3 deletions)

@@ -212,9 +212,6 @@ def create(
                     BatchTensorDescriptor(
                         1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                     ),
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
                 ),
                 kwargs_schema={},
                 outputs_schema=(
src/utils/misc.py (5 changes: 0 additions & 5 deletions)

@@ -1,12 +1,7 @@
 import torch

 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
-make_dummy_batch = lambda x: torch.empty(x)


 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
-
-
-def is_dummy_batch(tensor: torch.Tensor, batch_size: int):
-    return tensor.numel() == batch_size and tensor.ndim == 1
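Note on this file: with make_dummy_batch and is_dummy_batch removed, the remaining convention is a single zero-element DUMMY tensor as the "no prompt" placeholder, independent of batch size. A brief sketch of how that sentinel behaves, under the definitions kept above; the shapes below are illustrative:

import torch

DUMMY = torch.empty(0)  # zero-element placeholder, as defined above

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

prompt = DUMMY                      # "no prompt" for this block
assert is_dummy(prompt)             # detected regardless of batch size

real_prompt = torch.zeros(2, 4, 8)  # batch x pre_seq_len x hidden_size
assert not is_dummy(real_prompt)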