from surya.settings import settings

SPECIAL_TOKENS = 3
-QUERY_TOKENS = 192
+QUERY_TOKENS = 144
BBOX_SIZE = 1024
PADDED_BBOX_SIZE = BBOX_SIZE + 1
@@ -45,9 +45,14 @@ def __init__(self, **kwargs):

        encoder_config = kwargs.pop("encoder")
        decoder_config = kwargs.pop("decoder")
+        text_encoder_config = kwargs.pop("text_encoder", None)

        self.encoder = encoder_config
        self.decoder = decoder_config
+
+        if text_encoder_config is not None:
+            self.text_encoder = text_encoder_config
+
        self.is_encoder_decoder = True

        if isinstance(decoder_config, dict):
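For readers skimming the hunk above, here is a self-contained sketch of the optional sub-config pattern it introduces. `_ConfigSketch` is illustrative only (not the real class in this file); it mirrors the `kwargs.pop("text_encoder", None)` logic, so configs serialized before this change still load without a `text_encoder` entry.

# Illustrative sketch only -- mirrors the pattern in the hunk above, not the actual config class.
class _ConfigSketch:
    def __init__(self, **kwargs):
        self.encoder = kwargs.pop("encoder")
        self.decoder = kwargs.pop("decoder")
        text_encoder_config = kwargs.pop("text_encoder", None)  # optional, defaults to None
        if text_encoder_config is not None:
            self.text_encoder = text_encoder_config

old_style = _ConfigSketch(encoder={}, decoder={})                   # pre-change config: no text_encoder attribute
new_style = _ConfigSketch(encoder={}, decoder={}, text_encoder={})  # new config: attribute is set
assert not hasattr(old_style, "text_encoder")
assert hasattr(new_style, "text_encoder")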
@@ -221,6 +226,91 @@ def __init__(
            **kwargs,
        )

+    @property
+    def layers_block_type(self):
+        return (self.block_types * 100)[: self.num_hidden_layers]
+
+
+class SuryaLayoutTextEncoderConfig(PretrainedConfig):
+    model_type = "surya_layout"
+
+    def __init__(
+        self,
+        num_hidden_layers=4,
+        vocab_size=256,
+        hidden_size=512,
+        intermediate_size=4 * 512,
+        encoder_hidden_size=1024,
+        num_attention_heads=8,
+        lru_width=None,
+        attention_window_size=16,
+        conv1d_width=4,
+        logits_soft_cap=30.0,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        eos_token_id=1,
+        bos_token_id=1,
+        hidden_activation="gelu_pytorch_tanh",
+        rope_theta=10000.0,
+        block_types=("attention",),
+        cross_attn_layers=(0, 1, 2, 3),
+        self_attn_layers=(0, 1, 2, 3),
+        global_attn_layers=(0, 1, 2, 3),
+        attention_dropout=0.0,
+        num_key_value_heads=4,
+        attention_bias=False,
+        w_init_variance_scale=0.01,
+        init_std=0.02,
+        tie_word_embeddings=False,
+        aux_heads=0,  # How many n-token-ahead heads to add
+        iteration_count=1,
+        causal=False,
+        query_token_count=QUERY_TOKENS,
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        self.num_hidden_layers = num_hidden_layers
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.lru_width = lru_width if lru_width is not None else hidden_size
+        self.attention_window_size = attention_window_size
+        self.conv1d_width = conv1d_width
+        self.logits_soft_cap = logits_soft_cap
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.block_types = list(block_types)
+        self.hidden_activation = hidden_activation
+        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+        if self.num_key_value_heads > self.num_attention_heads:
+            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
+        self.cross_attn_layers = cross_attn_layers
+        self.self_attn_layers = self_attn_layers
+        self.global_attn_layers = global_attn_layers
+        self.attention_dropout = attention_dropout
+        self.attention_bias = attention_bias
+        self.w_init_variance_scale = w_init_variance_scale
+        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
+        self.init_std = init_std
+        self.tie_word_embeddings = tie_word_embeddings
+        self.aux_heads = aux_heads
+        self.encoder_hidden_size = encoder_hidden_size
+        self.iteration_count = iteration_count
+        self.causal = causal
+        self.query_token_count = query_token_count
+        self.layer_norm_eps = layer_norm_eps
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
    @property
    def layers_block_type(self):
        return (self.block_types * 100)[: self.num_hidden_layers]
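A hedged usage sketch for the new text-encoder config, assuming `SuryaLayoutTextEncoderConfig` as defined above is in scope and that `PretrainedConfig` is the usual `transformers` base class; the commented values follow directly from the defaults shown in the diff.

# Assumes the class defined above is importable/in scope; values follow from its defaults.
cfg = SuryaLayoutTextEncoderConfig()
print(cfg.layers_block_type)            # ["attention", "attention", "attention", "attention"]:
                                        # block_types tiled 100x and truncated to num_hidden_layers (4)
print(cfg.head_dim)                     # 64 = hidden_size (512) // num_attention_heads (8)
print(cfg.query_token_count)            # 144, picking up the QUERY_TOKENS change above
print(cfg.final_w_init_variance_scale)  # 0.5 = 2.0 / num_hidden_layers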