From 3184eff0b954be2e7cb93156ebb0e63c43fdc3cd Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Fri, 21 Oct 2022 12:03:12 +0000 Subject: [PATCH 1/6] add_docprompt_example_in_pipelines --- .../docprompt_example.py | 45 +++ .../document-intelligence/requirements.txt | 3 + .../run_docprompt_server.sh | 19 + .../run_docprompt_web.sh | 16 + pipelines/pipelines/__init__.py | 3 +- pipelines/pipelines/nodes/__init__.py | 1 + .../pipelines/nodes/document/__init__.py | 16 + .../nodes/document/document_intelligence.py | 234 ++++++++++++ .../nodes/document/document_preprocessor.py | 132 +++++++ pipelines/pipelines/pipelines/__init__.py | 3 +- .../pipelines/pipelines/standard_pipelines.py | 36 ++ pipelines/rest_api/controller/search.py | 18 +- pipelines/rest_api/pipeline/docprompt.yaml | 29 ++ pipelines/rest_api/schema.py | 16 + pipelines/ui/webapp_docprompt_gradio.py | 349 ++++++++++++++++++ 15 files changed, 917 insertions(+), 3 deletions(-) create mode 100644 pipelines/examples/document-intelligence/docprompt_example.py create mode 100644 pipelines/examples/document-intelligence/requirements.txt create mode 100644 pipelines/examples/document-intelligence/run_docprompt_server.sh create mode 100644 pipelines/examples/document-intelligence/run_docprompt_web.sh create mode 100644 pipelines/pipelines/nodes/document/__init__.py create mode 100644 pipelines/pipelines/nodes/document/document_intelligence.py create mode 100644 pipelines/pipelines/nodes/document/document_preprocessor.py create mode 100644 pipelines/rest_api/pipeline/docprompt.yaml create mode 100644 pipelines/ui/webapp_docprompt_gradio.py diff --git a/pipelines/examples/document-intelligence/docprompt_example.py b/pipelines/examples/document-intelligence/docprompt_example.py new file mode 100644 index 000000000000..36918569f0ed --- /dev/null +++ b/pipelines/examples/document-intelligence/docprompt_example.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
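+
+# This example chains the OCR preprocessing node with the DocPrompter node through a
+# DocPipeline. `pipe.run(meta=...)` returns a dict whose "results" entry holds one list
+# of predictions per document; each prediction pairs a prompt with its candidate answers
+# of the form {"value", "prob", "start", "end"}.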
+ +import argparse +import logging +import os + +import paddle +from pipelines.nodes import DocPreProcessor, DocPrompter +from pipelines import DocPipeline + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run docprompt system, defaults to gpu.") +parser.add_argument("--batch_size", default=4, type=int, help="The batch size of prompt for one image.") +args = parser.parse_args() +# yapf: enable + + +def docprompt_pipeline(): + + use_gpu = True if args.device == 'gpu' else False + + preprocessor = DocPreProcessor(use_gpu=use_gpu) + docprompter = DocPrompter(use_gpu=use_gpu, batch_size=args.batch_size) + pipe = DocPipeline(preprocessor=preprocessor, modelrunner=docprompter) + meta = {"doc": "./invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]} + + prediction = pipe.run(meta=meta) + print(prediction["results"]) + + +if __name__ == "__main__": + docprompt_pipeline() diff --git a/pipelines/examples/document-intelligence/requirements.txt b/pipelines/examples/document-intelligence/requirements.txt new file mode 100644 index 000000000000..dc65d4a14e9c --- /dev/null +++ b/pipelines/examples/document-intelligence/requirements.txt @@ -0,0 +1,3 @@ +numpy +opencv-python +requests diff --git a/pipelines/examples/document-intelligence/run_docprompt_server.sh b/pipelines/examples/document-intelligence/run_docprompt_server.sh new file mode 100644 index 000000000000..41544888a36c --- /dev/null +++ b/pipelines/examples/document-intelligence/run_docprompt_server.sh @@ -0,0 +1,19 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 指定语义检索系统的Yaml配置文件 +export CUDA_VISIBLE_DEVICES=0 +export PIPELINE_YAML_PATH=rest_api/pipeline/docprompt.yaml +# 使用端口号 8891 启动模型服务 +python rest_api/application.py 8891 \ No newline at end of file diff --git a/pipelines/examples/document-intelligence/run_docprompt_web.sh b/pipelines/examples/document-intelligence/run_docprompt_web.sh new file mode 100644 index 000000000000..b4055c148565 --- /dev/null +++ b/pipelines/examples/document-intelligence/run_docprompt_web.sh @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
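+
+# The Gradio UI queries the REST service started by run_docprompt_server.sh; if that
+# service is not at the script's default address, it can be pointed elsewhere via the
+# flags defined in ui/webapp_docprompt_gradio.py, e.g.:
+#   python ui/webapp_docprompt_gradio.py --serving_name 0.0.0.0 --serving_port 8891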
+unset http_proxy && unset https_proxy +# 配置模型服务地址 +python ui/webapp_docprompt_gradio.py \ No newline at end of file diff --git a/pipelines/pipelines/__init__.py b/pipelines/pipelines/__init__.py index 83dc75fcf6b2..045cf608cc7f 100644 --- a/pipelines/pipelines/__init__.py +++ b/pipelines/pipelines/__init__.py @@ -39,7 +39,8 @@ from pipelines.pipelines import Pipeline from pipelines.pipelines.standard_pipelines import (BaseStandardPipeline, ExtractiveQAPipeline, - SemanticSearchPipeline) + SemanticSearchPipeline, + DocPipeline) import pandas as pd diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py index a4285acaaf47..34d09cfb0e27 100644 --- a/pipelines/pipelines/nodes/__init__.py +++ b/pipelines/pipelines/nodes/__init__.py @@ -29,3 +29,4 @@ from pipelines.nodes.ranker import BaseRanker, ErnieRanker from pipelines.nodes.reader import BaseReader, ErnieReader from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever +from pipelines.nodes.document import DocPreProcessor, DocPrompter diff --git a/pipelines/pipelines/nodes/document/__init__.py b/pipelines/pipelines/nodes/document/__init__.py new file mode 100644 index 000000000000..7f07d481b6c0 --- /dev/null +++ b/pipelines/pipelines/nodes/document/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pipelines.nodes.document.document_preprocessor import DocPreProcessor +from pipelines.nodes.document.document_intelligence import DocPrompter \ No newline at end of file diff --git a/pipelines/pipelines/nodes/document/document_intelligence.py b/pipelines/pipelines/nodes/document/document_intelligence.py new file mode 100644 index 000000000000..366ef498aa54 --- /dev/null +++ b/pipelines/pipelines/nodes/document/document_intelligence.py @@ -0,0 +1,234 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
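+
+# DocPrompter wraps the Taskflow "docprompt" static-graph model: on first use the
+# parameters are downloaded into PPNLP_HOME and then run with paddle.inference, so this
+# node only performs inference on the OCR results produced by the upstream preprocessor.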
+import os +import collections +import math +from multiprocessing import cpu_count +from typing import Dict, List +import logging + +import paddle +from paddlenlp.transformers import AutoTokenizer +from paddlenlp.taskflow.utils import download_file, ImageReader, get_doc_pred, find_answer_pos, sort_res +from paddlenlp.taskflow.task import Task +from paddlenlp.utils.env import PPNLP_HOME + +from pipelines.nodes.base import BaseComponent + +logger = logging.getLogger(__name__) + +URLS = { + "docprompt": [ + "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/docprompt/docprompt_params.tar", + "8eae8148981731f230b328076c5a08bf" + ], +} + + +class DocPrompter(BaseComponent): + """ + DocPrompter: extract prompt's answers from the document input. + """ + return_no_answers: bool + outgoing_edges = 1 + query_count = 0 + query_time = 0 + + def __init__(self, + topn: int = 1, + use_gpu: bool = True, + task_path: str = None, + model: str = "docprompt", + device_id: int = 0, + num_threads: int = None, + lang: str = "ch", + batch_size: int = 1): + """ + Init Document Prompter. + :param topn: return top n answers. + :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. + :param task_path: Custom model path if using custom model parameters. + :param model: Choose model name. + :param device_id: Choose gpu device id. + :param num_threads: Number of processing threads. + :param lang: Choose langugae. + :param batch_size: Number of samples the model receives in one batch for inference. + Memory consumption is much lower in inference mode. Recommendation: Increase the batch size + to a value so only a single batch is used. + """ + self._use_gpu = False if paddle.get_device() == 'cpu' else use_gpu + self.model = model + self._device_id = device_id + self._num_threads = num_threads if num_threads else math.ceil( + cpu_count() / 2) + self._topn = topn + self._lang = lang + self._batch_size = batch_size + if task_path is None: + self._task_path = os.path.join(PPNLP_HOME, "pipelines", + "document_intelligence", self.model) + else: + self._task_path = task_path + + download_file(self._task_path, "docprompt_params.tar", + URLS[self.model][0], URLS[self.model][1]) + self._get_inference_model() + self._tokenizer = AutoTokenizer.from_pretrained( + "ernie-layoutx-base-uncased") + self._reader = ImageReader(super_rel_pos=False, + tokenizer=self._tokenizer) + + def _get_inference_model(self): + inference_model_path = os.path.join(self._task_path, "static", + "inference") + self._static_model_file = inference_model_path + ".pdmodel" + self._static_params_file = inference_model_path + ".pdiparams" + self._config = paddle.inference.Config(self._static_model_file, + self._static_params_file) + self._prepare_static_mode() + + def _prepare_static_mode(self): + """ + Construct the input data and predictor in the PaddlePaddele static mode. 
+ """ + if paddle.get_device() == 'cpu': + self._config.disable_gpu() + self._config.enable_mkldnn() + else: + self._config.enable_use_gpu(100, self._device_id) + self._config.delete_pass("embedding_eltwise_layernorm_fuse_pass") + self._config.set_cpu_math_library_num_threads(self._num_threads) + self._config.switch_use_feed_fetch_ops(False) + self._config.disable_glog_info() + self._config.enable_memory_optim() + self._config.switch_ir_optim(False) + self.predictor = paddle.inference.create_predictor(self._config) + self.input_names = [name for name in self.predictor.get_input_names()] + self.input_handles = [ + self.predictor.get_input_handle(name) + for name in self.predictor.get_input_names() + ] + self.output_handle = [ + self.predictor.get_output_handle(name) + for name in self.predictor.get_output_names() + ] + + def _run_model(self, inputs: List[dict]): + """ + Run docprompt model. + """ + all_predictions_list = [] + for example in inputs: + ocr_result = example["ocr_result"] + doc_path = example["doc"] + prompt = example["prompt"] + ocr_type = example["ocr_type"] + + if not ocr_result: + all_predictions = [{ + "prompt": + p, + "result": [{ + 'value': '', + 'prob': 0.0, + 'start': -1, + 'end': -1 + }] + } for p in prompt] + all_boxes = {} + else: + data_loader = self._reader.data_generator( + ocr_result, doc_path, prompt, self._batch_size, ocr_type) + + RawResult = collections.namedtuple("RawResult", + ["unique_id", "seq_logits"]) + + all_results = [] + for data in data_loader: + for idx in range(len(self.input_names)): + self.input_handles[idx].copy_from_cpu(data[idx]) + self.predictor.run() + outputs = [ + output_handle.copy_to_cpu() + for output_handle in self.output_handle + ] + unique_ids, seq_logits = outputs + + for idx in range(len(unique_ids)): + all_results.append( + RawResult( + unique_id=int(unique_ids[idx]), + seq_logits=seq_logits[idx], + )) + + all_examples = self._reader.examples["infer"] + all_features = self._reader.features["infer"] + all_key_probs = [1 for _ in all_examples] + + example_index_to_features = collections.defaultdict(list) + + for feature in all_features: + example_index_to_features[feature.qas_id].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + all_predictions = [] + all_boxes = {} + for (example_index, example) in enumerate(all_examples): + example_doc_tokens = example.doc_tokens + example_qas_id = example.qas_id + page_id = example_qas_id.split("_")[0] + if page_id not in all_boxes: + all_boxes[page_id] = example.ori_boxes + example_query = example.keys[0] + features = example_index_to_features[example_qas_id] + + preds = [] + # keep track of the minimum score of null start+end of position 0 + for feature in features: + if feature.unique_id not in unique_id_to_result: + continue + result = unique_id_to_result[feature.unique_id] + + # find preds + ans_pos = find_answer_pos(result.seq_logits, feature) + preds.extend( + get_doc_pred(result, ans_pos, example, + self._tokenizer, feature, True, + all_key_probs, example_index)) + + if not preds: + preds.append({ + 'value': '', + 'prob': 0., + 'start': -1, + 'end': -1 + }) + else: + preds = sort_res(example_query, preds, + example_doc_tokens, all_boxes[page_id], + self._lang)[:self._topn] + all_predictions.append({ + "prompt": example_query, + "result": preds + }) + all_predictions_list.append(all_predictions) + return all_predictions_list + + def run(self, example: dict): + results = self._run_model([example]) + output = 
{"results": results} + return output, "output_1" diff --git a/pipelines/pipelines/nodes/document/document_preprocessor.py b/pipelines/pipelines/nodes/document/document_preprocessor.py new file mode 100644 index 000000000000..ec36d97b4745 --- /dev/null +++ b/pipelines/pipelines/nodes/document/document_preprocessor.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +from typing import List, Dict +import numpy as np +import base64 +from PIL import Image +from io import BytesIO + +import paddle +from paddleocr import PaddleOCR +from pathlib import Path +from paddlenlp.taskflow.utils import download_file + +from pipelines.nodes.base import BaseComponent + +logger = logging.getLogger(__name__) + + +class DocPreProcessor(BaseComponent): + """ + Preprocess document input from image/image url/image bytestream to ocr outputs + """ + return_no_answers: bool + outgoing_edges = 1 + query_count = 0 + query_time = 0 + + def __init__(self, use_gpu: bool = True, lang: str = "ch"): + """ + Init Document Preprocessor. + :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. + :param lang: Choose ocr model processing langugae + """ + self._lang = lang + self._use_gpu = False if paddle.get_device() == 'cpu' else use_gpu + self._ocr = PaddleOCR(use_angle_cls=True, + show_log=False, + use_gpu=self._use_gpu, + lang=self._lang) + + def _check_input_text(self, inputs): + if isinstance(inputs, dict): + inputs = [inputs] + if isinstance(inputs, list): + input_list = [] + for example in inputs: + data = {} + if isinstance(example, dict): + if "doc" not in example.keys(): + raise ValueError( + "Invalid inputs, the inputs should contain an url to an image or a local path." + ) + else: + if isinstance(example["doc"], str): + + if example["doc"].startswith("http://") or example[ + "doc"].startswith("https://"): + download_file("./", + example["doc"].rsplit("/", 1)[-1], + example["doc"]) + data["doc"] = example["doc"].rsplit("/", 1)[-1] + elif os.path.isfile(example["doc"]): + data["doc"] = example["doc"] + else: + img = base64.b64decode( + example["doc"].encode('utf-8')) + img = np.frombuffer(bytearray(img), + dtype='uint8') + img = np.array( + Image.open(BytesIO(img)).convert('RGB')) + img = Image.fromarray(img) + img.save("./tmp.jpg") + data["doc"] = "./tmp.jpg" + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`" + ) + if "prompt" not in example.keys(): + raise ValueError( + "Invalid inputs, the inputs should contain the prompt." + ) + else: + if isinstance(example["prompt"], str): + data["prompt"] = [example["prompt"]] + elif isinstance(example["prompt"], list) and all( + isinstance(s, str) for s in example["prompt"]): + data["prompt"] = example["prompt"] + else: + raise TypeError( + "Incorrect prompt, prompt should be string or list of string." 
+ ) + if "word_boxes" in example.keys(): + data["word_boxes"] = example["word_boxes"] + input_list.append(data) + else: + raise TypeError( + "Invalid inputs, input for document intelligence task should be dict or list of dict, but type of {} found!" + .format(type(example))) + else: + raise TypeError( + "Invalid inputs, input for document intelligence task should be dict or list of dict, but type of {} found!" + .format(type(inputs))) + return input_list + + def run(self, meta: dict): + example = self._check_input_text(meta)[0] + + if "word_boxes" in example.keys(): + ocr_result = example["word_boxes"] + example["ocr_type"] = "word_boxes" + else: + ocr_result = self._ocr.ocr(example['doc'], cls=True) + example["ocr_type"] = "ppocr" + # Compatible with paddleocr>=2.6.0.2 + ocr_result = ocr_result[0] if len(ocr_result) == 1 else ocr_result + example["ocr_result"] = ocr_result + output = {"example": example} + return output, "output_1" diff --git a/pipelines/pipelines/pipelines/__init__.py b/pipelines/pipelines/pipelines/__init__.py index a1a25f53aa4b..91d1518940ef 100644 --- a/pipelines/pipelines/pipelines/__init__.py +++ b/pipelines/pipelines/pipelines/__init__.py @@ -15,4 +15,5 @@ from pipelines.pipelines.base import Pipeline, RootNode from pipelines.pipelines.standard_pipelines import (BaseStandardPipeline, ExtractiveQAPipeline, - SemanticSearchPipeline) + SemanticSearchPipeline, + DocPipeline) diff --git a/pipelines/pipelines/pipelines/standard_pipelines.py b/pipelines/pipelines/pipelines/standard_pipelines.py index d459c33db7c2..b35a58417291 100644 --- a/pipelines/pipelines/pipelines/standard_pipelines.py +++ b/pipelines/pipelines/pipelines/standard_pipelines.py @@ -25,6 +25,7 @@ from pipelines.nodes.retriever import BaseRetriever from pipelines.document_stores import BaseDocumentStore from pipelines.pipelines import Pipeline +from pipelines.nodes.base import BaseComponent logger = logging.getLogger(__name__) @@ -263,3 +264,38 @@ def run(self, """ output = self.pipeline.run(query=query, params=params, debug=debug) return output + + +class DocPipeline(BaseStandardPipeline): + """ + Pipeline for document intelligence. + """ + + def __init__(self, preprocessor: BaseComponent, modelrunner: BaseComponent): + """ + :param preprocessor: file/image preprocessor instance + :param modelrunner: document model runner instance + """ + self.pipeline = Pipeline() + self.pipeline.add_node(component=preprocessor, + name="PreProcessor", + inputs=["Query"]) + self.pipeline.add_node(component=modelrunner, + name="Runner", + inputs=["PreProcessor"]) + + def run(self, + meta: dict, + params: Optional[dict] = None, + debug: Optional[bool] = None): + """ + :param query: the query string. + :param params: params for the `retriever` and `reader`. For instance, params={"Retriever": {"top_k": 10}} + :param debug: Whether the pipeline should instruct nodes to collect debug information + about their execution. By default these include the input parameters + they received and the output they generated. 
+ All debug information can then be found in the dict returned + by this method under the key "_debug" + """ + output = self.pipeline.run(meta=meta, params=params, debug=debug) + return output diff --git a/pipelines/rest_api/controller/search.py b/pipelines/rest_api/controller/search.py index 29c7359e608b..77d55777493f 100644 --- a/pipelines/rest_api/controller/search.py +++ b/pipelines/rest_api/controller/search.py @@ -27,7 +27,7 @@ from pipelines.pipelines.base import Pipeline from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME from rest_api.config import LOG_LEVEL, CONCURRENT_REQUEST_PER_WORKER -from rest_api.schema import QueryRequest, QueryResponse +from rest_api.schema import QueryRequest, QueryResponse, DocumentRequest, DocumentResponse from rest_api.controller.utils import RequestLimiter logging.getLogger("pipelines").setLevel(LOG_LEVEL) @@ -81,6 +81,22 @@ def query(request: QueryRequest): return result +@router.post("/query_documents", + response_model=DocumentResponse, + response_model_exclude_none=True) +def query_documents(request: DocumentRequest): + """ + This endpoint receives the question as a string and allows the requester to set + additional parameters that will be passed on to the pipelines pipeline. + """ + result = {} + result['meta'] = request.meta + params = request.params or {} + res = PIPELINE.run(meta=request.meta, params=params, debug=request.debug) + result['results'] = res['results'] + return result + + def _process_request(pipeline, request) -> Dict[str, Any]: start_time = time.time() diff --git a/pipelines/rest_api/pipeline/docprompt.yaml b/pipelines/rest_api/pipeline/docprompt.yaml new file mode 100644 index 000000000000..819de3026217 --- /dev/null +++ b/pipelines/rest_api/pipeline/docprompt.yaml @@ -0,0 +1,29 @@ +version: '1.1.0' + +components: + - name: PreProcessor + params: + use_gpu: True + lang: ch + type: DocPreProcessor + - name: Runner + params: + topn: 1 + use_gpu: True + task_path: + model: docprompt + device_id: 0 + num_threads: + lang: ch + batch_size: 1 + type: DocPrompter + +pipelines: + - name: query_documents + nodes: + - name: PreProcessor + inputs: [Query] + - name: Runner + inputs: [PreProcessor] + + diff --git a/pipelines/rest_api/schema.py b/pipelines/rest_api/schema.py index 942a4e7029ef..5fe98a7990d0 100644 --- a/pipelines/rest_api/schema.py +++ b/pipelines/rest_api/schema.py @@ -83,3 +83,19 @@ class QueryResponse(BaseModel): answers: List[AnswerSerialized] = [] documents: List[DocumentSerialized] = [] debug: Optional[Dict] = Field(None, alias="_debug") + + +class DocumentRequest(BaseModel): + meta: dict + params: Optional[dict] = None + debug: Optional[bool] = False + + class Config: + # Forbid any extra fields in the request to avoid silent failures + extra = Extra.forbid + + +class DocumentResponse(BaseModel): + meta: dict + results: List[List[dict]] = [] + debug: Optional[Dict] = Field(None, alias="_debug") diff --git a/pipelines/ui/webapp_docprompt_gradio.py b/pipelines/ui/webapp_docprompt_gradio.py new file mode 100644 index 000000000000..f4a72cdf2dae --- /dev/null +++ b/pipelines/ui/webapp_docprompt_gradio.py @@ -0,0 +1,349 @@ +#-*- coding: UTF-8 -*- +# Copyright 2022 The Impira Team and the HuggingFace Team. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import base64 +from io import BytesIO +from PIL import Image +import traceback +import argparse + +import requests +import numpy as np +import gradio as gr +import fitz +import cv2 + +fitz_tools = fitz.Tools() + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument('--serving_name', default="0.0.0.0", help="Serving ip.") +parser.add_argument("--serving_port", default=8891, type=int, help="Serving port.") +args = parser.parse_args() +# yapf: enable + + +def load_document(path): + if path.startswith("http://") or path.startswith("https://"): + resp = requests.get(path, allow_redirects=True, stream=True) + b = resp.raw + else: + b = open(path, "rb") + + image = Image.open(b) + images_list = [np.array(image.convert("RGB"))] + return images_list + + +def process_path(path): + error = None + if path: + try: + images_list = load_document(path) + return ( + path, + gr.update(visible=True, value=images_list), + gr.update(visible=True), + gr.update(visible=False, value=None), + gr.update(visible=False, value=None), + None, + ) + except Exception as e: + traceback.print_exc() + error = str(e) + return ( + None, + gr.update(visible=False, value=None), + gr.update(visible=False), + gr.update(visible=False, value=None), + gr.update(visible=False, value=None), + gr.update(visible=True, value=error) if error is not None else None, + None, + ) + + +def process_upload(file): + if file: + return process_path(file.name) + else: + return ( + None, + gr.update(visible=False, value=None), + gr.update(visible=False), + gr.update(visible=False, value=None), + gr.update(visible=False, value=None), + None, + ) + + +def np2base64(image_np): + image = cv2.imencode('.jpg', image_np)[1] + base64_str = str(base64.b64encode(image))[2:-1] + return base64_str + + +def get_base64(path): + if path.startswith("http://") or path.startswith("https://"): + resp = requests.get(path, allow_redirects=True, stream=True) + b = resp.raw + else: + b = open(path, "rb") + + base64_str = base64.b64encode(b.read()).decode() + return base64_str + + +def process_prompt(prompt, document): + if not prompt: + prompt = "What is the total actual and/or obligated expenses of ECG Center?" 
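+    # The request body must match the DocumentRequest schema in rest_api/schema.py:
+    #   {"meta": {"doc": <path/URL/base64 image>, "prompt": [<question>, ...]}}
+    # and the response carries "results" as a list (one per document) of
+    # {"prompt": ..., "result": [{"value", "prob", "start", "end"}]} entries,
+    # which is what gets unpacked below.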
+ if document is None: + return None, None, None + + url = f"http://{args.serving_name}:{args.serving_port}/query_documents" + base64_str = get_base64(document) + r = requests.post(url, + json={"meta": { + "doc": base64_str, + "prompt": [prompt] + }}) + response = r.json() + predictions = response['results'][0] + pages = [Image.open(BytesIO(base64.b64decode(base64_str)))] + + text_value = predictions[0]['result'][0]['value'] + + return ( + gr.update(visible=True, value=pages), + gr.update(visible=True, value=predictions), + gr.update( + visible=True, + value=text_value, + ), + ) + + +def read_content(file_path: str) -> str: + """read the content of target file + """ + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + return content + + +CSS = """ +#prompt input { + font-size: 16px; +} +#url-textbox { + padding: 0 !important; +} +#short-upload-box .w-full { + min-height: 10rem !important; +} +/* I think something like this can be used to re-shape + * the table + */ +/* +.gr-samples-table tr { + display: inline; +} +.gr-samples-table .p-2 { + width: 100px; +} +*/ +#select-a-file { + width: 100%; +} +#file-clear { + padding-top: 2px !important; + padding-bottom: 2px !important; + padding-left: 8px !important; + padding-right: 8px !important; + margin-top: 10px; +} +.gradio-container .gr-button-primary { + background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%); + border: 1px solid #B0DCCC; + border-radius: 8px; + color: #1B8700; +} +.gradio-container.dark button#submit-button { + background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%); + border: 1px solid #B0DCCC; + border-radius: 8px; + color: #1B8700 +} +table.gr-samples-table tr td { + border: none; + outline: none; +} +table.gr-samples-table tr td:first-of-type { + width: 0%; +} +div#short-upload-box div.absolute { + display: none !important; +} +gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div { + gap: 0px 2%; +} +gradio-app div div div div.w-full, .gradio-app div div div div.w-full { + gap: 0px; +} +gradio-app h2, .gradio-app h2 { + padding-top: 10px; +} +#answer { + overflow-y: scroll; + color: white; + background: #666; + border-color: #666; + font-size: 20px; + font-weight: bold; +} +#answer span { + color: white; +} +#answer textarea { + color:white; + background: #777; + border-color: #777; + font-size: 18px; +} +#url-error input { + color: red; +} +""" + +with gr.Blocks(css=CSS) as demo: + document = gr.Variable() + example_prompt = gr.Textbox(visible=False) + example_image = gr.Image(visible=False) + with gr.Row(equal_height=True): + with gr.Column(): + with gr.Row(): + gr.Markdown("## 1. Select a file", elem_id="select-a-file") + img_clear_button = gr.Button("Clear", + variant="secondary", + elem_id="file-clear", + visible=False) + image = gr.Gallery(visible=False) + with gr.Row(equal_height=True): + with gr.Column(): + with gr.Row(): + url = gr.Textbox( + show_label=False, + placeholder="URL", + lines=1, + max_lines=1, + elem_id="url-textbox", + ) + submit = gr.Button("Get") + url_error = gr.Textbox( + visible=False, + elem_id="url-error", + max_lines=1, + interactive=False, + label="Error", + ) + gr.Markdown("— or —") + upload = gr.File(label=None, + interactive=True, + elem_id="short-upload-box") + + with gr.Column() as col: + gr.Markdown("## 2. Make a request") + prompt = gr.Textbox( + label= + "Prompt (No restrictions on the setting of prompt. You can type any prompt.)", + placeholder= + "e.g. 
What is the total actual and/or obligated expenses of ECG Center?", + lines=1, + max_lines=1, + ) + + with gr.Row(): + clear_button = gr.Button("Clear", variant="secondary") + submit_button = gr.Button("Submit", + variant="primary", + elem_id="submit-button") + with gr.Column(): + output_text = gr.Textbox(label="Top Answer", + visible=False, + elem_id="answer") + output = gr.JSON(label="Output", visible=False) + + for cb in [img_clear_button, clear_button]: + cb.click( + lambda _: ( + gr.update(visible=False, value=None), + None, + gr.update(visible=False, value=None), + gr.update(visible=False, value=None), + gr.update(visible=False), + None, + None, + None, + gr.update(visible=False, value=None), + None, + ), + inputs=clear_button, + outputs=[ + image, + document, + output, + output_text, + img_clear_button, + example_image, + upload, + url, + url_error, + prompt, + ], + ) + + upload.change( + fn=process_upload, + inputs=[upload], + outputs=[ + document, image, img_clear_button, output, output_text, url_error + ], + ) + submit.click( + fn=process_path, + inputs=[url], + outputs=[ + document, image, img_clear_button, output, output_text, url_error + ], + ) + + prompt.submit( + fn=process_prompt, + inputs=[prompt, document], + outputs=[image, output, output_text], + ) + + submit_button.click( + fn=process_prompt, + inputs=[prompt, document], + outputs=[image, output, output_text], + ) + +if __name__ == "__main__": + # To create a public link, set `share=True` in `launch()`. + demo.launch(enable_queue=False, share=True) From 99eebd9bf74ebfcf84c4a1db20c63dbc1a1a7692 Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Fri, 21 Oct 2022 12:34:13 +0000 Subject: [PATCH 2/6] t push lugim doc --- pipelines/examples/document-intelligence/docprompt_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/examples/document-intelligence/docprompt_example.py b/pipelines/examples/document-intelligence/docprompt_example.py index 36918569f0ed..b8ad1133e9c4 100644 --- a/pipelines/examples/document-intelligence/docprompt_example.py +++ b/pipelines/examples/document-intelligence/docprompt_example.py @@ -38,7 +38,7 @@ def docprompt_pipeline(): meta = {"doc": "./invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]} prediction = pipe.run(meta=meta) - print(prediction["results"]) + print(prediction["results"][0]) if __name__ == "__main__": From 0620ed5646a15c86f663706fd965e51b9e6853cd Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Fri, 21 Oct 2022 12:36:23 +0000 Subject: [PATCH 3/6] code_style --- pipelines/pipelines/nodes/__init__.py | 1 - pipelines/pipelines/pipelines/standard_pipelines.py | 4 ++-- pipelines/rest_api/controller/search.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py index 7ecfc223c3d7..e46e730e056a 100644 --- a/pipelines/pipelines/nodes/__init__.py +++ b/pipelines/pipelines/nodes/__init__.py @@ -31,4 +31,3 @@ from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever from pipelines.nodes.document import DocPreProcessor, DocPrompter from pipelines.nodes.text_to_image_generator import ErnieTextToImageGenerator - diff --git a/pipelines/pipelines/pipelines/standard_pipelines.py b/pipelines/pipelines/pipelines/standard_pipelines.py index 0d199ccec8ed..91964e3f1c0f 100644 --- a/pipelines/pipelines/pipelines/standard_pipelines.py +++ b/pipelines/pipelines/pipelines/standard_pipelines.py @@ -300,8 +300,8 @@ def run(self, """ output = self.pipeline.run(meta=meta, 
params=params, debug=debug) return output - - + + class TextToImagePipeline(BaseStandardPipeline): """ A simple pipeline that takes prompt texts as input and generates diff --git a/pipelines/rest_api/controller/search.py b/pipelines/rest_api/controller/search.py index aa6595f02414..bb2d010f50f3 100644 --- a/pipelines/rest_api/controller/search.py +++ b/pipelines/rest_api/controller/search.py @@ -100,8 +100,8 @@ def query_images(request: QueryRequest): if not "answers" in result: result["answers"] = [] return result - - + + @router.post("/query_documents", response_model=DocumentResponse, response_model_exclude_none=True) From 20d7634cd6ba02b291c56a5f5a9c94ac35dcdd40 Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Mon, 24 Oct 2022 06:27:06 +0000 Subject: [PATCH 4/6] add example readme & change node name --- .../examples/document-intelligence/README.md | 82 +++++++++++++++++++ .../docprompt_example.py | 4 +- .../document-intelligence/requirements.txt | 4 +- pipelines/pipelines/nodes/__init__.py | 2 +- .../pipelines/nodes/document/__init__.py | 2 +- .../nodes/document/document_intelligence.py | 3 +- .../nodes/document/document_preprocessor.py | 2 +- pipelines/rest_api/pipeline/docprompt.yaml | 2 +- 8 files changed, 90 insertions(+), 11 deletions(-) create mode 100644 pipelines/examples/document-intelligence/README.md diff --git a/pipelines/examples/document-intelligence/README.md b/pipelines/examples/document-intelligence/README.md new file mode 100644 index 000000000000..9a364f1a760b --- /dev/null +++ b/pipelines/examples/document-intelligence/README.md @@ -0,0 +1,82 @@ +# 端到端开放文档抽取问答系统 + +## 1. 系统介绍 + +开放文档抽取问答主要指对于网页、数字文档或扫描文档所包含的文本以及丰富的排版格式等信息,通过人工智能技术进行理解、分类、提取以及信息归纳的过程。开放文档抽取问答技术广泛应用于金融、保险、能源、物流、医疗等行业,常见的应用场景包括财务报销单、招聘简历、企业财报、合同文书、动产登记证、法律判决书、物流单据等多模态文档的关键信息抽取、问题回答等。 + +本项目提供了低成本搭建端到端开放文档抽取问答系统的能力。用户只需要处理好自己的业务数据,就可以使用本项目预置的开放文档抽取问答系统模型(文档OCR预处理模型、文档抽取问答模型)快速搭建一个针对自己业务数据的文档抽取问答系统,并提供基于[Gradio](https://gradio.app/) 的 Web 可视化服务。 + + +## 2. 快速开始 + +以下是针对mac和linux的安装流程: + + +### 2.1 运行环境 + +**安装PaddlePaddle:** + + 环境中paddlepaddle-gpu或paddlepaddle版本应大于或等于2.3, 请参见[飞桨快速安装](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)根据自己需求选择合适的PaddlePaddle下载命令。 + +**安装Paddle-Pipelines:** + +```bash +# pip 一键安装 +pip install --upgrade paddle-pipelines -i https://pypi.tuna.tsinghua.edu.cn/simple +# 或者源码进行安装最新版本 +cd ${HOME}/PaddleNLP/pipelines/ +pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +python setup.py install +``` + +【注意】以下的所有的流程都只需要在`pipelines`根目录下进行,不需要跳转目录 + +### 2.2 一键体验问答系统 +您可以通过如下命令快速体验开放文档抽取问答系统的效果。 + + +```bash +# 我们建议在 GPU 环境下运行本示例,运行速度较快 +# 设置 1 个空闲的 GPU 卡,此处假设 0 卡为空闲 GPU +export CUDA_VISIBLE_DEVICES=0 +python examples/document-intelligence/docprompt_example.py --device gpu +# 如果只有 CPU 机器,可以通过 --device 参数指定 cpu 即可, 运行耗时较长 +unset CUDA_VISIBLE_DEVICES +python examples/document-intelligence/docprompt_example.py --device cpu +``` + +### 2.3 构建 Web 可视化开放文档抽取问答系统 + +整个 Web 可视化问答系统主要包含两大组件: 1. 基于 RestAPI 构建模型服务 2. 
基于 Gradio 构建 WebUI。接下来我们依次搭建这 2 个服务并串联构成可视化的开放文档抽取问答系统。 + +#### 2.3.1 启动 RestAPI 模型服务 +```bash +# 指定智能问答系统的Yaml配置文件 +export PIPELINE_YAML_PATH=rest_api/pipeline/docprompt.yaml +# 使用端口号 8891 启动模型服务 +python rest_api/application.py 8891 +``` +Linux 用户推荐采用 Shell 脚本来启动服务: + +```bash +sh examples/document-intelligence/run_docprompt_server.sh +``` +启动后可以使用curl命令验证是否成功运行: + +``` +curl --request POST --url 'http://localhost:/query_documents' -H "Content-Type: application/json" --data '{"doc": "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/images/invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]}' +``` + +#### 2.3.2 启动 WebUI + +```bash +python ui/webapp_docprompt_gradio.py +``` + +Linux 用户推荐采用 Shell 脚本来启动服务: + +```bash +sh examples/document-intelligence/run_docprompt_web.sh +``` + +到这里您就可以打开浏览器访问 http://127.0.0.1:8502 地址体验开放文档抽取问答系统系统服务了。 diff --git a/pipelines/examples/document-intelligence/docprompt_example.py b/pipelines/examples/document-intelligence/docprompt_example.py index b8ad1133e9c4..28bdc6c305c9 100644 --- a/pipelines/examples/document-intelligence/docprompt_example.py +++ b/pipelines/examples/document-intelligence/docprompt_example.py @@ -17,7 +17,7 @@ import os import paddle -from pipelines.nodes import DocPreProcessor, DocPrompter +from pipelines.nodes import DocOCRProcessor, DocPrompter from pipelines import DocPipeline # yapf: disable @@ -32,7 +32,7 @@ def docprompt_pipeline(): use_gpu = True if args.device == 'gpu' else False - preprocessor = DocPreProcessor(use_gpu=use_gpu) + preprocessor = DocOCRProcessor(use_gpu=use_gpu) docprompter = DocPrompter(use_gpu=use_gpu, batch_size=args.batch_size) pipe = DocPipeline(preprocessor=preprocessor, modelrunner=docprompter) meta = {"doc": "./invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]} diff --git a/pipelines/examples/document-intelligence/requirements.txt b/pipelines/examples/document-intelligence/requirements.txt index dc65d4a14e9c..1db7aea116e2 100644 --- a/pipelines/examples/document-intelligence/requirements.txt +++ b/pipelines/examples/document-intelligence/requirements.txt @@ -1,3 +1 @@ -numpy -opencv-python -requests +opencv-python \ No newline at end of file diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py index e46e730e056a..4b2a2e02aacb 100644 --- a/pipelines/pipelines/nodes/__init__.py +++ b/pipelines/pipelines/nodes/__init__.py @@ -29,5 +29,5 @@ from pipelines.nodes.ranker import BaseRanker, ErnieRanker from pipelines.nodes.reader import BaseReader, ErnieReader from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever -from pipelines.nodes.document import DocPreProcessor, DocPrompter +from pipelines.nodes.document import DocOCRProcessor, DocPrompter from pipelines.nodes.text_to_image_generator import ErnieTextToImageGenerator diff --git a/pipelines/pipelines/nodes/document/__init__.py b/pipelines/pipelines/nodes/document/__init__.py index 7f07d481b6c0..a1da58e200f5 100644 --- a/pipelines/pipelines/nodes/document/__init__.py +++ b/pipelines/pipelines/nodes/document/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from pipelines.nodes.document.document_preprocessor import DocPreProcessor +from pipelines.nodes.document.document_preprocessor import DocOCRProcessor from pipelines.nodes.document.document_intelligence import DocPrompter \ No newline at end of file diff --git a/pipelines/pipelines/nodes/document/document_intelligence.py b/pipelines/pipelines/nodes/document/document_intelligence.py index 366ef498aa54..6785afca255b 100644 --- a/pipelines/pipelines/nodes/document/document_intelligence.py +++ b/pipelines/pipelines/nodes/document/document_intelligence.py @@ -15,13 +15,12 @@ import collections import math from multiprocessing import cpu_count -from typing import Dict, List +from typing import List import logging import paddle from paddlenlp.transformers import AutoTokenizer from paddlenlp.taskflow.utils import download_file, ImageReader, get_doc_pred, find_answer_pos, sort_res -from paddlenlp.taskflow.task import Task from paddlenlp.utils.env import PPNLP_HOME from pipelines.nodes.base import BaseComponent diff --git a/pipelines/pipelines/nodes/document/document_preprocessor.py b/pipelines/pipelines/nodes/document/document_preprocessor.py index ec36d97b4745..471128fb1b98 100644 --- a/pipelines/pipelines/nodes/document/document_preprocessor.py +++ b/pipelines/pipelines/nodes/document/document_preprocessor.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) -class DocPreProcessor(BaseComponent): +class DocOCRProcessor(BaseComponent): """ Preprocess document input from image/image url/image bytestream to ocr outputs """ diff --git a/pipelines/rest_api/pipeline/docprompt.yaml b/pipelines/rest_api/pipeline/docprompt.yaml index 819de3026217..3d556b512b44 100644 --- a/pipelines/rest_api/pipeline/docprompt.yaml +++ b/pipelines/rest_api/pipeline/docprompt.yaml @@ -5,7 +5,7 @@ components: params: use_gpu: True lang: ch - type: DocPreProcessor + type: DocOCRProcessor - name: Runner params: topn: 1 From 7b4e0850ed88cf1751b3b6e7fc361082b6f66e01 Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Mon, 24 Oct 2022 07:02:53 +0000 Subject: [PATCH 5/6] update_docprompt_pipelines --- pipelines/examples/document-intelligence/README.md | 5 +++-- .../examples/document-intelligence/docprompt_example.py | 9 ++++++++- pipelines/ui/webapp_docprompt_gradio.py | 5 ++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pipelines/examples/document-intelligence/README.md b/pipelines/examples/document-intelligence/README.md index 9a364f1a760b..daa110aad336 100644 --- a/pipelines/examples/document-intelligence/README.md +++ b/pipelines/examples/document-intelligence/README.md @@ -53,6 +53,7 @@ python examples/document-intelligence/docprompt_example.py --device cpu ```bash # 指定智能问答系统的Yaml配置文件 export PIPELINE_YAML_PATH=rest_api/pipeline/docprompt.yaml +export QUERY_PIPELINE_NAME=query_documents # 使用端口号 8891 启动模型服务 python rest_api/application.py 8891 ``` @@ -64,7 +65,7 @@ sh examples/document-intelligence/run_docprompt_server.sh 启动后可以使用curl命令验证是否成功运行: ``` -curl --request POST --url 'http://localhost:/query_documents' -H "Content-Type: application/json" --data '{"doc": "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/images/invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]}' +curl --request POST --url 'http://0.0.0.0:8891/query_documents' -H "Content-Type: application/json" --data '{"meta": {"doc": "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/images/invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]}}' ``` #### 2.3.2 启动 WebUI @@ -79,4 +80,4 @@ Linux 用户推荐采用 Shell 脚本来启动服务: sh 
examples/document-intelligence/run_docprompt_web.sh ``` -到这里您就可以打开浏览器访问 http://127.0.0.1:8502 地址体验开放文档抽取问答系统系统服务了。 +到这里您就可以打开浏览器访问 http://127.0.0.1:7860 地址体验开放文档抽取问答系统系统服务了。 diff --git a/pipelines/examples/document-intelligence/docprompt_example.py b/pipelines/examples/document-intelligence/docprompt_example.py index 28bdc6c305c9..8ecd3385e532 100644 --- a/pipelines/examples/document-intelligence/docprompt_example.py +++ b/pipelines/examples/document-intelligence/docprompt_example.py @@ -35,7 +35,14 @@ def docprompt_pipeline(): preprocessor = DocOCRProcessor(use_gpu=use_gpu) docprompter = DocPrompter(use_gpu=use_gpu, batch_size=args.batch_size) pipe = DocPipeline(preprocessor=preprocessor, modelrunner=docprompter) - meta = {"doc": "./invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]} + # image link input + meta = { + "doc": + "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/images/invoice.jpg", + "prompt": ["发票号码是多少?", "校验码是多少?"] + } + # image local path input + # meta = {"doc": "./invoice.jpg", "prompt": ["发票号码是多少?", "校验码是多少?"]} prediction = pipe.run(meta=meta) print(prediction["results"][0]) diff --git a/pipelines/ui/webapp_docprompt_gradio.py b/pipelines/ui/webapp_docprompt_gradio.py index f4a72cdf2dae..1c47981f31de 100644 --- a/pipelines/ui/webapp_docprompt_gradio.py +++ b/pipelines/ui/webapp_docprompt_gradio.py @@ -33,7 +33,7 @@ # yapf: disable parser = argparse.ArgumentParser() parser.add_argument('--serving_name', default="0.0.0.0", help="Serving ip.") -parser.add_argument("--serving_port", default=8891, type=int, help="Serving port.") +parser.add_argument("--serving_port", default=7860, type=int, help="Serving port.") args = parser.parse_args() # yapf: enable @@ -271,8 +271,7 @@ def read_content(file_path: str) -> str: prompt = gr.Textbox( label= "Prompt (No restrictions on the setting of prompt. You can type any prompt.)", - placeholder= - "e.g. What is the total actual and/or obligated expenses of ECG Center?", + placeholder="e.g. 校验码是多少?", lines=1, max_lines=1, ) From 2b356c89015f38551e6bb7d297c84ffddb1b2f0d Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Mon, 24 Oct 2022 07:06:40 +0000 Subject: [PATCH 6/6] update_docprompt_pipelines --- pipelines/examples/document-intelligence/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pipelines/examples/document-intelligence/README.md b/pipelines/examples/document-intelligence/README.md index daa110aad336..85136b614043 100644 --- a/pipelines/examples/document-intelligence/README.md +++ b/pipelines/examples/document-intelligence/README.md @@ -29,6 +29,11 @@ pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple python setup.py install ``` +**安装OpenCV:** +```bash +pip install opencv-python==4.6.0.66 +``` + 【注意】以下的所有的流程都只需要在`pipelines`根目录下进行,不需要跳转目录 ### 2.2 一键体验问答系统
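
For reference, the curl request shown in section 2.3.1 of the README can also be issued from Python — a minimal sketch, assuming the REST service started by `run_docprompt_server.sh` is listening on `0.0.0.0:8891` and exposes the `/query_documents` route added in this patch:

```python
# Minimal client for the /query_documents endpoint.
import requests

url = "http://0.0.0.0:8891/query_documents"
payload = {
    "meta": {
        "doc": "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/images/invoice.jpg",
        "prompt": ["发票号码是多少?", "校验码是多少?"],
    }
}

response = requests.post(url, json=payload)
response.raise_for_status()
data = response.json()

# "results" holds one list of predictions per document; each entry pairs a prompt
# with its extracted answers ({"value", "prob", "start", "end"}).
for prediction in data["results"][0]:
    print(prediction["prompt"], "->", prediction["result"][0]["value"])
```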