@@ -329,7 +329,11 @@ def fit_transform(self, data):
329
329
Returns:
330
330
dict: Step output from the ``self.transformer.fit_transform`` method
331
331
"""
332
+ if data :
333
+ assert isinstance (data , dict ), 'Step {}, "data" argument in the "fit_transform()" method must be dict, ' \
334
+ 'got {} instead.' .format (self .name , type (data ))
332
335
logger .info ('Step {}, working in "{}" mode' .format (self .name , self ._mode ))
336
+
333
337
if self ._mode == 'inference' :
334
338
ValueError ('Step {}, you are in "{}" mode, where you cannot run "fit".'
335
339
'Please change mode to "train" to enable fitting.'
@@ -384,7 +388,11 @@ def transform(self, data):
384
388
Returns:
385
389
dict: step output from the transformer.transform method
386
390
"""
391
+ if data :
392
+ assert isinstance (data , dict ), 'Step {}, "data" argument in the "transform()" method must be dict, ' \
393
+ 'got {} instead.' .format (self .name , type (data ))
387
394
logger .info ('Step {}, working in "{}" mode' .format (self .name , self ._mode ))
395
+
388
396
if self .output_is_cached :
389
397
logger .info ('Step {} using cached output' .format (self .name ))
390
398
step_output_data = self .output
@@ -556,6 +564,12 @@ def _fit_transform_operation(self, step_inputs):
556
564
raise StepError (msg ) from e
557
565
558
566
logger .info ('Step {}, transforming completed' .format (self .name ))
567
+
568
+ assert isinstance (step_output_data , dict ), 'Step {}, Transformer "{}", error. ' \
569
+ 'Output from transformer must be dict, got {} instead' .format (self .name ,
570
+ self .transformer .__class__ .__name__ ,
571
+ type (step_output_data ))
572
+
559
573
if self .cache_output :
560
574
logger .info ('Step {}, caching output' .format (self .name ))
561
575
self .output = step_output_data
@@ -596,6 +610,12 @@ def _transform_operation(self, step_inputs):
596
610
raise StepError (msg ) from e
597
611
598
612
logger .info ('Step {}, transforming completed' .format (self .name ))
613
+
614
+ assert isinstance (step_output_data , dict ), 'Step {}, Transformer "{}", error. ' \
615
+ 'Output from transformer must be dict, got {} instead' .format (self .name ,
616
+ self .transformer .__class__ .__name__ ,
617
+ type (step_output_data ))
618
+
599
619
if self .cache_output :
600
620
logger .info ('Step {}, caching output' .format (self .name ))
601
621
self .output = step_output_data
0 commit comments