OpenMMLab の MMOCR に日本語の学習をさせたい【v0.6.2】

公開日:2022年10月25日
最終更新日:2023年7月18日

PC環境

Windows 11
CUDA 11.6.2
Python 3.10.7

Python環境構築

pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html
pip install -r https://raw.githubusercontent.com/dai-ichiro/trainMMOCR/main/new_requirements.txt

学習データの作成

こちらを使わせて頂きました。
github.com
インストールは先の環境構築ですでに完了しています。

テキストファイルを準備する

スペースを含まない文字列を記入したテキストファイルを作成します。

東京
有楽町
新橋
浜松町
田町
品川
大崎
五反田

辞書ファイルを作成する

辞書ファイルとは使用する文字を羅列したテキストファイルです。

反
五
橋
有
楽
浜
川
崎
町
京
松
東
新
品
田
大

最初に準備したテキストファイルが「station.txt」という名前であれば以下のようにすれば辞書ファイル作成が可能です。

with open('station.txt', 'r', encoding='utf-8') as f:
    all_lines = f.readlines()
characters = ''.join(all_lines).replace('\n','')
with open('dicts.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(set(characters)))

フォントを準備する

以下のスクリプトで数種類のフォントを準備しました。

import os
import glob
import shutil
from torchvision.datasets.utils import download_url

os.makedirs('fonts', exist_ok=True)

# Copy from Windows 11
font_index = [84, 85, 179, 180, 181]
fonts = glob.glob('c://Windows/Fonts/*.ttf')

for i in font_index:
    fname = os.path.basename(fonts[i])
    fname = fname.replace('.TTF', '.ttf')
    shutil.copy(fonts[i], os.path.join('fonts', fname))
   
# Download from mmocr
font_url = 'https://download.openmmlab.com/mmocr/data/font.TTF'
font_fname = 'mmocr.ttf'
download_url(font_url, root = 'fonts', filename = font_fname)

# Download from TextRecognitionDataGenerator
font_url = 'https://raw.githubusercontent.com/Belval/TextRecognitionDataGenerator/master/trdg/fonts/ja/TakaoMincho.ttf'
font_fname = font_url.split('/')[-1]
download_url(font_url, root = 'fonts', filename = font_fname)



ここで出てくる「font_index = [84, 85, 179, 180, 181]」は各自のPCで異なります。
使用できるフォントを見つけるためには以下のスクリプトを実行してみて下さい。

import os
import glob
from PIL import Image, ImageDraw, ImageFont

os.makedirs('testresults', exist_ok=True)

fonts = glob.glob('c://Windows/Fonts/*.ttf')

text = '「アート」\n【あーと】\n東京駅'

for i, _font in enumerate(fonts):

    im = Image.new("RGB", (256, 256), (255, 255, 255))
    draw = ImageDraw.Draw(im)

    font = ImageFont.truetype(_font, 48)

    x = 20
    y = 20

    draw.multiline_text((x, y), text, fill=(0, 0, 255), font=font)
        
    im.save(os.path.join('testresults', f'{i}.jpg'))

TextRecognitionDataGeneratorを使用する

trdg -l ja -c 3000 -k 1 -rk -bl 1 -rbl -fd fonts -dt text.txt -na 2 --output_dir train
trdg -l ja -c 100 -k 1 -rk -bl 1 -rbl -fd fonts -dt text.txt -na 2 --output_dir test



これで「train」フォルダと「test」フォルダが作成されます。

Configファイルを作成する

以下のようなファイルを作成しました。

log_config = dict(interval=500, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
optimizer = dict(type='Adam', lr=0.000125)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='step', step=[3, 4])
runner = dict(type='EpochBasedRunner', max_epochs=10)
checkpoint_config = dict(interval=1)

label_convertor = dict(
    type='AttnConvertor',
    dict_file='dicts.txt',
    with_unknown=True,
    max_seq_len=35)

model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder', enc_bi_rnn=False, enc_do_rnn=0.1, enc_gru=False),
    decoder=dict(
        type='ParallelSARDecoder',
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor)
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

train = dict(
    type='OCRDataset',
    img_prefix='img',
    ann_file='train_label.txt',
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=None,
    test_mode=False)
test = dict(
    type='OCRDataset',
    img_prefix='img',
    ann_file='test_label.txt',
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=None,
    test_mode=False)

data = dict(
    samples_per_gpu=8,
    workers_per_gpu=2,
    val_dataloader=dict(samples_per_gpu=1),
    test_dataloader=dict(samples_per_gpu=1),
    train=dict(
        type='UniformConcatDataset',
        datasets=[
            dict(
                type='OCRDataset',
                img_prefix='img',
                ann_file='train_label.txt',
                loader=dict(
                    type='HardDiskLoader',
                    repeat=1,
                    parser=dict(
                        type='LineStrParser',
                        keys=['filename', 'text'],
                        keys_idx=[0, 1],
                        separator=' ')),
                pipeline=None,
                test_mode=False)
        ],
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=256,
                keep_aspect_ratio=True,
                width_downsample_ratio=0.25),
            dict(type='ToTensorOCR'),
            dict(
                type='NormalizeOCR', mean=[0.5, 0.5, 0.5], std=[0.5, 0.5,
                                                                0.5]),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'resize_shape', 'text',
                    'valid_ratio'
                ])
        ]),
    val=dict(
        type='UniformConcatDataset',
        datasets=[
            dict(
                type='OCRDataset',
                img_prefix='img',
                ann_file='test_label.txt',
                loader=dict(
                    type='HardDiskLoader',
                    repeat=1,
                    parser=dict(
                        type='LineStrParser',
                        keys=['filename', 'text'],
                        keys_idx=[0, 1],
                        separator=' ')),
                pipeline=None,
                test_mode=False)
        ],
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiRotateAugOCR',
                rotate_degrees=[0, 90, 270],
                transforms=[
                    dict(
                        type='ResizeOCR',
                        height=48,
                        min_width=48,
                        max_width=256,
                        keep_aspect_ratio=True,
                        width_downsample_ratio=0.25),
                    dict(type='ToTensorOCR'),
                    dict(
                        type='NormalizeOCR',
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5]),
                    dict(
                        type='Collect',
                        keys=['img'],
                        meta_keys=[
                            'filename', 'ori_shape', 'resize_shape',
                            'valid_ratio'
                        ])
                ])
        ]),
    test=dict(
        type='UniformConcatDataset',
        datasets=[
            dict(
                type='OCRDataset',
                img_prefix='img',
                ann_file='test_label.txt',
                loader=dict(
                    type='HardDiskLoader',
                    repeat=1,
                    parser=dict(
                        type='LineStrParser',
                        keys=['filename', 'text'],
                        keys_idx=[0, 1],
                        separator=' ')),
                pipeline=None,
                test_mode=False)
        ],
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiRotateAugOCR',
                rotate_degrees=[0, 90, 270],
                transforms=[
                    dict(
                        type='ResizeOCR',
                        height=48,
                        min_width=48,
                        max_width=256,
                        keep_aspect_ratio=True,
                        width_downsample_ratio=0.25),
                    dict(type='ToTensorOCR'),
                    dict(
                        type='NormalizeOCR',
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5]),
                    dict(
                        type='Collect',
                        keys=['img'],
                        meta_keys=[
                            'filename', 'ori_shape', 'resize_shape',
                            'valid_ratio'
                        ])
                ])
        ]))

evaluation = dict(interval=1, metric='acc')

学習用ファイルを実行する

次のPythonファイルを実行します。

import os
from mmcv import Config
from mmocr.datasets import build_dataset
from mmocr.models import build_detector
from mmocr.apis import train_detector

def main():
    cfg = Config.fromfile('SAR_japanese_cfg.py')

    os.makedirs('sar_output', exist_ok=True)

    ####
    ## modify configuration file
    ####

    # set output dir
    cfg.work_dir = 'sar_output'

    # Path to annotation file
    cfg.train.ann_file= 'train/labels.txt'
    cfg.test.ann_file = 'test/labels.txt'

    # Paht to image folder
    cfg.train.img_prefix = 'train'
    cfg.test.img_prefix = 'test'

    # Modify label_convertor
    cfg.label_convertor.dict_file='dicts.txt'
    cfg.label_convertor.max_seq_len = 40
    cfg.model.label_convertor = cfg.label_convertor
    
    # Modify data
    cfg.data.train.datasets = [cfg.train]
    cfg.data.val.datasets = [cfg.test]
    cfg.data.test.datasets = [cfg.test]

    # modify cuda setting
    cfg.gpu_ids = range(1)
    cfg.device = 'cuda'

    # Others
    cfg.optimizer.lr = 0.001 /8
    cfg.seed = 0
    cfg.runner.max_epochs = 1 # default 5 
    cfg.data.samples_per_gpu = 16
    cfg.log_config.interval = 1000

    cfg.dump('new_SAR_cfg.py')

    # Build dataset
    datasets = [build_dataset(cfg.data.train)]

    model = build_detector(cfg.model)
    model.CLASSES = datasets[0].CLASSES
    model.init_weights()

    train_detector(model, datasets, cfg, validate=True)

if __name__ == '__main__':
    main()

補足(WSL2からフォントファイルを取得する)

Ubuntu 20.04 on WSL2からフォントファイル(.ttf)をコピーしたい時は以下のようにすればできます。

import os
import glob
import shutil

os.makedirs('/mnt/e/fonts2004', exist_ok=True)

fonts = glob.glob('/usr/share/fonts/truetype/*/*.ttf')

for font in fonts:
    fname = os.path.basename(font)
    shutil.copy(font, os.path.join('/mnt/e/fonts2004', fname))