
add dist op costs (#44701)
Caozhou1995 authored Jul 29, 2022
1 parent fecbc95 commit ec1e0d5
Showing 9 changed files with 902 additions and 22 deletions.
30 changes: 25 additions & 5 deletions python/paddle/distributed/auto_parallel/cost/__init__.py
@@ -12,20 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License

from .base_cost import _g_op_cost_factory
from .base_cost import Cost
from .base_cost import CommContext
from .base_cost import build_comm_desc
from .base_cost import build_comp_desc_from_op
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_dp_costs
from .base_cost import build_comp_desc_str_for_predict
from .base_cost import build_comm_desc_from_dist_op
from .base_cost import build_comm_costs_from_descs
from .base_cost import build_comp_costs_from_descs

from .comp_op_cost import EmbeddingOpCost
from .comp_op_cost import EmbeddingGradOpCost
from .comp_op_cost import ConcatOpCost
from .comp_op_cost import MatmulOpCost
from .comp_op_cost import MatmulGradOpCost
from .comp_op_cost import MatmulV2OpCost
from .comp_op_cost import MatmulV2GradOpCost
from .comp_op_cost import MulOpCost
from .comp_op_cost import MulGradOpCost
from .comp_op_cost import Reshape2OpCost
from .comp_op_cost import Reshape2GradOpCost
from .comp_op_cost import SliceOpCost
from .comp_op_cost import SplitOpCost
from .comp_op_cost import SoftmaxOpCost
from .comp_op_cost import SoftmaxGradOpCost
from .comp_op_cost import Transpose2OpCost
from .comp_op_cost import Transpose2GradOpCost
from .comp_op_cost import FillConstantBatchSizeLikeOpCost

from .tensor_cost import TensorCost

from .estimate_cost import CostEstimator

from .comm_op_cost import SendOpCost
from .comm_op_cost import RecvOpCost
from .comm_op_cost import IdentityOpCost
19 changes: 19 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -15,6 +15,25 @@
from .base_cost import Cost, register_op_cost, CompOpCost, _g_op_cost_factory


@register_op_cost
class AdamOpCost(CompOpCost):
    OP_TYPE = "adam"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(AdamOpCost, self).__init__(op=op,
                                         op_desc=op_desc,
                                         cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0


@register_op_cost
class AssignOpCost(CompOpCost):
    OP_TYPE = "assign"
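The pattern above relies on @register_op_cost recording each cost class in _g_op_cost_factory. Below is a minimal sketch, not part of this commit, of how that registry might be queried; it assumes the decorator keys the factory by the class's OP_TYPE string.

# Hypothetical usage sketch: look up a registered comp-op cost class by op type.
from paddle.distributed.auto_parallel.cost import _g_op_cost_factory
from paddle.distributed.auto_parallel.cost.comp_op_cost import AdamOpCost, CompOpCost

adam_cost_cls = _g_op_cost_factory["adam"]  # assumes OP_TYPE is the registry key
assert adam_cost_cls is AdamOpCost
assert issubclass(adam_cost_cls, CompOpCost)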
6 changes: 4 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -831,8 +831,10 @@ def validate_dist_attr_for_program(self):
                if (dist_tensor
                    is not None) and (not dist_tensor.validate_dist_attr()):
                    assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
                        dist_tensor.serial_tensor.name,
                        dist_tensor.serial_tensor.desc.id(),
                        dist_tensor.serial_tensor.desc.original_id(),
                        dist_tensor.dist_attr)
            for op in block.ops:
                dist_op = self.get_dist_op_for_program(op)
                assert dist_op is not None, \
@@ -31,6 +31,9 @@
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost


class DistributedEmbedding(DistributedOperatorImplContainer):
@@ -53,6 +56,95 @@ def __init__(self, name):
        self._forward_implemented = True
        self._backward_implemented = True

    def calc_cost(self, op_role, dist_op, ctx, cluster):
        """Calculate the cost by the op role."""
        cost = None
        if int(op_role) == int(OpRole.Forward):
            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
        elif int(op_role) == int(OpRole.Backward):
            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
        assert cost is not None
        return cost

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
                                                    dist_context=ctx)
        processes = dist_op.dist_attr.process_mesh.processes
        # embedding needs start_index
        cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx,
                                                   processes, desc_mapping,
                                                   cluster)

        serial_op = dist_op.serial_op
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("W")[0])[0]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
            "c_allreduce_sum",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis)

        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
            cluster)

        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost

    def calc_bwd_cost(self, dist_op, ctx, cluster):
        # for now, the backward function only inserts the gradient allreduce for the dist op itself
        res = []
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        dist_attr = dist_op.dist_attr

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("W")[0])[0]
        parallel_axis = embedding_row_dim_mapping
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = [backward_op.input("Out@GRAD")[0]]
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
            "c_identity",
            dist_op,
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis)

        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
                                                    dist_context=ctx)
        cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx,
                                                   processes, desc_mapping,
                                                   cluster)
        res.append(cost_mapping)

        # need gradient allreduce when the batch dimension is sharded
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Ids")[0])
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('W@GRAD')[0]]
            build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
                           cluster)

        return res

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
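The calc_cost entry point added above dispatches on the op role and composes the computation costs with the communication costs of the inserted c_allreduce_sum / c_identity ops, plus the data-parallel gradient allreduce when the batch axis is sharded. A short sketch of how it might be driven, under the assumption that impl is a registered DistributedEmbedding implementation and that dist_op, dist_context, and cluster come from an already-planned auto-parallel program:

from paddle.distributed.fleet.meta_optimizers.common import OpRole


def estimate_embedding_costs(impl, dist_op, dist_context, cluster):
    # Forward: [comp cost mapping, c_allreduce_sum comm cost list]
    fwd_costs = impl.calc_cost(OpRole.Forward, dist_op, dist_context, cluster)
    # Backward: c_identity comm costs, EmbeddingGradOpCost comp costs, and, if the
    # batch dimension of Ids is sharded, the allreduce costs appended by build_dp_costs
    bwd_costs = impl.calc_cost(OpRole.Backward, dist_op, dist_context, cluster)
    return fwd_costs, bwd_costs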