Implement multi-token prediction option for models #479

Open
tmostak opened this issue May 4, 2024 · 7 comments

@tmostak

tmostak commented May 4, 2024

Per the recent paper from Meta, models that predict multiple future tokens can exhibit significantly greater sample efficiency than models trained only on next-token prediction. In addition, the extra token heads can be used to implement speculative decoding to speed up inference (up to 3X in their experiments), without the need for a separate draft model.

It would be amazing to see multi-token prediction implemented in nanoGPT, as it would allow the community to easily experiment with this promising technique.
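
For reference, if I am reading the paper right, the training objective is just a sum of independent next-token-style cross-entropy terms, one per future offset, all computed from the same shared trunk:

$$L_n = -\sum_t \log P_\theta(x_{t+n:t+1} \mid x_{t:1}) = -\sum_t \sum_{i=1}^{n} \log P_\theta(x_{t+i} \mid x_{t:1})$$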

@siddharthji07

Implementing multi-token prediction in NanoGPT could be very valuable. It would increase sample efficiency, so we could get more accurate results, and the models could also perform speculative decoding: generate multiple candidate future tokens and then keep the most likely continuation.

@NullLabTests

Was anybody able to do this? I think we can add this functionality.

@vnsmv

vnsmv commented Jul 3, 2024

> Was anybody able to do this? I think we can add this functionality.

I want to start working on this.

@thoorpukarnakar

I would also like to start contributing to this multi-token prediction work.

@MauroCE

MauroCE commented Jul 15, 2024

I'd love to start contributing too! I tried to think about it a bit, but I have not managed to reach a finalized solution:

  1. Add n_multi_token: int = 4 to GPTConfig (defaulting to 4, based on the paper).
  2. Between lines 130 and 131 of model.py, add hmt = nn.ModuleList([Block(config) for _ in range(config.n_multi_token)]) (sketched below).
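
Roughly, a sketch of what I mean by those two changes, assuming the current model.py layout (the field name n_multi_token and the key hmt are just my placeholders, and the heads share wte, ln_f and lm_head as in the paper's shared-unembedding setup):

from dataclasses import dataclass
import torch.nn as nn

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True
    n_multi_token: int = 4  # new: number of future tokens to predict (4 in the paper's main runs)

# inside GPT.__init__, add the extra per-future-token blocks to the existing ModuleDict:
self.transformer = nn.ModuleDict(dict(
    wte=nn.Embedding(config.vocab_size, config.n_embd),
    wpe=nn.Embedding(config.block_size, config.n_embd),
    drop=nn.Dropout(config.dropout),
    h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
    ln_f=LayerNorm(config.n_embd, bias=config.bias),
    hmt=nn.ModuleList([Block(config) for _ in range(config.n_multi_token)]),  # new: one block per future token
))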

In the forward pass, after lines 180-181, we need to perform the sequential pass described in Figure 2 of the paper. I was thinking something like this:

# Old code, stays the same
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
    x = block(x)

# New code. I am unclear about whether LN should go before or after the multi-token heads
# and unsure about how we actually implement Figure 2 here
if targets is not None:
    for block in self.transformer.hmt:
        logits = self.lm_head(self.transformer.ln_f(block(x)))
        # note: this overwrites logits and loss on every iteration; see the discussion below
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
    logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
    loss = None

I want to understand this (would love to hear people's opinions):

In the paper, they say that after the shared trunk we compute the multi-token (MT) heads sequentially to save memory, but they all share the same un-embedding layer. Does LayerNorm happen within the MT heads, or before/after them? ln_f is the final LayerNorm, so I would guess it has to go after the MT heads.

However, most importantly, in Figure 2 they do the backward pass directly. How do we adapt this to our code? I think the key idea is that we want to compute each loss independently (each MT head is independent) and therefore we want to accumulate the gradients and feed them back to the rest of the network. My code as is does not work because I am not accumulating gradients.
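
Here is how I currently read the Figure 2 trick, as a rough sketch (the function and the names z / z_detached are mine, and it assumes the transformer.hmt heads from above): detach the trunk output, call backward on each head's loss so the gradients accumulate in the detached tensor while that head's activations are freed, and finally backprop the accumulated gradient through the shared trunk once.

import torch
import torch.nn.functional as F

def multi_token_backward(model, idx, targets):
    # shared trunk forward, same as the existing nanoGPT forward up to the heads
    device = idx.device
    b, t = idx.size()
    pos = torch.arange(0, t, dtype=torch.long, device=device)
    tok_emb = model.transformer.wte(idx)
    pos_emb = model.transformer.wpe(pos)
    z = model.transformer.drop(tok_emb + pos_emb)
    for block in model.transformer.h:
        z = block(z)

    # detach so each head can be backpropagated on its own; gradients accumulate in z_detached.grad
    z_detached = z.detach().requires_grad_(True)
    total_loss = 0.0
    for k, head in enumerate(model.transformer.hmt):
        logits_k = model.lm_head(model.transformer.ln_f(head(z_detached)))
        loss_k = F.cross_entropy(logits_k.view(-1, logits_k.size(-1)),
                                 targets[:, :, k].reshape(-1), ignore_index=-1)
        loss_k.backward()            # frees this head's activation graph before the next head runs
        total_loss += loss_k.detach()

    # one backward pass through the shared trunk with the accumulated gradient
    z.backward(z_detached.grad)
    return total_loss

This would live in the training step rather than in GPT.forward, and would still need to be wired into train.py's gradient accumulation and GradScaler logic, so treat it as a starting point only.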

@MauroCE

MauroCE commented Jul 15, 2024

Actually, perhaps a better way to do this is as follows; this should accumulate the per-head losses:

# New code
logits_list = []
if targets is not None:
    loss = 0
    for i, head in enumerate(self.transformer.hmt):
        logits = self.lm_head(self.transformer.ln_f(head(x)))
        logits_list.append(logits)
        # targets is assumed to carry a trailing future-token dimension (B, T, F); see the batch change in the next comment
        loss += F.cross_entropy(logits.view(-1, logits.size(-1)), targets[..., i].reshape(-1), ignore_index=-1)
else:
    # inference-time optimization: only forward the lm_head on the very last position
    for head in self.transformer.hmt:
        logits = self.lm_head(self.transformer.ln_f(head(x[:, [-1], :]))) # note: using list [-1] to preserve the time dim
        logits_list.append(logits)
    loss = None

return logits_list, loss
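
One detail I am unsure about: summing the per-head losses means the loss scale grows with n_multi_token, so the effective learning rate would change as we add heads. It might be safer to average over the heads instead, e.g.

loss = loss / len(self.transformer.hmt)  # optional: mean over heads instead of sum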

@MauroCE

MauroCE commented Jul 16, 2024

An important change is to the training batch. The context should have the same shape as before, but now we want targets of shape (B, T, F), where F is the number of future tokens. We recover the standard scenario with F=1 after squeezing the last dimension.

def get_batch(split, multi_token=False, F=1):
    assert F >= 1, "number of future tokens must be at least 1."
    if multi_token:
        assert F > 1, "when multi_token is True, F must be larger than 1."
    else:
        assert F == 1, "when next-token prediction is being used, F must be 1."
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    # leave room for the F future tokens at the end of the data
    ix = torch.randint(len(data) - block_size - F, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.stack([torch.from_numpy((data[i+j+1:i+j+1+block_size]).astype(np.int64)) for j in range(F)], dim=-1) for i in ix])
    if not multi_token:
        y = y.squeeze(-1)
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
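
As a quick sanity check of the shapes (this assumes the usual train.py globals batch_size, block_size, data_dir, device and device_type are in scope):

x, y = get_batch('train', multi_token=True, F=4)
# x.shape == (batch_size, block_size)       context tokens, unchanged
# y.shape == (batch_size, block_size, 4)    y[:, :, k] holds the (k+1)-th future token for each position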

The get_batch script above can definitely be improved, but it might be a decent starting point. The following forward pass should then do the job. I have tried it in a similar repo of mine and it seems to run, but I cannot yet claim whether it is working correctly.

# feed the trunk output x through each head separately (each head sees the trunk output, not the previous head)
head_outputs = []
for head in self.transformer.hmt:
    head_outputs.append(head(x))
# Stack the head outputs and apply the final LayerNorm
x = torch.stack(head_outputs, dim=-2)  # (B, T, n_multi_token, n_embd)
x = self.transformer.ln_f(x)  # (B, T, n_multi_token, n_embd), works because LayerNorm acts on the final dimension
# Final linear layer mapping (B, T, n_multi_token, n_embd) -> (B, T, n_multi_token, vocab_size)
logits = self.lm_head(x)  # (B, T, n_multi_token, vocab_size)
if targets is None:
    loss = None
else:
    # Compute log-probabilities and gather the log-probability of each true future token
    log_probs = F.log_softmax(logits, dim=-1).view(-1, self.config.vocab_size)  # (B*T*n_multi_token, vocab_size)
    expanded_targets = targets.reshape(-1, 1)  # (B*T*n_multi_token, 1), targets of shape (B, T, F) with F == n_multi_token
    log_probs_true_tokens = torch.gather(
        input=log_probs, dim=-1, index=expanded_targets).squeeze(-1)  # (B*T*n_multi_token,)
    loss = -log_probs_true_tokens.mean()  # scalar
return logits, loss
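
For what it's worth, I believe the gather-based loss above should match a single cross_entropy call on the flattened logits (modulo the ignore_index handling, which the gather version does not have), which would keep it closer to the existing nanoGPT loss:

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)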

A few edits might be necessary to make this work smoothly with the rest of the package. First of all, one should be able to choose whether to use only the next-token head or all of the future-token heads, especially at inference time. This also does not implement self-speculative decoding.
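
As a starting point for the self-speculative part, here is a very rough greedy draft-and-verify sketch. It assumes batch size 1 and a hypothetical all_positions=True flag on the forward pass that returns the stacked (B, T, n_multi_token, vocab_size) logits at every position even without targets; neither exists in the snippets above yet, and block_size cropping is ignored for brevity.

import torch

@torch.no_grad()
def self_speculative_step(model, idx):
    # one greedy draft-and-verify step
    B, T = idx.size()
    assert B == 1, "this sketch assumes batch size 1"

    # 1) draft: a single forward pass, take every head's greedy prediction at the last position
    logits, _ = model(idx, all_positions=True)             # (1, T, n_heads, vocab_size), hypothetical flag
    draft = logits[:, -1, :, :].argmax(dim=-1)             # (1, n_heads); draft[:, k] guesses token T + k

    # 2) verify: run prefix + draft once and keep the longest run of drafted tokens that the
    #    ordinary next-token head (head 0) would itself have chosen greedily
    candidate = torch.cat([idx, draft], dim=1)             # (1, T + n_heads)
    v_logits, _ = model(candidate, all_positions=True)     # (1, T + n_heads, n_heads, vocab_size)
    next_token_pred = v_logits[:, :, 0, :].argmax(dim=-1)  # (1, T + n_heads), greedy next token at each position

    accepted = idx
    for k in range(draft.size(1)):
        # the next-token prediction at position T - 1 + k is for the token at index T + k, i.e. draft[:, k]
        if next_token_pred[0, T - 1 + k] != draft[0, k]:
            break
        accepted = torch.cat([accepted, draft[:, k:k + 1]], dim=1)
    return accepted

The first drafted token is always accepted (it is just the ordinary next-token prediction), so every step emits at least one token; the speed-up comes from however many of the extra drafted tokens survive verification.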
