@@ -565,7 +565,7 @@ def create_patch_flux_forward_orig(model,
     from comfy.ldm.flux.model import timestep_embedding
 
     def call_remaining_blocks(self, blocks_replace, control, img, txt, vec, pe,
-                              attn_mask, ca_idx, timesteps):
+                              attn_mask, ca_idx, timesteps, transformer_options):
         original_hidden_states = img
 
         extra_block_forward_kwargs = {}
@@ -595,7 +595,8 @@ def block_wrap(args):
                     "pe": pe,
                     **extra_block_forward_kwargs
                 }, {
-                    "original_block": block_wrap
+                    "original_block": block_wrap,
+                    "transformer_options": transformer_options
                 })
                 txt = out["txt"]
                 img = out["img"]
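With this change, a replace patch registered for a block key such as `("double_block", i)` receives the full `transformer_options` dict in its second argument, next to `original_block`. A minimal consumer sketch, assuming only what this diff establishes about the extra dict (the patch body and the `sigmas` lookup are illustrative, not part of the diff):

    # Sketch of a block replace patch that reads the newly forwarded
    # transformer_options. Only extra["original_block"] and
    # extra["transformer_options"] are established by this diff; the
    # "sigmas" key is an assumption about what the sampler may populate.
    def my_double_block_patch(args, extra):
        original_block = extra["original_block"]
        transformer_options = extra["transformer_options"]  # new in this change
        sigmas = transformer_options.get("sigmas")  # hypothetical consumer logic
        return original_block(args)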
@@ -644,7 +645,8 @@ def block_wrap(args):
                     "pe": pe,
                     **extra_block_forward_kwargs
                 }, {
-                    "original_block": block_wrap
+                    "original_block": block_wrap,
+                    "transformer_options": transformer_options
                 })
                 img = out["img"]
             else:
@@ -741,7 +743,8 @@ def block_wrap(args):
                     "pe": pe,
                     **extra_block_forward_kwargs
                 }, {
-                    "original_block": block_wrap
+                    "original_block": block_wrap,
+                    "transformer_options": transformer_options
                 })
                 txt = out["txt"]
                 img = out["img"]
@@ -799,6 +802,7 @@ def block_wrap(args):
                     attn_mask,
                     ca_idx,
                     timesteps,
+                    transformer_options,
                 )
                 set_buffer("hidden_states_residual", hidden_states_residual)
                 torch._dynamo.graph_break()
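The call site now threads `transformer_options` through to `call_remaining_blocks`, so patches invoked from the remaining-blocks path see the same options as those invoked earlier in the forward pass. If registration goes through ComfyUI's usual `ModelPatcher.set_model_patch_replace` helper (an assumption; this diff does not touch registration), a patch like the sketch above could be attached as:

    # Hypothetical registration; "dit" is the patches_replace namespace the
    # flux forward reads block patches from, and 0 is the block index.
    patched_model = model.clone()
    patched_model.set_model_patch_replace(my_double_block_patch, "dit", "double_block", 0)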