add fusedce to metrics #575

Closed
wants to merge 4 commits
53 changes: 53 additions & 0 deletions llmfoundry/models/layers/cross_entropy.py
@@ -0,0 +1,53 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+try:
+    from flash_attn.losses.cross_entropy import SoftmaxCrossEntropyLossFn
+except ImportError:
+    SoftmaxCrossEntropyLossFn = None
+
+import torch.nn as nn
+
+
+class FusedCrossEntropyLoss(nn.Module):
+
+    def __init__(
+        self,
+        ignore_index=-100,
+        reduction='mean',
+        label_smoothing=0.0,
+        inplace_backward=False,
+        process_group=None,
+    ):
+        if SoftmaxCrossEntropyLossFn is None:
+            raise ValueError(
+                'Fused Cross Entropy is not installed. Either (1) have a '
+                'CUDA-compatible GPU and `pip install .[gpu]` if installing '
+                'from source, or `pip install xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy` '
+                'if installing from PyPI, or (2) set '
+                'model.loss_fn=torch_crossentropy in your config.'
+            )
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+        self.inplace_backward = inplace_backward
+        self.process_group = process_group
+
+    def forward(self, input, target):
+        assert input.is_cuda and target.is_cuda
+        # SoftmaxCrossEntropyLossFn implicitly casts to float
+        loss = SoftmaxCrossEntropyLossFn.apply(
+            input,
+            target,
+            self.label_smoothing,
+            self.ignore_index,
+            self.inplace_backward,
+            self.process_group,
+        )
+        if self.reduction == 'mean':  # average over non-ignored tokens only
+            return loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:
+            return loss
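
Reviewer note: a minimal usage sketch of the new wrapper, assuming a CUDA device and a built flash-attn xentropy extension; the shapes and vocab size below are illustrative, not taken from the PR.

import torch

from llmfoundry.models.layers.cross_entropy import FusedCrossEntropyLoss

# Flattened (batch * seq_len, vocab_size) logits and integer targets, the
# shape the MPT loss path produces before the loss is computed.
logits = torch.randn(8, 50368, device='cuda')
targets = torch.randint(0, 50368, (8,), device='cuda')
targets[0] = -100  # ignored position: excluded from the 'mean' denominator

loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='mean')
loss = loss_fn(logits, targets)  # scalar: loss.sum() / (# non-ignored tokens)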
43 changes: 25 additions & 18 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -29,6 +29,7 @@

 from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
 from llmfoundry.models.layers.blocks import MPTBlock
+from llmfoundry.models.layers.cross_entropy import FusedCrossEntropyLoss
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
 from llmfoundry.models.layers.ffn import \
@@ -690,10 +691,29 @@ def __init__(
         hf_config = MPTConfig.from_dict(resolved_om_model_config)
         model = MPTForCausalLM(hf_config)

-        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
+        loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
+
+        train_lang_ce = LanguageCrossEntropy()
+        train_lang_pp = LanguagePerplexity()
+        if loss_fn_config == 'fused_crossentropy':
+            train_lang_ce.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                                          reduction='sum')
+            train_lang_pp.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                                          reduction='sum')
+
+        train_metrics = [train_lang_ce, train_lang_pp]
+
+        eval_lang_ce = LanguageCrossEntropy()
+        eval_lang_pp = LanguagePerplexity()
+        if loss_fn_config == 'fused_crossentropy':
+            eval_lang_ce.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                                         reduction='sum')
+            eval_lang_pp.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                                         reduction='sum')
+
         eval_metrics = [
-            LanguageCrossEntropy(),
-            LanguagePerplexity(),
+            eval_lang_ce,
+            eval_lang_pp,
             InContextLearningLMAccuracy(),
             InContextLearningMultipleChoiceAccuracy(),
             InContextLearningQAAccuracy(),
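
Reviewer note: the metric copies get reduction='sum' because Composer's LanguageCrossEntropy accumulates a running loss total and a token count across batches and divides only at compute() time; a per-batch 'mean' would weight short batches the same as long ones. A simplified sketch of that accumulation pattern (modeled on Composer's metric, not code from this PR; the class and state names here are illustrative):

import torch
from torchmetrics import Metric


class SumReducedCrossEntropy(Metric):
    """Sketch of why the metric's loss_fn must use reduction='sum'."""

    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn  # e.g. FusedCrossEntropyLoss(reduction='sum')
        self.add_state('sum_loss', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        # Accumulate the summed loss and the count of non-ignored tokens.
        self.sum_loss += self.loss_fn(output.view(-1, output.size(-1)),
                                      target.view(-1))
        self.total_items += (target != -100).sum()

    def compute(self):
        # Token-weighted mean over all batches (and all ranks, via dist_reduce_fx).
        return self.sum_loss / self.total_items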
@@ -715,22 +735,9 @@ def __init__(

         loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
         if loss_fn_config == 'fused_crossentropy':
-            try:
-                from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
-
-                if hf_config.verbose > 1:
-                    warnings.warn('Using Fused Cross Entropy Loss.')
-                self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
-            except:
-                raise ValueError(
-                    'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU '
-                    +
-                    'and `pip install .[gpu]` if installing from source or `pip install xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy` '
-                    +
-                    'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.'
-                )
+            self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='mean')
         elif loss_fn_config == 'torch_crossentropy':
-            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
         else:
             raise ValueError(
                 f'Specified loss_fn={loss_fn_config} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].'
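
Reviewer note: a quick parity check one might run locally to confirm the fused loss matches torch's implementation, assuming a CUDA machine with the xentropy kernel installed; shapes and tolerances are illustrative:

import torch
import torch.nn as nn

from llmfoundry.models.layers.cross_entropy import FusedCrossEntropyLoss

torch.manual_seed(0)
logits = torch.randn(16, 50368, device='cuda', dtype=torch.float16)
targets = torch.randint(0, 50368, (16,), device='cuda')
targets[:2] = -100  # check that ignored tokens are dropped identically

fused = FusedCrossEntropyLoss(ignore_index=-100, reduction='mean')
ref = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')

# The fused kernel upcasts to float internally, so compare against a fp32 reference.
assert torch.allclose(fused(logits, targets).float(),
                      ref(logits.float(), targets).float(),
                      rtol=1e-3, atol=1e-4)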