Skip to content

Commit

Permalink
config torch to avoid graph breaks caused by logger (deepspeedai#6999)
Browse files Browse the repository at this point in the history
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <shelly.nahir@live.biu.ac.il>
Co-authored-by: snahir <snahir@habana.ai>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
  • Loading branch information
2 people authored and deepcharm committed Feb 27, 2025
1 parent 3767709 commit 495606a
Showing 1 changed file with 10 additions and 31 deletions.
41 changes: 10 additions & 31 deletions deepspeed/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import logging
import sys
import os
from deepspeed.runtime.compiler import is_compile_supported, is_compiling
import torch
from deepspeed.utils.torch import required_torch_version

log_levels = {
"debug": logging.DEBUG,
Expand All @@ -20,31 +21,6 @@

class LoggerFactory:

def create_warning_filter(logger):
warn = False

def warn_once(record):
nonlocal warn
if is_compile_supported() and is_compiling() and not warn:
warn = True
logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
return True

return warn_once

@staticmethod
def logging_decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_compiling():
return
else:
return func(*args, **kwargs)

return wrapper

@staticmethod
def create_logger(name=None, level=logging.INFO):
"""create a logger
Expand All @@ -70,12 +46,15 @@ def create_logger(name=None, level=logging.INFO):
ch.setLevel(level)
ch.setFormatter(formatter)
logger_.addHandler(ch)
if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']:
if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
excluded_set = {
item.strip()
for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",")
}
ignore_set = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'} - excluded_set
for method in ignore_set:
original_logger = getattr(logger_, method)
setattr(logger_, method, LoggerFactory.logging_decorator(original_logger))
else:
logger_.addFilter(LoggerFactory.create_warning_filter(logger_))
torch._dynamo.config.ignore_logger_methods.add(original_logger)
return logger_


Expand Down

0 comments on commit 495606a

Please sign in to comment.