
Commit defbfca

Add pause tokens

1 parent 2a4716e

6 files changed: +63 -21 lines

surya/layout.py (+21 -2)

@@ -67,7 +67,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
         batch_pixel_values = model_inputs["pixel_values"]
         batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
 
-        pause_token = [model.config.decoder.size_token_id] * 7
+        pause_token = [model.config.decoder.pause_token_id] * 7
         start_token = [model.config.decoder.bos_token_id] * 7
         batch_decoder_input = [
             [start_token] + [pause_token] * model.config.decoder.pause_token_count
@@ -80,12 +80,14 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
         model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
 
         batch_predictions = [[] for _ in range(len(images))]
+        batch_entropies = [[] for _ in range(len(images))]
 
         with torch.inference_mode():
             encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values)[0]
 
             token_count = 0
             all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
+            paused = [False] * current_batch_size
 
             while token_count < settings.LAYOUT_MAX_BOXES:
                 is_prefill = token_count == 0
@@ -101,6 +103,9 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
                 box_logits = return_dict["bbox_logits"][:current_batch_size, -1, :].detach()
                 class_logits = return_dict["class_logits"][:current_batch_size, -1, :].detach()
 
+                probs = torch.nn.functional.softmax(class_logits, dim=-1).detach().cpu()
+                entropy = torch.special.entr(probs).sum(dim=-1)
+
                 class_preds = class_logits.argmax(-1)
                 box_preds = box_logits * model.config.decoder.bbox_size
 
@@ -115,7 +120,20 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
 
                 for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
                     if not status:
-                        batch_predictions[j].append(pred[0].detach().clone())
+                        if paused[j]:
+                            if len(batch_entropies[j]) == 0 or entropy[j].item() < batch_entropies[j][-1]:
+                                batch_predictions[j][-1] = pred[0].detach().clone()
+                                batch_entropies[j][-1] = entropy[j].item()
+                        else:
+                            batch_predictions[j].append(pred[0].detach().clone())
+                            batch_entropies[j].append(entropy[j].item())
+
+                        # Add a pause token if needed
+                        if entropy[j].item() > .75 and not paused[j]:
+                            paused[j] = True
+                            batch_decoder_input[j, :] = model.decoder.config.pause_token_id
+                        else:
+                            paused[j] = False
 
                 token_count += inference_token_count
                 inference_token_count = batch_decoder_input.shape[1]
@@ -124,6 +142,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
         for j, (preds, orig_size) in enumerate(zip(batch_predictions, orig_sizes)):
            boxes = []
            if len(preds) > 0:
+                preds = [p for p in preds if p[6] > model.decoder.config.special_token_count]  # Remove special tokens, like pause
                 stacked_preds = torch.stack(preds, dim=0)
                 polygons = prediction_to_polygon(
                     stacked_preds,
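
A note on the decoding change above: torch.special.entr computes -p * log(p) elementwise, so summing it over the class axis gives the Shannon entropy (in nats) of each image's class distribution. When a step's entropy rises above the hardcoded .75 threshold, that image's decoder input is overwritten with pause tokens and the same position is decoded again, keeping whichever attempt was more confident. A minimal standalone sketch of the gating loop, with a hypothetical decode_step stub standing in for the real decoder:

    import torch

    PAUSE_ENTROPY_THRESHOLD = 0.75  # nats, matching the hardcoded .75 in the diff

    def shannon_entropy(logits: torch.Tensor) -> torch.Tensor:
        # torch.special.entr(p) = -p * log(p); summing over classes gives entropy in nats
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.special.entr(probs).sum(dim=-1)

    def decode_step(step: int):
        # Hypothetical stand-in for one decoder step: (7-wide prediction, class logits)
        torch.manual_seed(step)
        return torch.randint(0, 10, (7,)), torch.randn(11)

    predictions, entropies = [], []
    paused = False
    for step in range(5):
        pred, class_logits = decode_step(step)
        ent = shannon_entropy(class_logits).item()
        if paused:
            # Re-decoded after a pause: keep the more confident of the two attempts
            if not entropies or ent < entropies[-1]:
                predictions[-1], entropies[-1] = pred, ent
        else:
            predictions.append(pred)
            entropies.append(ent)
        # Pause (feed pause tokens and retry this position) only when not already paused
        paused = ent > PAUSE_ENTROPY_THRESHOLD and not paused

The final one-liner reproduces the diff's if/else: a paused stream always unpauses on the next step, so each position is retried at most once.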

surya/model/common/donut/encoder.py (+29 -6)

@@ -617,9 +617,31 @@ def __init__(self, config, layer_num, dim, input_resolution, depth, num_heads, n
 
         self.pointing = False
 
-        self.position_embeddings = None
-        if layer_num == 0 and config.starting_positional_embeddings:
-            self.position_embeddings = nn.Parameter(torch.zeros(1, input_resolution[0] * input_resolution[1] + config.encoder_length, dim))
+        self.positional_encoding = None
+        if config.use_positional_embeddings:
+            self.positional_encoding = self.build_2d_sincos_position_embedding(
+                input_resolution[1],
+                input_resolution[0],
+                embed_dim=dim,
+            )
+
+    @staticmethod
+    def build_2d_sincos_position_embedding(
+        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+    ):
+        grid_w = torch.arange(int(width), dtype=dtype, device=device)
+        grid_h = torch.arange(int(height), dtype=dtype, device=device)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+        if embed_dim % 4 != 0:
+            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
+        omega = 1.0 / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None]
+        out_h = grid_h.flatten()[..., None] @ omega[None]
+
+        return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
 
     def forward(
         self,
@@ -630,6 +652,10 @@ def forward(
         always_partition: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
         height, width = input_dimensions
+
+        if self.positional_encoding is not None:
+            hidden_states = hidden_states + self.positional_encoding.to(hidden_states.dtype).to(hidden_states.device)
+
         for i, layer_module in enumerate(self.blocks):
             layer_head_mask = head_mask[i] if head_mask is not None else None
@@ -639,9 +665,6 @@ def forward(
 
             hidden_states = layer_outputs[0]
 
-        if self.position_embeddings is not None:
-            hidden_states = hidden_states + self.position_embeddings[:, :hidden_states.size(1)]
-
         hidden_states_before_downsampling = hidden_states
         if self.downsample is not None:
             height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
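
The learned position_embeddings parameter gives way to a fixed 2D sin-cos table that is added once at the top of each stage's forward pass rather than after every block. The builder splits embed_dim into four chunks (sin and cos for each axis) over a geometric frequency ladder. A standalone copy of the new builder, runnable outside the encoder class, with an illustrative shape check (the 12x12 grid is an example, not a model constant):

    import torch

    def build_2d_sincos_position_embedding(
        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
    ):
        grid_w = torch.arange(int(width), dtype=dtype, device=device)
        grid_h = torch.arange(int(height), dtype=dtype, device=device)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        # Four chunks of embed_dim // 4: sin/cos of the width axis, sin/cos of the height axis
        if embed_dim % 4 != 0:
            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
        omega = 1.0 / (temperature**omega)  # geometric frequency ladder, as in transformer PEs

        out_w = grid_w.flatten()[..., None] @ omega[None]  # (width * height, pos_dim)
        out_h = grid_h.flatten()[..., None] @ omega[None]

        # (1, width * height, embed_dim), broadcastable over the batch dimension
        return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

    pe = build_2d_sincos_position_embedding(12, 12, embed_dim=256)
    assert pe.shape == (1, 144, 256)

Because the table is deterministic, it carries no checkpoint weights and can be rebuilt for any input_resolution.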

surya/model/layout/config.py (+6 -6)

@@ -93,11 +93,11 @@ def __init__(
         attention_probs_dropout_prob=0.0,
         drop_path_rate=0,
         hidden_act="gelu",
-        use_absolute_embeddings=True,
+        use_absolute_embeddings=False,
+        use_positional_embeddings=True,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         encoder_length=768,
-        starting_positional_embeddings=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -117,14 +117,14 @@ def __init__(
         self.attention_probs_dropout_prob = attention_probs_dropout_prob
         self.drop_path_rate = drop_path_rate
         self.hidden_act = hidden_act
-        self.use_absolute_embeddings = use_absolute_embeddings
+        self.use_absolute_embeddings = False
         self.layer_norm_eps = layer_norm_eps
         self.initializer_range = initializer_range
         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.encoder_length = encoder_length
-        self.starting_positional_embeddings = starting_positional_embeddings
+        self.use_positional_embeddings = use_positional_embeddings
 
 
 class SuryaLayoutDecoderConfig(PretrainedConfig):
@@ -151,7 +151,7 @@ def __init__(
         pad_token_id=0,
         eos_token_id=1,
         bos_token_id=1,
-        size_token_id=2,
+        pause_token_id=2,
         img_size_bucket=100,
         hidden_activation="gelu_pytorch_tanh",
         rope_theta=10000.0,
@@ -206,7 +206,7 @@ def __init__(
         self.bbox_size = bbox_size
         self.label_count = label_count
         self.skew_scaler = skew_scaler
-        self.size_token_id = size_token_id
+        self.pause_token_id = pause_token_id
         self.img_size_bucket = img_size_bucket
         self.special_token_count = special_token_count
         self.layer_norm_eps = layer_norm_eps
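
These decoder config fields drive the pause mechanism end to end: pause_token_id (renamed from size_token_id) fills both the prefill input and the mid-decode retries in surya/layout.py, and any prediction whose class id (slot 6 of the 7-wide token) falls at or below special_token_count is stripped before boxes are built. A toy sketch of both ends; the pause_token_count and special_token_count values here are illustrative, not the checkpoint's:

    import torch

    bos_token_id, pause_token_id = 1, 2
    pause_token_count = 2    # illustrative; the real value lives in the decoder config
    special_token_count = 3  # illustrative

    # Prefill: one BOS 7-tuple followed by pause 7-tuples, as in batch_layout_detection
    start_token = [bos_token_id] * 7
    pause_token = [pause_token_id] * 7
    decoder_input = torch.tensor([[start_token] + [pause_token] * pause_token_count])
    assert decoder_input.shape == (1, 1 + pause_token_count, 7)

    # Post-processing: drop special-token predictions before building boxes
    preds = [
        torch.tensor([10, 10, 90, 90, 0, 0, 5]),          # real layout box, class id 5
        torch.tensor([0, 0, 0, 0, 0, 0, pause_token_id]), # pause step, filtered out
    ]
    preds = [p for p in preds if p[6] > special_token_count]
    assert len(preds) == 1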

surya/model/recognition/config.py (+2 -2)

@@ -57,7 +57,7 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         encoder_length=256,
-        starting_positional_embeddings=False,
+        use_positional_embeddings=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -84,7 +84,7 @@ def __init__(
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.encoder_length = encoder_length
-        self.starting_positional_embeddings = starting_positional_embeddings
+        self.use_positional_embeddings = use_positional_embeddings
 
 
 class SuryaOCRDecoderConfig(PretrainedConfig):

surya/model/table_rec/config.py (+2 -2)

@@ -72,7 +72,7 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         encoder_length=1024,
-        starting_positional_embeddings=False,
+        use_positional_embeddings=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -99,7 +99,7 @@ def __init__(
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.encoder_length = encoder_length
-        self.starting_positional_embeddings = starting_positional_embeddings
+        self.use_positional_embeddings = use_positional_embeddings
 
 
 class SuryaTableRecDecoderConfig(PretrainedConfig):

surya/settings.py (+3 -3)

@@ -65,11 +65,11 @@ def TORCH_DEVICE_MODEL(self) -> str:
     RECOGNITION_ENCODER_BATCH_DIVISOR: int = 1  # Divisor for batch size in decoder
 
     # Layout
-    LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/layout_order_hr"
-    LAYOUT_IMAGE_SIZE: Dict = {"height": 896, "width": 896}
+    LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/layout_order_hr3"
+    LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
     LAYOUT_BATCH_SIZE: Optional[int] = None
     LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"
-    LAYOUT_MAX_BOXES: int = 75
+    LAYOUT_MAX_BOXES: int = 150
     COMPILE_LAYOUT: bool = False
 
     # Table Rec
