【Auto Parallel】add dist op costs in gpt #44701

Merged (1 commit) on Jul 29, 2022
30 changes: 25 additions & 5 deletions python/paddle/distributed/auto_parallel/cost/__init__.py
@@ -12,20 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License

from .base_cost import _g_op_cost_factory
from .base_cost import Cost
from .base_cost import CommContext
from .base_cost import _g_op_cost_factory
from .base_cost import build_comm_desc
from .base_cost import build_comp_desc_from_op
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_dp_costs
from .base_cost import build_comp_desc_str_for_predict
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_comm_desc_from_dist_op
from .base_cost import build_comm_costs_from_descs
from .base_cost import build_comp_costs_from_descs
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator

from .comp_op_cost import EmbeddingOpCost
from .comp_op_cost import EmbeddingGradOpCost
from .comp_op_cost import ConcatOpCost
from .comp_op_cost import MatmulOpCost
from .comp_op_cost import MatmulGradOpCost
from .comp_op_cost import MatmulV2OpCost
from .comp_op_cost import MatmulV2GradOpCost
from .comp_op_cost import MulOpCost
from .comp_op_cost import MulGradOpCost
from .comp_op_cost import Reshape2OpCost
from .comp_op_cost import Reshape2GradOpCost
from .comp_op_cost import SliceOpCost
from .comp_op_cost import SplitOpCost
from .comp_op_cost import SoftmaxOpCost
from .comp_op_cost import SoftmaxGradOpCost
from .comp_op_cost import Transpose2OpCost
from .comp_op_cost import Transpose2GradOpCost
from .comp_op_cost import FillConstantBatchSizeLikeOpCost

from .tensor_cost import TensorCost

from .estimate_cost import CostEstimator

from .comm_op_cost import SendOpCost
from .comm_op_cost import RecvOpCost
from .comm_op_cost import IdentityOpCost
19 changes: 19 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -15,6 +15,25 @@
from .base_cost import Cost, register_op_cost, CompOpCost, _g_op_cost_factory


@register_op_cost
class AdamOpCost(CompOpCost):
OP_TYPE = "adam"

def __init__(self, op=None, op_desc=None, cluster=None):
super(AdamOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)

# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0

def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0


@register_op_cost
class AssignOpCost(CompOpCost):
OP_TYPE = "assign"
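As an aside on the pattern above: register_op_cost records each cost class under its OP_TYPE in _g_op_cost_factory, so a concrete comp op cost can be looked up by op type and must override calc_flops and calc_time. The following is a minimal standalone sketch of that mechanism, not Paddle's actual implementation; the class bodies are illustrative only.

# Minimal standalone sketch of the register_op_cost / _g_op_cost_factory
# pattern. Illustrative only; Paddle's real CompOpCost carries op/op_desc/
# cluster state and registers many more op types.
_g_op_cost_factory = {}


def register_op_cost(cls):
    # Map the declared OP_TYPE string to its cost class.
    _g_op_cost_factory[cls.OP_TYPE] = cls
    return cls


class CompOpCost:

    def __init__(self, op=None, op_desc=None, cluster=None):
        self.op = op
        self.op_desc = op_desc
        self.cluster = cluster

    # Concrete comp op costs override calc_flops and calc_time.
    def calc_flops(self):
        raise NotImplementedError

    def calc_time(self):
        raise NotImplementedError


@register_op_cost
class AdamOpCost(CompOpCost):
    OP_TYPE = "adam"

    def calc_flops(self):
        return 0  # placeholder, as in the PR

    def calc_time(self):
        return 0  # placeholder, as in the PR


# Look up a registered cost class by op type, as the cost model does.
adam_cost = _g_op_cost_factory["adam"]()
print(adam_cost.calc_flops(), adam_cost.calc_time())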
6 changes: 4 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -831,8 +831,10 @@ def validate_dist_attr_for_program(self):
if (dist_tensor
is not None) and (not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.desc.id(),
dist_tensor.desc.original_id(), dist_tensor.dist_attr)
dist_tensor.serial_tensor.name,
dist_tensor.serial_tensor.desc.id(),
dist_tensor.serial_tensor.desc.original_id(),
dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \
@@ -31,6 +31,9 @@
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost


class DistributedEmbedding(DistributedOperatorImplContainer):
@@ -53,6 +56,95 @@ def __init__(self, name):
self._forward_implemented = True
self._backward_implemented = True

def calc_cost(self, op_role, dist_op, ctx, cluster):
"""Calculate the cost by the op role."""
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost

def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
# embedding needs start_index
cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx,
processes, desc_mapping,
cluster)

serial_op = dist_op.serial_op
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("W")[0])[0]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)

comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)

res_cost = [cost_mapping, comm_op_cost_list]

return res_cost

def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
main_block = backward_op.block
dist_attr = dist_op.dist_attr

embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("W")[0])[0]
parallel_axis = embedding_row_dim_mapping
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = [backward_op.input("Out@GRAD")[0]]
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)

process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)

# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)

# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Ids")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('W@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)

return res

def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
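One note on calc_bwd_cost above: data-parallel gradient allreduce costs for W@GRAD are only added when the batch dimension of Ids is sharded, i.e. when its first dims-mapping entry points at a mesh axis spanning more than one process. The snippet below is a standalone illustration of that condition; needs_grad_allreduce is a hypothetical helper written for this example, not part of Paddle.

# Standalone illustration of the check
# `batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1`.
def needs_grad_allreduce(var_dim_mapping, mesh_shape):
    batch_size_axis = var_dim_mapping[0]
    return batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1


# Ids batch dim sharded along axis 0 of a 2x4 mesh -> dp allreduce cost added.
print(needs_grad_allreduce([0, -1], [2, 4]))   # True
# Ids replicated (batch dim not sharded) -> no dp allreduce cost.
print(needs_grad_allreduce([-1, -1], [2, 4]))  # False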