Creating My Own Training Data for MMOCR and Training on It

Published: November 16, 2022
Last updated: November 24, 2022

Introduction

Lately I have been hooked on creating training data for OCR.

I previously introduced two methods:

Using TRDG

touch-sp.hatenablog.com

Using PySide6

touch-sp.hatenablog.com


This time I wrote a single script that runs both methods at once.

Python script

from PySide6.QtWidgets import QMainWindow, QApplication, QLabel
from PySide6.QtGui import QFont
from PySide6.QtCore import Qt, Signal, Slot, QThread
from trdg.generators import GeneratorFromStrings
import sys
import os
import glob
import json
import random
from PIL import ImageQt, Image
import time
import numpy as np

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--repeat', type=int, default=3, help='count of repeats')
parser.add_argument('--noise', action='store_true', help="add Gaussian Noise")
parser.add_argument('--trdg', action='store_true', help="use TRDG")
args = parser.parse_args()
repeat_n = args.repeat
noise_true = args.noise
trdg_true = args.trdg

# worker thread that writes the MMOCR textrecog label file(s)
class MakeJson(QThread):
    makingJson_finish_signal = Signal(bool)
    def __init__(self):
        super().__init__()
    def run(self):
        data_list = []
        for i in range(repeat_n):
            for text_i, text in enumerate(texts):
                image_fname = f'{i}_{text_i}.jpg'
        
                data = {
                    'img_path': image_fname,
                    'instances':[{'text':text}]
                    }
                data_list.append(data)
        
        result = {
            'metainfo':{
                'dataset_type':'TextRecogDataset',
                'task_name':'textrecog'
            },
            'data_list':data_list
        }

        with open('train_labels.json', 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

        if trdg_true:
            with open('trainTRDG_labels.json', 'w', encoding='utf-8') as f:
                json.dump(result, f, indent=2, ensure_ascii=False)

        self.makingJson_finish_signal.emit(True)

# worker thread that renders images with TRDG
class MakeImageTRDG(QThread):

    trdg_finish_signal = Signal(int)

    def __init__(self, thread_num):
        super().__init__()
        self.thread_num = thread_num
        self.fonts = random.sample(fontsTRDG, len(fontsTRDG))  # shuffled copy of the font list
        
    def run(self):
        
        generator = GeneratorFromStrings(
            texts,
            fonts = self.fonts,
            count = len(texts),
            # Text blurring
            blur = 1,
            random_blur = True,
            # Text skewing
            skewing_angle = 1,
            random_skew = True,
        )
        
        for i, (img, _) in enumerate(generator):
            image_fname = f'{self.thread_num}_{i}.jpg'
            img.save(os.path.join('trainTRDG', image_fname), quality=95)
        
        self.trdg_finish_signal.emit(self.thread_num)

# worker thread that renders images by grabbing an off-screen QLabel
class MakeImage(QThread):

    finish_signal = Signal(int)

    def __init__(self, thread_num):
        super().__init__()
        self.currentfont = QFont()
        self.thread_num = thread_num
        
    def run(self):
        self.label_1 = QLabel()
        self.label_1.setAlignment(Qt.AlignmentFlag.AlignCenter | Qt.AlignmentFlag.AlignVCenter)

        self.saveimage()

        self.finish_signal.emit(self.thread_num)
    
    def saveimage(self):
        for i, text in enumerate(texts):
            self.label_1.setText(text)

            # Font: pick a random "family,bold" entry from fonts.txt
            random_font = random.randrange(0, len(fonts))
            fontfamily, bold = fonts[random_font].split(',')
            self.currentfont.setFamily(fontfamily)
            self.currentfont.setBold(int(bold))

            # Letter spacing (85-115%, in steps of 5)
            random_spacing = random.randrange(85, 120, 5)
            self.currentfont.setLetterSpacing(QFont.PercentageSpacing, random_spacing)

            # Font size (16, 18 or 20 pt)
            random_size = random.randrange(16, 22, 2)
            self.currentfont.setPointSize(random_size)

            self.label_1.setFont(self.currentfont)
            self.label_1.adjustSize()

            # Margin (4, 8 or 12 px added around the text)
            random_margin = random.randrange(4, 16, 4)
            width = self.label_1.width() + random_margin
            height = self.label_1.height() + random_margin
            self.label_1.resize(width, height)

            # render the label off-screen and convert the grabbed pixmap to a PIL image (RGB)
            image = ImageQt.fromqpixmap(self.label_1.grab())

            # Noise: add Gaussian noise, clipped to the valid uint8 range
            # so that bright pixels do not wrap around to black
            if noise_true:
                original_img = np.array(image)
                noise = np.random.normal(0, 2, original_img.shape)
                image = Image.fromarray(np.clip(original_img + noise, 0, 255).astype('uint8'))

            # JPEG quality (85, 90 or 95)
            random_quality = random.randrange(85, 100, 5)
            image_fname = f'{self.thread_num}_{i}.jpg'
            image.save(os.path.join('train', image_fname), quality=random_quality)

class Window(QMainWindow):

    def __init__(self):
        super().__init__()
        self.thread_count = 0
        self.finish_count = 0
        self.initUI()

    def initUI(self):

        self.thread_list = []

        self.start_time = time.time()

        for i in range(repeat_n):
            self.thread_list.append(MakeImage(i))
            self.thread_count += 1

        for i in range(repeat_n):
            self.thread_list[i].finish_signal.connect(self.pyside_finish)
        for i in range(repeat_n):
            self.thread_list[i].start()

        self.makingJsonThread = MakeJson()
        self.thread_count += 1
        self.makingJsonThread.makingJson_finish_signal.connect(self.makingJson_finish)
        self.makingJsonThread.start()

        if trdg_true:
            self.trdg_thread_list = []
            for i in range(repeat_n):
                self.trdg_thread_list.append(MakeImageTRDG(i))
                self.thread_count += 1
            for i in range(repeat_n):
                self.trdg_thread_list[i].trdg_finish_signal.connect(self.trdg_finish)
            for i in range(repeat_n):
                self.trdg_thread_list[i].start()

    @Slot(int)
    def pyside_finish(self, received_signal):
        self.thread_list[received_signal].quit()
        self.finish_count += 1
        self.all_finish()

    @Slot(int)
    def trdg_finish(self, received_signal):
        self.trdg_thread_list[received_signal].quit()
        self.finish_count += 1
        self.all_finish()

    @Slot(bool)
    def makingJson_finish(self, received_signal):
        if received_signal:
            self.makingJsonThread.quit()
            self.finish_count += 1
            self.all_finish()

    def all_finish(self):
        # exit once every worker thread has reported completion
        if self.finish_count == self.thread_count:
            elapsed = time.time() - self.start_time
            print(f'{elapsed} sec')
            sys.exit()

if __name__ == "__main__":

    os.makedirs('train', exist_ok=True)
    
    with open('fonts.txt', 'r', encoding='utf-8') as f:
        lines = f.readlines()
    fonts = [x.strip() for x in lines]

    with open('texts.txt', 'r', encoding='utf-8') as f:
        lines = f.readlines()
    texts = [x.strip() for x in lines]

    if trdg_true:
        os.makedirs('trainTRDG', exist_ok=True)
        fontsTRDG = glob.glob('fonts/*.ttf')

    app = QApplication([])
    ex = Window()
    app.exec()
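
Assuming the script above is saved as, say, make_train_data.py (the file name is arbitrary), running it with every option enabled looks like this:

python make_train_data.py --repeat 3 --noise --trdg

Each repetition spawns one PySide6 rendering thread (and, with --trdg, one TRDG thread), so the images land in train/ and trainTRDG/ while MakeJson writes the matching label files. The generated train_labels.json follows the MMOCR textrecog annotation format; an abridged example:

{
  "metainfo": {
    "dataset_type": "TextRecogDataset",
    "task_name": "textrecog"
  },
  "data_list": [
    {
      "img_path": "0_0.jpg",
      "instances": [
        {
          "text": "ロコイド軟膏0.1%"
        }
      ]
    }
  ]
}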

Environment

Windows 11
Python 3.11.0
pip install pyside6==6.4.0.1
pip install openpyxl==3.0.10
pip install trdg==1.8.0

What to prepare

Text file and dictionary file

You need to prepare a text file containing the strings to train on (texts.txt) and a text file containing the characters used in them (dicts.txt).

This time, the goal is to extract drug names from a medication notebook (お薬手帳).

I downloaded the Excel files of the NHI drug price list (薬価基準収載品目リスト) from this page of the Ministry of Health, Labour and Welfare and read them in with the script below. Note that this script also needs pandas, with the openpyxl installed above serving as its Excel engine.

import glob
import pandas as pd
import random

max_len = 25

# translation table from half-width ASCII (U+0021-U+007E) to full-width forms
ZEN = "".join(chr(0xff01 + i) for i in range(94))
HAN = "".join(chr(0x21 + i) for i in range(94))
HAN2ZEN = str.maketrans(HAN, ZEN)  # e.g. '12'.translate(HAN2ZEN) -> '１２'

excel_files = glob.glob('*.xlsx')

product_name = set([])

for excel_file in excel_files:
    df = pd.read_excel(excel_file)
    product_name = product_name.union(set(df['品名']))

# strip both half-width and full-width (U+3000) spaces from the product names
product_name = set([x.replace(' ', '').replace('\u3000', '') for x in product_name])

new_product_name = []
        
for each in product_name:
    if len(each) < (max_len - 3):
        # short enough: also add three variants with a numbered header
        # such as "[１２]" (at most 4 extra characters)
        new_product_name.append(each)
        random_nums = random.sample(range(20), 3)
        for num in random_nums:
            with_header = f'[{str(num).translate(HAN2ZEN)}]{each}'
            new_product_name.append(with_header)
    elif len(each) > max_len:
        # too long: keep both the leading and the trailing max_len characters
        new_product_name.append(each[:max_len])
        new_product_name.append(each[-max_len:])
    else:
        new_product_name.append(each)

all_names = set(new_product_name)

with open('texts.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(all_names))

with open('dicts.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(set(''.join(all_names))))
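
For reference, both outputs are plain text files with one entry per line: texts.txt holds one training string per line, and dicts.txt holds each distinct character on its own line (in arbitrary order, since it is built from a set). Illustrative excerpts, borrowing a drug name that appears in the results later:

texts.txt (excerpt)
ロコイド軟膏0.1%
[２]ロコイド軟膏0.1%

dicts.txt (excerpt)
ロ
コ
イ
ド
軟
膏
0
.
1
%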

Font list for PySide6

This is a text file (fonts.txt) listing the fonts to use; each line gives a font family name and a bold flag (0 or 1).

Arial,0
Arial,1
Courier New,0
Courier New,1
Consolas,0
Consolas,1
BIZ UDPゴシック,0
BIZ UDPゴシック,1
BIZ UDP明朝 Medium,0
Lucida Console,0
UD デジタル 教科書体 N-R,0
UD デジタル 教科書体 NK-R,0
メイリオ,0
メイリオ,1
游明朝,0
游ゴシック,0
游ゴシック,1
MS Pゴシック,0
MS P明朝,0
HGS創英角ゴシックUB,0

Font files for TRDG

I collected these using the method described here.
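
Unlike PySide6, TRDG needs actual font files rather than installed system fonts: the script picks up every .ttf under fonts/ via glob.glob('fonts/*.ttf'), so the expected layout is simply the following (the file names here are placeholders):

fonts/
├── some_font_a.ttf
├── some_font_b.ttf
└── ...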

Creating test data as well

from PySide6.QtWidgets import QMainWindow, QApplication, QLabel
from PySide6.QtGui import QFont
from PySide6.QtCore import Qt, Signal, Slot, QThread
import sys
import os
import json
import random
from PIL import ImageQt, Image
import time
import numpy as np

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--sample', type=int, default=500, help='count of samples')
parser.add_argument('--repeat', type=int, default=1, help='count of repeats')
parser.add_argument('--noise', action='store_true', help="add Gaussian Noise")
args = parser.parse_args()
sample_n = args.sample
repeat_n = args.repeat
noise_true = args.noise

class MakeJson(QThread):
    makingJson_finish_signal = Signal(bool)
    def __init__(self):
        super().__init__()
    def run(self):
        data_list = []
        for i in range(repeat_n):
            for text_i, text in enumerate(texts):
                image_fname = f'{i}_{text_i}.jpg'
        
                data = {
                    'img_path': image_fname,
                    'instances':[{'text':text}]
                    }
                data_list.append(data)
        
        result = {
            'metainfo':{
                'dataset_type':'TextRecogDataset',
                'task_name':'textrecog'
            },
            'data_list':data_list
        }

        with open('test_labels.json', 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

        self.makingJson_finish_signal.emit(True)

class MakeImage(QThread):

    finish_signal = Signal(int)

    def __init__(self, thread_num):
        super().__init__()
        self.currentfont = QFont()
        self.thread_num = thread_num
        
    def run(self):
        self.label_1 = QLabel()
        self.label_1.setAlignment(Qt.AlignmentFlag.AlignCenter | Qt.AlignmentFlag.AlignVCenter)

        self.saveimage()

        self.finish_signal.emit(self.thread_num)
    
    def saveimage(self):
        for i, text in enumerate(texts):
            self.label_1.setText(text)

            # Font: pick a random "family,bold" entry from fonts.txt
            random_font = random.randrange(0, len(fonts))
            fontfamily, bold = fonts[random_font].split(',')
            self.currentfont.setFamily(fontfamily)
            self.currentfont.setBold(int(bold))

            # Letter spacing (85-115%, in steps of 5)
            random_spacing = random.randrange(85, 120, 5)
            self.currentfont.setLetterSpacing(QFont.PercentageSpacing, random_spacing)

            # Font size (16, 18 or 20 pt)
            random_size = random.randrange(16, 22, 2)
            self.currentfont.setPointSize(random_size)

            self.label_1.setFont(self.currentfont)
            self.label_1.adjustSize()

            # Margin (4, 8 or 12 px added around the text)
            random_margin = random.randrange(4, 16, 4)
            width = self.label_1.width() + random_margin
            height = self.label_1.height() + random_margin
            self.label_1.resize(width, height)

            # render the label off-screen and convert the grabbed pixmap to a PIL image (RGB)
            image = ImageQt.fromqpixmap(self.label_1.grab())

            # Noise: add Gaussian noise, clipped to the valid uint8 range
            # so that bright pixels do not wrap around to black
            if noise_true:
                original_img = np.array(image)
                noise = np.random.normal(0, 2, original_img.shape)
                image = Image.fromarray(np.clip(original_img + noise, 0, 255).astype('uint8'))

            # JPEG quality (85, 90 or 95)
            random_quality = random.randrange(85, 100, 5)
            image_fname = f'{self.thread_num}_{i}.jpg'
            image.save(os.path.join('test', image_fname), quality=random_quality)

class Window(QMainWindow):

    def __init__(self):
        super().__init__()
        self.thread_count = 0
        self.finish_count = 0
        self.initUI()

    def initUI(self):

        self.thread_list = []

        self.start_time = time.time()

        for i in range(repeat_n):
            self.thread_list.append(MakeImage(i))
            self.thread_count += 1

        for i in range(repeat_n):
            self.thread_list[i].finish_signal.connect(self.pyside_finish)
        for i in range(repeat_n):
            self.thread_list[i].start()

        self.makingJsonThread = MakeJson()
        self.thread_count += 1
        self.makingJsonThread.makingJson_finish_signal.connect(self.makingJson_finish)
        self.makingJsonThread.start()

    @Slot(int)
    def pyside_finish(self, received_signal):
        self.thread_list[received_signal].quit()
        self.finish_count += 1
        self.all_finish()

    @Slot(bool)
    def makingJson_finish(self, received_signal):
        if received_signal:
            self.makingJsonThread.quit()
            self.finish_count += 1
            self.all_finish()

    def all_finish(self):
        # exit once every worker thread has reported completion
        if self.finish_count == self.thread_count:
            elapsed = time.time() - self.start_time
            print(f'{elapsed} sec')
            sys.exit()

if __name__ == "__main__":

    os.makedirs('test', exist_ok=True)
    
    with open('fonts.txt', 'r', encoding='utf-8') as f:
        lines = f.readlines()
    fonts = [x.strip() for x in lines]

    with open('texts.txt', 'r', encoding='utf-8') as f:
        lines = f.readlines()
    texts = [x.strip() for x in lines]

    texts = random.sample(texts, sample_n)

    app = QApplication([])
    ex = Window()
    app.exec()
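
As before, assuming this script is saved as make_test_data.py (an arbitrary name), the following renders 500 randomly sampled strings into the test directory and writes test_labels.json:

python make_test_data.py --sample 500 --noise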

Training

The libraries required to create the training data and to run the training are completely different, so I set up a separate virtual environment.

Windows 11
Python 3.10.7
pip install torch==1.12.1 torchvision==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install openmim==0.3.3
pip install mmengine==0.3.1
pip install mmcv==2.0.0rc2 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12.0/index.html
pip install mmdet==3.0.0rc3
pip install mmocr==1.0.0rc3
pip install albumentations==1.3.0



To train, prepare the following config file (satrn_japanese_cfg.py) and then run the Python script (satrn_train.py):

python satrn_train.py

satrn_japanese_cfg.py

train0 = dict(
    type='OCRDataset',
    data_prefix=dict(img_path='train'),
    ann_file='train_labels.json',
    test_mode=False,
    pipeline=None)

train1 = dict(
    type='OCRDataset',
    data_prefix=dict(img_path='trainTRDG'),
    ann_file='trainTRDG_labels.json',
    test_mode=False,
    pipeline=None)

test = dict(
    type='OCRDataset',
    data_prefix=dict(img_path='test'),
    ann_file='test_labels.json',
    test_mode=True,
    pipeline=None)

default_scope = 'mmocr'

env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))

randomness = dict(seed=None)

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    sync_buffer=dict(type='SyncBuffersHook'),
    visualization=dict(
        type='VisualizationHook',
        interval=1,
        enable=False,
        show=False,
        draw_gt=False,
        draw_pred=False))

log_level = 'INFO'
log_processor = dict(type='LogProcessor', window_size=10, by_epoch=True)
load_from = None
resume = False

val_evaluator = dict(
    type='Evaluator',
    metrics=[
        dict(
            type='WordMetric',
            mode=['exact', 'ignore_case', 'ignore_case_symbol']),
        dict(type='CharMetric')
    ])

test_evaluator = dict(
    type='Evaluator',
    metrics=[
        dict(
            type='WordMetric',
            mode=['exact', 'ignore_case', 'ignore_case_symbol']),
        dict(type='CharMetric')
    ])

vis_backends = [dict(type='LocalVisBackend')]

visualizer = dict(
    type='TextRecogLocalVisualizer',
    name='visualizer',
    vis_backends=[dict(type='LocalVisBackend')])

optim_wrapper = dict(
    type='OptimWrapper', optimizer=dict(type='Adam', lr=0.0003))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=5, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [dict(type='MultiStepLR', milestones=[3, 4], end=5)]

file_client_args = dict(backend='disk')

dictionary = dict(
    type='Dictionary',
    dict_file='dicts.txt',
    with_padding=True,
    with_unknown=True,
    same_start_end=True,
    with_start=True,
    with_end=True)

model = dict(
    type='SATRN',
    backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=512),
    encoder=dict(
        type='SATRNEncoder',
        n_layers=12,
        n_head=8,
        d_k=64,
        d_v=64,
        d_model=512,
        n_position=150,
        d_inner=2048,
        dropout=0.1),
    decoder=dict(
        type='NRTRDecoder',
        n_layers=6,
        d_embedding=512,
        n_head=8,
        d_model=512,
        d_inner=2048,
        d_k=64,
        d_v=64,
        module_loss=dict(
            type='CEModuleLoss', flatten=True, ignore_first_char=True),
        dictionary=dictionary,
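        # max_seq_len matches the max_len = 25 used when building texts.txt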
        max_seq_len=25,
        postprocessor=dict(type='AttentionPostprocessor')),
    data_preprocessor=dict(
        type='TextRecogDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375]))

train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=dict(backend='disk'),
        ignore_empty=True,
        min_size=2),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(type='Resize', scale=(150, 32), keep_ratio=False),
    dict(
        type='RandomApply',
        prob=0.5,
        transforms=[
            dict(
                type='RandomChoice',
                transforms=[
                    dict(
                        type='RandomRotate',
                        max_angle=5,
                    ),
                ])
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(type='PyramidRescale'),
            dict(
                type='mmdet.Albu',
                transforms=[
                    dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
                    dict(type='MotionBlur', blur_limit=5, p=0.5),
                ]),
        ]),
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1),
        ]),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='Resize', scale=(150, 32), keep_ratio=False),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

train_dataset = dict(
    type='ConcatDataset',
    datasets=[train0, train1],
    pipeline=train_pipeline)

test_dataset = dict(
    type='ConcatDataset',
    datasets=[test],
    pipeline=test_pipeline)

train_dataloader = dict(
    batch_size=16,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)

test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=test_dataset)

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=test_dataset)

auto_scale_lr = dict(base_batch_size=512)
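
Before moving on, it can be worth a quick sanity check that the config parses; a minimal sketch using the same mmengine Config class that satrn_train.py relies on:

from mmengine.config import Config

cfg = Config.fromfile('satrn_japanese_cfg.py')
# confirm the annotation files are wired up as expected
print(cfg.train_dataloader.dataset.datasets[0].ann_file)  # train_labels.json
print(cfg.test_dataloader.dataset.datasets[0].ann_file)   # test_labels.json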

satrn_train.py

import os
from mmengine.config import Config
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

def main():
    cfg = Config.fromfile('satrn_japanese_cfg.py')

    os.makedirs('satrn_output', exist_ok=True)

    ####
    ## modify configuration file
    ####

    # Set output dir
    cfg.work_dir = 'satrn_output'
    
    # Modify cuda setting
    cfg.gpu_ids = range(1)
    cfg.device = 'cuda'

    # Others
    cfg.train_cfg.max_epochs = 3 # default 5 
    cfg.default_hooks.logger.interval = 2000
    
    # Build the runner from config
    if 'runner_type' not in cfg:
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

    # Start training
    runner.train()
    
if __name__ == '__main__':
    main()
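
Because max_epochs is overridden to 3 and work_dir is satrn_output, training leaves epoch_1.pth through epoch_3.pth in satrn_output/; epoch_3.pth is the checkpoint that the inference script below loads.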

Results

[4]ヒルドイドソフト軟膏0.3%
[3]メサデルム軟膏0.1%
[2]ロコイド軟膏0.1%
[1]ザイザルシロップ0.05%

It extracts only the drug names, perfectly. A great success.

For text detection I used a pretrained model published by MMOCR (textsnake_resnet50_fpn-unet_1200e_ctw1500).

It apparently detects regions from bottom to top (a small sketch for reordering the output top-to-bottom follows the script below).

from mmocr.ocr import MMOCR
from mim.commands.download import download
import os
import sys
import numpy as np
import cv2

img = sys.argv[1]

os.makedirs('models', exist_ok=True)

det_checkpoint_name = 'textsnake_resnet50_fpn-unet_1200e_ctw1500'
checkpoint = download(package='mmocr', configs=[det_checkpoint_name], dest_root="models")
config_paths = os.path.join('models', det_checkpoint_name + '.py')
checkpoint_paths = os.path.join('models', checkpoint[0])

det_model = MMOCR(
    det_config = config_paths, 
    det_ckpt = checkpoint_paths,
    recog = None,
    device = 'cuda'
    )

recog_cfg = 'satrn_output/satrn_japanese_cfg.py'
recog_checkpoint = 'satrn_output/epoch_3.pth'

recog_model = MMOCR(
    det = None,
    recog_config = recog_cfg, 
    recog_ckpt = recog_checkpoint,
    device = 'cuda'
    )

det_result = det_model.readtext(img) # -> dict(key:['det_polygons', 'det_scores'])

polygons = det_result['det_polygons']              # -> list (len: number of bboxes)

original_image = cv2.imread(img)

for each_array in polygons:
    poly = np.array(each_array).reshape(-1, 1, 2).astype(np.float32) 

    x, y, width, height = cv2.boundingRect(poly)
    trim_image = original_image[y:y+height, x:x+width, :]

    recog_result = recog_model.readtext(trim_image, print_result=False, show=False)
    score = np.mean(np.array(recog_result['rec_scores']))
    if score > 0.95:
        print(recog_result['rec_texts'][0])
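
Since the detector returns regions bottom-to-top, a minimal sketch for printing the results top-to-bottom instead is to sort the polygons by the y coordinate of their bounding boxes before the recognition loop (this reuses only the calls already present in the script above):

polygons = sorted(
    det_result['det_polygons'],
    key=lambda p: cv2.boundingRect(
        np.array(p).reshape(-1, 1, 2).astype(np.float32))[1])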