Skip to content

Commit 5c286cc

Browse files
committed
detection working on xla
1 parent dd7a79d commit 5c286cc

File tree

6 files changed

+18
-12
lines changed

6 files changed

+18
-12
lines changed

benchmark/detection.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from tabulate import tabulate
1818
import datasets
1919

20+
import torch
21+
import torch_xla.core.xla_model as xm
2022

2123
def main():
2224
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
@@ -27,7 +29,7 @@ def main():
2729
parser.add_argument("--tesseract", action="store_true", help="Run tesseract as well.", default=False)
2830
args = parser.parse_args()
2931

30-
model = load_model()
32+
model = load_model(device=xm.xla_device(), dtype=torch.bfloat16)
3133
processor = load_processor()
3234

3335
if args.pdf_path is not None:

surya/detection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def batch_detection(
9797
if current_shape != correct_shape:
9898
logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)
9999

100-
logits = logits.cpu().detach().numpy().astype(np.float32)
100+
logits = logits.to(torch.float32).cpu().detach().numpy()
101101
preds = []
102102
for i, (idx, height) in enumerate(zip(split_index, split_heights)):
103103
# If our current prediction length is below the image idx, that means we have a new image

surya/model/detection/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
import torch.nn as nn
1616
import torch.nn.functional as F
17+
import torch_xla.core.xla_model as xm
1718

1819
from transformers import PreTrainedModel
1920
from transformers.modeling_outputs import SemanticSegmenterOutput
@@ -35,7 +36,7 @@ def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TO
3536
torch._dynamo.config.suppress_errors = False
3637

3738
print(f"Compiling detection model {checkpoint} on device {device} with dtype {dtype}")
38-
model = torch.compile(model)
39+
model = torch.compile(model, backend='openxla')
3940

4041
print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
4142
return model
@@ -805,4 +806,4 @@ def forward(
805806
loss=None,
806807
logits=logits,
807808
hidden_states=encoder_hidden_states
808-
)
809+
)

surya/model/layout/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch_xla.core.xla_model as xm
23

34
from surya.model.layout.encoderdecoder import SuryaLayoutModel
45
from surya.model.layout.config import SuryaLayoutConfig, SuryaLayoutDecoderConfig, DonutSwinLayoutConfig
@@ -25,8 +26,8 @@ def load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=settings.TORC
2526
torch._dynamo.config.suppress_errors = False
2627

2728
print(f"Compiling layout model {checkpoint} on device {device} with dtype {dtype}")
28-
model.encoder = torch.compile(model.encoder)
29-
model.decoder = torch.compile(model.decoder)
29+
model.encoder = torch.compile(model.encoder, backend='openxla')
30+
model.decoder = torch.compile(model.decoder, backend='openxla')
3031

3132
print(f"Loaded layout model {checkpoint} on device {device} with dtype {dtype}")
3233
return model

surya/model/recognition/model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

33
import torch
4+
import torch_xla.core.xla_model as xm
45

56
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
67

@@ -52,9 +53,9 @@ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings
5253

5354

5455
print(f"Compiling recognition model {checkpoint} on device {device} with dtype {dtype}")
55-
model.encoder = torch.compile(model.encoder)
56-
model.decoder = torch.compile(model.decoder)
57-
model.text_encoder = torch.compile(model.text_encoder)
56+
model.encoder = torch.compile(model.encoder, backend='openxla')
57+
model.decoder = torch.compile(model.decoder, backend='openxla')
58+
model.text_encoder = torch.compile(model.text_encoder, backend='openxla')
5859

5960
print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
6061
return model

surya/model/table_rec/model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from surya.settings import settings
77

88
import torch
9+
import torch_xla.core.xla_model as xm
910

1011

1112
def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE) -> TableRecEncoderDecoderModel:
@@ -39,9 +40,9 @@ def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.T
3940

4041

4142
print(f"Compiling table recognition model {checkpoint} on device {device} with dtype {dtype}")
42-
model.encoder = torch.compile(model.encoder)
43-
model.decoder = torch.compile(model.decoder)
44-
model.text_encoder = torch.compile(model.text_encoder)
43+
model.encoder = torch.compile(model.encoder, backend='openxla')
44+
model.decoder = torch.compile(model.decoder, backend='openxla')
45+
model.text_encoder = torch.compile(model.text_encoder, backend='openxla')
4546

4647
print(f"Loaded table recognition model {checkpoint} on device {device} with dtype {dtype}")
4748
return model

0 commit comments

Comments
 (0)