Skip to content

Commit 4362daa

Browse files
author
Kevin Cortacero
committed
refactoring
1 parent 85f3362 commit 4362daa

File tree

4 files changed

+781
-225
lines changed

4 files changed

+781
-225
lines changed

src/kartezio/core/components.py

+68-56
Original file line numberDiff line numberDiff line change
@@ -138,58 +138,59 @@ def name_of(component_class: type) -> str:
138138
def display():
139139
pprint(Components._registry)
140140

141+
def add_as(self, fundamental: type, replace: type = None):
142+
"""
143+
Register a component to the Components registry.
141144
142-
def register(fundamental: type, replace: type = None):
143-
"""
144-
Register a component to the Components registry.
145+
Args:
146+
fundamental (type): The fundamental type of the component.
147+
replace (type): If not None, replace an existing component with the type.
145148
146-
Args:
147-
fundamental (type): The fundamental type of the component.
148-
replace (type): If not None, replace an existing component with the type.
149+
Returns:
150+
Callable: A decorator for registering the component.
151+
"""
152+
fundamental_name = fundamental.__name__
153+
154+
def inner(item_cls):
155+
name = item_cls.__name__
156+
if Components._contains(fundamental_name, name):
157+
if not replace:
158+
raise KeyError(
159+
f"""Error registering {fundamental_name} called '{name}'.
160+
Here is the list of all registered {fundamental_name} components:
161+
\n{Components._registry[fundamental_name].keys()}.
162+
\n > Replace it using 'replace=True' in @register, or use another name.
163+
"""
164+
)
165+
if replace:
166+
replace_name = replace.__name__
167+
if Components._contains(fundamental_name, replace_name):
168+
print(
169+
f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'"
170+
)
171+
Components.add(fundamental_name, replace_name, item_cls)
172+
else:
173+
Components.add(fundamental_name, name, item_cls)
174+
return item_cls
149175

150-
Returns:
151-
Callable: A decorator for registering the component.
152-
"""
153-
fundamental_name = fundamental.__name__
154-
155-
def inner(item_cls):
156-
name = item_cls.__name__
157-
if Components._contains(fundamental_name, name):
158-
if not replace:
159-
raise KeyError(
160-
f"""Error registering {fundamental_name} called '{name}'.
161-
Here is the list of all registered {fundamental_name} components:
162-
\n{Components._registry[fundamental_name].keys()}.
163-
\n > Replace it using 'replace=True' in @register, or use another name.
164-
"""
165-
)
166-
if replace:
167-
replace_name = replace.__name__
168-
if Components._contains(fundamental_name, replace_name):
169-
print(
170-
f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'"
171-
)
172-
Components.add(fundamental_name, replace_name, item_cls)
173-
else:
174-
Components.add(fundamental_name, name, item_cls)
175-
return item_cls
176+
return inner
176177

177-
return inner
178+
def declare(self):
179+
"""
180+
Register a fundamental component to the Components registry.
178181
182+
Returns:
183+
Callable: A decorator for registering the fundamental component.
184+
"""
179185

180-
def component():
181-
"""
182-
Register a fundamental component to the Components registry.
186+
def inner(item_cls):
187+
Components.add_component(item_cls.__name__)
188+
return item_cls
183189

184-
Returns:
185-
Callable: A decorator for registering the fundamental component.
186-
"""
190+
return inner
187191

188-
def inner(item_cls):
189-
Components.add_component(item_cls.__name__)
190-
return item_cls
191192

192-
return inner
193+
registry = Components()
193194

194195

195196
def load_component(
@@ -223,7 +224,7 @@ def dump_component(component: KartezioComponent) -> Dict:
223224
return base_dict
224225

225226

226-
@component()
227+
@registry.declare()
227228
class Node(KartezioComponent, ABC):
228229
"""
229230
Abstract base class for a Node in the CGP framework.
@@ -232,7 +233,7 @@ class Node(KartezioComponent, ABC):
232233
pass
233234

234235

235-
@component()
236+
@registry.declare()
236237
class Preprocessing(Node, ABC):
237238
"""
238239
Preprocessing node, called before training loop.
@@ -265,7 +266,7 @@ def then(self, preprocessing: "Preprocessing"):
265266
return self
266267

267268

268-
@component()
269+
@registry.declare()
269270
class Primitive(Node, ABC):
270271
"""
271272
Primitive function called inside the CGP Graph.
@@ -286,7 +287,7 @@ def __to_dict__(self) -> Dict:
286287
return {"name": self.name}
287288

288289

289-
@component()
290+
@registry.declare()
290291
class Genotype(KartezioComponent):
291292
"""
292293
Represents the genotype for Cartesian Genetic Programming (CGP).
@@ -392,7 +393,7 @@ def clone(self) -> "Genotype":
392393
return copy.deepcopy(self)
393394

394395

395-
@component()
396+
@registry.declare()
396397
class Reducer(Node, ABC):
397398
def batch(self, x: List):
398399
y = []
@@ -405,7 +406,7 @@ def reduce(self, x):
405406
pass
406407

407408

408-
@component()
409+
@registry.declare()
409410
class Endpoint(Node, ABC):
410411
"""
411412
Represents the final node in a CGP graph, responsible for producing the final outputs.
@@ -425,7 +426,6 @@ def __init__(self, inputs: List[KType]):
425426

426427
@classmethod
427428
def __from_dict__(cls, dict_infos: Dict) -> "Endpoint":
428-
from kartezio.core.endpoints import Endpoint
429429
"""
430430
Create an Endpoint instance from a dictionary representation.
431431
@@ -441,8 +441,14 @@ def __from_dict__(cls, dict_infos: Dict) -> "Endpoint":
441441
**dict_infos["args"],
442442
)
443443

444+
@classmethod
445+
def from_config(cls, config):
446+
return registry.instantiate(
447+
cls.__name__, config["name"], **config["args"]
448+
)
449+
444450

445-
@component()
451+
@registry.declare()
446452
class Fitness(KartezioComponent, ABC):
447453
def __init__(self, reduction="mean"):
448454
super().__init__()
@@ -482,14 +488,15 @@ def evaluate(self, y_true, y_pred):
482488
@classmethod
483489
def __from_dict__(cls, dict_infos: Dict) -> "Fitness":
484490
from kartezio.core.fitness import Fitness
491+
485492
return Components.instantiate(
486493
"Fitness",
487494
dict_infos["name"],
488495
**dict_infos["args"],
489496
)
490497

491498

492-
@component()
499+
@registry.declare()
493500
class Library(KartezioComponent):
494501
def __init__(self, rtype):
495502
super().__init__()
@@ -596,7 +603,7 @@ def size(self):
596603
return len(self._primitives)
597604

598605

599-
@component()
606+
@registry.declare()
600607
class Mutation(KartezioComponent, ABC):
601608
def __init__(self, adapter):
602609
super().__init__()
@@ -685,15 +692,15 @@ def __to_dict__(self) -> Dict:
685692
return {}
686693

687694

688-
@component()
695+
@registry.declare()
689696
class Initialization(KartezioComponent, ABC):
690697
""" """
691698

692699
def __init__(self):
693700
super().__init__()
694701

695702

696-
@register(Initialization)
703+
@registry.add_as(Initialization)
697704
class CopyGenotype(Initialization):
698705
@classmethod
699706
def __from_dict__(cls, dict_infos: Dict) -> "CopyGenotype":
@@ -707,7 +714,7 @@ def mutate(self, genotype):
707714
return self.genotype.clone()
708715

709716

710-
@register(Initialization)
717+
@registry.add_as(Initialization)
711718
class RandomInit(Initialization, Mutation):
712719
"""
713720
Can be used to initialize genome (genome) randomly
@@ -736,3 +743,8 @@ def mutate(self, genotype: Genotype):
736743
def random(self):
737744
genotype = self.adapter.new_genotype()
738745
return self.mutate(genotype)
746+
747+
748+
if __name__ == "__main__":
749+
registry.display()
750+
print("Done!")

0 commit comments

Comments
 (0)