MENU

敵ロボット検出スクリプトをGitHubで公開します

以前Autogluonを使ったObject Detectionでロボットパネルを検出する課題に取り組みました。
touch-sp.hatenablog.com

学習データ作成のために動画を撮影しました。
その動画ファイルを含めてスクリプトGitHubで公開します。
github.com
良かったら使用してみて下さい。
簡単に再現できます。

Pandasのfactorizeが便利すぎる!

はじめに

最近Pandasを勉強しています。
factorizeというのを初めて知りました。
AutoGluonを使ったObject Detectionでは結果がPandas DataFrameで返ってきます。その時にfactorizeが非常に役に立ちます。

   predict_class  predict_score                                       predict_rois
0           head       0.988660  {'xmin': 0.3798920810222626, 'ymin': 0.2936957...  
1           head       0.981958  {'xmin': 0.16018624603748322, 'ymin': 0.251445...  
2           head       0.964868  {'xmin': 0.7895624041557312, 'ymin': 0.1375325...  
3           head       0.962894  {'xmin': 0.5869008302688599, 'ymin': 0.2351982...  
4           head       0.299219  {'xmin': 0.5046889185905457, 'ymin': 0.4765846...  
..           ...            ...                                                ...  
84          head       0.058102  {'xmin': 0.1765645295381546, 'ymin': 0.3947232...  
85          head       0.057993  {'xmin': 0.8565279245376587, 'ymin': 0.1774100...  
86          head       0.057855  {'xmin': 0.2825169563293457, 'ymin': 0.3654884...  
87          head       0.057760  {'xmin': 0.2772541344165802, 'ymin': 0.5180094...  
88          head       0.057700  {'xmin': 0.06226371228694916, 'ymin': 0.014346...  

以前の方法

class_names = list(set(selected_result['predict_class']))
class_ids = np.array([class_names.index(i) for i in list(selected_result['predict_class'])])

新たに習得した方法

class_ids , class_names = selected_result['predict_class'].factorize()

これで以前の方法と同じ結果が得られます。返ってくるclass_namesがリストかnumpy arrayかの違いはありますが。

感想

いままでなんとなくスクリプト書いていましたがPandas、Numpy、pillow、matplotlibなど基本から勉強するべきと痛感しました。

One-Hotベクトル化はPandasのget_dummiesを使えば簡単だった!

はじめに

最近Pandasを勉強しています。
get_dummiesを使えばOne-Hotベクトル化が簡単に実現できました。

以前の方法

GluonTS 0.6.4 が公開された - パソコン関連もろもろ
わざわざ関数を定義してOne-Hotベクトル化を行っていました。

def one_hot(x, start_zero = True):

    if not start_zero:
        x = x-1 
    category_n = x.max() + 1
    one_hot_vec = np.identity(category_n)[x]

    return one_hot_vec.transpose(1,0)

feat4 = one_hot(df.weekday, start_zero=True)

新たに習得した方法

feat4 = np.array(pd.get_dummies(df.weekday)).transpose((1,0))

これで以前の方法と同じ結果が得られます。
また、データフレームに文字列が入っていてもそのまま実行可能です。

感想

いままでなんとなくスクリプト書いていましたがPandas、Numpy、pillow、matplotlibなど基本から勉強するべきと痛感しました。

WSL2で画像やグラフの表示

はじめに

WSL2では今のところGUIが標準では使えないので画像やグラフの表示ができないと思っていました。
しかし、VS code 内でJupyter Notebookを使用すると簡単にできました。

結果

f:id:touch-sp:20210220152525p:plain:w400
pillowを使って表示
f:id:touch-sp:20210220152545p:plain:w400
matplotlibを使って表示

(画像はクリックすると拡大できます)

方法

VS codeのインストールはWindows環境だけで問題ありません。

  • WSL2内のPython環境にnotebookをインストール(pipで可能)
pip install notebook


これだけでVS code内でJupyter Notebookが使えるようになり画像やグラフが表示できました。

【GluonCV】【物体検出】胸部レントゲンの結節影を検出せよ!

はじめに

今回は前回の続きです。
touch-sp.hatenablog.com
肺の結節影を検出することにチャレンジしました。
前回ダウンロードさせて頂いた154枚の結節影を有する胸部レントゲン写真を使いました。
(学習用:146枚、テスト用:8枚)
非常に少ないですがそれしか手に入らないので仕方がありません。

結果

先に結果を示します。
良い結果は得られませんでした。今後の課題です。
左画像が正解画像、右画像が今回の学習で検出したものです。
8枚のうち正しく検出できたのは2枚だけでした。1枚は違う部位を結節と認識、残りは未検出です。
f:id:touch-sp:20210213124039p:plainf:id:touch-sp:20210213124111p:plain
f:id:touch-sp:20210213124124p:plainf:id:touch-sp:20210213124133p:plain
f:id:touch-sp:20210213124144p:plainf:id:touch-sp:20210213124152p:plain
f:id:touch-sp:20210213124203p:plainf:id:touch-sp:20210213124214p:plain
f:id:touch-sp:20210213124224p:plainf:id:touch-sp:20210213124231p:plain
f:id:touch-sp:20210213124244p:plainf:id:touch-sp:20210213124254p:plain
f:id:touch-sp:20210213124302p:plainf:id:touch-sp:20210213124312p:plain
f:id:touch-sp:20210213124319p:plainf:id:touch-sp:20210213124326p:plain

方法

以下に方法を示します。

学習データとテストデータに分ける

from sklearn.model_selection import train_test_split
import pandas as pd 

data = pd.read_csv('CLNDAT_EN.TXT', sep='\t', header=None)

train, test = train_test_split(data, train_size=0.95)

train.to_csv('train.csv')
test.to_csv('test.csv')

'''
読み込む時は
train_data = pd.read_csv('train.csv', index_col=0)
test_data = pd.read_csv('test.csv', index_col=0)
'''

データ標準化のために平均、標準偏差を求める

この部分に限り、今回使用する画像に加えて結節影がうつっていない正常画像93枚も使用しています。

import numpy as np
import pydicom
import glob

all_array = []

all_path = glob.glob('./Nodule154images/*.dcm')

for path in all_path:

    a = pydicom.read_file(path)
    b = a.pixel_array

    all_array.append(b)

all_path = glob.glob('./NonNodule93images/*.dcm')

for path in all_path:

    a = pydicom.read_file(path)
    b = a.pixel_array

    all_array.append(b)

all_np = np.stack(all_array)/4095

mean = np.mean(all_np)
std = np.std(all_np)

mean_std = {'mean': mean, 'std': std}

import pickle
with open('mean_std.pkl', 'wb') as f:
    pickle.dump(mean_std, f)

Data Augmentation

以下のスクリプトを「grey_augmentation.py」という名前で保存しました。

import numpy as np
import mxnet as mx
from gluoncv.data.transforms import image as timage
from gluoncv.data.transforms import bbox as tbbox

def random_augmentation(src, label, brightness_delta=512, contrast_low=0.5, contrast_high=1.5, img_mean=0.5, img_std=0.25):
    
    """
    Note that input image should in original range [0, 255].

    Parameters
    ----------
    src : mxnet.nd.NDArray
        Input image as HWC format.
    label : numpy.ndarray
        shape N x 6 (N = number of bbox)
    brightness_delta : int
        Maximum brightness delta. Defaults to 32.
    contrast_low : float
        Lowest contrast. Defaults to 0.5.
    contrast_high : float
        Highest contrast. Defaults to 1.5.

    Returns
    -------
    mxnet.nd.NDArray
        Distorted image in HWC format.
    numpy.ndarray
        new bounding box

    """

    def brightness(src, delta):
        if np.random.uniform(0, 1) > 0.5:
            delta = np.random.uniform(-delta, delta)
            src += delta
            src.clip(0, 4095)
        return src

    def contrast(src, low, high):
        if np.random.uniform(0, 1) > 0.5:
            alpha = np.random.uniform(low, high)
            src *= alpha
            src.clip(0, 4095)
        return src
    
    def random_flip(src, label):
        h, w, _ = src.shape
        src, flips = timage.random_flip(src, px=0.5, py=0.5)
        new_bbox = tbbox.flip(label, (w, h), flip_x=flips[0], flip_y=flips[1])
        return src, new_bbox

    src = src.astype('float32')

    src = brightness(src, brightness_delta)
    src = contrast(src, contrast_low, contrast_high)
    src, label = random_flip(src, label)

    src = src/4095
    src = (src-img_mean)/img_std
    src = src.transpose((2,0,1))
    src = mx.nd.concat(src, src, src, dim=0)
    
    return src, label

学習

import numpy as np
import os
import pickle
import time
import pandas as pd
import pydicom
import mxnet as mx
from mxnet import gluon, autograd

from gluoncv import model_zoo
from gluoncv.loss import SSDMultiBoxLoss

from grey_augmentation import random_augmentation

from gluoncv.model_zoo.ssd.target import SSDTargetGenerator
target_generator = SSDTargetGenerator(
                iou_thresh=0.5, stds=(0.1, 0.1, 0.2, 0.2), negative_mining_ratio=-1)

ctx = [mx.gpu()]

classes = ['nodule']
net = model_zoo.get_model('ssd_512_mobilenet1.0_voc', pretrained=True, ctx = ctx[0], root='./models')
net.reset_class(classes)
net.hybridize()

x = mx.nd.zeros(shape=(1, 3, 512, 512),ctx=ctx[0])
with autograd.train_mode():
    _, _, anchors = net(x)

anchors = anchors.as_in_context(mx.cpu())

with open('mean_std.pkl', 'rb') as f:
    mean_std = pickle.load(f)

img_mean = mean_std['mean']
img_std = mean_std['std']

csv_data = pd.read_csv('train.csv', index_col=0)

def make_train_dataset():

    img_list = []
    cls_list = []
    box_list = []

    for i in range(len(csv_data)):

        file_name = os.path.splitext(csv_data.iloc[i,0])[0] + '.dcm'

        img = pydicom.read_file(os.path.join('Nodule154images', file_name))
        img = img.pixel_array

        position_x = csv_data.iloc[i,5]
        position_y = csv_data.iloc[i,6]
        size = csv_data.iloc[i,2]
        nodule_size = int(size/(0.175*2))

        x_min = position_x - nodule_size
        x_max = position_x + nodule_size

        y_min = position_y - nodule_size
        y_max = position_y + nodule_size

        if position_x < 1024:
            crop_xmin = np.random.randint(max(3, x_max-450), x_min-3)
        else:
            crop_xmin = np.random.randint(x_max+3, min(2044, x_min+450)) - 512
        
        if position_y < 1024:

            crop_ymin = np.random.randint(max(3, y_max-450), y_min-3)
        else:
            crop_ymin = np.random.randint(y_max+3, min(2044, y_min+450)) - 512 

        x_min -= crop_xmin
        x_max -= crop_xmin
        y_min -= crop_ymin
        y_max -= crop_ymin

        img = img[crop_ymin:crop_ymin+512, crop_xmin:crop_xmin+512]

        img = mx.nd.array(img).expand_dims(2)
        label = np.array([x_min, y_min, x_max, y_max, 0, 0]).reshape((1, -1)).astype('float32')

        img, label = random_augmentation(img, label, img_mean=img_mean, img_std=img_std)

        gt_bboxes = mx.nd.array(np.expand_dims(label[:,:4], 0))
        gt_ids = mx.nd.array(np.expand_dims(label[:,4:5], 0))
        cls_targets, box_targets, _ = target_generator(anchors, None, gt_bboxes, gt_ids)

        img_list.append(img)
        cls_list.append(cls_targets[0])
        box_list.append(box_targets[0])
        
    return img_list, cls_list, box_list

#hyperparameters
train_epochs = 100
train_batch = 4
num_workers = 0

mbox_loss = SSDMultiBoxLoss()
ce_metric = mx.metric.Loss('CrossEntropy')
smoothl1_metric = mx.metric.Loss('SmoothL1')

trainer = gluon.Trainer(
    net.collect_params(), 'sgd',
    {'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9})

for epoch in range(train_epochs):

    img_list, cls_list, box_list = make_train_dataset()
    epoch_dataset = gluon.data.dataset.ArrayDataset(img_list, cls_list, box_list)
    train_loader = gluon.data.DataLoader(epoch_dataset, batch_size=train_batch, 
            shuffle=True, last_batch='rollover', num_workers=num_workers)

    ce_metric.reset()
    smoothl1_metric.reset()
    
    tic = time.time()
    btic = time.time()
    
    for i, batch in enumerate(train_loader):
        batch_size = batch[0].shape[0]
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
        with autograd.record():
            cls_preds = []
            box_preds = []
            for x in data:
                cls_pred, box_pred, _ = net(x)
                cls_preds.append(cls_pred)
                box_preds.append(box_pred)
            sum_loss, cls_loss, box_loss = mbox_loss(
                cls_preds, box_preds, cls_targets, box_targets)
            autograd.backward(sum_loss)
        trainer.step(1)
        ce_metric.update(0, [l * batch_size for l in cls_loss])
        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
        name1, loss1 = ce_metric.get()
        name2, loss2 = smoothl1_metric.get()
        if i % 20 == 0:
            print('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))

net.save_parameters('ssd_512_mobilenet1.0_nodule_%d.params'%train_epochs)

テストデータを変形

テストデータはモデルに代入するために変形が必要です。縮小ではなく切り取りを行いました。切り取りの際には結節影が必ず含まれるように工夫しています。何度も使えるようにpickleで保存しています。

import numpy as np
import os
import pickle
import pandas as pd
import pydicom

csv_data = pd.read_csv('test.csv', index_col=0)

def make_test_dataset():

    original_img_list = []
    test_label_list = []

    for i in range(len(csv_data)):

        file_name = os.path.splitext(csv_data.iloc[i,0])[0] + '.dcm'

        img = pydicom.read_file(os.path.join('Nodule154images', file_name))
        img = img.pixel_array

        position_x = csv_data.iloc[i,5]
        position_y = csv_data.iloc[i,6]
        size = csv_data.iloc[i,2]
        nodule_size = int(size/(0.175*2))

        x_min = position_x - nodule_size
        x_max = position_x + nodule_size

        y_min = position_y - nodule_size
        y_max = position_y + nodule_size

        if position_x < 1024:
            crop_xmin = np.random.randint(max(3, x_max-450), x_min-3)
        else:
            crop_xmin = np.random.randint(x_max+3, min(2044, x_min+450)) - 512
        
        if position_y < 1024:

            crop_ymin = np.random.randint(max(3, y_max-450), y_min-3)
        else:
            crop_ymin = np.random.randint(y_max+3, min(2044, y_min+450)) - 512 

        x_min -= crop_xmin
        x_max -= crop_xmin
        y_min -= crop_ymin
        y_max -= crop_ymin

        img = img[crop_ymin:crop_ymin+512, crop_xmin:crop_xmin+512]
        
        label = np.array([x_min, y_min, x_max, y_max])

        original_img_list.append(img)
        test_label_list.append(label)
        
    return original_img_list, test_label_list

original_img_list, label_list = make_test_dataset()

with open('test_data.pkl', 'wb') as f:
    pickle.dump([original_img_list, label_list], f)

テストデータに対してモデルを適用

import os
import pickle
import mxnet as mx
from PIL import Image, ImageDraw

from gluoncv import model_zoo

os.mkdir('result')

ctx = mx.gpu()

classes = ['nodule']
net = model_zoo.get_model('ssd_512_mobilenet1.0_voc', pretrained=True, ctx = ctx, root='./models')
net.reset_class(classes)
net.load_parameters('ssd_512_mobilenet1.0_nodule.params')
net.hybridize()

with open('mean_std.pkl', 'rb') as f:
    mean_std = pickle.load(f)

with open('test_data.pkl', 'rb') as f:
    test_dataset = pickle.load(f)

img_mean = mean_std['mean']
img_std = mean_std['std']

for i in range(len(test_dataset[0])):

    original_img = test_dataset[0][i]
    label = test_dataset[1][i]

    src = mx.nd.array(original_img).expand_dims(2)
    src = src/4095
    src = (src-img_mean)/img_std
    src = src.transpose((2,0,1))
    src = mx.nd.concat(src, src, src, dim=0)
        
    #正解表示
    img = Image.fromarray((original_img/16).astype('uint8'), mode='L')
    draw = ImageDraw.Draw(img)
    draw.rectangle(list(label), outline=255, width=4)
    
    img.save(os.path.join('result', 'correct_position_%d.png'%i))

    #検出結果表示
    normalize_img = src.expand_dims(0)
    class_IDs, scores, bounding_boxs = net(normalize_img.as_in_context(ctx))

    result_num = int(mx.nd.sum(scores>0.7).asscalar())

    img = Image.fromarray((original_img/16).astype('uint8'), mode='L')
    draw = ImageDraw.Draw(img)

    for x in range(result_num):
        draw.rectangle([int(x.asscalar()) for x in bounding_boxs[0][x]], outline=255, width=4)

    img.save(os.path.join('result', 'predict_reslt_%d.png'%i))

考察

今回良い結果は得られませんでした。
もう少しどうにかなるのではないかと考えています。
この課題にもう少し取り組んでいこうと思います。いつか良い結果がでればまた記事にします。

このエントリーをはてなブックマークに追加

PythonでDICOMファイルを扱う

はじめに

データは
標準ディジタル画像データベース(DICOM版) | 日本放射線技術学会 画像部会から『Nodule154images (806MB)』をダウンロードさせて頂きました。

Shiraishi J, Katsuragawa S, Ikezoe J, Matsumoto T, Kobayashi T, Komatsu K, Matsui M, Fujita H, Kodera Y, and Doi K.: Development of a digital image database for chest radiographs with and without a lung nodule: Receiver operating characteristic analysis of radiologists’ detection of pulmonary nodules. AJR 174; 71-74, 2000

データは12bitグレースケールですが今回は8bitに落としてJPEGで保存しました。それによって情報の一部は失われてしまいます。
「pydicom」を使用しましたがpipを使って簡単にインストールできました。

JPEGに変換するPythonスクリプト

import pydicom
from PIL import Image
import os
import glob

all_path = glob.glob('./Nodule154images/*.dcm')

os.mkdir('nodule')

for path in all_path:

    file_name = os.path.splitext(os.path.basename(path))[0]

    a = pydicom.read_file(path)

    b = a.pixel_array

    e = Image.fromarray((b/16).astype('uint8'), mode = 'L')

    e.save(os.path.join('nodule', '%s.jpg'%file_name)

病変を表示するPythonスクリプト

DICOMファイルと同時に臨床情報のテキストファイルもダウンロードさせて頂きました。(アカウントの作成が必要です。)

from PIL import Image, ImageDraw
import numpy as np
import os
import pandas as pd

data = pd.read_csv('CLNDAT_EN.TXT', sep='\t', header=None)

i = 2

file_name = os.path.splitext(data.iloc[i,0])[0] + '.jpg'

position_x = data.iloc[i,5]
position_y = data.iloc[i,6]

size = data.iloc[i,2]

nodule_size = size/0.175

half_width = int(nodule_size/2)

img = Image.open(os.path.join('nodule', file_name))
draw = ImageDraw.Draw(img)

draw.rectangle((position_x - half_width, position_y - half_width, position_x + half_width, position_y + half_width), outline=255, width=3)

img.show()

f:id:touch-sp:20210210191941p:plain
病変部位が四角で囲われています。

GluonCVのplot_bboxを使う場合は以下のようになります。

import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from mxnet import image
from gluoncv.utils.viz import plot_bbox

data = pd.read_csv('CLNDAT_EN.TXT', sep='\t', header=None)

i = 2

file_name = os.path.splitext(data.iloc[i,0])[0] + '.jpg'
img = image.imread(os.path.join('nodule', file_name))

position_x = data.iloc[i,5]
position_y = data.iloc[i,6]
size = data.iloc[i,2]
nodule_size = size/0.175
half_width = int(nodule_size/2)

x_min = position_x - half_width
x_max = position_x + half_width

y_min = position_y - half_width
y_max = position_y + half_width

bounding_boxes = np.array([x_min, y_min, x_max, y_max]).reshape(1,-1)
class_ids = np.zeros(shape=(1,1))

plot_bbox(img, bounding_boxes, scores=None, labels=class_ids, class_names=['nodule'])

plt.axis("off")
plt.show()

f:id:touch-sp:20210209132034p:plain:w320

【GluonCV】胸部X線写真に写っている肺をセグメンテーション【改】

はじめに

touch-sp.hatenablog.com
前回からさらに良い結果を求めてスクリプトを書き換えました。

今回使用したデータ

miniJSRT_database | 日本放射線技術学会 画像部会から「Segmentation > >Segmentation01(256×256,RGB Color:24bit)」をダウンロードさせて頂きました。
学習データ50画像、テストデータ10画像です。

結果

左が前回の結果、右が今回の結果です。
f:id:touch-sp:20210205223024p:plainf:id:touch-sp:20210207224655p:plain
f:id:touch-sp:20210205223038p:plainf:id:touch-sp:20210207224725p:plain
f:id:touch-sp:20210205223054p:plainf:id:touch-sp:20210207224736p:plain
f:id:touch-sp:20210205223118p:plainf:id:touch-sp:20210207224745p:plain
f:id:touch-sp:20210205223131p:plainf:id:touch-sp:20210207224754p:plain
f:id:touch-sp:20210205223140p:plainf:id:touch-sp:20210207224805p:plain
f:id:touch-sp:20210205223157p:plainf:id:touch-sp:20210207224815p:plain
f:id:touch-sp:20210205223207p:plainf:id:touch-sp:20210207224827p:plain
f:id:touch-sp:20210205223216p:plainf:id:touch-sp:20210207224835p:plain
f:id:touch-sp:20210205223225p:plainf:id:touch-sp:20210207224848p:plain

学習スクリプト

学習用画像が50枚しかないので「切り取り」と「拡大」のAugmentationを追加しました。
MXNetではfixed_cropを使えば同時に実行できます。
明らかにセグメンテーションがおかしい2枚の画像は排除しました。結果的に学習に使用した画像は48枚です。

import os
import glob
import random
import mxnet as mx
from mxnet import gluon, image, autograd
from mxnet.gluon.data.vision import transforms
import gluoncv

#hyperparameters
epochs = 400
train_path = 'train_image0'

ctx = mx.gpu() if mx.context.num_gpus() >0 else mx.cpu()

jitter_param = 0.4
lighting_param = 0.1
input_transform = transforms.Compose([
    transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                 saturation=jitter_param),
    transforms.RandomLighting(lighting_param),
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])

all_path = glob.glob(train_path + '/train/label/*.png')
all_img = [os.path.split(x)[1] for x in all_path]

def make_dataset(all_image_file):
    img_list = []
    label_list = []

    for i in all_image_file:

        x  = random.randint(6, 10)
        width = int(256 * x /10)

        y1 = random.randint(0, 256- width)
        y2 = random.randint(0, 256- width)

        img = image.imread(train_path + '/train/org/' + i)
        post_image = image.fixed_crop(img, y1, y2, width, width, size=(256,256))
        img_list.append(post_image)

        label = image.imread(train_path + '/train/label/' + i, flag=0)
        post_label = image.fixed_crop(label, y1, y2, width, width, size=(256,256), interp=0)
        post_label = post_label/255
        label_list.append(mx.nd.squeeze(post_label))

    train_dataset = gluon.data.dataset.ArrayDataset(img_list, label_list)

    return train_dataset

model = gluoncv.model_zoo.FCN(
    root = './models',
    nclass = 2, 
    backbone = 'resnet50', 
    aux = True, 
    ctx = ctx, 
    pretrained_base=True, 
    crop_size=256)

lr_scheduler = gluoncv.utils.LRScheduler('poly', base_lr=0.001,
                                         nepochs=50, 
                                         iters_per_epoch=len(all_img), 
                                         power=0.9)

criterion = gluoncv.loss.MixSoftmaxCrossEntropyLoss(aux=True)

trainer = gluon.Trainer(model.collect_params(), 'sgd',
                          {'lr_scheduler': lr_scheduler,
                           'wd':0.0001,
                           'momentum': 0.9,
                           'multi_precision': True})

for epoch in range(epochs):

    train_dataset = make_dataset(all_img)

    train_dataloader = gluon.data.DataLoader(
        train_dataset.transform_first(input_transform), batch_size=4 , shuffle=True)
    
    train_loss = 0.0
    data_count = 0

    for i, (data, target) in enumerate(train_dataloader):
        with autograd.record():
            outputs = model(data.as_in_context(ctx))
            losses = criterion(outputs[0], outputs[1], target.as_in_context(ctx))
            losses.backward()
        trainer.step(data.shape[0])

        data_count += data.shape[0]
        train_loss += mx.nd.sum(losses).asscalar()
        
        print('Epoch %d, batch %d[%d/%d], training loss %.3f'%(epoch+1, i+1, data_count, len(train_dataset), train_loss/data_count))

model.save_parameters('seg.params')


このエントリーをはてなブックマークに追加