diff --git a/surya/detection/__init__.py b/surya/detection/__init__.py
index 2db088a..7f8993f 100644
--- a/surya/detection/__init__.py
+++ b/surya/detection/__init__.py
@@ -135,26 +135,52 @@ def batch_detection(
 
 class InlineDetectionPredictor(DetectionPredictor):
     model_loader_cls = InlineDetectionModelLoader
-
-    def batch_generator(self, iterable, batch_size=None):
-        if batch_size is None:
-            batch_size = self.get_batch_size()
-
-        for i in range(0, len(iterable), batch_size):
-            yield iterable[i:i+batch_size]
 
     def __call__(self, images, text_boxes: List[List[List[float]]], batch_size=None, include_maps=False) -> List[TextDetectionResult]:
+        """Run inline detection on ``images`` and post-process each prediction.
+
+        Every prediction yielded by ``batch_detection`` is paired, in order,
+        with the entry of ``text_boxes`` at the same position and submitted
+        to ``parallel_get_inline_boxes``.
+
+        Args:
+            images: Images to run detection on.
+            text_boxes: One list of boxes per image, ordered like ``images``.
+            batch_size: Forwarded to ``batch_detection``.
+            include_maps: Forwarded to ``parallel_get_inline_boxes``.
+
+        Returns:
+            One post-processing result per image, in submission order.
+
+        Raises:
+            ValueError: If ``images`` and ``text_boxes`` differ in length, or
+                the detector yields a different number of predictions than images.
+        """
+        # Validate with an explicit raise: `assert` is stripped under `python -O`.
+        if len(images) != len(text_boxes):
+            raise ValueError(
+                f"Expected one list of text boxes per image, got {len(images)} images and {len(text_boxes)} text box lists"
+            )
+
         detection_generator = self.batch_detection(images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE)
-        text_box_generator = self.batch_generator(text_boxes, batch_size=batch_size)
 
         postprocessing_futures = []
         max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
         parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
         executor = ThreadPoolExecutor if parallelize else FakeExecutor
-        with executor(max_workers=max_workers) as e:
-            for (preds, orig_sizes), batch_text_boxes in zip(detection_generator, text_box_generator):
-                for pred, orig_size, image_text_boxes in zip(preds, orig_sizes, batch_text_boxes):
-                    postprocessing_futures.append(e.submit(parallel_get_inline_boxes, pred, orig_size, image_text_boxes, include_maps))
 
-        assert len(postprocessing_futures) == len(images) == len(text_boxes)  # Ensure we have a 1:1 mapping
+        # Assumes ``batch_detection`` yields predictions in image order; one shared iterator
+        # then keeps text boxes aligned across batches. Keep it as the LAST zip argument: zip
+        # stops at the first exhausted input, so it is never advanced without a prediction.
+        text_boxes_iter = iter(text_boxes)
+        with executor(max_workers=max_workers) as e:
+            for preds, orig_sizes in detection_generator:
+                for pred, orig_size, image_text_boxes in zip(preds, orig_sizes, text_boxes_iter):
+                    postprocessing_futures.append(e.submit(parallel_get_inline_boxes, pred, orig_size, image_text_boxes, include_maps))
+
+        if len(postprocessing_futures) != len(images):
+            raise ValueError(
+                f"Expected one detection result per image, got {len(postprocessing_futures)} results for {len(images)} images"
+            )
+
         return [future.result() for future in postprocessing_futures]
\ No newline at end of file