@@ -67,7 +67,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
67
67
batch_pixel_values = model_inputs ["pixel_values" ]
68
68
batch_pixel_values = torch .tensor (np .array (batch_pixel_values ), dtype = model .dtype ).to (model .device )
69
69
70
- pause_token = [model .config .decoder .size_token_id ] * 7
70
+ pause_token = [model .config .decoder .pause_token_id ] * 7
71
71
start_token = [model .config .decoder .bos_token_id ] * 7
72
72
batch_decoder_input = [
73
73
[start_token ] + [pause_token ] * model .config .decoder .pause_token_count
@@ -80,12 +80,14 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
80
80
model .decoder .model ._setup_cache (model .config , batch_size , model .device , model .dtype )
81
81
82
82
batch_predictions = [[] for _ in range (len (images ))]
83
+ batch_entropies = [[] for _ in range (len (images ))]
83
84
84
85
with torch .inference_mode ():
85
86
encoder_hidden_states = model .encoder (pixel_values = batch_pixel_values )[0 ]
86
87
87
88
token_count = 0
88
89
all_done = torch .zeros (current_batch_size , dtype = torch .bool , device = model .device )
90
+ paused = [False ] * current_batch_size
89
91
90
92
while token_count < settings .LAYOUT_MAX_BOXES :
91
93
is_prefill = token_count == 0
@@ -101,6 +103,9 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
101
103
box_logits = return_dict ["bbox_logits" ][:current_batch_size , - 1 , :].detach ()
102
104
class_logits = return_dict ["class_logits" ][:current_batch_size , - 1 , :].detach ()
103
105
106
+ probs = torch .nn .functional .softmax (class_logits , dim = - 1 ).detach ().cpu ()
107
+ entropy = torch .special .entr (probs ).sum (dim = - 1 )
108
+
104
109
class_preds = class_logits .argmax (- 1 )
105
110
box_preds = box_logits * model .config .decoder .bbox_size
106
111
@@ -115,7 +120,20 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
115
120
116
121
for j , (pred , status ) in enumerate (zip (batch_decoder_input , all_done )):
117
122
if not status :
118
- batch_predictions [j ].append (pred [0 ].detach ().clone ())
123
+ if paused [j ]:
124
+ if len (batch_entropies [j ]) == 0 or entropy [j ].item () < batch_entropies [j ][- 1 ]:
125
+ batch_predictions [j ][- 1 ] = pred [0 ].detach ().clone ()
126
+ batch_entropies [j ][- 1 ] = entropy [j ].item ()
127
+ else :
128
+ batch_predictions [j ].append (pred [0 ].detach ().clone ())
129
+ batch_entropies [j ].append (entropy [j ].item ())
130
+
131
+ # Add a pause token if needed
132
+ if entropy [j ].item () > .75 and not paused [j ]:
133
+ paused [j ] = True
134
+ batch_decoder_input [j , :] = model .decoder .config .pause_token_id
135
+ else :
136
+ paused [j ] = False
119
137
120
138
token_count += inference_token_count
121
139
inference_token_count = batch_decoder_input .shape [1 ]
@@ -124,6 +142,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
124
142
for j , (preds , orig_size ) in enumerate (zip (batch_predictions , orig_sizes )):
125
143
boxes = []
126
144
if len (preds ) > 0 :
145
+ preds = [p for p in preds if p [6 ] > model .decoder .config .special_token_count ] # Remove special tokens, like pause
127
146
stacked_preds = torch .stack (preds , dim = 0 )
128
147
polygons = prediction_to_polygon (
129
148
stacked_preds ,
0 commit comments