diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml index 77a2519f..ee31101c 100644 --- a/.github/workflows/scripts.yml +++ b/.github/workflows/scripts.yml @@ -32,3 +32,7 @@ jobs: run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test detection folder run: poetry run surya_detect benchmark_data/pdfs --page_range 0 + - name: Test texify + env: + TEXIFY_MAX_TOKENS: 25 + run: poetry run surya_latex_ocr benchmark_data/pdfs --page_range 0 diff --git a/README.md b/README.md index caf9171c..c9dfa114 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Surya is a document OCR toolkit that does: - Layout analysis (table, image, header, etc detection) - Reading order detection - Table recognition (detecting rows/columns) +- LaTeX OCR It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details). @@ -19,9 +20,9 @@ It works on a range of documents (see [usage](#usage) and [benchmar |:------------------------------------------------------------------:|:--------------------------------------------------------------------------:| | | | -| Table Recognition | | -|:-------------------------------------------------------------:|:----------------:| -| | | +| Table Recognition | LaTeX OCR | +|:-------------------------------------------------------------:|:------------------------------------------------------:| +| | | Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision. @@ -284,10 +285,48 @@ from surya.table_rec import TableRecPredictor image = Image.open(IMAGE_PATH) table_rec_predictor = TableRecPredictor() -# list of dicts, one per image table_predictions = table_rec_predictor([image]) ``` +## LaTeX OCR + +This command will write out a json file with the LaTeX of the equations. You must pass in images that are already cropped to the equations. You can do this by running the layout model and then cropping to the detected equation regions. + +```shell +surya_latex_ocr DATA_PATH +``` + +- `DATA_PATH` can be an image, pdf, or folder of images/pdfs +- `--output_dir` specifies the directory to save results to instead of the default +- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. + +The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: + +- `text` - the detected LaTeX text - it will be in KaTeX compatible LaTeX, with `<math display="block">...</math>` and `<math>...</math>` as delimiters. +- `confidence` - the prediction confidence from 0-1. +- `page` - the page number in the file + +### From python + +```python +from PIL import Image +from surya.texify import TexifyPredictor + +image = Image.open(IMAGE_PATH) +predictor = TexifyPredictor() + +predictor([image]) +``` + +### Interactive app + +You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with: + +```shell +pip install streamlit==1.40 streamlit-drawable-canvas-jsretry +texify_gui +``` + # Limitations - This is specialized for document OCR. It will likely not work on photos or other images.
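The README section above says `surya_latex_ocr` expects images that are already cropped to individual equations, with the layout model as one way to produce those crops. Below is a minimal sketch of that layout-then-crop flow; it is not part of this diff, and it assumes layout results expose a `bboxes` list whose entries carry `label` and `bbox` attributes and that equation regions are labeled `Formula`, which should be checked against the layout schema before relying on it.

```python
from PIL import Image

from surya.layout import LayoutPredictor
from surya.texify import TexifyPredictor

# Assumptions: layout results have .bboxes with .label and .bbox (x1, y1, x2, y2),
# and equation regions are labeled "Formula" -- verify against the layout schema.
image = Image.open("page.png")  # placeholder path
layout_predictor = LayoutPredictor()
texify_predictor = TexifyPredictor()

layout_result = layout_predictor([image])[0]
equation_crops = [
    image.crop(box.bbox)
    for box in layout_result.bboxes
    if box.label == "Formula"
]

if equation_crops:
    # One TexifyResult (text + confidence) per cropped equation image.
    for result in texify_predictor(equation_crops):
        print(result.confidence, result.text)
```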
diff --git a/detect_layout.py b/detect_layout.py index 2469b405..a087a837 100644 --- a/detect_layout.py +++ b/detect_layout.py @@ -1,4 +1,4 @@ -from surya.scripts import detect_layout_cli +from surya.scripts.detect_layout import detect_layout_cli if __name__ == "__main__": detect_layout_cli() diff --git a/detect_text.py b/detect_text.py index 87e3c07c..9bbaa532 100644 --- a/detect_text.py +++ b/detect_text.py @@ -1,4 +1,4 @@ -from surya.scripts import detect_text_cli +from surya.scripts.detect_text import detect_text_cli if __name__ == "__main__": detect_text_cli() diff --git a/ocr_app.py b/ocr_app.py index 5f52c842..98eedd72 100644 --- a/ocr_app.py +++ b/ocr_app.py @@ -1,4 +1,4 @@ -from surya.scripts import streamlit_app_cli +from surya.scripts.run_streamlit_app import streamlit_app_cli if __name__ == "__main__": streamlit_app_cli() \ No newline at end of file diff --git a/ocr_latex.py b/ocr_latex.py new file mode 100644 index 00000000..ef774336 --- /dev/null +++ b/ocr_latex.py @@ -0,0 +1,4 @@ +from surya.scripts.ocr_latex import ocr_latex_cli + +if __name__ == "__main__": + ocr_latex_cli() diff --git a/ocr_text.py b/ocr_text.py index 0abfcfd4..aa5dd1b8 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -1,4 +1,4 @@ -from surya.scripts import ocr_text_cli +from surya.scripts.ocr_text import ocr_text_cli if __name__ == "__main__": ocr_text_cli() diff --git a/poetry.lock b/poetry.lock index 2d37140d..25bf4927 100644 --- a/poetry.lock +++ b/poetry.lock @@ -309,13 +309,13 @@ files = [ [[package]] name = "attrs" -version = "24.3.0" +version = "25.1.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" files = [ - {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, - {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, + {file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"}, + {file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"}, ] [package.extras] @@ -1063,13 +1063,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "huggingface-hub" -version = "0.27.1" +version = "0.28.0" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.27.1-py3-none-any.whl", hash = "sha256:1c5155ca7d60b60c2e2fc38cbb3ffb7f7c3adf48f824015b219af9061771daec"}, - {file = "huggingface_hub-0.27.1.tar.gz", hash = "sha256:c004463ca870283909d715d20f066ebd6968c2207dae9393fdffb3c1d4d8f98b"}, + {file = "huggingface_hub-0.28.0-py3-none-any.whl", hash = "sha256:71cff4e500efe68061d94b7f6d3114e183715088be7a90bf4dd84af83b5f5cdb"}, + {file = "huggingface_hub-0.28.0.tar.gz", hash = "sha256:c2b18c02a47d4384763caddb4d0ab2a8fc6c16e0800d6de4d55d0a896244aba3"}, ] [package.dependencies] @@ -1082,13 +1082,13 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions 
(>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] hf-transfer = ["hf-transfer (>=0.1.4)"] inference = ["aiohttp"] -quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.5.0)"] +quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] tensorflow-testing = ["keras (<3.0)", "tensorflow"] testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] @@ -1711,13 +1711,13 @@ files = [ [[package]] name = "mistune" -version = "3.1.0" +version = "3.1.1" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = false python-versions = ">=3.8" files = [ - {file = "mistune-3.1.0-py3-none-any.whl", hash = "sha256:b05198cf6d671b3deba6c87ec6cf0d4eb7b72c524636eddb6dbf13823b52cee1"}, - {file = "mistune-3.1.0.tar.gz", hash = "sha256:dbcac2f78292b9dc066cd03b7a3a26b62d85f8159f2ea5fd28e55df79908d667"}, + {file = "mistune-3.1.1-py3-none-any.whl", hash = "sha256:02106ac2aa4f66e769debbfa028509a275069dcffce0dfa578edd7b991ee700a"}, + {file = "mistune-3.1.1.tar.gz", hash = "sha256:e0740d635f515119f7d1feb6f9b192ee60f0cc649f80a8f944f905706a21654c"}, ] [package.dependencies] @@ -1870,13 +1870,13 @@ dill = ">=0.3.8" [[package]] name = "narwhals" -version = "1.23.0" +version = "1.24.0" description = "Extremely lightweight compatibility layer between dataframe libraries" optional = false python-versions = ">=3.8" files = [ - {file = "narwhals-1.23.0-py3-none-any.whl", hash = "sha256:8d6e7fa0b13af01784837efc060e2a663e5d888decf31f261ff8fc06a7cefeb4"}, - {file = "narwhals-1.23.0.tar.gz", hash = "sha256:3da4b1e7675b3d8ed69bd40c263b135066248af28354f104ea36c788b23d1e3e"}, + {file = "narwhals-1.24.0-py3-none-any.whl", hash = "sha256:73ff60578641059221de2e4f337bfdf0260378fb1553f787d27411602cfc5e72"}, + {file = "narwhals-1.24.0.tar.gz", hash = 
"sha256:23f0a05efbe29864d184842dd6bf11c044210bca1d443d6dbffe7e65a70bf063"}, ] [package.extras] @@ -1918,13 +1918,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.16.5" +version = "7.16.6" description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)." optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.16.5-py3-none-any.whl", hash = "sha256:e12eac052d6fd03040af4166c563d76e7aeead2e9aadf5356db552a1784bd547"}, - {file = "nbconvert-7.16.5.tar.gz", hash = "sha256:c83467bb5777fdfaac5ebbb8e864f300b277f68692ecc04d6dab72f2d8442344"}, + {file = "nbconvert-7.16.6-py3-none-any.whl", hash = "sha256:1375a7b67e0c2883678c48e506dc320febb57685e5ee67faa51b18a90f3a712b"}, + {file = "nbconvert-7.16.6.tar.gz", hash = "sha256:576a7e37c6480da7b8465eefa66c17844243816ce1ccc372633c6b71c3c0f582"}, ] [package.dependencies] @@ -2422,13 +2422,13 @@ testing = ["docopt", "pytest"] [[package]] name = "pdftext" -version = "0.5.0" +version = "0.5.1" description = "Extract structured text from pdfs quickly" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "pdftext-0.5.0-py3-none-any.whl", hash = "sha256:e14179c5039c711dc5c490ecb1bc15c92ab920e5f7715034b7ae5a387b3b2787"}, - {file = "pdftext-0.5.0.tar.gz", hash = "sha256:f6487d170abc97867d7539774fecdb0a17599965ba88287b3b89731f5cd7d612"}, + {file = "pdftext-0.5.1-py3-none-any.whl", hash = "sha256:6de0406473846f6486b969fb4b1832b94ebe4c92a4bae5f3d1ead645d43d9994"}, + {file = "pdftext-0.5.1.tar.gz", hash = "sha256:81646068c98df4874064f739f507908543188e93e1a5d84b30a0989329f32af6"}, ] [package.dependencies] @@ -2730,6 +2730,8 @@ files = [ {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:8df0178ba8a9e5bc84fed9cfa61d54601b371fbec5c8eebad27575f1e105c0d4"}, {file = "psutil-6.1.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:1924e659d6c19c647e763e78670a05dbb7feaf44a0e9c94bf9e14dfc6ba50468"}, {file = "psutil-6.1.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:018aeae2af92d943fdf1da6b58665124897cfc94faa2ca92098838f83e1b1bca"}, + {file = "psutil-6.1.1-cp27-none-win32.whl", hash = "sha256:6d4281f5bbca041e2292be3380ec56a9413b790579b8e593b1784499d0005dac"}, + {file = "psutil-6.1.1-cp27-none-win_amd64.whl", hash = "sha256:c777eb75bb33c47377c9af68f30e9f11bc78e0f07fbf907be4a5d70b2fe5f030"}, {file = "psutil-6.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8"}, {file = "psutil-6.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377"}, {file = "psutil-6.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003"}, @@ -2838,13 +2840,13 @@ files = [ [[package]] name = "pydantic" -version = "2.10.5" +version = "2.10.6" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53"}, - {file = "pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff"}, + 
{file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, + {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, ] [package.dependencies] @@ -3477,13 +3479,13 @@ all = ["numpy"] [[package]] name = "referencing" -version = "0.36.1" +version = "0.36.2" description = "JSON Referencing + Python" optional = false python-versions = ">=3.9" files = [ - {file = "referencing-0.36.1-py3-none-any.whl", hash = "sha256:363d9c65f080d0d70bc41c721dce3c7f3e77fc09f269cd5c8813da18069a6794"}, - {file = "referencing-0.36.1.tar.gz", hash = "sha256:ca2e6492769e3602957e9b831b94211599d2aade9477f5d44110d2530cf9aade"}, + {file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"}, + {file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 8b4dd2d1..06b1c0f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surya-ocr" -version = "0.9.3" +version = "0.10.0" description = "OCR, layout, reading order, and table recognition in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" @@ -41,6 +41,8 @@ surya_ocr = "surya.scripts.ocr_text:ocr_text_cli" surya_layout = "surya.scripts.detect_layout:detect_layout_cli" surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli" surya_table = "surya.scripts.table_recognition:table_recognition_cli" +surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli" +texify_gui = "surya.scripts.run_texify_app:texify_app_cli" [build-system] requires = ["poetry-core"] diff --git a/static/images/latex_ocr.png b/static/images/latex_ocr.png new file mode 100644 index 00000000..3ebc4610 Binary files /dev/null and b/static/images/latex_ocr.png differ diff --git a/static/images/scanned_tablerec.png b/static/images/scanned_tablerec.png index 7eb3be66..f5371c23 100644 Binary files a/static/images/scanned_tablerec.png and b/static/images/scanned_tablerec.png differ diff --git a/surya/common/predictor.py b/surya/common/predictor.py index f2a5e2e2..bbb71050 100644 --- a/surya/common/predictor.py +++ b/surya/common/predictor.py @@ -1,5 +1,6 @@ from typing import Optional import torch +import torch.nn.functional as F from surya.common.load import ModelLoader from surya.settings import settings @@ -36,5 +37,16 @@ def get_batch_size(self): batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL] return batch_size + @staticmethod + def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): + current_batch_size = tensor.shape[0] + if current_batch_size >= batch_size: + return tensor + + pad_size = batch_size - current_batch_size + padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) + + return F.pad(tensor, padding, mode='constant', value=0) + def __call__(self, *args, **kwargs): raise NotImplementedError() \ No newline at end of file diff --git a/surya/detection/__init__.py b/surya/detection/__init__.py index 8efb46bc..68982097 100644 --- a/surya/detection/__init__.py +++ b/surya/detection/__init__.py @@ -42,16 +42,6 @@ def __call__(self, images: List[Image.Image], batch_size=None, include_maps=Fals return [future.result() for future in postprocessing_futures] - def pad_to_batch_size(self, tensor, batch_size): - current_batch_size = tensor.shape[0] - if current_batch_size >= batch_size: - return tensor - - 
pad_size = batch_size - current_batch_size - padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) - - return F.pad(tensor, padding, mode='constant', value=0) - def prepare_image(self, img): new_size = (self.processor.size["width"], self.processor.size["height"]) diff --git a/surya/models.py b/surya/models.py index 03c92b6e..659a17fb 100644 --- a/surya/models.py +++ b/surya/models.py @@ -8,6 +8,7 @@ from surya.ocr_error import OCRErrorPredictor from surya.recognition import RecognitionPredictor from surya.table_rec import TableRecPredictor +from surya.texify import TexifyPredictor def load_predictors( @@ -19,5 +20,6 @@ def load_predictors( "ocr_error": OCRErrorPredictor(device=device, dtype=dtype), "recognition": RecognitionPredictor(device=device, dtype=dtype), "detection": DetectionPredictor(device=device, dtype=dtype), - "table_rec": TableRecPredictor(device=device, dtype=dtype) + "table_rec": TableRecPredictor(device=device, dtype=dtype), + "texify": TexifyPredictor(device=device, dtype=dtype) } \ No newline at end of file diff --git a/surya/recognition/__init__.py b/surya/recognition/__init__.py index 79372ff5..d44ab28f 100644 --- a/surya/recognition/__init__.py +++ b/surya/recognition/__init__.py @@ -176,16 +176,6 @@ def slice_bboxes( "polygons": all_polygons } - def pad_to_batch_size(self, tensor, batch_size): - current_batch_size = tensor.shape[0] - if current_batch_size >= batch_size: - return tensor - - pad_size = batch_size - current_batch_size - padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) - - return F.pad(tensor, padding, mode='constant', value=0) - def prepare_input(self, batch_langs, batch_pixel_values, batch_size): batch_decoder_input = [[self.model.config.decoder_start_token_id] + lang for lang in batch_langs] max_input_length = max(len(tokens) for tokens in batch_decoder_input) @@ -256,18 +246,9 @@ def batch_recognition( sequence_scores = None all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=self.model.device) - encoder_hidden_states = None with torch.inference_mode(): - encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR - for z in range(0, batch_pixel_values.shape[0], encoder_batch_size): - encoder_pixel_values = batch_pixel_values[ - z:min(z + encoder_batch_size, batch_pixel_values.shape[0])] - encoder_hidden_states_batch = self.model.encoder(pixel_values=encoder_pixel_values).last_hidden_state - if encoder_hidden_states is None: - encoder_hidden_states = encoder_hidden_states_batch - else: - encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0) + encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state text_encoder_input_ids = torch.arange( self.model.text_encoder.config.query_token_count, @@ -335,7 +316,6 @@ def batch_recognition( sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1) detected_text = self.processor.tokenizer.batch_decode(batch_predictions) - detected_text = [truncate_repetitions(dt) for dt in detected_text] # Convert sequence_scores to list for the current batch batch_confidences = sequence_scores.tolist() diff --git a/surya/scripts/__init__.py b/surya/scripts/__init__.py index c52de61c..e69de29b 100644 --- a/surya/scripts/__init__.py +++ b/surya/scripts/__init__.py @@ -1,5 +0,0 @@ -from surya.scripts.detect_layout import detect_layout_cli -from surya.scripts.detect_text import detect_text_cli -from surya.scripts.run_streamlit_app import streamlit_app_cli -from 
surya.scripts.ocr_text import ocr_text_cli -from surya.scripts.table_recognition import table_recognition_cli \ No newline at end of file diff --git a/surya/scripts/ocr_latex.py b/surya/scripts/ocr_latex.py new file mode 100644 index 00000000..3128f38b --- /dev/null +++ b/surya/scripts/ocr_latex.py @@ -0,0 +1,37 @@ +import os +import click +import json +import time +from collections import defaultdict + +from surya.scripts.config import CLILoader +from surya.texify import TexifyPredictor + + +@click.command(help="OCR LaTeX equations.") +@CLILoader.common_options +def ocr_latex_cli(input_path: str, **kwargs): + loader = CLILoader(input_path, kwargs, highres=True) + + texify_predictor = TexifyPredictor() + + start = time.time() + predictions_by_image = texify_predictor( + loader.images, + ) + + if loader.debug: + print(f"OCR took {time.time() - start:.2f} seconds") + max_chars = max([len(p.text) for p in predictions_by_image]) + print(f"Max chars: {max_chars}") + + out_preds = defaultdict(list) + for name, pred, image in zip(loader.names, predictions_by_image, loader.images): + out_pred = pred.model_dump() + out_pred["page"] = len(out_preds[name]) + 1 + out_preds[name].append(out_pred) + + with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(out_preds, f, ensure_ascii=False) + + print(f"Wrote results to {loader.result_path}") \ No newline at end of file diff --git a/surya/scripts/ocr_text.py b/surya/scripts/ocr_text.py index 7cd9da62..274ee3b2 100644 --- a/surya/scripts/ocr_text.py +++ b/surya/scripts/ocr_text.py @@ -12,7 +12,7 @@ from surya.scripts.config import CLILoader -@click.command(help="Detect bboxes in an input file or folder (PDFs or image).") +@click.command(help="OCR text.") @CLILoader.common_options @click.option("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None) @click.option("--lang_file", type=str, help="Optional path to file with languages to use for OCR. 
Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None) diff --git a/surya/scripts/run_texify_app.py b/surya/scripts/run_texify_app.py new file mode 100644 index 00000000..43580d01 --- /dev/null +++ b/surya/scripts/run_texify_app.py @@ -0,0 +1,9 @@ +import subprocess +import os + + +def texify_app_cli(): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + ocr_app_path = os.path.join(cur_dir, "texify_app.py") + cmd = ["streamlit", "run", ocr_app_path] + subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) \ No newline at end of file diff --git a/surya/scripts/streamlit_app.py b/surya/scripts/streamlit_app.py index b03a6777..2ae5eb1e 100644 --- a/surya/scripts/streamlit_app.py +++ b/surya/scripts/streamlit_app.py @@ -122,7 +122,6 @@ def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult): rec_img = draw_text_on_image(bboxes, text, img.size, langs) return rec_img, img_pred - def open_pdf(pdf_file): stream = io.BytesIO(pdf_file.getvalue()) return pypdfium2.PdfDocument(stream) @@ -176,7 +175,6 @@ def page_counter(pdf_file): st.stop() filetype = in_file.type -whole_image = False page_count = None if "pdf" in filetype: page_count = page_counter(in_file) @@ -189,11 +187,11 @@ def page_counter(pdf_file): pil_image_highres = pil_image page_number = None -text_det = st.sidebar.button("Run Text Detection") -text_rec = st.sidebar.button("Run OCR") -layout_det = st.sidebar.button("Run Layout Analysis") -table_rec = st.sidebar.button("Run Table Rec") -ocr_errors = st.sidebar.button("Run bad PDF text detection") +run_text_det = st.sidebar.button("Run Text Detection") +run_text_rec = st.sidebar.button("Run OCR") +run_layout_det = st.sidebar.button("Run Layout Analysis") +run_table_rec = st.sidebar.button("Run Table Rec") +run_ocr_errors = st.sidebar.button("Run bad PDF text detection") use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.") skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.") @@ -201,7 +199,7 @@ def page_counter(pdf_file): st.stop() # Run Text Detection -if text_det: +if run_text_det: det_img, pred = text_detection(pil_image) with col1: st.image(det_img, caption="Detected Text", use_container_width=True) @@ -209,14 +207,14 @@ def page_counter(pdf_file): # Run layout -if layout_det: +if run_layout_det: layout_img, pred = layout_detection(pil_image) with col1: st.image(layout_img, caption="Detected Layout", use_container_width=True) st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) # Run OCR -if text_rec: +if run_text_rec: rec_img, pred = ocr(pil_image, pil_image_highres, languages) with col1: st.image(rec_img, caption="OCR Result", use_container_width=True) @@ -227,13 +225,13 @@ def page_counter(pdf_file): st.text("\n".join([p.text for p in pred.text_lines])) -if table_rec: +if run_table_rec: table_img, pred = table_recognition(pil_image, pil_image_highres, skip_table_detection) with col1: st.image(table_img, caption="Table Recognition", use_container_width=True) st.json([p.model_dump() for p in pred], expanded=True) -if ocr_errors: +if run_ocr_errors: if "pdf" not in filetype: st.error("This feature only works with PDFs.") label, results = run_ocr_errors(in_file, page_count) diff --git a/surya/scripts/texify_app.py b/surya/scripts/texify_app.py new file 
mode 100644 index 00000000..d4813b23 --- /dev/null +++ b/surya/scripts/texify_app.py @@ -0,0 +1,147 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS + +import io + +import pandas as pd +import streamlit as st +from streamlit_drawable_canvas import st_canvas +import hashlib +import pypdfium2 + +from surya.settings import settings +from surya.texify import TexifyPredictor +from surya.texify.util import convert_math_delimiters +from PIL import Image + +MAX_WIDTH = 800 +MAX_HEIGHT = 1000 + + +@st.cache_resource() +def load_predictor(): + return TexifyPredictor() + + +@st.cache_data() +def inference(pil_image, bbox): + input_img = pil_image.crop(bbox) + model_output = predictor([input_img]) + return model_output[0].text, convert_math_delimiters(model_output[0].text) + + +def open_pdf(pdf_file): + stream = io.BytesIO(pdf_file.getvalue()) + return pypdfium2.PdfDocument(stream) + + +@st.cache_data() +def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES): + doc = open_pdf(pdf_file) + renderer = doc.render( + pypdfium2.PdfBitmap.to_pil, + page_indices=[page_num - 1], + scale=dpi / 72, + ) + png = list(renderer)[0] + png_image = png.convert("RGB") + doc.close() + return png_image + + +@st.cache_data() +def page_counter(pdf_file): + doc = open_pdf(pdf_file) + doc_len = len(doc) + doc.close() + return doc_len + + +def resize_image(pil_image): + if pil_image is None: + return + pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) + +def get_canvas_hash(pil_image): + return hashlib.md5(pil_image.tobytes()).hexdigest() + + +st.set_page_config(layout="wide") + +top_message = """### LaTeX OCR + +After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right. 
+""" + +st.markdown(top_message) +col1, col2 = st.columns([.7, .3]) + +predictor = load_predictor() + +in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]) +if in_file is None: + st.stop() + +if in_file is None: + st.stop() + +filetype = in_file.type +page_count = None +if "pdf" in filetype: + page_count = page_counter(in_file) + page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, + max_value=page_count) + pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES) +else: + pil_image = Image.open(in_file).convert("RGB") + page_number = None + +if pil_image is None: + st.stop() + +pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) +canvas_hash = get_canvas_hash(pil_image) + +with col1: + # Create a canvas component + canvas_result = st_canvas( + fill_color="rgba(255, 165, 0, 0.1)", # Fixed fill color with some opacity + stroke_width=1, + stroke_color="#FFAA00", + background_color="#FFF", + background_image=pil_image, + update_streamlit=True, + height=pil_image.height, + width=pil_image.width, + drawing_mode="rect", + point_display_radius=0, + key=canvas_hash, + ) + +if not canvas_result.json_data: + st.stop() + +objects = pd.json_normalize(canvas_result.json_data["objects"]) # need to convert obj to str because PyArrow +bbox_list = None +if objects.shape[0] > 0: + boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]] + boxes["right"] = boxes["left"] + boxes["width"] + boxes["bottom"] = boxes["top"] + boxes["height"] + bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist() + +if bbox_list: + with col2: + texts = [inference(pil_image, bbox) for bbox in bbox_list] + for idx, (raw, renderable) in enumerate(reversed(texts)): + st.markdown(f"### {len(texts) - idx}") + st.markdown(renderable) + st.code(raw) + st.divider() + +with col2: + tips = """ + ### Usage tips + - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple. 
+ """ + st.markdown(tips) \ No newline at end of file diff --git a/surya/settings.py b/surya/settings.py index e08dd8df..f159ce5d 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -62,7 +62,6 @@ def TORCH_DEVICE_MODEL(self) -> str: RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench" RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255 COMPILE_RECOGNITION: bool = False # Static cache for torch compile - RECOGNITION_ENCODER_BATCH_DIVISOR: int = 1 # Divisor for batch size in decoder # Layout LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/surya_layout@7ac8e390226ee5fa2125dd303d827f79d31d1a1f" @@ -83,6 +82,13 @@ def TORCH_DEVICE_MODEL(self) -> str: TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench" COMPILE_TABLE_REC: bool = False + # Texify + TEXIFY_MODEL_CHECKPOINT: str = "datalab-to/texify@ee63647a66edfd1fd45d39ff0b034ddb2e8d252c" + TEXIFY_IMAGE_SIZE: Dict = {"height": 480, "width": 480} + TEXIFY_MAX_TOKENS: int = 768 + TEXIFY_BATCH_SIZE: Optional[int] = None + COMPILE_TEXIFY: bool = False + # OCR Error Detection OCR_ERROR_MODEL_CHECKPOINT: str = "datalab-to/ocr_error_detection@c1cbda3757670fd520553eaa5197656d331de414" OCR_ERROR_BATCH_SIZE: Optional[int] = None @@ -113,6 +119,10 @@ def TABLE_REC_STATIC_CACHE(self) -> bool: def OCR_ERROR_STATIC_CACHE(self) -> bool: return self.COMPILE_ALL or self.COMPILE_OCR_ERROR + @computed_field + def TEXIFY_STATIC_CACHE(self) -> bool: + return self.COMPILE_ALL or self.COMPILE_TEXIFY + @computed_field @property def MODEL_DTYPE(self) -> torch.dtype: diff --git a/surya/table_rec/__init__.py b/surya/table_rec/__init__.py index f68f7021..025484a5 100644 --- a/surya/table_rec/__init__.py +++ b/surya/table_rec/__init__.py @@ -29,19 +29,6 @@ class TableRecPredictor(BasePredictor): def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]: return self.batch_table_recognition(images, batch_size) - @staticmethod - def pad_to_batch_size(tensor: torch.Tensor, batch_size: int) -> torch.Tensor: - current_batch_size = tensor.shape[0] - if current_batch_size >= batch_size: - return tensor - - pad_size = batch_size - current_batch_size - repeats = (pad_size + current_batch_size - 1) // current_batch_size - repeated_rows = tensor.repeat((repeats, *[1] * (tensor.dim() - 1))) - pad_tensor = repeated_rows[:pad_size] - - return torch.cat([tensor, pad_tensor], dim=0) - def inference_loop( self, encoder_hidden_states: torch.Tensor, diff --git a/surya/table_rec/model/encoder.py b/surya/table_rec/model/encoder.py index 4bfb75c0..c05822eb 100644 --- a/surya/table_rec/model/encoder.py +++ b/surya/table_rec/model/encoder.py @@ -1,6 +1,8 @@ +from typing import Optional, Union, Tuple + import torch import torch.nn as nn -from typing import Optional, Union, Tuple + from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder diff --git a/surya/texify/__init__.py b/surya/texify/__init__.py new file mode 100644 index 00000000..bacfb121 --- /dev/null +++ b/surya/texify/__init__.py @@ -0,0 +1,129 @@ +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from PIL import Image +from tqdm import tqdm + +from surya.common.predictor import BasePredictor +from surya.settings import settings +from surya.texify.loader import TexifyModelLoader +from surya.texify.schema import TexifyResult + + +class TexifyPredictor(BasePredictor): + model_loader_cls = TexifyModelLoader + batch_size = settings.TEXIFY_BATCH_SIZE + default_batch_sizes = { + "cpu": 2, + 
"mps": 6, + "cuda": 48 + } + + def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TexifyResult]: + text, confidences = self.batch_texify(images, batch_size=batch_size) + return [TexifyResult(text=t, confidence=c) for t, c in zip(text, confidences)] + + def prepare_input(self, images: List[Image.Image], batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + batch_images = [img.convert("RGB") for img in images] + processed = self.processor(batch_images) + batch_pixel_values = processed["pixel_values"].to(self.model.device).to(self.model.dtype) + batch_input_ids = processed["input_ids"].to(self.model.device).to(torch.long) + + if settings.TEXIFY_STATIC_CACHE: + batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size) + batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) + + return batch_pixel_values, batch_input_ids + + + def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tuple[List[str], List[float]]: + if batch_size is None: + batch_size = self.get_batch_size() + + # Sort images by width, so similar length ones go together + sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False) + indices, images = zip(*sorted_pairs) + indices = list(indices) + images = list(images) + + output_text = [] + confidences = [] + for i in tqdm(range(0, len(images), batch_size), desc="Texify inference"): + batch = images[i:i+batch_size] + batch_pixel_values, batch_input_ids = self.prepare_input(batch, batch_size) + current_batch_size = len(batch) + + token_count = 0 + inference_token_count = batch_input_ids.shape[-1] + batch_predictions = [[] for _ in range(current_batch_size)] + + decoder_position_ids = torch.ones_like(batch_input_ids[0, :], dtype=torch.int64, device=self.model.device).cumsum(0) - 1 + self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) + + sequence_scores = None + all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=self.model.device) + + with torch.inference_mode(): + encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state + + while token_count < settings.TEXIFY_MAX_TOKENS - 1: + is_prefill = token_count == 0 + + return_dict = self.model.decoder( + input_ids=batch_input_ids, + encoder_hidden_states=encoder_hidden_states, + cache_position=decoder_position_ids, + use_cache=True, + prefill=is_prefill + ) + + decoder_position_ids = decoder_position_ids[-1:] + 1 + logits = return_dict["logits"][:current_batch_size] # Ignore batch padding + + preds = torch.argmax(logits[:, -1], dim=-1) + scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1) + done = (preds == self.processor.tokenizer.eos_token_id) | (preds == self.processor.tokenizer.pad_token_id) + all_done = all_done | done + + if is_prefill: + sequence_scores = scores + else: + scores = scores.masked_fill(all_done, 0) + sequence_scores = torch.cat([sequence_scores, scores], dim=1) + + if all_done.all(): + break + + batch_input_ids = preds.unsqueeze(1) + + for j, (pred, status) in enumerate(zip(preds, all_done)): + if not status: + batch_predictions[j].append(int(pred)) + + token_count += inference_token_count + inference_token_count = batch_input_ids.shape[-1] + max_position_id = torch.max(decoder_position_ids).item() + decoder_position_ids = torch.ones_like(batch_input_ids[0, :], dtype=torch.int64, + device=self.model.device).cumsum(0) - 1 + max_position_id + + if 
settings.TEXIFY_STATIC_CACHE: + batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) + + batch_confidences = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1) + detected_text = self.processor.tokenizer.batch_decode(batch_predictions) + + batch_confidences = batch_confidences.tolist() + + if settings.TEXIFY_STATIC_CACHE: + detected_text = detected_text[:current_batch_size] + batch_confidences = batch_confidences[:current_batch_size] + + output_text.extend(detected_text) + confidences.extend(batch_confidences) + + output_text = sorted(zip(indices, output_text), key=lambda x: x[0]) + confidences = sorted(zip(indices, confidences), key=lambda x: x[0]) + output_text = [text for _, text in output_text] + confidences = [conf for _, conf in confidences] + return output_text, confidences \ No newline at end of file diff --git a/surya/texify/loader.py b/surya/texify/loader.py new file mode 100644 index 00000000..f1e1cb98 --- /dev/null +++ b/surya/texify/loader.py @@ -0,0 +1,59 @@ +from typing import Optional + +import torch + +from surya.common.load import ModelLoader +from surya.settings import settings +from surya.texify.model.config import TexifyConfig, TexifyDecoderConfig, TexifyEncoderConfig + +from surya.texify.model.encoderdecoder import TexifyModel +from surya.texify.processor import TexifyProcessor + + +class TexifyModelLoader(ModelLoader): + def __init__(self, checkpoint: Optional[str] = None): + super().__init__(checkpoint) + + if self.checkpoint is None: + self.checkpoint = settings.TEXIFY_MODEL_CHECKPOINT + + self.checkpoint, self.revision = self.split_checkpoint_revision(self.checkpoint) + + def model( + self, + device=settings.TORCH_DEVICE_MODEL, + dtype=settings.MODEL_DTYPE + ) -> TexifyModel: + if device is None: + device = settings.TORCH_DEVICE_MODEL + if dtype is None: + dtype = settings.MODEL_DTYPE + + config = TexifyConfig.from_pretrained(self.checkpoint, revision=self.revision) + decoder_config = config.decoder + decoder = TexifyDecoderConfig(**decoder_config) + config.decoder = decoder + + encoder_config = config.encoder + encoder = TexifyEncoderConfig(**encoder_config) + config.encoder = encoder + + model = TexifyModel.from_pretrained(self.checkpoint, config=config, torch_dtype=dtype, revision=self.revision) + + model = model.to(device) + model = model.eval() + + if settings.TEXIFY_STATIC_CACHE: + torch.set_float32_matmul_precision('high') + torch._dynamo.config.cache_size_limit = 16 + torch._dynamo.config.suppress_errors = False + + print(f"Compiling texify model {self.checkpoint} on device {device} with dtype {dtype}") + model.encoder = torch.compile(model.encoder) + model.decoder = torch.compile(model.decoder) + + print(f"Loaded texify model {self.checkpoint} on device {device} with dtype {dtype}") + return model + + def processor(self) -> TexifyProcessor: + return TexifyProcessor(self.checkpoint, self.revision) \ No newline at end of file diff --git a/surya/texify/model/__init__.py b/surya/texify/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/surya/texify/model/config.py b/surya/texify/model/config.py new file mode 100644 index 00000000..b7371352 --- /dev/null +++ b/surya/texify/model/config.py @@ -0,0 +1,188 @@ +from dataclasses import dataclass +from typing import Dict + +import torch +from transformers import PretrainedConfig +from transformers.utils import ModelOutput + +from surya.settings import settings + + +class TexifyConfig(PretrainedConfig): + model_type = "vision-encoder-decoder" 
is_composition = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if "encoder" in kwargs: + encoder_config = kwargs.pop("encoder") + decoder_config = kwargs.pop("decoder") + else: + encoder_config = TexifyEncoderConfig() + decoder_config = TexifyDecoderConfig() + + self.encoder = encoder_config + self.decoder = decoder_config + self.is_encoder_decoder = True + + if isinstance(decoder_config, dict): + self.decoder_start_token_id = decoder_config["bos_token_id"] + self.decoder_end_token_id = decoder_config["eos_token_id"] + self.pad_token_id = decoder_config["pad_token_id"] + self.eos_token_id = decoder_config["eos_token_id"] + else: + self.decoder_start_token_id = decoder_config.bos_token_id + self.decoder_end_token_id = decoder_config.eos_token_id + self.pad_token_id = decoder_config.pad_token_id + self.eos_token_id = decoder_config.eos_token_id + + +@dataclass +class TexifyModelOutput(ModelOutput): + logits: Dict[str, torch.Tensor] + hidden_states: torch.Tensor | None = None + + +class TexifyEncoderConfig(PretrainedConfig): + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=(settings.TEXIFY_IMAGE_SIZE["height"], settings.TEXIFY_IMAGE_SIZE["width"]), + patch_size=4, + num_channels=3, + embed_dim=128, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + num_kv_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_length=1024, + use_positional_embeddings=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.encoder_length = encoder_length + self.use_positional_embeddings = use_positional_embeddings + + +class TexifyDecoderConfig(PretrainedConfig): + model_type = "texify" + + def __init__( + self, + num_hidden_layers=6, + vocab_size=68549, + hidden_size=512, + intermediate_size=4 * 512, + encoder_hidden_size=1024, + num_attention_heads=8, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=1, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + encoder_cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + self_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + global_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + attention_dropout=0.0, + 
num_key_value_heads=4, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + tie_word_embeddings=False, + aux_heads=0, # How many n-token-ahead heads to add + causal=True, + layer_norm_eps=1e-5, + dropout=.1, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.aux_heads = aux_heads + self.encoder_hidden_size=encoder_hidden_size + self.causal = causal + self.encoder_cross_attn_layers = encoder_cross_attn_layers + self.layer_norm_eps = layer_norm_eps + self.dropout = dropout + self.double_residual_flow = False + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] \ No newline at end of file diff --git a/surya/texify/model/decoder.py b/surya/texify/model/decoder.py new file mode 100644 index 00000000..29c8d786 --- /dev/null +++ b/surya/texify/model/decoder.py @@ -0,0 +1,79 @@ +from typing import Optional, Union, Tuple + +import torch + +from surya.common.adetr.decoder import SuryaADETRDecoderPreTrainedModel, SuryaADETRDecoderModel, WrappedEmbedding +from torch import nn + +from surya.settings import settings +from surya.texify.model.config import TexifyModelOutput + + +class TexifyDecoder(SuryaADETRDecoderPreTrainedModel): + _tied_weights_keys = None + + def __init__(self, config, **kwargs): + super().__init__(config) + embed_tokens = WrappedEmbedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.model = SuryaADETRDecoderModel( + config, + embedder=embed_tokens, + static_cache=settings.TEXIFY_STATIC_CACHE, + max_boxes=settings.TEXIFY_MAX_TOKENS + ) + self.vocab_size = config.vocab_size + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + 
def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, TexifyModelOutput]: + outputs = self.model( + input_ids=input_ids, + cache_position=cache_position, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + ) + + hidden_states = self.pre_output_norm(outputs[0]) + logits = self.lm_head(hidden_states) + + return TexifyModelOutput( + logits=logits, + hidden_states=outputs.hidden_states, + ) \ No newline at end of file diff --git a/surya/texify/model/encoder.py b/surya/texify/model/encoder.py new file mode 100644 index 00000000..99ea3244 --- /dev/null +++ b/surya/texify/model/encoder.py @@ -0,0 +1,87 @@ +from typing import Optional, Union, Tuple + +from torch import nn +import torch + +from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinEmbeddings, DonutSwinEncoder, \ + DonutSwinModelOutput + +class TexifyEncoder(DonutSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.position_embeddings = None + if hasattr(config, "encoder_length"): + self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DonutSwinModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + return_dict=return_dict, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + ) + + last_hidden_state = encoder_outputs[0] + if self.position_embeddings is not None: + last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] + + return DonutSwinModelOutput( + last_hidden_state=last_hidden_state, + ) \ No newline at end of file diff --git a/surya/texify/model/encoderdecoder.py b/surya/texify/model/encoderdecoder.py new file mode 100644 index 00000000..50e45a08 --- /dev/null +++ b/surya/texify/model/encoderdecoder.py @@ -0,0 +1,121 @@ +from typing import Optional, Union, Tuple + +import torch +from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput + +from surya.texify.model.decoder import TexifyDecoder +from surya.texify.model.encoder import TexifyEncoder + + +class TexifyModel(PreTrainedModel): + config_class = VisionEncoderDecoderConfig + base_model_prefix = "vision_encoder_decoder" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + ): + # initialize with config + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + config.decoder.tie_word_embeddings = False + super().__init__(config) + + if encoder is None: + encoder = TexifyEncoder(config.encoder) + + if decoder is None: + decoder = TexifyDecoder(config.decoder, attn_implementation=config._attn_implementation) + + self.encoder: TexifyEncoder = encoder + self.decoder: TexifyDecoder = decoder + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + 
decoder_cache_position: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # else: + encoder_attention_mask = None + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + cache_position=decoder_cache_position, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + **kwargs_decoder, + ) + + return Seq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + encoder_last_hidden_state=encoder_outputs.last_hidden_state + ) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past_key_values, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past_key_values, beam_idx) \ No newline at end of file diff --git a/surya/texify/processor.py b/surya/texify/processor.py new file mode 100644 index 00000000..8e5259ad --- /dev/null +++ b/surya/texify/processor.py @@ -0,0 +1,69 @@ +from typing import List + +import numpy as np +import torch +from PIL import Image +from transformers import PreTrainedTokenizerFast, ProcessorMixin + +from surya.common.donut.processor import SuryaEncoderImageProcessor +from surya.settings import settings + + +class TexifyProcessor(ProcessorMixin): + attributes = ["image_processor"] + image_processor_class = "AutoImageProcessor" + + def __init__(self, checkpoint, revision, **kwargs): + image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint, revision=revision) + image_processor.do_align_long_axis = False + image_processor.max_size = settings.TEXIFY_IMAGE_SIZE + self.image_processor = image_processor + + tokenizer = TexifyTokenizer.from_pretrained(checkpoint, revision=revision) + tokenizer.model_max_length = settings.TEXIFY_MAX_TOKENS + self.tokenizer = tokenizer + + super().__init__(image_processor) + + def __call__( + self, + images: List[Image.Image] | None, + *args, + **kwargs + ): + input_ids = [[self.tokenizer.bos_token_id]] * len(images) + input_ids = torch.tensor(input_ids) + + pixel_values = self.image_processor(images, **kwargs)["pixel_values"] + pixel_values = torch.tensor(np.array(pixel_values)) + + inputs = { + 
"input_ids": input_ids, + "pixel_values": pixel_values + } + return inputs + + + +class TexifyTokenizer(PreTrainedTokenizerFast): + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + **kwargs, + ) \ No newline at end of file diff --git a/surya/texify/schema.py b/surya/texify/schema.py new file mode 100644 index 00000000..4d170f99 --- /dev/null +++ b/surya/texify/schema.py @@ -0,0 +1,8 @@ +from typing import Optional + +from pydantic import BaseModel + + +class TexifyResult(BaseModel): + text: str + confidence: Optional[float] = None diff --git a/surya/texify/util.py b/surya/texify/util.py new file mode 100644 index 00000000..e4699eac --- /dev/null +++ b/surya/texify/util.py @@ -0,0 +1,6 @@ +import re + +def convert_math_delimiters(text): + text = re.sub(r'(.*?)', r'$$\1$$', text) + text = re.sub(r'(.*?)', r'$\1$', text) + return text \ No newline at end of file diff --git a/table_recognition.py b/table_recognition.py index 489fc54c..d4b4955b 100644 --- a/table_recognition.py +++ b/table_recognition.py @@ -1,4 +1,4 @@ -from surya.scripts import table_recognition_cli +from surya.scripts.table_recognition import table_recognition_cli if __name__ == "__main__": table_recognition_cli() \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 526de194..04f6de2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from surya.layout import LayoutPredictor from surya.recognition import RecognitionPredictor from surya.table_rec import TableRecPredictor +from surya.texify import TexifyPredictor @pytest.fixture(scope="session") @@ -42,6 +43,12 @@ def table_rec_predictor() -> TableRecPredictor: yield table_rec_predictor del table_rec_predictor +@pytest.fixture(scope="session") +def texify_predictor() -> TexifyPredictor: + texify_predictor = TexifyPredictor() + yield texify_predictor + del texify_predictor + @pytest.fixture() def test_image(): image = Image.new("RGB", (1024, 1024), "white") diff --git a/tests/test_latex_ocr.py b/tests/test_latex_ocr.py new file mode 100644 index 00000000..2d118885 --- /dev/null +++ b/tests/test_latex_ocr.py @@ -0,0 +1,14 @@ +from PIL import Image, ImageDraw + + +def test_latex_ocr(texify_predictor): + img = Image.new('RGB', (200, 100), color='white') + draw = ImageDraw.Draw(img) + draw.text((10, 10), "E = mc2", fill='black', font_size=48) + + results = texify_predictor([img]) + text = results[0].text.strip() + assert len(results) == 1 + + assert text.startswith("") + assert text.endswith("") \ No newline at end of file diff --git a/texify_app.py b/texify_app.py new file mode 100644 index 00000000..feeccde1 --- /dev/null +++ b/texify_app.py @@ -0,0 +1,4 @@ +from surya.scripts.run_texify_app import texify_app_cli + +if __name__ == "__main__": + texify_app_cli() \ No newline at end of file