公開日:2022年10月25日
最終更新日:2023年7月18日
PC環境
Windows 11 CUDA 11.6.2 Python 3.10.7
Python環境構築
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html pip install -r https://raw.githubusercontent.com/dai-ichiro/trainMMOCR/main/new_requirements.txt
学習データの作成
こちらを使わせて頂きました。github.com
インストールは先の環境構築ですでに完了しています。
テキストファイルを準備する
スペースを含まない文字列を記入したテキストファイルを作成します。東京 有楽町 新橋 浜松町 田町 品川 大崎 五反田
辞書ファイルを作成する
辞書ファイルとは使用する文字を羅列したテキストファイルです。反 五 橋 有 楽 浜 川 崎 町 京 松 東 新 品 田 大
最初に準備したテキストファイルが「station.txt」という名前であれば以下のようにすれば辞書ファイル作成が可能です。
with open('station.txt', 'r', encoding='utf-8') as f: all_lines = f.readlines() characters = ''.join(all_lines).replace('\n','') with open('dicts.txt', 'w', encoding='utf-8') as f: f.write('\n'.join(set(characters)))
フォントを準備する
以下のスクリプトで数種類のフォントを準備しました。import os import glob import shutil from torchvision.datasets.utils import download_url os.makedirs('fonts', exist_ok=True) # Copy from Windows 11 font_index = [84, 85, 179, 180, 181] fonts = glob.glob('c://Windows/Fonts/*.ttf') for i in font_index: fname = os.path.basename(fonts[i]) fname = fname.replace('.TTF', '.ttf') shutil.copy(fonts[i], os.path.join('fonts', fname)) # Download from mmocr font_url = 'https://download.openmmlab.com/mmocr/data/font.TTF' font_fname = 'mmocr.ttf' download_url(font_url, root = 'fonts', filename = font_fname) # Download from TextRecognitionDataGenerator font_url = 'https://raw.githubusercontent.com/Belval/TextRecognitionDataGenerator/master/trdg/fonts/ja/TakaoMincho.ttf' font_fname = font_url.split('/')[-1] download_url(font_url, root = 'fonts', filename = font_fname)
ここで出てくる「font_index = [84, 85, 179, 180, 181]」は各自のPCで異なります。
使用できるフォントを見つけるためには以下のスクリプトを実行してみて下さい。
import os import glob from PIL import Image, ImageDraw, ImageFont os.makedirs('testresults', exist_ok=True) fonts = glob.glob('c://Windows/Fonts/*.ttf') text = '「アート」\n【あーと】\n東京駅' for i, _font in enumerate(fonts): im = Image.new("RGB", (256, 256), (255, 255, 255)) draw = ImageDraw.Draw(im) font = ImageFont.truetype(_font, 48) x = 20 y = 20 draw.multiline_text((x, y), text, fill=(0, 0, 255), font=font) im.save(os.path.join('testresults', f'{i}.jpg'))
TextRecognitionDataGeneratorを使用する
trdg -l ja -c 3000 -k 1 -rk -bl 1 -rbl -fd fonts -dt text.txt -na 2 --output_dir train trdg -l ja -c 100 -k 1 -rk -bl 1 -rbl -fd fonts -dt text.txt -na 2 --output_dir test
これで「train」フォルダと「test」フォルダが作成されます。
Configファイルを作成する
以下のようなファイルを作成しました。log_config = dict(interval=500, hooks=[dict(type='TextLoggerHook')]) dist_params = dict(backend='nccl') log_level = 'INFO' load_from = None resume_from = None workflow = [('train', 1)] opencv_num_threads = 0 mp_start_method = 'fork' optimizer = dict(type='Adam', lr=0.000125) optimizer_config = dict(grad_clip=None) lr_config = dict(policy='step', step=[3, 4]) runner = dict(type='EpochBasedRunner', max_epochs=10) checkpoint_config = dict(interval=1) label_convertor = dict( type='AttnConvertor', dict_file='dicts.txt', with_unknown=True, max_seq_len=35) model = dict( type='SARNet', backbone=dict(type='ResNet31OCR'), encoder=dict( type='SAREncoder', enc_bi_rnn=False, enc_do_rnn=0.1, enc_gru=False), decoder=dict( type='ParallelSARDecoder', enc_bi_rnn=False, dec_bi_rnn=False, dec_do_rnn=0, dec_gru=False, pred_dropout=0.1, d_k=512, pred_concat=True), loss=dict(type='SARLoss'), label_convertor=label_convertor) img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) train = dict( type='OCRDataset', img_prefix='img', ann_file='train_label.txt', loader=dict( type='HardDiskLoader', repeat=1, parser=dict( type='LineStrParser', keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), pipeline=None, test_mode=False) test = dict( type='OCRDataset', img_prefix='img', ann_file='test_label.txt', loader=dict( type='HardDiskLoader', repeat=1, parser=dict( type='LineStrParser', keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), pipeline=None, test_mode=False) data = dict( samples_per_gpu=8, workers_per_gpu=2, val_dataloader=dict(samples_per_gpu=1), test_dataloader=dict(samples_per_gpu=1), train=dict( type='UniformConcatDataset', datasets=[ dict( type='OCRDataset', img_prefix='img', ann_file='train_label.txt', loader=dict( type='HardDiskLoader', repeat=1, parser=dict( type='LineStrParser', keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), pipeline=None, test_mode=False) ], pipeline=[ dict(type='LoadImageFromFile'), dict( type='ResizeOCR', height=48, min_width=48, max_width=256, keep_aspect_ratio=True, width_downsample_ratio=0.25), dict(type='ToTensorOCR'), dict( type='NormalizeOCR', mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), dict( type='Collect', keys=['img'], meta_keys=[ 'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio' ]) ]), val=dict( type='UniformConcatDataset', datasets=[ dict( type='OCRDataset', img_prefix='img', ann_file='test_label.txt', loader=dict( type='HardDiskLoader', repeat=1, parser=dict( type='LineStrParser', keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), pipeline=None, test_mode=False) ], pipeline=[ dict(type='LoadImageFromFile'), dict( type='MultiRotateAugOCR', rotate_degrees=[0, 90, 270], transforms=[ dict( type='ResizeOCR', height=48, min_width=48, max_width=256, keep_aspect_ratio=True, width_downsample_ratio=0.25), dict(type='ToTensorOCR'), dict( type='NormalizeOCR', mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), dict( type='Collect', keys=['img'], meta_keys=[ 'filename', 'ori_shape', 'resize_shape', 'valid_ratio' ]) ]) ]), test=dict( type='UniformConcatDataset', datasets=[ dict( type='OCRDataset', img_prefix='img', ann_file='test_label.txt', loader=dict( type='HardDiskLoader', repeat=1, parser=dict( type='LineStrParser', keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), pipeline=None, test_mode=False) ], pipeline=[ dict(type='LoadImageFromFile'), dict( type='MultiRotateAugOCR', rotate_degrees=[0, 90, 270], transforms=[ dict( type='ResizeOCR', height=48, min_width=48, max_width=256, keep_aspect_ratio=True, width_downsample_ratio=0.25), dict(type='ToTensorOCR'), dict( type='NormalizeOCR', mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), dict( type='Collect', keys=['img'], meta_keys=[ 'filename', 'ori_shape', 'resize_shape', 'valid_ratio' ]) ]) ])) evaluation = dict(interval=1, metric='acc')
学習用ファイルを実行する
次のPythonファイルを実行します。import os from mmcv import Config from mmocr.datasets import build_dataset from mmocr.models import build_detector from mmocr.apis import train_detector def main(): cfg = Config.fromfile('SAR_japanese_cfg.py') os.makedirs('sar_output', exist_ok=True) #### ## modify configuration file #### # set output dir cfg.work_dir = 'sar_output' # Path to annotation file cfg.train.ann_file= 'train/labels.txt' cfg.test.ann_file = 'test/labels.txt' # Paht to image folder cfg.train.img_prefix = 'train' cfg.test.img_prefix = 'test' # Modify label_convertor cfg.label_convertor.dict_file='dicts.txt' cfg.label_convertor.max_seq_len = 40 cfg.model.label_convertor = cfg.label_convertor # Modify data cfg.data.train.datasets = [cfg.train] cfg.data.val.datasets = [cfg.test] cfg.data.test.datasets = [cfg.test] # modify cuda setting cfg.gpu_ids = range(1) cfg.device = 'cuda' # Others cfg.optimizer.lr = 0.001 /8 cfg.seed = 0 cfg.runner.max_epochs = 1 # default 5 cfg.data.samples_per_gpu = 16 cfg.log_config.interval = 1000 cfg.dump('new_SAR_cfg.py') # Build dataset datasets = [build_dataset(cfg.data.train)] model = build_detector(cfg.model) model.CLASSES = datasets[0].CLASSES model.init_weights() train_detector(model, datasets, cfg, validate=True) if __name__ == '__main__': main()
補足(WSL2からフォントファイルを取得する)
Ubuntu 20.04 on WSL2からフォントファイル(.ttf)をコピーしたい時は以下のようにすればできます。import os import glob import shutil os.makedirs('/mnt/e/fonts2004', exist_ok=True) fonts = glob.glob('/usr/share/fonts/truetype/*/*.ttf') for font in fonts: fname = os.path.basename(font) shutil.copy(font, os.path.join('/mnt/e/fonts2004', fname))