From 76ea193e1a514ad8a05d7e77a3e02143235c5df2 Mon Sep 17 00:00:00 2001
From: drownfish19
Date: Tue, 10 Sep 2024 11:40:14 +0000
Subject: [PATCH 1/2] fix

---
 .../transformers/ring_flash_attention.py     | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py
index 9fa8ea52b655..792229445294 100644
--- a/paddlenlp/transformers/ring_flash_attention.py
+++ b/paddlenlp/transformers/ring_flash_attention.py
@@ -20,17 +20,6 @@
 from paddle import _C_ops
 from paddle.autograd.py_layer import PyLayer
 
-try:
-    from paddlenlp_ops import flash_attn_bwd
-except (ImportError, ModuleNotFoundError):
-    from paddlenlp.utils.log import logger
-
-    logger.warning(
-        "if you run ring_flash_attention.py, please ensure you install "
-        "the paddlenlp_ops by following the instructions "
-        "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
 
 class RingCommunicator:
     def __init__(self, group, local_key, local_value):
@@ -233,6 +222,18 @@ def balanced_ring_flash_attention_bwd_func(
     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
 
+    try:
+        from paddlenlp_ops import flash_attn_bwd
+    except (ImportError, ModuleNotFoundError):
+        pass
+        from paddlenlp.utils.log import logger
+
+        logger.warning(
+            "if you run ring_flash_attention.py, please ensure you install "
+            "the paddlenlp_ops by following the instructions "
+            "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
+        )
+
     for step in range(cp_size):
         block_k, block_v = kv_comm_buffer.get_buffers()
 

From 2fa00fcdd1ad4ef5c3f9267306cfc9dbe8370fc2 Mon Sep 17 00:00:00 2001
From: drownfish19
Date: Tue, 10 Sep 2024 11:42:03 +0000
Subject: [PATCH 2/2] fix

---
 paddlenlp/transformers/ring_flash_attention.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py
index 792229445294..b3faf2463dff 100644
--- a/paddlenlp/transformers/ring_flash_attention.py
+++ b/paddlenlp/transformers/ring_flash_attention.py
@@ -225,7 +225,6 @@ def balanced_ring_flash_attention_bwd_func(
     try:
         from paddlenlp_ops import flash_attn_bwd
     except (ImportError, ModuleNotFoundError):
-        pass
         from paddlenlp.utils.log import logger
 
         logger.warning(
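
Note on the net effect of this series: patch 1 moves the paddlenlp_ops import guard from module scope into balanced_ring_flash_attention_bwd_func, and patch 2 drops the stray pass statement left in the except branch. As a rough sketch only (the real function signature and surrounding body are elided, not shown here), the guard inside the backward function ends up reading:

    # Inside balanced_ring_flash_attention_bwd_func, after both patches:
    # the optional paddlenlp_ops dependency is imported lazily, so merely
    # importing ring_flash_attention.py no longer emits the warning.
    try:
        from paddlenlp_ops import flash_attn_bwd
    except (ImportError, ModuleNotFoundError):
        from paddlenlp.utils.log import logger

        logger.warning(
            "if you run ring_flash_attention.py, please ensure you install "
            "the paddlenlp_ops by following the instructions "
            "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
        )

With this placement, the missing-dependency warning only fires when the balanced backward path is actually executed, rather than every time ring_flash_attention.py is imported.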