Add LaTeX OCR #292

Merged
merged 5 commits into from
Jan 28, 2025
4 changes: 4 additions & 0 deletions .github/workflows/scripts.yml
@@ -32,3 +32,7 @@ jobs:
         run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0
       - name: Test detection folder
         run: poetry run surya_detect benchmark_data/pdfs --page_range 0
+      - name: Test texify
+        env:
+          TEXIFY_MAX_TOKENS: 25
+        run: poetry run surya_latex_ocr benchmark_data/pdfs --page_range 0
47 changes: 43 additions & 4 deletions README.md
@@ -7,6 +7,7 @@ Surya is a document OCR toolkit that does:
- Layout analysis (table, image, header, etc detection)
- Reading order detection
- Table recognition (detecting rows/columns)
- LaTeX OCR

It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).

@@ -19,9 +20,9 @@ It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmar
|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
| <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> |

-| Table Recognition | |
-|:-------------------------------------------------------------:|:----------------:|
-| <img src="static/images/scanned_tablerec.png" width="500px"/> | <img width="500px"/> |
+| Table Recognition | LaTeX OCR |
+|:-------------------------------------------------------------:|:------------------------------------------------------:|
+| <img src="static/images/scanned_tablerec.png" width="500px"/> | <img src="static/images/latex_ocr.png" width="500px"/> |


Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
@@ -284,10 +285,48 @@ from surya.table_rec import TableRecPredictor
image = Image.open(IMAGE_PATH)
table_rec_predictor = TableRecPredictor()

# list of dicts, one per image
table_predictions = table_rec_predictor([image])
```

## LaTeX OCR

This command writes out a JSON file containing the LaTeX for each equation. Input images must already be cropped to the equations; one way to do this is to run the layout model first and then crop.

```shell
surya_latex_ocr DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
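
The `--page_range` syntax described above can be illustrated with a small parser. This is a minimal sketch, not surya's own implementation: `parse_page_range` is a hypothetical helper, and it assumes that ranges such as `5-10` include both endpoints, which the text above does not state explicitly.

```python
def parse_page_range(spec: str) -> list[int]:
    """Expand a spec such as "0,5-10,20" into a sorted list of unique page indices.

    Assumption: a range "a-b" includes both a and b.
    """
    pages: set[int] = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue  # tolerate stray commas
        if "-" in part:
            start_text, end_text = part.split("-", 1)
            start, end = int(start_text), int(end_text)
            if end < start:
                raise ValueError(f"descending range: {part!r}")
            pages.update(range(start, end + 1))
        else:
            pages.add(int(part))
    return sorted(pages)


print(parse_page_range("0,5-10,20"))  # [0, 5, 6, 7, 8, 9, 10, 20]
```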

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `text` - the detected LaTeX text - it will be in KaTeX compatible LaTeX, with `<math display="block">...</math>` and `<math>...</math>` as delimiters.
- `confidence` - the prediction confidence from 0-1.
- `page` - the page number in the file
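
As a rough illustration of consuming this output, the sketch below loads a `results.json` file and splits each page's `text` into block and inline equations using the `<math>` delimiters listed above. It relies only on the three keys described here; the function names and the `min_confidence` parameter are hypothetical, and it assumes `<math>` tags are never nested.

```python
import json
import re
from pathlib import Path

# Matches both <math>...</math> and <math display="block">...</math>.
MATH_TAG = re.compile(r"<math(?P<attrs>[^>]*)>(?P<body>.*?)</math>", re.DOTALL)


def extract_equations(text: str) -> list[tuple[str, str]]:
    """Return (kind, latex) pairs, where kind is "block" or "inline"."""
    equations = []
    for match in MATH_TAG.finditer(text):
        kind = "block" if 'display="block"' in match.group("attrs") else "inline"
        equations.append((kind, match.group("body").strip()))
    return equations


def load_equations(results_path: str, min_confidence: float = 0.0):
    """Yield (name, page, kind, latex) for every equation in a results.json file."""
    results = json.loads(Path(results_path).read_text(encoding="utf-8"))
    for name, pages in results.items():
        for page in pages:
            # Skip pages whose prediction confidence is below the threshold.
            if page["confidence"] < min_confidence:
                continue
            for kind, latex in extract_equations(page["text"]):
                yield name, page["page"], kind, latex
```

Calling `list(load_equations("results.json", min_confidence=0.5))` would then give one tuple per equation, keyed by input filename and page number.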

### From python

```python
from PIL import Image
from surya.texify import TexifyPredictor

image = Image.open(IMAGE_PATH)
predictor = TexifyPredictor()

predictor([image])
```

### Interactive app

You can also run an interactive app that lets you select equations and OCR them (similar to Mathpix Snip) with:

```shell
pip install streamlit==1.40 streamlit-drawable-canvas-jsretry
texify_gui
```

# Limitations

- This is specialized for document OCR. It will likely not work on photos or other images.
2 changes: 1 addition & 1 deletion detect_layout.py
@@ -1,4 +1,4 @@
-from surya.scripts import detect_layout_cli
+from surya.scripts.detect_layout import detect_layout_cli

 if __name__ == "__main__":
     detect_layout_cli()
2 changes: 1 addition & 1 deletion detect_text.py
@@ -1,4 +1,4 @@
-from surya.scripts import detect_text_cli
+from surya.scripts.detect_text import detect_text_cli

 if __name__ == "__main__":
     detect_text_cli()
2 changes: 1 addition & 1 deletion ocr_app.py
@@ -1,4 +1,4 @@
-from surya.scripts import streamlit_app_cli
+from surya.scripts.run_streamlit_app import streamlit_app_cli

 if __name__ == "__main__":
     streamlit_app_cli()
4 changes: 4 additions & 0 deletions ocr_latex.py
@@ -0,0 +1,4 @@
+from surya.scripts.ocr_latex import ocr_latex_cli
+
+if __name__ == "__main__":
+    ocr_latex_cli()
2 changes: 1 addition & 1 deletion ocr_text.py
@@ -1,4 +1,4 @@
-from surya.scripts import ocr_text_cli
+from surya.scripts.ocr_text import ocr_text_cli

 if __name__ == "__main__":
     ocr_text_cli()
56 changes: 29 additions & 27 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.9.3"
+version = "0.10.0"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
 readme = "README.md"
@@ -41,6 +41,8 @@ surya_ocr = "surya.scripts.ocr_text:ocr_text_cli"
 surya_layout = "surya.scripts.detect_layout:detect_layout_cli"
 surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli"
 surya_table = "surya.scripts.table_recognition:table_recognition_cli"
+surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli"
+texify_gui = "surya.scripts.run_texify_app:texify_app_cli"

 [build-system]
 requires = ["poetry-core"]
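
Each script entry in this hunk maps a command name to a `module:function` target, and the installed command imports that module and calls the function. The sketch below shows roughly how such a target string resolves to a callable. `resolve_entry_point` is a hypothetical helper written for illustration, and it is demonstrated on a standard-library target so that it runs without surya installed.

```python
from importlib import import_module


def resolve_entry_point(spec: str):
    """Resolve a "package.module:attribute" string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    if not module_name or not attr_path:
        raise ValueError(f"expected 'module:attribute', got {spec!r}")
    target = import_module(module_name)
    # Walk dotted attribute paths such as "Class.method".
    for attr in attr_path.split("."):
        target = getattr(target, attr)
    return target


# Stand-in for a target such as "surya.scripts.ocr_latex:ocr_latex_cli".
dumps = resolve_entry_point("json:dumps")
print(dumps({"ok": True}))  # {"ok": true}
```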
Binary file added static/images/latex_ocr.png
Binary file modified static/images/scanned_tablerec.png
12 changes: 12 additions & 0 deletions surya/common/predictor.py
@@ -1,5 +1,6 @@
 from typing import Optional
 import torch
+import torch.nn.functional as F

 from surya.common.load import ModelLoader
 from surya.settings import settings
@@ -36,5 +37,16 @@ def get_batch_size(self):
             batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL]
         return batch_size

+    @staticmethod
+    def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
+        current_batch_size = tensor.shape[0]
+        if current_batch_size >= batch_size:
+            return tensor
+
+        pad_size = batch_size - current_batch_size
+        padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
+
+        return F.pad(tensor, padding, mode='constant', value=0)
+
     def __call__(self, *args, **kwargs):
         raise NotImplementedError()
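
The padding tuple in `pad_to_batch_size` follows the `F.pad` convention of listing (before, after) pairs starting from the last dimension, so the final pair applies to dimension 0, the batch axis. A minimal pure-Python sketch of that arithmetic, with no torch dependency; `batch_padding` and `padded_shape` are hypothetical helpers used only to illustrate the shape calculation:

```python
def batch_padding(shape: tuple[int, ...], batch_size: int) -> tuple[int, ...]:
    """Build the pad tuple used by pad_to_batch_size, or () when no padding is needed."""
    current = shape[0]
    if current >= batch_size:
        return ()
    pad_size = batch_size - current
    # One (0, 0) pair per non-batch dimension, then (0, pad_size) for dimension 0.
    return (0, 0) * (len(shape) - 1) + (0, pad_size)


def padded_shape(shape: tuple[int, ...], padding: tuple[int, ...]) -> tuple[int, ...]:
    """Apply a pad tuple to a shape, pairing entries with dimensions from last to first."""
    new_shape = list(shape)
    for i in range(len(padding) // 2):
        dim = len(shape) - 1 - i
        new_shape[dim] += padding[2 * i] + padding[2 * i + 1]
    return tuple(new_shape)


shape = (3, 3, 224, 224)  # a batch of 3 images, padded up to a batch of 8
padding = batch_padding(shape, batch_size=8)
print(padding)                       # (0, 0, 0, 0, 0, 0, 0, 5)
print(padded_shape(shape, padding))  # (8, 3, 224, 224)
```

Only the batch axis grows, and the extra rows are appended at the end, so the first `current_batch_size` entries of the output still correspond to the real inputs.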
10 changes: 0 additions & 10 deletions surya/detection/__init__.py
@@ -42,16 +42,6 @@ def __call__(self, images: List[Image.Image], batch_size=None, include_maps=Fals

         return [future.result() for future in postprocessing_futures]

-    def pad_to_batch_size(self, tensor, batch_size):
-        current_batch_size = tensor.shape[0]
-        if current_batch_size >= batch_size:
-            return tensor
-
-        pad_size = batch_size - current_batch_size
-        padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
-
-        return F.pad(tensor, padding, mode='constant', value=0)
-
     def prepare_image(self, img):
         new_size = (self.processor.size["width"], self.processor.size["height"])

4 changes: 3 additions & 1 deletion surya/models.py
@@ -8,6 +8,7 @@
 from surya.ocr_error import OCRErrorPredictor
 from surya.recognition import RecognitionPredictor
 from surya.table_rec import TableRecPredictor
+from surya.texify import TexifyPredictor


 def load_predictors(
@@ -19,5 +20,6 @@ def load_predictors(
         "ocr_error": OCRErrorPredictor(device=device, dtype=dtype),
         "recognition": RecognitionPredictor(device=device, dtype=dtype),
         "detection": DetectionPredictor(device=device, dtype=dtype),
-        "table_rec": TableRecPredictor(device=device, dtype=dtype)
+        "table_rec": TableRecPredictor(device=device, dtype=dtype),
+        "texify": TexifyPredictor(device=device, dtype=dtype)
     }
22 changes: 1 addition & 21 deletions surya/recognition/__init__.py
@@ -176,16 +176,6 @@ def slice_bboxes(
             "polygons": all_polygons
         }

-    def pad_to_batch_size(self, tensor, batch_size):
-        current_batch_size = tensor.shape[0]
-        if current_batch_size >= batch_size:
-            return tensor
-
-        pad_size = batch_size - current_batch_size
-        padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
-
-        return F.pad(tensor, padding, mode='constant', value=0)
-
     def prepare_input(self, batch_langs, batch_pixel_values, batch_size):
         batch_decoder_input = [[self.model.config.decoder_start_token_id] + lang for lang in batch_langs]
         max_input_length = max(len(tokens) for tokens in batch_decoder_input)
@@ -256,18 +246,9 @@ def batch_recognition(

             sequence_scores = None
             all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=self.model.device)
-            encoder_hidden_states = None

             with torch.inference_mode():
-                encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR
-                for z in range(0, batch_pixel_values.shape[0], encoder_batch_size):
-                    encoder_pixel_values = batch_pixel_values[
-                        z:min(z + encoder_batch_size, batch_pixel_values.shape[0])]
-                    encoder_hidden_states_batch = self.model.encoder(pixel_values=encoder_pixel_values).last_hidden_state
-                    if encoder_hidden_states is None:
-                        encoder_hidden_states = encoder_hidden_states_batch
-                    else:
-                        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0)
+                encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state

                 text_encoder_input_ids = torch.arange(
                     self.model.text_encoder.config.query_token_count,
@@ -335,7 +316,6 @@ def batch_recognition(

             sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
             detected_text = self.processor.tokenizer.batch_decode(batch_predictions)
-            detected_text = [truncate_repetitions(dt) for dt in detected_text]

             # Convert sequence_scores to list for the current batch
             batch_confidences = sequence_scores.tolist()