Skip to content

Commit

Permalink
fix sp import
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Apr 12, 2024
1 parent d1c43cd commit cef772f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@
MinLengthLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
)
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
except:
pass

from paddlenlp.transformers.segment_parallel_utils import ReshardLayer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
from ppfleetx.core.module.basic_module import BasicModule
from ppfleetx.data.tokenizers import GPTTokenizer
from ppfleetx.distributed.apis import env
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
register_sequence_parallel_allreduce_hooks,
)
try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
register_sequence_parallel_allreduce_hooks,
)
except:
pass
from ppfleetx.utils.log import logger

# TODO(haohongxiang): to solve the problem of cross-reference
Expand Down
12 changes: 8 additions & 4 deletions paddlenlp/transformers/gpt/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ScatterOp,
mark_as_sequence_parallel_parameter,
)

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ScatterOp,
mark_as_sequence_parallel_parameter,
)
except:
pass

from ...utils.converter import StateDictNameMapping
from .. import PretrainedModel, register_base_model
Expand Down
12 changes: 8 additions & 4 deletions paddlenlp/transformers/mc2_seqence_parallel_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@

from paddle import distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)
except:
pass

__all_gather_recomputation__ = False
if int(os.getenv("MC2_Recompute", 0)):
Expand Down

0 comments on commit cef772f

Please sign in to comment.