Distributed prompt tuning #42

Merged: 31 commits, Aug 12, 2022

Changes from 1 commit

Commits (31)
- d87b8d6 distributed deep & shallow ptune (dbaranchuk, Aug 7, 2022)
- 16ab6fb debug (dbaranchuk, Aug 8, 2022)
- 15e49f9 fix bug (dbaranchuk, Aug 8, 2022)
- 78d50a4 make all computations in handler (dbaranchuk, Aug 10, 2022)
- f66547b rm redundant kwarg (dbaranchuk, Aug 10, 2022)
- 6d0c018 Merge remote-tracking branch 'origin/main' into distributed-deep-ptune (dbaranchuk, Aug 10, 2022)
- 24bdf33 Update src/client/sequential_autograd.py (dbaranchuk, Aug 10, 2022)
- a3e8c41 fix mini bug (dbaranchuk, Aug 10, 2022)
- cb9d9f8 Merge branch 'distributed-deep-ptune' of github.com:learning-at-home/… (dbaranchuk, Aug 10, 2022)
- 4e64597 crutch fix for tests (justheuristic, Aug 10, 2022)
- 64b04c3 black-isort (justheuristic, Aug 10, 2022)
- 1ce22a9 make new api backwards compatible (justheuristic, Aug 10, 2022)
- b3b3264 actually fix tests (justheuristic, Aug 10, 2022)
- c0b7dde Merge branch 'main' into distributed-deep-ptune (justheuristic, Aug 10, 2022)
- c316ca3 make comments more polite ;) (justheuristic, Aug 10, 2022)
- 4365fc8 Merge remote-tracking branch 'origin/distributed-deep-ptune' into dis… (justheuristic, Aug 10, 2022)
- 185f914 rm unused (justheuristic, Aug 10, 2022)
- e96a791 more polite comments (justheuristic, Aug 10, 2022)
- f54eaba rollback (justheuristic, Aug 10, 2022)
- 50a647d fix type warning (justheuristic, Aug 10, 2022)
- cb52c3f reuse existing variable (justheuristic, Aug 10, 2022)
- 0cfda12 black-isort (justheuristic, Aug 10, 2022)
- 7eb4ca3 hotfix for grads (justheuristic, Aug 11, 2022)
- e364a1e fix for batch sizes != 1 (justheuristic, Aug 11, 2022)
- a66ff41 hotfix for grads (justheuristic, Aug 12, 2022)
- 7d22978 grad prompts reversed (justheuristic, Aug 12, 2022)
- 48f65a1 [test remotely in CI] (justheuristic, Aug 12, 2022)
- 169e619 fix prompts for partial chains (justheuristic, Aug 12, 2022)
- 0791f85 fix prompts for partial chains (justheuristic, Aug 12, 2022)
- cdde27a black-isort (justheuristic, Aug 12, 2022)
- 1f5afd1 unused lines (justheuristic, Aug 12, 2022)
Commit 48f65a1c8ce487e3a92535618f4503581da80bcf ("[test remotely in CI]")
justheuristic committed Aug 12, 2022
48 changes: 48 additions & 0 deletions tests/test_remote_sequential.py
@@ -4,6 +4,7 @@
from test_utils import *

from src import RemoteSequential
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_model import DistributedBloomConfig

use_hivemind_log_handler("in_root_logger")
@@ -41,3 +42,50 @@ def test_remote_sequential():

    (second_half_outputs * grad_proj).sum().backward()
    assert torch.allclose(test_inputs.grad, full_grad)


@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    remote_sequential = RemoteSequential(config, dht)

    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)

    input_prompts = input_prompts.detach().requires_grad_(True)
    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
    with torch.no_grad():
        intermediate_prompts[...] = torch.randn_like(intermediate_prompts)

    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)

    # forward/backward through the distributed sequence of blocks
    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)

    (outputs * output_proj).sum().backward()
    assert intermediate_prompts.grad is not None

    # reference: run the same blocks locally, injecting each block's prompt
    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)

    assert input_prompts_ref.grad is None
    assert intermediate_prompts_ref.grad is None

    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
    for block_index in range(config.n_layer):
        block_prompt = intermediate_prompts_ref[block_index]
        outputs_ref[:, : block_prompt.shape[1]] += block_prompt

        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
        (outputs_ref,) = block(outputs_ref)

    assert torch.allclose(outputs_ref, outputs)

    (outputs_ref * output_proj).sum().backward()
    assert input_prompts_ref.grad is not None
    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
    assert intermediate_prompts_ref.grad is not None
    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
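
The reference loop above doubles as a specification of what the distributed path must compute: before each transformer block runs, that block's slice of `prompts` is added to the first `pre_seq_len` positions of the hidden states, so gradients flow back into every prompt tensor. A minimal local sketch of this deep-prompt injection pattern (illustrative only; `run_with_deep_prompts` and its `blocks` argument are hypothetical names for this sketch, not part of the repository API):

```python
import torch


def run_with_deep_prompts(hidden_states: torch.Tensor, prompts: torch.Tensor, blocks) -> torch.Tensor:
    """Hypothetical sketch mirroring the reference loop in test_remote_sequential_prompts.

    hidden_states: (batch, seq_len + pre_seq_len, hidden_size)
    prompts:       (n_layer, batch, pre_seq_len, hidden_size), trainable
    blocks:        n_layer callables, each returning a tuple whose first item
                   is the transformed hidden states
    """
    assert len(blocks) == prompts.shape[0]
    for block_index, block in enumerate(blocks):
        block_prompt = prompts[block_index]
        pre_seq_len = block_prompt.shape[1]
        # inject this block's trainable prompt into the leading positions
        # (clone so the caller's tensor is not modified in place)
        hidden_states = hidden_states.clone()
        hidden_states[:, :pre_seq_len] += block_prompt
        hidden_states = block(hidden_states)[0]
    return hidden_states
```

Assuming `blocks = [load_pretrained_block(MODEL_NAME, i, torch_dtype=torch.float32) for i in range(config.n_layer)]`, this should reproduce the `outputs_ref` computed by the test, with gradients reaching `prompts` through the in-place `+=` on each clone.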