Distributed prompt tuning #42

Merged: 31 commits from distributed-deep-ptune into main on Aug 12, 2022
Changes from 1 commit (7d22978, "grad prompts reversed"; see the diff below)

Commits (31):
d87b8d6  distributed deep & shallow ptune (dbaranchuk, Aug 7, 2022)
16ab6fb  debug (dbaranchuk, Aug 8, 2022)
15e49f9  fix bug (dbaranchuk, Aug 8, 2022)
78d50a4  make all computations in handler (dbaranchuk, Aug 10, 2022)
f66547b  rm redundant kwarg (dbaranchuk, Aug 10, 2022)
6d0c018  Merge remote-tracking branch 'origin/main' into distributed-deep-ptune (dbaranchuk, Aug 10, 2022)
24bdf33  Update src/client/sequential_autograd.py (dbaranchuk, Aug 10, 2022)
a3e8c41  fix mini bug (dbaranchuk, Aug 10, 2022)
cb9d9f8  Merge branch 'distributed-deep-ptune' of github.com:learning-at-home/… (dbaranchuk, Aug 10, 2022)
4e64597  crutch fix for tests (justheuristic, Aug 10, 2022)
64b04c3  black-isort (justheuristic, Aug 10, 2022)
1ce22a9  make new api backwards compatible (justheuristic, Aug 10, 2022)
b3b3264  actually fix tests (justheuristic, Aug 10, 2022)
c0b7dde  Merge branch 'main' into distributed-deep-ptune (justheuristic, Aug 10, 2022)
c316ca3  make comments more polite ;) (justheuristic, Aug 10, 2022)
4365fc8  Merge remote-tracking branch 'origin/distributed-deep-ptune' into dis… (justheuristic, Aug 10, 2022)
185f914  rm unused (justheuristic, Aug 10, 2022)
e96a791  more polite comments (justheuristic, Aug 10, 2022)
f54eaba  rollback (justheuristic, Aug 10, 2022)
50a647d  fix type warning (justheuristic, Aug 10, 2022)
cb52c3f  reuse existing variable (justheuristic, Aug 10, 2022)
0cfda12  black-isort (justheuristic, Aug 10, 2022)
7eb4ca3  hotfix for grads (justheuristic, Aug 11, 2022)
e364a1e  fix for batch sizes != 1 (justheuristic, Aug 11, 2022)
a66ff41  hotfix for grads (justheuristic, Aug 12, 2022)
7d22978  grad prompts reversed (justheuristic, Aug 12, 2022)
48f65a1  [test remotely in CI] (justheuristic, Aug 12, 2022)
169e619  fix prompts for partial chains (justheuristic, Aug 12, 2022)
0791f85  fix prompts for partial chains (justheuristic, Aug 12, 2022)
cdde27a  black-isort (justheuristic, Aug 12, 2022)
1f5afd1  unused lines (justheuristic, Aug 12, 2022)
Commit 7d22978a9e7661d5cd7ac64a4c979e2a78bda0a9: grad prompts reversed
justheuristic committed Aug 12, 2022
17 changes: 11 additions & 6 deletions src/server/handler.py
@@ -265,22 +265,27 @@ async def _rpc_backward(
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = [inputs]
+    inter_inputs = []
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
             inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs)
 
-    grad_prompts = []
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
 
-    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else DUMMY
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
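
For readers skimming the diff, a minimal synchronous sketch of the pattern this commit lands on may help. `ToyBlock` and `backward_through_chain` below are hypothetical stand-ins invented for illustration (the real handler submits work to `backend.forward_pool` / `backend.backward_pool` asynchronously, and this is not the Petals API); only the control flow mirrors the new `_rpc_backward` body: store each block's input during the forward sweep, walk the chain last-to-first during backward, and un-reverse the collected prompt gradients before concatenating.

```python
import torch

DUMMY = torch.empty(0)  # same sentinel idea as in src/server/handler.py

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

class ToyBlock:
    """Stand-in for one served transformer block (hypothetical; not the real backend)."""

    def __init__(self, hidden_size: int):
        self.weight = torch.randn(hidden_size, hidden_size) * 0.02

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

    def backward(self, inp: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
        # For y = x @ W, dL/dx = dL/dy @ W.T; inp is unused by a linear toy block
        # but kept to mirror the (input, grad_output) signature of backward_pool.
        return grad_out @ self.weight.T

def backward_through_chain(blocks, prompts, inputs, grad_outputs, pre_seq_len):
    # Forward sweep: remember the *input* of every block (after adding its prompt),
    # since backward needs inputs, not outputs. The last block is never run forward
    # because its output is not needed.
    inter_inputs = []
    for block, prompt in zip(blocks[:-1], prompts[:-1]):
        if not is_dummy(prompt):
            inputs[:, :pre_seq_len] += prompt
        inter_inputs.append(inputs)
        inputs = block.forward(inputs)

    if not is_dummy(prompts[-1]):
        inputs[:, :pre_seq_len] += prompts[-1]
    inter_inputs.append(inputs)
    assert len(inter_inputs) == len(prompts) == len(blocks)

    # Backward sweep walks the chain last-to-first, so prompt gradients are
    # collected in reverse block order...
    grad_prompts_reversed = []
    for inp, prompt, block in zip(*map(reversed, (inter_inputs, prompts, blocks))):
        grad_outputs = block.backward(inp, grad_outputs)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))

    # ...and must be flipped back before stacking, so that row i of grad_prompts
    # corresponds to the i-th prompted block in chain order.
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return grad_outputs, grad_prompts

# Usage: 3 toy blocks, with prompts on the first and last block only.
blocks = [ToyBlock(16) for _ in range(3)]
prompts = [torch.randn(1, 4, 16), DUMMY, torch.randn(1, 4, 16)]
hidden = torch.randn(1, 10, 16)    # [batch, seq_len, hidden]
grad_out = torch.randn(1, 10, 16)  # gradient w.r.t. the chain output
grad_in, grad_prompts = backward_through_chain(blocks, prompts, hidden, grad_out, pre_seq_len=4)
print(grad_prompts.shape)  # torch.Size([2, 1, 4, 16]): row 0 -> blocks[0], row 1 -> blocks[2]
```

The final `[::-1]` is the point of the commit: without it, the concatenated `grad_prompts` rows come back in reverse block order, so each prompt gradient would be applied to the wrong block.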