いちろう’s blog

すーぱーえんじにあ

GoogleVisionAPIでレシートから合計金額を抽出する

はじめに

お店でもらったレシートから自動で合計金額を抽出し、自宅のDBに保存できるといいなとふと思い、N番煎じではあるがレシートのOCRに挑戦してみた。同様の試みを行っている方は多数いたが、レシートを撮影する環境や頻繁に利用する店舗のレシートのフォーマットの傾向もあり、汎用的に利用できるものは存在しないため自作した。

主に下記の方の記事を参考に、自身の環境にチューニングしたものを作成してみた。

qiita.com

qiita.com

実装

前準備

今回OCRを行うAPIとして、GoogleVisionAPIを利用する。GoogleVisonAPIを利用する場合は、下記の公式ドキュメントのガイドを参考に、サービスアカウントを作成する必要がある。

cloud.google.com

サービスアカウント、サービスアカウントのキーの作成後、キーをダウンロードする。そのキーをconfig/service_account_key.jsonに保存する。

その後、GoogleVisionAPIを利用するライブラリをインストールする。

pip install google-cloud-vision==2.7.1

合計金額の抽出処理

抽出処理の流れ

合計金額の抽出は以下の流れで実行する。

  1. 画像に対してOCRを実行
  2. OCRの結果から、合計金額を表現する単語(今回は「合計」)のx,y座標を抽出
  3. 「合計」と同じy座標上かつ、右側に存在する文字(=候補文字)を抽出
  4. 候補文字を数字情報に再構成

1では、GoogleViosionAPIを利用し、レシート画像に対するOCRを行う。

2では、GoogleVisionAPIのOCRの結果から、合計金額を表現する単語(以下、キーワード)のx,y座標を抽出する。キーワードは店による揺らぎが多く、他の記事を見ると、「現計」「信用」「対象計算」などの単語で合計金額を表現しているものも存在する。しかし、自分の住んでいる周辺のスーパーやコンビニのレシートは、全て「合計」の文字の横に合計金額が記載されていたため、今回は「合計」の文字をキーワードとした。

3では、同じy座標上かつ、抽出したキーワードより右側にある文字を、合計金額の候補文字(以下、候補文字)として抽出する。イメージは以下の画像の通り。

合計金額抽出のイメージ

この方法だと、撮影されたレシートの傾き具合によっては、候補文字を検出できない場合がある。しかし今回レシートを撮影するのは自分で、傾きは自身でコントロールが可能なので、傾きによる誤検知のリスクは許容した。また、y座標の揺らぎの許容範囲を horizonal_thresholdで調整できるようにすることで、 horizonal_thresholdで指定したピクセル分のズレを無視できるように調整した。

4では、3で抽出した候補文字を再構成し金額に変換する。レシートによっては候補文字の抽出結果に揺らぎがあるので、それらを吸収できるようにした。具体的には下記のような変換処理を行なった。

  • 「1,000」場合、「1」「,」「0」「0」「0」とバラバラに抽出される場合があるので、バラバラに抽出された文字を結合。
  • 「¥」「,」「.(カンマの誤検知)」「円」の削除

上記の1〜4の処理を実行するスクリプトReceiptOcrClient.py として作成する。

import io
import json
import logging
import os
import re

from google.cloud import vision
from google.cloud.vision_v1 import AnnotateImageResponse

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)


class ReceiptOcrClient:
    def __init__(
        self,
        credentials_path: str,
        horizonal_threshold: int = 15,
    ):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path
        # Instantiates a client
        self.client = vision.ImageAnnotatorClient()
        self.credentials_path = credentials_path
        self.horizonal_threshold = horizonal_threshold

        logger.debug(self.__dict__)

    def get_payment_amount(self, content: bytes, key_word: str = "合計") -> int:
        """レシートの画像から支払った合計金額を返す。

        Args:
            content (bytes): OCRを実行するレシート画像
            key_word (str, optional): レシート中の「合計金額」を意味する単語. Defaults to "合計".

        Returns:
            int: 検出された合計金額の値。検出できない場合は-999を返す。
        """
        response = self.ocr(target_image=content)
        texts = response["textAnnotations"]
        max_x, min_y, max_y = 0, 0, 0

        for text in texts[1:]:
            if key_word in text["description"]:
                vertices = [
                    "({},{})".format(vertex["x"], vertex["y"])
                    for vertex in text["boundingPoly"]["vertices"]
                ]
                max_x = text["boundingPoly"]["vertices"][1]["x"]  # 右上のポイントのXの値
                min_y = text["boundingPoly"]["vertices"][0]["y"]  # 左上のポイントのYの値
                max_y = text["boundingPoly"]["vertices"][3]["y"]  # 右下のポイントのYの値
                logger.info(
                    "検出された文字列: {}, 座標: {}".format(
                        text["description"], ",".join(vertices)
                    )
                )
                break

        payment_amount = ""
        for text in texts[1:]:
            target_min_x = text["boundingPoly"]["vertices"][1]["x"]  # 左上のポイントのXの値
            target_min_y = text["boundingPoly"]["vertices"][0]["y"]  # 左上のポイントのYの値
            target_max_y = text["boundingPoly"]["vertices"][3]["y"]  # 右下のポイントのYの値
            if (
                abs(target_min_y - min_y) < self.horizonal_threshold
                and abs(target_max_y - max_y) < self.horizonal_threshold
                and target_min_x - max_x > 0  # 合計の文字より右にある文字の場合
            ):
                logger.info(f"key_wordと同じ座標にある文字: {text['description']}")
                payment_amount += text["description"]

        payment_amount = re.sub(r"[¥,.]", "", payment_amount)  # 「¥」は無視
        logger.info(f"数字情報に再構成: {payment_amount}")
        try:
            return int(payment_amount)
        except ValueError:
            logger.error("レシートから合計金額を検出できませんでした。")
            return -999
        except:
            logger.error("実行中にエラーが発生しました。")
            return -999

    def get_payment_amount_from_filename(
        self, file_name: str, key_word: str = "合計"
    ) -> int:
        """レシートの画像から支払った合計金額を返す。

        Args:
            file_name (str): OCRを実行する画像ファイルのパス。
            key_word (str, optional): レシート中の「合計金額」を意味する単語. Defaults to "合計".

        Returns:
            int: 検出された合計金額の値。検出できない場合は-999を返す。
        """

        # Loads the image into memory
        with io.open(file_name, "rb") as image_file:
            content: bytes = image_file.read()

        return self.get_payment_amount(content=content, key_word=key_word)

    def ocr(self, target_image: bytes) -> dict:
        """file_nameに指定した画像に対して、CloudVisionAPIのOCR処理を実行する。

        Args:
            target_image (bytes): OCRを実行する画像。

        Returns:
            dict: _description_
        """
        image = vision.Image(content=target_image)
        response = self.client.document_text_detection(image=image)

        return json.loads(AnnotateImageResponse.to_json(response))


if __name__ == "__main__":
    import argparse

    # 引数の設定
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "target_img",
        help="OCR対象の画像のパス",
    )
    parser.add_argument(
        "--credential_path",
        help="GoogleCloudVisionAPIを利用する場合のサービスアカウントの認証情報。",
        default="config/service_account_key.json",
    )
    parser.add_argument(
        "--horizonal_threshold",
        help="キーワードと合計金額の水平位置のズレを、何pxまで許容するかの閾値",
        default=15,
        type=int,
    )
    args = parser.parse_args()

    # OCRクライアントの初期化
    receipt_ocr_client = ReceiptOcrClient(
        args.credential_path,
        horizonal_threshold=args.horizonal_threshold,
    )
    #  OCRの実行
    amount = receipt_ocr_client.get_payment_amount_from_filename(
        file_name=os.path.abspath(args.target_img)
    )
    logger.info(f"合計金額は{amount}円です!")

実行

上記のプログラムを実行する場合は、第一引数にOCR対象の画像を選択し実行する。

はじめに、傾きが少なくみやすいレシートのサンプル(sample.png)に対して合計金額の抽出処理を実行してみる。

レシートのサンプル(sample.png)

$ python ReceiptOcrClient.py sample.png
INFO:検出された文字列: 合計, 座標: (16,377),(87,376),(87,386),(16,387)
INFO:key_wordと同じ座標にある文字: ¥
INFO:key_wordと同じ座標にある文字: 1
INFO:key_wordと同じ座標にある文字: ,
INFO:key_wordと同じ座標にある文字: 161
INFO:数字情報に再構成: 1161 
INFO:合計金額は1161円です!

実行することで、レシートの合計金額を抽出できたことが確認できる。key_word(合計)の文字と同じy座標にある文字を抽出するのみでは、「¥」「1」「,」「161」と分割して検知されていたが、再構成処理で「1161」と期待通り変換して抽出できている。

また家で実際に出たレシート(sample2.png)に対しても実行してみる。(黒塗りは掲載用で、処理の実行時には外しています。)

レシートのサンプル2(sample2.png)

$ python ReceiptOcrClient.py sample2.png
INFO:検出された文字列: 合計, 座標: (593,1841),(762,1841),(762,1915),(593,1915)
INFO:key_wordと同じ座標にある文字: ¥
INFO:key_wordと同じ座標にある文字: 5.211
INFO:数字情報に再構成: 5211
INFO:合計金額は5211円です!

こちらは初期の検出の段階では「5.211」とカンマがピリオドに誤検知されていたが、再構成の処理によりピリオドが取り除かれ「5211」に変換されていることが確認できた。

おわりに

かなり込み入った処理を行うことにはなったが、レシートの合計金額の抽出に成功した。既存のレシートOCRAPIとして、LINEが提供している「CLOVA OCR」が存在するようだが、こちらは30日のトライアル期間以外は有料となるので、個人では使いにくい。

CLOVA OCR | LINE CLOVA公式サイト

それに対してGoogleのCloud Vision APIは、1000件/月は無料なので使いやすく、精度も非常に高くて素晴らしい。

将来的には今回開発したOCR機能をAPI化し、自動でDBに保存できるようにしたいなあと考え中。