Skip to content

Commit

Permalink
Fix inline prediction - original image mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
tarun-menta committed Mar 9, 2025
1 parent 292576e commit bc86ba8
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions surya/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,26 +135,25 @@ def batch_detection(

class InlineDetectionPredictor(DetectionPredictor):
model_loader_cls = InlineDetectionModelLoader

def batch_generator(self, iterable, batch_size=None):
if batch_size is None:
batch_size = self.get_batch_size()

for i in range(0, len(iterable), batch_size):
yield iterable[i:i+batch_size]

def __call__(self, images, text_boxes: List[List[List[float]]], batch_size=None, include_maps=False) -> List[TextDetectionResult]:
detection_generator = self.batch_detection(images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE)
text_box_generator = self.batch_generator(text_boxes, batch_size=batch_size)

postprocessing_futures = []
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
executor = ThreadPoolExecutor if parallelize else FakeExecutor
with executor(max_workers=max_workers) as e:
for (preds, orig_sizes), batch_text_boxes in zip(detection_generator, text_box_generator):
for pred, orig_size, image_text_boxes in zip(preds, orig_sizes, batch_text_boxes):
postprocessing_futures.append(e.submit(parallel_get_inline_boxes, pred, orig_size, image_text_boxes, include_maps))

assert len(postprocessing_futures) == len(images) == len(text_boxes) # Ensure we have a 1:1 mapping
current_image_idx = 0
with executor(max_workers=max_workers) as e:
for (preds, orig_sizes) in detection_generator:
for pred, orig_size in zip(preds, orig_sizes):
postprocessing_futures.append(e.submit(parallel_get_inline_boxes, pred, orig_size, text_boxes[current_image_idx], include_maps))
current_image_idx += 1

try:
assert len(postprocessing_futures) == len(images) == len(text_boxes) # Ensure we have a 1:1 mapping
except:
print(len(postprocessing_futures), len(images), len(text_boxes))
raise ValueError()
return [future.result() for future in postprocessing_futures]

0 comments on commit bc86ba8

Please sign in to comment.