Commit 91c5161

Add text encoder config
1 parent e790f50 commit 91c5161

6 files changed, +187 -5 lines changed

surya/layout.py (+20)
@@ -91,6 +91,8 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
 
         decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
         model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
+        if hasattr(model, "text_encoder"):
+            model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
 
         batch_predictions = [[] for _ in range(len(images))]
 
@@ -100,6 +102,24 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
         token_count = 0
         all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
 
+        if hasattr(model, "text_encoder"):
+            text_encoder_input_ids = torch.arange(
+                model.text_encoder.config.query_token_count,
+                device=encoder_hidden_states.device,
+                dtype=torch.long
+            ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1)
+
+            text_encoder_hidden_states = model.text_encoder(
+                input_ids=text_encoder_input_ids,
+                cache_position=None,
+                attention_mask=None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=None,
+                use_cache=False
+            ).hidden_states
+
+            encoder_hidden_states = torch.cat([encoder_hidden_states, text_encoder_hidden_states], dim=1)
+
         while token_count < settings.LAYOUT_MAX_BOXES:
             is_prefill = token_count == 0
             return_dict = model.decoder(
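
For context: the added block above embeds `query_token_count` query ids in the new text encoder, which cross-attends to the image features, and then appends the resulting states to `encoder_hidden_states` along the sequence axis before decoding begins. A minimal standalone shape sketch (illustrative tensors only, not the surya API; it assumes the text encoder's hidden size equals the image feature width, which the `dim=1` concatenation requires):

import torch

# Hypothetical sizes for illustration; the real values come from the checkpoint config.
batch_size, image_tokens, feature_dim = 2, 576, 1024
query_token_count = 144  # QUERY_TOKENS from surya/model/layout/config.py after this commit

# Stand-ins for the Swin encoder output and the text encoder output.
encoder_hidden_states = torch.randn(batch_size, image_tokens, feature_dim)
text_encoder_hidden_states = torch.randn(batch_size, query_token_count, feature_dim)

# Same concatenation as in batch_layout_detection: the decoder now cross-attends
# to the image tokens plus the query-token representations.
encoder_hidden_states = torch.cat([encoder_hidden_states, text_encoder_hidden_states], dim=1)
print(encoder_hidden_states.shape)  # torch.Size([2, 720, 1024])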

surya/model/layout/config.py (+91 -1)
@@ -7,7 +7,7 @@
 from surya.settings import settings
 
 SPECIAL_TOKENS = 3
-QUERY_TOKENS = 192
+QUERY_TOKENS = 144
 BBOX_SIZE = 1024
 PADDED_BBOX_SIZE = BBOX_SIZE + 1
 
@@ -45,9 +45,14 @@ def __init__(self, **kwargs):
 
         encoder_config = kwargs.pop("encoder")
         decoder_config = kwargs.pop("decoder")
+        text_encoder_config = kwargs.pop("text_encoder", None)
 
         self.encoder = encoder_config
         self.decoder = decoder_config
+
+        if text_encoder_config is not None:
+            self.text_encoder = text_encoder_config
+
         self.is_encoder_decoder = True
 
         if isinstance(decoder_config, dict):
@@ -221,6 +226,91 @@ def __init__(
             **kwargs,
         )
 
+    @property
+    def layers_block_type(self):
+        return (self.block_types * 100)[: self.num_hidden_layers]
+
+
+class SuryaLayoutTextEncoderConfig(PretrainedConfig):
+    model_type = "surya_layout"
+
+    def __init__(
+        self,
+        num_hidden_layers=4,
+        vocab_size=256,
+        hidden_size=512,
+        intermediate_size=4 * 512,
+        encoder_hidden_size=1024,
+        num_attention_heads=8,
+        lru_width=None,
+        attention_window_size=16,
+        conv1d_width=4,
+        logits_soft_cap=30.0,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        eos_token_id=1,
+        bos_token_id=1,
+        hidden_activation="gelu_pytorch_tanh",
+        rope_theta=10000.0,
+        block_types=("attention",),
+        cross_attn_layers=(0, 1, 2, 3),
+        self_attn_layers=(0, 1, 2, 3),
+        global_attn_layers=(0, 1, 2, 3),
+        attention_dropout=0.0,
+        num_key_value_heads=4,
+        attention_bias=False,
+        w_init_variance_scale=0.01,
+        init_std=0.02,
+        tie_word_embeddings=False,
+        aux_heads=0,  # How many n-token-ahead heads to add
+        iteration_count=1,
+        causal=False,
+        query_token_count=QUERY_TOKENS,
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        self.num_hidden_layers = num_hidden_layers
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.lru_width = lru_width if lru_width is not None else hidden_size
+        self.attention_window_size = attention_window_size
+        self.conv1d_width = conv1d_width
+        self.logits_soft_cap = logits_soft_cap
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.block_types = list(block_types)
+        self.hidden_activation = hidden_activation
+        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+        if self.num_key_value_heads > self.num_attention_heads:
+            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
+        self.cross_attn_layers = cross_attn_layers
+        self.self_attn_layers = self_attn_layers
+        self.global_attn_layers = global_attn_layers
+        self.attention_dropout = attention_dropout
+        self.attention_bias = attention_bias
+        self.w_init_variance_scale = w_init_variance_scale
+        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
+        self.init_std = init_std
+        self.tie_word_embeddings = tie_word_embeddings
+        self.aux_heads = aux_heads
+        self.encoder_hidden_size = encoder_hidden_size
+        self.iteration_count = iteration_count
+        self.causal = causal
+        self.query_token_count = query_token_count
+        self.layer_norm_eps = layer_norm_eps
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
     @property
     def layers_block_type(self):
         return (self.block_types * 100)[: self.num_hidden_layers]
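
A quick sketch of exercising the new config class with its defaults (the published checkpoint may override any of these values):

from surya.model.layout.config import QUERY_TOKENS, SuryaLayoutTextEncoderConfig

cfg = SuryaLayoutTextEncoderConfig()
print(cfg.query_token_count)  # 144, i.e. QUERY_TOKENS after this commit
print(cfg.head_dim)           # 64 = hidden_size (512) // num_attention_heads (8)
print(cfg.causal)             # False: the query tokens attend bidirectionally by default
assert cfg.query_token_count == QUERY_TOKENS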

surya/model/layout/decoder.py (+60 -1)
@@ -8,7 +8,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-from surya.model.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel
+from surya.model.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel, WrappedEmbedding
 from surya.model.layout.config import LayoutModelOutput
 from transformers.modeling_outputs import CausalLMOutput
 from surya.settings import settings
@@ -126,4 +126,63 @@ def forward(
             bbox_logits=bbox_logits,
             class_logits=class_logits,
             hidden_states=outputs.hidden_states,
+        )
+
+
+@dataclass
+class TextEncoderOutput(CausalLMOutput):
+    hidden_states: torch.FloatTensor = None
+
+
+class SuryaLayoutTextEncoder(SuryaADETRDecoderPreTrainedModel):
+    _tied_weights_keys = None
+
+    def __init__(self, config, **kwargs):
+        super().__init__(config)
+        embed_tokens = WrappedEmbedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+
+        self.model = SuryaADETRDecoderModel(
+            config,
+            embedder=embed_tokens,
+            static_cache=settings.LAYOUT_STATIC_CACHE,
+            max_boxes=settings.LAYOUT_MAX_BOXES
+        )
+        self.vocab_size = config.vocab_size
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    # Ignore copy
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple, CausalLMOutput]:
+        outputs = self.model(
+            input_ids=input_ids,
+            cache_position=cache_position,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        return TextEncoderOutput(
+            hidden_states=outputs.last_hidden_state,
         )

surya/model/layout/encoderdecoder.py (+6 -1)
@@ -5,7 +5,7 @@
 from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
 from surya.model.layout.encoder import DonutSwinLayoutModel
-from surya.model.layout.decoder import SuryaLayoutDecoder
+from surya.model.layout.decoder import SuryaLayoutDecoder, SuryaLayoutTextEncoder
 from transformers.utils import ModelOutput
 
 @dataclass
@@ -28,6 +28,7 @@ def __init__(
         config: Optional[PretrainedConfig] = None,
         encoder: Optional[PreTrainedModel] = None,
         decoder: Optional[PreTrainedModel] = None,
+        text_encoder: Optional[PreTrainedModel] = None,
     ):
         # initialize with config
         # make sure input & output embeddings is not tied
@@ -41,6 +42,10 @@ def __init__(
         if decoder is None:
             decoder = SuryaLayoutDecoder(config.decoder, attn_implementation=config._attn_implementation)
 
+        if text_encoder is None and hasattr(config, "text_encoder"):
+            text_encoder = SuryaLayoutTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)
+        self.text_encoder = text_encoder
+
         self.encoder = encoder
         self.decoder = decoder

surya/model/layout/model.py (+8 -1)
@@ -1,12 +1,14 @@
 import torch
 
 from surya.model.layout.encoderdecoder import SuryaLayoutModel
-from surya.model.layout.config import SuryaLayoutConfig, SuryaLayoutDecoderConfig, DonutSwinLayoutConfig
+from surya.model.layout.config import SuryaLayoutConfig, SuryaLayoutDecoderConfig, DonutSwinLayoutConfig, \
+    SuryaLayoutTextEncoderConfig
 from surya.settings import settings
 
 
 def load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE) -> SuryaLayoutModel:
     config = SuryaLayoutConfig.from_pretrained(checkpoint)
+
     decoder_config = config.decoder
     decoder = SuryaLayoutDecoderConfig(**decoder_config)
     config.decoder = decoder
@@ -15,6 +17,11 @@ def load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=settings.TORC
     encoder = DonutSwinLayoutConfig(**encoder_config)
     config.encoder = encoder
 
+    if hasattr(config, "text_encoder"):
+        text_encoder_config = config.text_encoder
+        text_encoder = SuryaLayoutTextEncoderConfig(**text_encoder_config)
+        config.text_encoder = text_encoder
+
     model = SuryaLayoutModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
     model = model.to(device)
     model = model.eval()
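
A hedged end-to-end sketch of using the updated loader: `load_model` and `batch_layout_detection` match the signatures shown in this diff, while `load_processor` and the sample image path are assumptions, not part of this commit.

from PIL import Image

from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor  # assumed helper; not part of this diff

model = load_model()          # uses settings.LAYOUT_MODEL_CHECKPOINT ("datalab-to/layout_order_te")
processor = load_processor()

# Checkpoints whose config carries a "text_encoder" section now expose the extra module.
print(hasattr(model, "text_encoder"))

image = Image.open("page.png")  # hypothetical input page
layout_results = batch_layout_detection([image], model, processor)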

surya/settings.py (+2 -1)
@@ -65,13 +65,14 @@ def TORCH_DEVICE_MODEL(self) -> str:
     RECOGNITION_ENCODER_BATCH_DIVISOR: int = 1  # Divisor for batch size in decoder
 
     # Layout
-    LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/layout_order_hr4"
+    LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/layout_order_te"
     LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
     LAYOUT_BATCH_SIZE: Optional[int] = None
     LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"
     LAYOUT_MAX_BOXES: int = 100
     COMPILE_LAYOUT: bool = False
     ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"
+    LAYOUT_MAX_DIMS: Dict = {"height": 1200, "width": 1200}
 
     # Table Rec
     TABLE_REC_MODEL_CHECKPOINT: str = "vikp/surya_tablerec"
