公開日:2022年11月16日
最終更新日:2022年11月24日
はじめに
最近、OCR学習用のデータを作ることにはまっています。以前に次の二つの方法を紹介しました。
TRDGを使う方法
touch-sp.hatenablog.com
PySide6を使う方法
touch-sp.hatenablog.com
今回は、一つのスクリプトでこの二つの方法を同時に実行できるようにしました。
Pythonスクリプト
# NOTE: keep PySide6 imported before PIL.ImageQt -- ImageQt binds to the Qt
# package that is already present in sys.modules.
from PySide6.QtWidgets import QMainWindow, QApplication, QLabel
from PySide6.QtGui import QFont
from PySide6.QtCore import Qt, Signal, Slot, QThread
from trdg.generators import GeneratorFromStrings
import sys
import os
import glob
import json
import random
from PIL import ImageQt, Image
import time
import numpy as np
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--repeat', type=int, default=3, help='count of repeat')
parser.add_argument('--noise', action='store_true', help="add Gaussian Noise")
parser.add_argument('--trdg', action='store_true', help="use TRDG")
args = parser.parse_args()

repeat_n = args.repeat
noise_true = args.noise
trdg_true = args.trdg


class MakeJson(QThread):
    """Write the MMOCR text-recognition label file(s) for the generated images.

    Reads the module-level ``texts`` list (set in ``__main__``) and emits
    ``makingJson_finish_signal(True)`` when the file(s) have been written.
    """

    makingJson_finish_signal = Signal(bool)

    def __init__(self):
        super().__init__()

    def run(self):
        # One entry per (repeat, text) pair, named "<repeat>_<text index>.jpg"
        # to match the files saved by the image threads.
        entries = [
            {'img_path': f'{rep}_{idx}.jpg', 'instances': [{'text': label}]}
            for rep in range(repeat_n)
            for idx, label in enumerate(texts)
        ]
        payload = {
            'metainfo': {
                'dataset_type': 'TextRecogDataset',
                'task_name': 'textrecog',
            },
            'data_list': entries,
        }
        targets = ['train_labels.json']
        if trdg_true:
            # TRDG images follow the same naming scheme, so they share labels.
            targets.append('trainTRDG_labels.json')
        for target in targets:
            with open(target, 'w', encoding='utf-8') as fp:
                json.dump(payload, fp, indent=2, ensure_ascii=False)
        self.makingJson_finish_signal.emit(True)


class MakeImageTRDG(QThread):
    """Generate one image per string in ``texts`` with TRDG into ``trainTRDG/``.

    Emits ``trdg_finish_signal(thread_num)`` once every image is saved.
    """

    trdg_finish_signal = Signal(int)

    def __init__(self, thread_num):
        super().__init__()
        self.thread_num = thread_num
        # Shuffled copy of the module-level TRDG font file list.
        self.fonts = random.sample(fontsTRDG, len(fontsTRDG))

    def run(self):
        options = dict(
            fonts=self.fonts,
            count=len(texts),
            blur=1, random_blur=True,            # text blurring
            skewing_angle=1, random_skew=True,   # text skewing
        )
        generator = GeneratorFromStrings(texts, **options)
        out_dir = 'trainTRDG'
        for idx, (picture, _label) in enumerate(generator):
            picture.save(os.path.join(out_dir, f'{self.thread_num}_{idx}.jpg'), quality=95)
        self.trdg_finish_signal.emit(self.thread_num)
class MakeImage(QThread):
    """Render every string in ``texts`` with a randomly styled QLabel.

    Each image is saved as ``train/<thread_num>_<index>.jpg`` (the names that
    MakeJson writes into train_labels.json). ``finish_signal`` carries
    ``thread_num`` when the thread is done.
    """

    finish_signal = Signal(int)

    def __init__(self, thread_num):
        super().__init__()
        self.currentfont = QFont()
        self.thread_num = thread_num

    def run(self):
        # NOTE(review): Qt documents QWidget subclasses such as QLabel as
        # GUI-thread-only, but this label is created and grabbed inside a
        # worker thread. It works in the author's setup; if it ever crashes,
        # painting text onto a QImage with QPainter is the thread-safe option.
        try:
            self.label_1 = QLabel()
            self.label_1.setAlignment(Qt.AlignmentFlag.AlignCenter | Qt.AlignmentFlag.AlignVCenter)
            self.saveimage()
        finally:
            # Report completion even after an exception; otherwise Window's
            # finish counter is never reached and the application hangs.
            self.finish_signal.emit(self.thread_num)

    def saveimage(self):
        for i, text in enumerate(texts):
            self.label_1.setText(text)
            # Font: each fonts.txt entry is "<family>,<bold flag 0/1>".
            fontfamily, bold = random.choice(fonts).rsplit(',', 1)
            self.currentfont.setFamily(fontfamily)
            self.currentfont.setBold(bool(int(bold)))
            # Letter spacing: 85-115 % in steps of 5
            spacing = random.randrange(start=85, stop=120, step=5)
            self.currentfont.setLetterSpacing(QFont.PercentageSpacing, spacing)
            # Font size: 16, 18 or 20 pt
            font_size = random.randrange(start=16, stop=22, step=2)
            self.currentfont.setPointSize(font_size)
            self.label_1.setFont(self.currentfont)
            self.label_1.adjustSize()
            # Margin: 4, 8 or 12 px added to both width and height
            margin = random.randrange(start=4, stop=16, step=4)
            self.label_1.resize(self.label_1.width() + margin, self.label_1.height() + margin)
            image = ImageQt.fromqpixmap(self.label_1.grab())  # RGB
            # Noise
            if noise_true:
                pixels = np.asarray(image, dtype=np.float64)
                noisy = pixels + np.random.normal(0, 2, pixels.shape)
                # Clip before casting: out-of-range floats do not convert to
                # uint8 reliably (negative values typically wrap to ~255),
                # which would turn dark text pixels white.
                image = Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))
            # JPEG quality: 85, 90 or 95
            quality = random.randrange(start=85, stop=100, step=5)
            image.save(os.path.join('train', f'{self.thread_num}_{i}.jpg'), quality=quality)


class Window(QMainWindow):
    """Start every generator thread and quit the application when all finish."""

    def __init__(self):
        super().__init__()
        self.thread_count = 0
        self.finish_count = 0
        self.initUI()

    def initUI(self):
        # Finish signals are queued to this (GUI) thread and only delivered
        # once app.exec() runs, so thread_count is complete before any
        # finish is counted.
        self.start_time = time.time()

        self.thread_list = [MakeImage(i) for i in range(repeat_n)]
        self.thread_count += len(self.thread_list)
        for worker in self.thread_list:
            worker.finish_signal.connect(self.pyside_finish)
        for worker in self.thread_list:
            worker.start()

        self.makingJsonThread = MakeJson()
        self.thread_count += 1
        self.makingJsonThread.makingJson_finish_signal.connect(self.makingJson_finish)
        self.makingJsonThread.start()

        if trdg_true:
            self.trdg_thread_list = [MakeImageTRDG(i) for i in range(repeat_n)]
            self.thread_count += len(self.trdg_thread_list)
            for worker in self.trdg_thread_list:
                worker.trdg_finish_signal.connect(self.trdg_finish)
            for worker in self.trdg_thread_list:
                worker.start()

    # The finish signal is the last thing a worker's run() does, so wait()
    # returns promptly. QThread.quit() would do nothing here because run()
    # never starts an event loop.
    @Slot(int)
    def pyside_finish(self, recieved_signal):
        self.thread_list[recieved_signal].wait()
        self.finish_count += 1
        self.all_finish()

    @Slot(int)
    def trdg_finish(self, recieved_signal):
        self.trdg_thread_list[recieved_signal].wait()
        self.finish_count += 1
        self.all_finish()

    @Slot(bool)
    def makingJson_finish(self, recieved_signal):
        if recieved_signal:
            self.makingJsonThread.wait()
            self.finish_count += 1
            self.all_finish()

    def all_finish(self):
        """Print the elapsed time and leave the event loop once all threads are done."""
        if self.finish_count == self.thread_count:
            elapsed = time.time() - self.start_time
            print(f'{elapsed} sec')
            QApplication.quit()


def _read_nonblank_lines(path):
    """Return the stripped, non-empty lines of the UTF-8 text file *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


if __name__ == "__main__":
    os.makedirs('train', exist_ok=True)
    # Blank lines (e.g. a trailing newline) would break the "family,bold"
    # split in MakeImage or produce images with empty labels.
    fonts = _read_nonblank_lines('fonts.txt')
    if not fonts:
        sys.exit('fonts.txt contains no font entries')
    texts = _read_nonblank_lines('texts.txt')
    if trdg_true:
        os.makedirs('trainTRDG', exist_ok=True)
        fontsTRDG = glob.glob('fonts/*.ttf')
    app = QApplication([])
    ex = Window()
    sys.exit(app.exec())
実行環境
Windows 11 Python 3.11.0
pip install pyside6==6.4.0.1
pip install openpyxl==3.0.10
pip install trdg==1.8.0
準備するもの
テキストファイルと辞書ファイル
学習する文字列が書かれたテキストファイル(texts.txt)と使用している文字が書かれたテキストファイル(dicts.txt)を準備する必要があります。今回は「お薬手帳」から薬名を抽出することを目的とします。厚生労働省のこちらのページから薬価基準収載品目リストのExcelファイルをダウンロードさせて頂き以下のスクリプトで読み込みました。import glob import pandas as pd import random max_len = 25 ZEN = "".join(chr(0xff01 + i) for i in range(94)) HAN = "".join(chr(0x21 + i) for i in range(94)) HAN2ZEN = str.maketrans(HAN, ZEN) excel_files = glob.glob('*.xlsx') product_name = set([]) for excel_file in excel_files: df = pd.read_excel(excel_file) product_name = product_name.union(set(df['品名'])) product_name = set([x.replace(' ', '').replace(' ', '') for x in product_name]) new_product_name = [] for each in product_name: if len(each) < (max_len -3): new_product_name.append(each) random_nums = random.sample(range(20), 3) for num in random_nums: with_header = f'[{str(num).translate(HAN2ZEN)}]{each}' new_product_name.append(with_header) elif len(each) > max_len: new_product_name.append(each[:max_len]) new_product_name.append(each[-max_len:]) else: new_product_name.append(each) all_names = set(new_product_name) with open('texts.txt', 'w', encoding='utf-8') as f: f.write('\n'.join(all_names)) with open('dicts.txt', 'w', encoding='utf-8') as f: f.write('\n'.join(set(''.join(all_names))))
PySide6で使用するフォントファイル
使用するフォントを記入したテキストファイル(fonts.txt)です。1行に1つずつ「フォント名,太字フラグ(0または1)」の形式で記入します。
Arial,0
Arial,1
Courier New,0
Courier New,1
Consolas,0
Consolas,1
BIZ UDPゴシック,0
BIZ UDPゴシック,1
BIZ UDP明朝 Medium,0
Lucida Console,0
UD デジタル 教科書体 N-R,0
UD デジタル 教科書体 NK-R,0
メイリオ,0
メイリオ,1
游明朝,0
游ゴシック,0
游ゴシック,1
MS Pゴシック,0
MS P明朝,0
HGS創英角ゴシックUB,0
TRDGで使用するフォントファイル
こちらの方法で収集しました。

テスト用データも作る
# NOTE: keep PySide6 imported before PIL.ImageQt -- ImageQt binds to the Qt
# package that is already present in sys.modules.
from PySide6.QtWidgets import QMainWindow, QApplication, QLabel
from PySide6.QtGui import QFont
from PySide6.QtCore import Qt, Signal, Slot, QThread
from trdg.generators import GeneratorFromStrings
import sys
import os
import glob
import json
import random
from PIL import ImageQt, Image
import time
import numpy as np
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--sample', type=int, default=500, help='count of samples')
parser.add_argument('--repeat', type=int, default=1, help='count of repeat')
parser.add_argument('--noise', action='store_true', help="add Gaussian Noise")
args = parser.parse_args()

sample_n = args.sample
repeat_n = args.repeat
noise_true = args.noise


def _read_nonblank_lines(path):
    """Return the stripped, non-empty lines of the UTF-8 text file *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


class MakeJson(QThread):
    """Write test_labels.json (MMOCR text-recognition annotation format)."""

    makingJson_finish_signal = Signal(bool)

    def __init__(self):
        super().__init__()

    def run(self):
        # Names must match the files saved by MakeImage: "<repeat>_<index>.jpg".
        data_list = [
            {'img_path': f'{i}_{text_i}.jpg', 'instances': [{'text': text}]}
            for i in range(repeat_n)
            for text_i, text in enumerate(texts)
        ]
        result = {
            'metainfo': {
                'dataset_type': 'TextRecogDataset',
                'task_name': 'textrecog',
            },
            'data_list': data_list,
        }
        with open('test_labels.json', 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        self.makingJson_finish_signal.emit(True)


class MakeImage(QThread):
    """Render every string in ``texts`` into ``test/<thread_num>_<index>.jpg``."""

    finish_signal = Signal(int)

    def __init__(self, thread_num):
        super().__init__()
        self.currentfont = QFont()
        self.thread_num = thread_num

    def run(self):
        # NOTE(review): Qt documents QLabel (a QWidget) as GUI-thread-only;
        # it is created and grabbed here in a worker thread.
        try:
            self.label_1 = QLabel()
            self.label_1.setAlignment(Qt.AlignmentFlag.AlignCenter | Qt.AlignmentFlag.AlignVCenter)
            self.saveimage()
        finally:
            # Report completion even after an exception so the app cannot hang.
            self.finish_signal.emit(self.thread_num)

    def saveimage(self):
        for i, text in enumerate(texts):
            self.label_1.setText(text)
            # Font: each fonts.txt entry is "<family>,<bold flag 0/1>".
            fontfamily, bold = random.choice(fonts).rsplit(',', 1)
            self.currentfont.setFamily(fontfamily)
            self.currentfont.setBold(bool(int(bold)))
            # Letter spacing: 85-115 % in steps of 5
            spacing = random.randrange(start=85, stop=120, step=5)
            self.currentfont.setLetterSpacing(QFont.PercentageSpacing, spacing)
            # Font size: 16, 18 or 20 pt
            font_size = random.randrange(start=16, stop=22, step=2)
            self.currentfont.setPointSize(font_size)
            self.label_1.setFont(self.currentfont)
            self.label_1.adjustSize()
            # Margin: 4, 8 or 12 px added to both width and height
            margin = random.randrange(start=4, stop=16, step=4)
            self.label_1.resize(self.label_1.width() + margin, self.label_1.height() + margin)
            image = ImageQt.fromqpixmap(self.label_1.grab())  # RGB
            # Noise
            if noise_true:
                pixels = np.asarray(image, dtype=np.float64)
                noisy = pixels + np.random.normal(0, 2, pixels.shape)
                # Clip before casting: out-of-range floats do not convert to
                # uint8 reliably (negative values typically wrap to ~255).
                image = Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))
            # JPEG quality: 85, 90 or 95
            quality = random.randrange(start=85, stop=100, step=5)
            image.save(os.path.join('test', f'{self.thread_num}_{i}.jpg'), quality=quality)


class Window(QMainWindow):
    """Start the image and JSON threads and quit when all of them finish."""

    def __init__(self):
        super().__init__()
        self.thread_count = 0
        self.finish_count = 0
        self.initUI()

    def initUI(self):
        # Finish signals are queued to the GUI thread and delivered only after
        # app.exec() starts, so thread_count is complete by then.
        self.start_time = time.time()

        self.thread_list = [MakeImage(i) for i in range(repeat_n)]
        self.thread_count += len(self.thread_list)
        for worker in self.thread_list:
            worker.finish_signal.connect(self.pyside_finish)
        for worker in self.thread_list:
            worker.start()

        self.makingJsonThread = MakeJson()
        self.thread_count += 1
        self.makingJsonThread.makingJson_finish_signal.connect(self.makingJson_finish)
        self.makingJsonThread.start()

    # wait() returns promptly because the finish signal is the worker's last
    # action; QThread.quit() would be a no-op (run() has no event loop).
    @Slot(int)
    def pyside_finish(self, recieved_signal):
        self.thread_list[recieved_signal].wait()
        self.finish_count += 1
        self.all_finish()

    @Slot(bool)
    def makingJson_finish(self, recieved_signal):
        if recieved_signal:
            self.makingJsonThread.wait()
            self.finish_count += 1
            self.all_finish()

    def all_finish(self):
        """Print the elapsed time and leave the event loop once all threads are done."""
        if self.finish_count == self.thread_count:
            elapsed = time.time() - self.start_time
            print(f'{elapsed} sec')
            QApplication.quit()


if __name__ == "__main__":
    os.makedirs('test', exist_ok=True)
    fonts = _read_nonblank_lines('fonts.txt')
    if not fonts:
        sys.exit('fonts.txt contains no font entries')
    texts = _read_nonblank_lines('texts.txt')
    # random.sample raises ValueError when asked for more items than exist.
    texts = random.sample(texts, min(sample_n, len(texts)))
    app = QApplication([])
    ex = Window()
    sys.exit(app.exec())
学習する
学習データを作成する時と学習する時とでは必要なライブラリが大きく異なるので、別の仮想環境を用意しました。
Windows 11, Python 3.10.7
pip install torch==1.12.1 torchvision==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install openmim==0.3.3
pip install mmengine==0.3.1
pip install mmcv==2.0.0rc2 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12.0/index.html
pip install mmdet==3.0.0rc3
pip install mmocr==1.0.0rc3
pip install albumentations==1.3.0
学習するには以下のconfigファイル(satrn_japanese_cfg.py)を用意したうえでPythonスクリプト(satrn_train.py)を実行します。
python satrn_train.py
satrn_japanese_cfg.py
# MMOCR 1.x config: SATRN text recognizer for Japanese drug names.

# ---- datasets ----
train0 = dict(
    type='OCRDataset',
    data_prefix=dict(img_path='train'),
    ann_file='train_labels.json',
    test_mode=False,
    pipeline=None)
train1 = dict(
    type='OCRDataset',
    data_prefix=dict(img_path='trainTRDG'),
    ann_file='trainTRDG_labels.json',
    test_mode=False,
    pipeline=None)
test = dict(
    type='OCRDataset',
    data_prefix=dict(img_path='test'),
    ann_file='test_labels.json',
    test_mode=True,
    pipeline=None)

# ---- runtime ----
default_scope = 'mmocr'
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
randomness = dict(seed=None)
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    sync_buffer=dict(type='SyncBuffersHook'),
    visualization=dict(
        type='VisualizationHook',
        interval=1,
        enable=False,
        show=False,
        draw_gt=False,
        draw_pred=False))
log_level = 'INFO'
log_processor = dict(type='LogProcessor', window_size=10, by_epoch=True)
load_from = None
resume = False

# ---- evaluation ----
val_evaluator = dict(
    type='Evaluator',
    metrics=[
        dict(
            type='WordMetric',
            mode=['exact', 'ignore_case', 'ignore_case_symbol']),
        dict(type='CharMetric')
    ])
test_evaluator = dict(
    type='Evaluator',
    metrics=[
        dict(
            type='WordMetric',
            mode=['exact', 'ignore_case', 'ignore_case_symbol']),
        dict(type='CharMetric')
    ])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='TextRecogLocalVisualizer',
    name='visualizer',
    vis_backends=[dict(type='LocalVisBackend')])

# ---- schedule ----
optim_wrapper = dict(
    type='OptimWrapper', optimizer=dict(type='Adam', lr=0.0003))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=5, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
param_scheduler = [dict(type='MultiStepLR', milestones=[3, 4], end=5)]
file_client_args = dict(backend='disk')

# ---- model ----
dictionary = dict(
    type='Dictionary',
    dict_file='dicts.txt',
    with_padding=True,
    with_unknown=True,
    same_start_end=True,
    with_start=True,
    with_end=True)
model = dict(
    type='SATRN',
    backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=512),
    encoder=dict(
        type='SATRNEncoder',
        n_layers=12,
        n_head=8,
        d_k=64,
        d_v=64,
        d_model=512,
        n_position=150,
        d_inner=2048,
        dropout=0.1),
    decoder=dict(
        type='NRTRDecoder',
        n_layers=6,
        d_embedding=512,
        n_head=8,
        d_model=512,
        d_inner=2048,
        d_k=64,
        d_v=64,
        module_loss=dict(
            type='CEModuleLoss', flatten=True, ignore_first_char=True),
        dictionary=dictionary,
        # Labels are up to 25 characters (max_len in the texts.txt script).
        # MMOCR 1.x pads/truncates the loss target (start token + text + end
        # token) to max_seq_len, so 25 would cut the end token (and the last
        # character) from long labels. Keep this >= 25 + 2.
        max_seq_len=30,
        postprocessor=dict(type='AttentionPostprocessor')),
    data_preprocessor=dict(
        type='TextRecogDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375]))

# ---- pipelines ----
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=dict(backend='disk'),
        ignore_empty=True,
        min_size=2),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(type='Resize', scale=(150, 32), keep_ratio=False),
    dict(
        type='RandomApply',
        prob=0.5,
        transforms=[
            dict(
                type='RandomChoice',
                transforms=[
                    dict(
                        type='RandomRotate',
                        max_angle=5,
                    ),
                ])
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(type='PyramidRescale'),
            dict(
                type='mmdet.Albu',
                transforms=[
                    dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
                    dict(type='MotionBlur', blur_limit=5, p=0.5),
                ]),
        ]),
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1),
        ]),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='Resize', scale=(150, 32), keep_ratio=False),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

# ---- dataloaders ----
train_dataset = dict(
    type='ConcatDataset', datasets=[train0, train1], pipeline=train_pipeline)
test_dataset = dict(
    type='ConcatDataset', datasets=[test], pipeline=test_pipeline)
train_dataloader = dict(
    batch_size=16,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)
test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=test_dataset)
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=test_dataset)
auto_scale_lr = dict(base_batch_size=512)
satrn_train.py
import os
from mmengine.config import Config
from mmengine.registry import RUNNERS
from mmengine.runner import Runner


def main():
    """Train SATRN from satrn_japanese_cfg.py, writing results to satrn_output/."""
    cfg = Config.fromfile('satrn_japanese_cfg.py')

    work_dir = 'satrn_output'
    os.makedirs(work_dir, exist_ok=True)

    # ---- configuration overrides ----
    cfg.work_dir = work_dir
    # CUDA settings
    cfg.gpu_ids = range(1)
    cfg.device = 'cuda'
    # Train for 3 epochs instead of the config's 5.
    # NOTE(review): the config's MultiStepLR milestones are [3, 4], so with
    # 3 epochs the learning-rate decay appears never to take effect during
    # training -- confirm that is intended.
    cfg.train_cfg.max_epochs = 3
    cfg.default_hooks.logger.interval = 2000

    # Build a registry runner only when the config names a custom runner_type.
    runner = RUNNERS.build(cfg) if 'runner_type' in cfg else Runner.from_cfg(cfg)
    runner.train()


if __name__ == '__main__':
    main()
結果
[4]ヒルドイドソフト軟膏0.3%
[3]メサデルム軟膏0.1%
[2]ロコイド軟膏0.1%
[1]ザイザルシロップ0.05%
完璧に薬名のみ抽出しています。大成功です。
テキスト検出モデルはMMOCRで公開されている学習済みモデル(textsnake_resnet50_fpn-unet_1200e_ctw1500)を使用しています。
検出結果は下から上の順に並ぶ仕様のようです。そのため、上の結果も[4]から[1]の順に出力されています。
from mmocr.ocr import MMOCR
from mim.commands.download import download
import os
import sys
import numpy as np
import cv2

DET_CHECKPOINT_NAME = 'textsnake_resnet50_fpn-unet_1200e_ctw1500'
RECOG_CONFIG = 'satrn_output/satrn_japanese_cfg.py'
RECOG_CHECKPOINT = 'satrn_output/epoch_3.pth'
SCORE_THRESHOLD = 0.95  # print a recognized line only above this mean score


def crop_polygon(image, polygon):
    """Return the axis-aligned crop of *image* bounding *polygon*, or None.

    *polygon* is a flat x0, y0, x1, y1, ... sequence (reshaped to points below).
    The box is clipped to the image: a detection touching or crossing the
    border can yield a negative x/y, which Python slicing would treat as an
    offset from the end and return a wrong or empty crop.
    """
    points = np.array(polygon, dtype=np.float32).reshape(-1, 1, 2)
    x, y, width, height = cv2.boundingRect(points)
    img_h, img_w = image.shape[:2]
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + width, img_w), min(y + height, img_h)
    if x1 <= x0 or y1 <= y0:
        return None
    return image[y0:y1, x0:x1, :]


def main():
    """Detect text regions in sys.argv[1], recognize each and print confident results."""
    if len(sys.argv) < 2:
        sys.exit(f'usage: python {sys.argv[0]} IMAGE_PATH')
    img = sys.argv[1]
    original_image = cv2.imread(img)
    if original_image is None:
        sys.exit(f'cannot read image: {img}')

    os.makedirs('models', exist_ok=True)
    checkpoint = download(package='mmocr', configs=[DET_CHECKPOINT_NAME], dest_root="models")
    config_paths = os.path.join('models', DET_CHECKPOINT_NAME + '.py')
    checkpoint_paths = os.path.join('models', checkpoint[0])
    det_model = MMOCR(
        det_config=config_paths,
        det_ckpt=checkpoint_paths,
        recog=None,
        device='cuda'
    )
    recog_model = MMOCR(
        det=None,
        recog_config=RECOG_CONFIG,
        recog_ckpt=RECOG_CHECKPOINT,
        device='cuda'
    )

    det_result = det_model.readtext(img)  # -> dict(key:['det_polygons', 'det_scores'])
    for polygon in det_result['det_polygons']:
        trim_image = crop_polygon(original_image, polygon)
        if trim_image is None:
            continue  # degenerate or fully out-of-image box
        recog_result = recog_model.readtext(trim_image, print_result=False, show=False)
        score = np.mean(np.array(recog_result['rec_scores']))
        if score > SCORE_THRESHOLD:
            print(recog_result['rec_texts'][0])


if __name__ == '__main__':
    main()