[AutoConfig]Add multi prune #60954

Merged · 8 commits · Feb 2, 2024
163 changes: 134 additions & 29 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -47,16 +47,17 @@ def log_pruned_info(cur_cfg, pruned_reason):
)


def same_cfgs_beside(attr, cur_cfg, history_cfgs=[]):
def same_cfgs_beside(attrs, cur_cfg, history_cfgs=[]):
"""
Compare the current configuration with the history configurations,
and collect those that match the current one on every key except the given attrs.
"""
results = []
same = True

for cfg in history_cfgs:
for key in cur_cfg:
if key == attr:
if key in attrs:
continue
if key not in cfg or (
cfg[key] != cur_cfg[key]
@@ -189,6 +190,38 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=[]):
return False


@register_prune_history
def prune_by_mp_pp_history(tuner_cfg, cur_cfg, history_cfgs, pruned_cfgs):
mp_degree = cur_cfg.get("mp_degree", None)
pp_degree = cur_cfg.get("pp_degree", None)
use_recompute = cur_cfg.get("recompute", None)

if mp_degree is None or pp_degree is None or use_recompute is None:
return False

history_cfgs.extend(pruned_cfgs)
cfgs = same_cfgs_beside(["mp_degree", "pp_degree"], cur_cfg, history_cfgs)
if cur_cfg.get("sharding_degree") == 1:
cfgs = same_cfgs_beside(
["mp_degree", "pp_degree", "sharding_satge"], cur_cfg, history_cfgs
)

if cfgs:
for cfg in cfgs:
if (
not use_recompute
and cfg["mp_degree"] * cfg["pp_degree"] == mp_degree * pp_degree
and cfg["mp_degree"] > mp_degree
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"mp_degree {mp_degree}, pp_degree {pp_degree} may cause oom because {cfg['mp_degree']}, {cfg['pp_degree']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["max_mem_usage"] = "OOM"
return True

return False
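The history prunes above now call `same_cfgs_beside` with a list of attributes rather than a single one. A minimal, self-contained sketch of that comparison, with purely illustrative config values:

```python
# Sketch of the multi-attribute comparison (illustrative, not the PR's exact code).
def same_cfgs_beside_sketch(attrs, cur_cfg, history_cfgs):
    """Return history configs that agree with cur_cfg on every key not listed in attrs."""
    return [
        cfg
        for cfg in history_cfgs
        if all(cfg.get(k) == v for k, v in cur_cfg.items() if k not in attrs)
    ]

# Hypothetical configs: the two differ only in mp_degree/pp_degree, so the
# history entry is returned once those keys are excluded from the comparison.
history = [{"mp_degree": 4, "pp_degree": 2, "micro_batch_size": 2, "max_mem_usage": "OOM"}]
current = {"mp_degree": 2, "pp_degree": 4, "micro_batch_size": 2}
assert same_cfgs_beside_sketch(["mp_degree", "pp_degree"], current, history) == history
```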


@register_prune
def prune_by_vpp(tuner_cfg, cur_cfg, history_cfgs=[]):
"""
Expand Down Expand Up @@ -238,11 +271,19 @@ def prune_by_vpp(tuner_cfg, cur_cfg, history_cfgs=[]):


@register_prune_history
def prune_by_vpp_history(tuner_cfg, cur_cfg, history_cfgs=[]):
def prune_by_vpp_history(tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]):
vpp_degree = cur_cfg.get("vpp_degree", None)
if vpp_degree is None:
return False

history_cfgs.extend(pruned_cfgs)

cfgs = same_cfgs_beside("vpp_degree", cur_cfg, history_cfgs)
if cur_cfg.get("sharding_degree") == 1:
cfgs = same_cfgs_beside(
["vpp_degree", "sharding_satge"], cur_cfg, history_cfgs
)

if cfgs:
for cfg in cfgs:
# memory prune
@@ -252,7 +293,9 @@ def prune_by_vpp_history(tuner_cfg, cur_cfg, history_cfgs=[]):
):
pruned_reason = f"vpp_degree {vpp_degree} may cause oom because { cfg['vpp_degree']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["max_mem_usage"] = "OOM"
return True

return False


@@ -308,11 +351,23 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=[]):


@register_prune_history
def prune_by_mbs_history(tuner_cfg, cur_cfg, history_cfgs=[]):
def prune_by_mbs_history(tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]):
micro_batch_size = cur_cfg.get("micro_batch_size", None)
if micro_batch_size is None:
return False
cfgs = same_cfgs_beside("micro_batch_size", cur_cfg, history_cfgs)

history_cfgs.extend(pruned_cfgs)

cfgs = same_cfgs_beside(
["micro_batch_size", "acc_steps"], cur_cfg, history_cfgs
)
if cur_cfg.get("sharding_degree") == 1:
cfgs = same_cfgs_beside(
["micro_batch_size", "sharding_satge", "acc_steps"],
cur_cfg,
history_cfgs,
)

if cfgs:
for cfg in cfgs:
if (
@@ -321,15 +376,16 @@ def prune_by_mbs_history(tuner_cfg, cur_cfg, history_cfgs=[]):
):
pruned_reason = f"micro_batch_size {micro_batch_size} may be slower because {cfg['micro_batch_size']} has been already runnable."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["time"] = cfg["time"]
return True

# memory prune
if (
cfg["micro_batch_size"] < micro_batch_size
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"micro_batch_size {micro_batch_size} may cause oom because {cfg['micro_batch_size']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["max_mem_usage"] = "OOM"
return True
return False
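A notable change across these history prunes is that a pruned candidate is no longer just dropped: the predicted outcome (`time` or `max_mem_usage`) is written back onto `cur_cfg`, so when the config later reappears through `pruned_cfgs` it can prune further candidates. A hedged sketch of the memory branch shown above, with a hypothetical helper name and sample values:

```python
# Sketch of the visible memory branch: a strictly smaller micro batch size that
# already OOMed implies the larger candidate will OOM too, so the candidate is
# pruned and annotated with the predicted result.
def memory_prune_sketch(cur_cfg, pool):
    mbs = cur_cfg.get("micro_batch_size")
    if mbs is None:
        return False
    for cfg in pool:
        if cfg.get("micro_batch_size", mbs) < mbs and cfg.get("max_mem_usage") == "OOM":
            cur_cfg["max_mem_usage"] = "OOM"  # predicted outcome recorded on the candidate
            return True
    return False

# Hypothetical usage: the pool mixes real history with previously pruned configs.
pool = [{"micro_batch_size": 2, "max_mem_usage": "OOM"}]
cand = {"micro_batch_size": 4}
assert memory_prune_sketch(cand, pool) and cand["max_mem_usage"] == "OOM"
```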

@@ -342,6 +398,7 @@ def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs=[]):
2. Sharding stage and degree should be in the candidates of user defined.
3. If PP (pipeline-parallelism) degree is not 1, sharding stage must be 1.
4. Prune if a similar configuration with a lower sharding stage resulted in a valid run.
5. If sharding degree is 1, sharding stage has no effect, so configurations that differ only in sharding stage are pruned.
"""
sharding_stage = cur_cfg.get("sharding_stage", None)
sharding_degree = cur_cfg.get("sharding_degree", None)
@@ -372,11 +429,18 @@ def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs=[]):
if pp_degree and pp_degree != 1 and sharding_stage != 1:
return True

if sharding_degree == 1:
cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
if cfgs:
return True

return False
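Rule 5 above can be read as: when `sharding_degree` is 1, the sharding stage cannot change the outcome, so a candidate that matches history on everything except `sharding_stage` is a duplicate. A small illustrative check with hypothetical configs:

```python
# With sharding_degree == 1, configs that differ only in sharding_stage are duplicates.
history = [{"sharding_degree": 1, "sharding_stage": 1, "mp_degree": 2}]
candidate = {"sharding_degree": 1, "sharding_stage": 2, "mp_degree": 2}

duplicates = [
    cfg
    for cfg in history
    if all(cfg.get(k) == v for k, v in candidate.items() if k != "sharding_stage")
]
should_prune = candidate["sharding_degree"] == 1 and bool(duplicates)
assert should_prune  # the candidate would be pruned
```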


@register_prune_history
def prune_by_sharding_history(tuner_cfg, cur_cfg, history_cfgs=[]):
def prune_by_sharding_history(
tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]
):
sharding_degree = cur_cfg.get("sharding_degree", None)
if sharding_degree is None:
return False
@@ -385,6 +449,8 @@ def prune_by_sharding_history(tuner_cfg, cur_cfg, history_cfgs=[]):
if sharding_stage is None:
return False

history_cfgs.extend(pruned_cfgs)

cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
@@ -394,6 +460,7 @@ def prune_by_sharding_history(tuner_cfg, cur_cfg, history_cfgs=[]):
):
pruned_reason = f"sharding_stage {sharding_stage} may be slower because {cfg['sharding_stage'] } has been already runnable."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["time"] = cfg["time"]
return True

# memory prune
@@ -403,13 +470,9 @@ def prune_by_sharding_history(tuner_cfg, cur_cfg, history_cfgs=[]):
):
pruned_reason = f"sharding_stage {sharding_stage} may cause oom because {cfg['sharding_stage']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["max_mem_usage"] = "OOM"
return True

if sharding_degree == 1:
cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
if cfgs:
return True

return False


@@ -421,9 +484,12 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs=[]):
2. Usage of recompute and recompute granularity should be in the candidates of user defined.
3. If recompute is not used, but recompute granularity is set, return True for pruning.
4. Prune if a similar configuration without using recompute resulted in a valid run.
5. If recompute is disabled, recompute granularity has no effect, so redundant granularity settings are pruned.
"""
recompute_granularity = cur_cfg.get("recompute_granularity", None)
use_recompute = cur_cfg.get("use_recompute", None)
recompute_level = get_config_recompute_level(cur_cfg)

if use_recompute is None:
return False

@@ -442,41 +508,78 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs=[]):
if recompute_granularity not in recompute_granularity_candidates:
return True

if not use_recompute:
if recompute_granularity != "full":
return True

cfgs = same_cfgs_beside(
["use_recompute", "recompute_granularity"], cur_cfg, history_cfgs
)
if cfgs:
for cfg in cfgs:
if recompute_level == get_config_recompute_level(cfg):
return True

return False


@register_prune_history
def prune_by_recompute_history(tuner_cfg, cur_cfg, history_cfgs=[]):
use_recompute = cur_cfg.get("use_recompute", None)
def get_config_recompute_level(cfg):
recompute_granularity_level = {"full": 3, "full_attn": 2, "core_attn": 1}
use_recompute = cfg.get("use_recompute", None)
recompute_granularity = cfg.get("recompute_granularity", None)

if use_recompute is None:
return None

if not use_recompute:
return 0
else:
return recompute_granularity_level[recompute_granularity]
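`get_config_recompute_level` collapses `use_recompute` and `recompute_granularity` into a single ordered level (0 means recompute off, then core_attn < full_attn < full), which is what the history checks below compare. An illustrative walk-through with hypothetical configs:

```python
# Level 0 means recompute is disabled; higher levels recompute more and save more memory.
LEVELS = {"full": 3, "full_attn": 2, "core_attn": 1}

def level(cfg):
    return 0 if not cfg["use_recompute"] else LEVELS[cfg["recompute_granularity"]]

ran_ok = {"use_recompute": False, "recompute_granularity": None, "time": 123.0}
went_oom = {"use_recompute": True, "recompute_granularity": "full", "max_mem_usage": "OOM"}
candidate = {"use_recompute": True, "recompute_granularity": "core_attn"}

# A lower level already ran successfully, so the candidate can only be slower.
assert level(ran_ok) < level(candidate) and ran_ok["time"] > 0
# A higher level already OOMed, so the candidate (which recomputes less) is predicted to OOM.
assert level(went_oom) > level(candidate) and went_oom["max_mem_usage"] == "OOM"
```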


@register_prune_history
def prune_by_recompute_history(
tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]
):
recompute_level = get_config_recompute_level(cur_cfg)

if recompute_level is None:
return False
cfgs = same_cfgs_beside("use_recompute", cur_cfg, history_cfgs)

history_cfgs.extend(pruned_cfgs)

cfgs = same_cfgs_beside(
["use_recompute", "recompute_granularity"], cur_cfg, history_cfgs
)
if cur_cfg.get("sharding_degree") == 1:
cfgs = same_cfgs_beside(
["use_recompute", "recompute_granularity", "sharding_satge"],
cur_cfg,
history_cfgs,
)

if cfgs:
for cfg in cfgs:
cfg["recompute_level"] = get_config_recompute_level(cfg)

if (
not cfg["use_recompute"]
and use_recompute
cfg["recompute_level"] < recompute_level
and cfg.get("time", -1) > 0
):
pruned_reason = f"use_recompute {use_recompute} may be slower because {cfg['use_recompute']} has been already runnable."
pruned_reason = f"use_recompute may be slower because {cfg['use_recompute']} has been already runnable."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["time"] = cfg["time"]
return True

if (
cfg["use_recompute"]
and not use_recompute
cfg["recompute_level"] > recompute_level
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"use_recompute {use_recompute} may cause oom because {cfg['use_recompute']} already oom."
pruned_reason = f"use_recompute may cause oom because {cfg['use_recompute']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
cur_cfg["max_mem_usage"] = "OOM"
return True

if not use_recompute:
cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
if cfgs:
pruned_reason = f"recompute_granularity invalid because use_recompute is {use_recompute}."
log_pruned_info(cur_cfg, pruned_reason)
return True
return False


@@ -604,7 +707,9 @@ def prune_by_memory_estimation(tuner_cfg, cur_cfg, history_cfgs=[]):


@register_prune_history
def prune_by_sharding_overlap(tuner_cfg, cur_cfg, history_cfgs=[]):
def prune_by_sharding_overlap(
tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]
):
"""Prune by sharding overlap for single dp estimation"""
if "sharding_overlap" in cur_cfg:
result = same_cfgs_beside_sharding_overlap(
10 changes: 7 additions & 3 deletions python/paddle/distributed/auto_tuner/search.py
@@ -31,14 +31,15 @@
class SearchAlgo(ABC):
def __init__(self, tuner_cfg):
self.tuner_cfg = tuner_cfg
self.pruned_cfgs = []

@abstractmethod
def search_once(self, history_cfgs):
pass

def prune(self, tuner_cfg, cur_cfg, history_cfgs):
def prune(self, tuner_cfg, cur_cfg, history_cfgs, pruned_cfgs):
for func in _PRUNE_HISTORY_FUNC:
result = func(tuner_cfg, cur_cfg, history_cfgs)
result = func(tuner_cfg, cur_cfg, history_cfgs, pruned_cfgs)
if result:
return True
return False
@@ -57,7 +58,10 @@ def search_once(self, history_cfgs):
if self.idx < len(self.all_tasks):
new_cfg = self.all_tasks[self.idx]
self.idx += 1
stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
stop = not self.prune(
self.tuner_cfg, new_cfg, history_cfgs, self.pruned_cfgs
)
self.pruned_cfgs.append(new_cfg)
else:
return None
return new_cfg
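Read together with the prune.py changes, the search loop now threads `self.pruned_cfgs` into every registered history-prune function. A hedged sketch of that dispatch (class and helper names below stand in for the real `_PRUNE_HISTORY_FUNC` registry):

```python
# Sketch of the updated dispatch: each history-prune callable now takes
# (tuner_cfg, cur_cfg, history_cfgs, pruned_cfgs).
class SearchSketch:
    def __init__(self, prune_funcs, tuner_cfg=None):
        self.prune_funcs = prune_funcs  # stands in for _PRUNE_HISTORY_FUNC
        self.tuner_cfg = tuner_cfg
        self.pruned_cfgs = []           # grows as candidates are examined

    def prune(self, cur_cfg, history_cfgs):
        return any(
            func(self.tuner_cfg, cur_cfg, history_cfgs, self.pruned_cfgs)
            for func in self.prune_funcs
        )

# Hypothetical rule: prune anything whose mp_degree was already seen among examined configs.
def seen_mp(tuner_cfg, cfg, history_cfgs, pruned):
    return any(p.get("mp_degree") == cfg.get("mp_degree") for p in pruned)

search = SearchSketch([seen_mp])
print(search.prune({"mp_degree": 2}, []))  # False: nothing examined yet
search.pruned_cfgs.append({"mp_degree": 2})
print(search.prune({"mp_degree": 2}, []))  # True: duplicate of an examined config
```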
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_tuner/utils.py
@@ -406,7 +406,7 @@ def search_all(tuner_cfg):
for cur_cfg in new_all_cfgs:
pruned = False
for func in _PRUNE_FUNC:
result = func(tuner_cfg, cur_cfg, [])
result = func(tuner_cfg, cur_cfg, pruned_all_cfgs)
if result:
pruned = True
break
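In utils.py, the build-time prune functions now also receive previously examined configs instead of an empty list. Since the definition of `pruned_all_cfgs` sits outside the visible hunk, the sketch below assumes it simply accumulates configs rejected earlier in the same sweep:

```python
# Hedged, self-contained sketch of the assumed search_all flow (not the PR's exact code).
def search_all_sketch(all_cfgs, prune_funcs, tuner_cfg=None):
    kept, pruned_all_cfgs = [], []
    for cur_cfg in all_cfgs:
        if any(func(tuner_cfg, cur_cfg, pruned_all_cfgs) for func in prune_funcs):
            pruned_all_cfgs.append(cur_cfg)  # assumed: rejected configs feed later checks
        else:
            kept.append(cur_cfg)
    return kept

# Hypothetical prune rule: cap mp_degree at 4 and drop repeats of an already-pruned value.
def reject(tuner_cfg, cfg, pruned):
    return cfg["mp_degree"] > 4 or any(p["mp_degree"] == cfg["mp_degree"] for p in pruned)

print(search_all_sketch([{"mp_degree": 8}, {"mp_degree": 8}, {"mp_degree": 2}], [reject]))
# -> [{'mp_degree': 2}]
```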