Skip to content

Commit dd7a79d

Browse files
committed
inference mode doesn't work with torch xla
1 parent a3fde2f commit dd7a79d

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

surya/detection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def batch_detection(
8888
if static_cache:
8989
batch = pad_to_batch_size(batch, batch_size)
9090

91-
with torch.inference_mode():
91+
with torch.no_grad():
9292
pred = model(pixel_values=batch)
9393

9494
logits = pred.logits

surya/layout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
116116

117117
batch_predictions = [[] for _ in range(current_batch_size)]
118118

119-
with torch.inference_mode():
119+
with torch.no_grad():
120120
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values)[0]
121121

122122
token_count = 0

surya/recognition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def batch_recognition(images: List[Image.Image], languages: List[List[str] | Non
8888
all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
8989
encoder_hidden_states = None
9090

91-
with torch.inference_mode():
91+
with torch.no_grad():
9292
encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR
9393
for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
9494
encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])]

surya/tables.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
100100

101101
batch_predictions = [[] for _ in range(current_batch_size)]
102102

103-
with torch.inference_mode():
103+
with torch.no_grad():
104104
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
105105
text_encoder_hidden_states = model.text_encoder(
106106
input_boxes=batch_bboxes,

0 commit comments

Comments
 (0)