@@ -138,58 +138,59 @@ def name_of(component_class: type) -> str:
138
138
def display ():
139
139
pprint (Components ._registry )
140
140
141
+ def add_as (self , fundamental : type , replace : type = None ):
142
+ """
143
+ Register a component to the Components registry.
141
144
142
- def register ( fundamental : type , replace : type = None ) :
143
- """
144
- Register a component to the Components registry .
145
+ Args :
146
+ fundamental (type): The fundamental type of the component.
147
+ replace (type): If not None, replace an existing component with the type .
145
148
146
- Args:
147
- fundamental (type): The fundamental type of the component.
148
- replace (type): If not None, replace an existing component with the type.
149
+ Returns:
150
+ Callable: A decorator for registering the component.
151
+ """
152
+ fundamental_name = fundamental .__name__
153
+
154
+ def inner (item_cls ):
155
+ name = item_cls .__name__
156
+ if Components ._contains (fundamental_name , name ):
157
+ if not replace :
158
+ raise KeyError (
159
+ f"""Error registering { fundamental_name } called '{ name } '.
160
+ Here is the list of all registered { fundamental_name } components:
161
+ \n { Components ._registry [fundamental_name ].keys ()} .
162
+ \n > Replace it using 'replace=True' in @register, or use another name.
163
+ """
164
+ )
165
+ if replace :
166
+ replace_name = replace .__name__
167
+ if Components ._contains (fundamental_name , replace_name ):
168
+ print (
169
+ f"Component '{ fundamental_name } /{ replace_name } ' will be replaced by '{ name } '"
170
+ )
171
+ Components .add (fundamental_name , replace_name , item_cls )
172
+ else :
173
+ Components .add (fundamental_name , name , item_cls )
174
+ return item_cls
149
175
150
- Returns:
151
- Callable: A decorator for registering the component.
152
- """
153
- fundamental_name = fundamental .__name__
154
-
155
- def inner (item_cls ):
156
- name = item_cls .__name__
157
- if Components ._contains (fundamental_name , name ):
158
- if not replace :
159
- raise KeyError (
160
- f"""Error registering { fundamental_name } called '{ name } '.
161
- Here is the list of all registered { fundamental_name } components:
162
- \n { Components ._registry [fundamental_name ].keys ()} .
163
- \n > Replace it using 'replace=True' in @register, or use another name.
164
- """
165
- )
166
- if replace :
167
- replace_name = replace .__name__
168
- if Components ._contains (fundamental_name , replace_name ):
169
- print (
170
- f"Component '{ fundamental_name } /{ replace_name } ' will be replaced by '{ name } '"
171
- )
172
- Components .add (fundamental_name , replace_name , item_cls )
173
- else :
174
- Components .add (fundamental_name , name , item_cls )
175
- return item_cls
176
+ return inner
176
177
177
- return inner
178
+ def declare (self ):
179
+ """
180
+ Register a fundamental component to the Components registry.
178
181
182
+ Returns:
183
+ Callable: A decorator for registering the fundamental component.
184
+ """
179
185
180
- def component ( ):
181
- """
182
- Register a fundamental component to the Components registry.
186
+ def inner ( item_cls ):
187
+ Components . add_component ( item_cls . __name__ )
188
+ return item_cls
183
189
184
- Returns:
185
- Callable: A decorator for registering the fundamental component.
186
- """
190
+ return inner
187
191
188
- def inner (item_cls ):
189
- Components .add_component (item_cls .__name__ )
190
- return item_cls
191
192
192
- return inner
193
+ registry = Components ()
193
194
194
195
195
196
def load_component (
@@ -223,7 +224,7 @@ def dump_component(component: KartezioComponent) -> Dict:
223
224
return base_dict
224
225
225
226
226
- @component ()
227
+ @registry . declare ()
227
228
class Node (KartezioComponent , ABC ):
228
229
"""
229
230
Abstract base class for a Node in the CGP framework.
@@ -232,7 +233,7 @@ class Node(KartezioComponent, ABC):
232
233
pass
233
234
234
235
235
- @component ()
236
+ @registry . declare ()
236
237
class Preprocessing (Node , ABC ):
237
238
"""
238
239
Preprocessing node, called before training loop.
@@ -265,7 +266,7 @@ def then(self, preprocessing: "Preprocessing"):
265
266
return self
266
267
267
268
268
- @component ()
269
+ @registry . declare ()
269
270
class Primitive (Node , ABC ):
270
271
"""
271
272
Primitive function called inside the CGP Graph.
@@ -286,7 +287,7 @@ def __to_dict__(self) -> Dict:
286
287
return {"name" : self .name }
287
288
288
289
289
- @component ()
290
+ @registry . declare ()
290
291
class Genotype (KartezioComponent ):
291
292
"""
292
293
Represents the genotype for Cartesian Genetic Programming (CGP).
@@ -392,7 +393,7 @@ def clone(self) -> "Genotype":
392
393
return copy .deepcopy (self )
393
394
394
395
395
- @component ()
396
+ @registry . declare ()
396
397
class Reducer (Node , ABC ):
397
398
def batch (self , x : List ):
398
399
y = []
@@ -405,7 +406,7 @@ def reduce(self, x):
405
406
pass
406
407
407
408
408
- @component ()
409
+ @registry . declare ()
409
410
class Endpoint (Node , ABC ):
410
411
"""
411
412
Represents the final node in a CGP graph, responsible for producing the final outputs.
@@ -425,7 +426,6 @@ def __init__(self, inputs: List[KType]):
425
426
426
427
@classmethod
427
428
def __from_dict__ (cls , dict_infos : Dict ) -> "Endpoint" :
428
- from kartezio .core .endpoints import Endpoint
429
429
"""
430
430
Create an Endpoint instance from a dictionary representation.
431
431
@@ -441,8 +441,14 @@ def __from_dict__(cls, dict_infos: Dict) -> "Endpoint":
441
441
** dict_infos ["args" ],
442
442
)
443
443
444
+ @classmethod
445
+ def from_config (cls , config ):
446
+ return registry .instantiate (
447
+ cls .__name__ , config ["name" ], ** config ["args" ]
448
+ )
449
+
444
450
445
- @component ()
451
+ @registry . declare ()
446
452
class Fitness (KartezioComponent , ABC ):
447
453
def __init__ (self , reduction = "mean" ):
448
454
super ().__init__ ()
@@ -482,14 +488,15 @@ def evaluate(self, y_true, y_pred):
482
488
@classmethod
483
489
def __from_dict__ (cls , dict_infos : Dict ) -> "Fitness" :
484
490
from kartezio .core .fitness import Fitness
491
+
485
492
return Components .instantiate (
486
493
"Fitness" ,
487
494
dict_infos ["name" ],
488
495
** dict_infos ["args" ],
489
496
)
490
497
491
498
492
- @component ()
499
+ @registry . declare ()
493
500
class Library (KartezioComponent ):
494
501
def __init__ (self , rtype ):
495
502
super ().__init__ ()
@@ -596,7 +603,7 @@ def size(self):
596
603
return len (self ._primitives )
597
604
598
605
599
- @component ()
606
+ @registry . declare ()
600
607
class Mutation (KartezioComponent , ABC ):
601
608
def __init__ (self , adapter ):
602
609
super ().__init__ ()
@@ -685,15 +692,15 @@ def __to_dict__(self) -> Dict:
685
692
return {}
686
693
687
694
688
- @component ()
695
+ @registry . declare ()
689
696
class Initialization (KartezioComponent , ABC ):
690
697
""" """
691
698
692
699
def __init__ (self ):
693
700
super ().__init__ ()
694
701
695
702
696
- @register (Initialization )
703
+ @registry . add_as (Initialization )
697
704
class CopyGenotype (Initialization ):
698
705
@classmethod
699
706
def __from_dict__ (cls , dict_infos : Dict ) -> "CopyGenotype" :
@@ -707,7 +714,7 @@ def mutate(self, genotype):
707
714
return self .genotype .clone ()
708
715
709
716
710
- @register (Initialization )
717
+ @registry . add_as (Initialization )
711
718
class RandomInit (Initialization , Mutation ):
712
719
"""
713
720
Can be used to initialize genome (genome) randomly
@@ -736,3 +743,8 @@ def mutate(self, genotype: Genotype):
736
743
def random (self ):
737
744
genotype = self .adapter .new_genotype ()
738
745
return self .mutate (genotype )
746
+
747
+
748
+ if __name__ == "__main__" :
749
+ registry .display ()
750
+ print ("Done!" )
0 commit comments