【GluonCV】胸部X線写真に写っている肺をセグメンテーション

今回使用したデータ

miniJSRT_database | 日本放射線技術学会 画像部会から「Segmentation > >Segmentation01(256×256,RGB Color:24bit)」をダウンロードさせて頂きました。
学習データ50画像、テストデータ10画像です。
学習データが少ないですがAugmentationで増やせということでしょうか?

学習スクリプト

ラベル画像に影響しないAugmentation(RandomColorJitterとRandomLighting)だけ使用しました。
回転や上下反転、左右反転などはラベル画像にも変換を加える必要があるため使用しませんでした。
学習率はデフォルトのままです。
auxiliary lossを使用しています。

import mxnet as mx
from mxnet import gluon, image, autograd
from mxnet.gluon.data.vision import transforms
import gluoncv

ctx = mx.gpu() if mx.context.num_gpus() >0 else mx.cpu()

jitter_param = 0.4
lighting_param = 0.1
input_transform = transforms.Compose([
    transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                 saturation=jitter_param),
    transforms.RandomLighting(lighting_param),
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])

img_list = []
label_list = []

for i in range(1, 51):
    img = image.imread('seg_image/train/org/%d.png'%i)
    img_list.append(img)

    label = image.imread('seg_image/train/label/%d.png'%i, flag=0)
    label = label/255
    label_list.append(mx.nd.squeeze(label))

train_dataset = gluon.data.dataset.ArrayDataset(img_list, label_list)

train_dataloader = gluon.data.DataLoader(
    train_dataset.transform_first(input_transform), batch_size=4 , shuffle=True)

model = gluoncv.model_zoo.FCN(
    root = './models',
    nclass = 2, 
    backbone = 'resnet50', 
    aux = True, 
    ctx = ctx, 
    pretrained_base=True, 
    crop_size=256)

criterion = gluoncv.loss.MixSoftmaxCrossEntropyLoss(aux=True)

trainer = gluon.Trainer(model.collect_params(), 'sgd')

epochs = 50

for epoch in range(1, epochs + 1):

    train_loss = 0.0
    data_count = 0

    for i, (data, target) in enumerate(train_dataloader):
        with autograd.record():
            outputs = model(data.as_in_context(ctx))
            losses = criterion(outputs[0], outputs[1], target.as_in_context(ctx))
            losses.backward()
        trainer.step(data.shape[0])

        data_count += data.shape[0]
        train_loss += mx.nd.sum(losses).asscalar()
        
        print('Epoch %d, batch %d[%d/%d], training loss %.3f'%(epoch, i+1, data_count, len(train_dataset), train_loss/data_count))

model.save_parameters('seg.params')

結果

Epoch 1, batch 1[4/50], training loss 1.014
Epoch 1, batch 2[8/50], training loss 0.873
Epoch 1, batch 3[12/50], training loss 0.804
Epoch 1, batch 4[16/50], training loss 0.767
Epoch 1, batch 5[20/50], training loss 0.743
・
・
・
Epoch 50, batch 9[36/50], training loss 0.104
Epoch 50, batch 10[40/50], training loss 0.103
Epoch 50, batch 11[44/50], training loss 0.102
Epoch 50, batch 12[48/50], training loss 0.102
Epoch 50, batch 13[50/50], training loss 0.102

テスト画像を用いた評価

import glob
import numpy as np
from PIL import Image
import mxnet as mx
import gluoncv
from mxnet import image
from mxnet.gluon.data.vision import transforms

ctx = mx.gpu() if mx.context.num_gpus() >0 else mx.cpu()

input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])

model = gluoncv.model_zoo.FCN(
    root = './models',
    nclass = 2, 
    backbone = 'resnet50', 
    aux = True, 
    ctx = ctx, 
    pretrained_base=True, 
    crop_size=256)

model.load_parameters('seg.params')

img_files = glob.glob('./seg_image/test/org/*.png')

for i, img_file in enumerate(img_files):
    img = image.imread(img_file)
    img_PIL = Image.fromarray(img.asnumpy())

    img = input_transform(img)

    output = model.predict(img.expand_dims(0).as_in_context(ctx))

    output = mx.nd.squeeze(output)

    prediction = np.argmax(output.asnumpy(), axis=0).astype('int8')
    prediction = prediction * 255

    mask = Image.fromarray(prediction).convert('RGB')
    mask.putalpha(60)
    img_PIL.paste(mask, mask)

    img_PIL.save('result_%d.png'%i)


f:id:touch-sp:20210205223024p:plain
f:id:touch-sp:20210205223038p:plain
f:id:touch-sp:20210205223054p:plain
f:id:touch-sp:20210205223118p:plain
f:id:touch-sp:20210205223131p:plain
f:id:touch-sp:20210205223140p:plain
f:id:touch-sp:20210205223157p:plain
f:id:touch-sp:20210205223207p:plain
f:id:touch-sp:20210205223216p:plain
f:id:touch-sp:20210205223225p:plain

感想

たった50画像を学習しただけですがなかなかの結果になったと思います。

動作環境

Windows 10 with NVIDIA GTX 1080
Python 3.7.9

atomicwrites==1.4.0
attrs==20.3.0
autocfg==0.0.6
autogluon.core==0.0.16b20210204
autograd==1.3
bcrypt==3.2.0
boto3==1.17.1
botocore==1.20.1
certifi==2020.12.5
cffi==1.14.4
chardet==3.0.4
click==7.1.2
cloudpickle==1.6.0
colorama==0.4.4
ConfigSpace==0.4.16
cryptography==3.3.1
cycler==0.10.0
Cython==0.29.21
dask==2021.1.1
decord==0.4.2
dill==0.3.3
distributed==2021.1.1
future==0.18.2
gluoncv==0.9.2
graphviz==0.8.4
HeapDict==1.0.1
idna==2.6
importlib-metadata==3.4.0
iniconfig==1.1.1
jmespath==0.10.0
joblib==1.0.0
kiwisolver==1.3.1
matplotlib==3.3.4
msgpack==1.0.2
mxnet-cu102==1.7.0
numpy==1.16.6
opencv-python==4.5.1.48
packaging==20.9
pandas==1.2.1
paramiko==2.7.2
Pillow==8.1.0
pluggy==0.13.1
portalocker==2.2.1
protobuf==3.14.0
psutil==5.8.0
py==1.10.0
pyaml==20.4.0
pycparser==2.20
PyNaCl==1.4.0
pyparsing==2.4.7
pytest==6.2.2
python-dateutil==2.8.1
pytz==2021.1
pywin32==300
PyYAML==5.4.1
requests==2.25.1
s3transfer==0.3.4
scikit-learn==0.23.2
scikit-optimize==0.8.1
scipy==1.4.1
six==1.15.0
sortedcontainers==2.3.0
tblib==1.7.0
tensorboardX==2.1
threadpoolctl==2.1.0
toml==0.10.2
toolz==0.11.1
tornado==6.1
tqdm==4.56.0
typing-extensions==3.7.4.3
urllib3==1.26.3
yacs==0.1.8
zict==2.0.0
zipp==3.4.0

2021年2月7日追記

さらに良い結果を求めたのがこちらになります。
touch-sp.hatenablog.com

このエントリーをはてなブックマークに追加