Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] use better Singleton #64340

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def __hash__(self):
return hash((tuple(self.shape), self.dtype, self.stop_gradient))


@Singleton
class VariableCreator:
class VariableCreator(metaclass=Singleton):
"""
We use the static graph Variable to infer the meta information of Tensor.
This singleton class is used to create Variable for infer meta.
Expand Down Expand Up @@ -305,8 +304,7 @@ def ast_infer_meta(static_function, *args, **kwargs):
return out


@Singleton
class SpecialInferMeta:
class SpecialInferMeta(metaclass=Singleton):
"""
There are some functions that cannot be inferred directly through static graph,
and need to be implemented manually. This class is used to implement infer meta
Expand Down Expand Up @@ -340,8 +338,7 @@ def infermeta_grad(
return inputs


@Singleton
class InferMetaCache(Cache):
class InferMetaCache(Cache, metaclass=Singleton):
def key_fn(
self, func, *args, **kwargs
): # args & kwargs have transformed to MetaInfo
Expand All @@ -362,8 +359,7 @@ def value_fn(self, func, *args, **kwargs):
return infer_meta(func, *args, **kwargs)


@Singleton
class LayerInferMetaCache(Cache):
class LayerInferMetaCache(Cache, metaclass=Singleton):
def key_fn(self, layer, *args, **kwargs):
params = [
MetaInfo.from_tensor(x)
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/jit/sot/opcode_translator/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __hash__(self):
return hash((self.file, self.line, self.co_name, self.offset))


@Singleton
class BreakpointManager:
class BreakpointManager(metaclass=Singleton):
def __init__(self):
self.breakpoints = set()
self.executors = OpcodeExecutorBase.call_stack
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
dummy_guard.lambda_expr = "lambda frame: True"


@Singleton
class OpcodeExecutorCache:
class OpcodeExecutorCache(metaclass=Singleton):
"""
A singleton class that implements a cache for translated instructions.
This cache is used to store previously translated instructions along with their corresponding guard functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from ...utils import Singleton


@Singleton
class StackAnalyser:
class StackAnalyser(metaclass=Singleton):
def stack_effect(self, instr):
if "BINARY" in instr.opname or "INPLACE" in instr.opname:
return 2, 1
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ def __call__(self, *args, **kwargs):
return outputs


@Singleton
class CompileSIRCache(Cache):
class CompileSIRCache(Cache, metaclass=Singleton):
"""
Cache the compiled function of SIR
"""
Expand Down
6 changes: 2 additions & 4 deletions python/paddle/jit/sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ def __repr__(self):
return self.__str__()


@Singleton
class StatementIRFactory:
class StatementIRFactory(metaclass=Singleton):
"""
It is used to create a StatementIR.
"""
Expand Down Expand Up @@ -320,8 +319,7 @@ def clear(self):
del self.cache[key]


@Singleton
class SIRRuntimeCache:
class SIRRuntimeCache(metaclass=Singleton):
"""
It is used to cache the runtime information of the StatementIR.
"""
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/jit/sot/utils/call_ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def _is_wrapped(f):
return func


@Singleton
class StaticFunctionManager:
class StaticFunctionManager(metaclass=Singleton):
def __init__(self):
self.code_map = {}

Expand Down
28 changes: 11 additions & 17 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from enum import Enum
from typing import Any, Generic, TypeVar
from typing import Any, TypeVar
from weakref import WeakValueDictionary

import numpy as np
Expand All @@ -48,15 +48,13 @@
ConstTypes = (int, float, str, bool, type(None))


class Singleton(Generic[T]):
def __init__(self, cls: type[T]):
self._cls = cls
self._instance = {}
class Singleton(type):
_instances: dict[Any, Any] = {}

def __call__(self) -> T:
if self._cls not in self._instance:
self._instance[self._cls] = self._cls()
return self._instance[self._cls]
def __call__(cls, *args: Any, **kwargs: Any):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]


class NameGenerator:
Expand Down Expand Up @@ -107,8 +105,7 @@ def current_tmp_name_records():
return _tmp_name_records


@Singleton
class ResumeFnNameFactory:
class ResumeFnNameFactory(metaclass=Singleton):
def __init__(self) -> None:
self.gen = NameGenerator('resume_')

Expand Down Expand Up @@ -316,8 +313,7 @@ def get_unbound_method(obj, name):
return getattr(obj.__class__, name)


@Singleton
class GraphLogger:
class GraphLogger(metaclass=Singleton):
graph_num: int
op_num: int
graphs: list[Program]
Expand Down Expand Up @@ -376,8 +372,7 @@ def print_info(self):
print(self)


@Singleton
class SotUndefinedVar:
class SotUndefinedVar(metaclass=Singleton):
pass


Expand Down Expand Up @@ -461,8 +456,7 @@ def need_dynamic_info(self):
return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS


@Singleton
class StepInfoManager:
class StepInfoManager(metaclass=Singleton):
def __init__(self):
self.step_record = {}
self.current_code = None
Expand Down