Skip to content

Commit

Permalink
feat: filter OCR results for quality
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Neemann authored and steveoh committed Mar 16, 2023
1 parent 9071c35 commit 00db2bf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
72 changes: 72 additions & 0 deletions row.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import logging
import math
import re
from datetime import datetime
from io import BytesIO
from itertools import islice
Expand Down Expand Up @@ -859,6 +860,77 @@ def download_ocr_results(bucket_name, run_name, out_dir):
Path(ocr_file).unlink()

return out_dir


def filter_ocr_results(original_results_file, out_dir):
"""download ocr results from a GCP bucket
Args:
original_results (str): path to the parquet file with original combined results (path_to_file.gz)
out_dir (str): where to save the CSV file results
Returns:
str: the location of the output CSV file
"""
#: silence pandas SettingWithCopyWarning
pd.options.mode.chained_assignment = None

out_dir = Path(out_dir)

if not out_dir.exists():
out_dir.mkdir(parents=True)

results_df = pd.read_parquet(original_results_file)

orig_length = len(results_df.index)
logging.info("rumber of rows before cleanup: %i", orig_length)

#: Add column for the original UDOT filename
results_df["udot_file_name"] = results_df.apply(lambda r: r["file_name"].split("/mosaics/", 1)[1].strip(), axis=1)

#: Remove spaces and newline characters adjacent to colons
results_df["text"] = results_df.apply(lambda r: r["text"].replace(":\n", ":").strip(), axis=1)
results_df["text"] = results_df.apply(lambda r: r["text"].replace("\n:", ":").strip(), axis=1)
results_df["text"] = results_df.apply(lambda r: r["text"].replace(": ", ":").strip(), axis=1)
results_df["text"] = results_df.apply(lambda r: r["text"].replace(" :", ":").strip(), axis=1)
#: Then remove newline characters and replace with spaces
results_df["text"] = results_df.apply(lambda r: r["text"].replace("\n", " ").strip(), axis=1)

#: Remove special characters except for colons with a regular expression
regex = r"[^a-zA-Z0-9 ](?<!:)"
results_df["text"] = results_df.apply(lambda r: re.sub(regex, "", r["text"]), axis=1)

#: Convert string to list
results_df["text"] = results_df.apply(lambda r: r["text"].split(), axis=1)

#: Remove alpha-only items - not relevant, should start with a number
results_df["text"] = results_df.apply(lambda r: [item for item in r["text"] if not item.isalpha()], axis=1)

#: Remove rows that start with a letter, should start with a number
results_df["text"] = results_df.apply(
lambda r: [item for item in r["text"] if item and not item[0].isalpha()], axis=1
)

#: Remove rows where length of text list is zero
results_df = results_df[results_df["text"].apply(lambda r: len(r)) > 0]

#: Convert list column to string
results_df["text"] = results_df.apply(lambda r: " ".join(r["text"]), axis=1)

intermediate_length = len(results_df.index)
logging.info("rumber of rows before de-duplicating: %i", intermediate_length)
results_df.drop_duplicates(inplace=True, ignore_index=True)

final_length = len(results_df.index)
diff = intermediate_length - final_length
logging.info("rumber of rows after removing duplicates: %i", final_length)
logging.info("removed %i duplicate rows", diff)

out_file = out_dir / "filtered_ocr_results.csv"
results_df.to_csv(out_file)
logging.info("saved filtered ocr results to %s", out_file)

return out_dir
def summarize_run(folder, run_name):
"""summarize the results of a run
Expand Down
7 changes: 7 additions & 0 deletions row_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
row_cli.py results download <run_name> (--from=location)
row_cli.py results summarize <run_name> (--from=location)
row_cli.py ocr-results download <run_name> (--from=location --save-to=location)
row_cli.py ocr-results filter <file_name> (--save-to=location)
Options:
--from=location The bucket or directory to operate on
Expand All @@ -31,6 +32,7 @@
python row_cli.py process circles ---job=test --from=./test-data --save-to=./.ephemeral --index=./test-data --task-index=0 --file-count=1 --instances=1 --project=123456789 --processor=123456789
python row_cli.py results download bobcat --from=bucket-name
python row_cli.py ocr-results download alligator --from=bucket-name --save-to=./data
python row_cli.py ocr-results filter ./data/alligator/combined_ocr_results --save-to=./data
"""

import logging
Expand Down Expand Up @@ -171,6 +173,11 @@ def main():

print(f"files downloaded to {location}")

if args["ocr-results"] and args["filter"]:
location = row.filter_ocr_results(args["<file_name>"], args["--save-to"])

print(f"files downloaded to {location}")

if args["index"] and args["filter"]:
index = Path(args["<file_name>"])
total_lines = 0
Expand Down

0 comments on commit 00db2bf

Please sign in to comment.