U-Netをtensorflow 2で実装して、画像セグメンテーションをする。

U-Netをtensorflow2で実装して、oxford_iiit_petデータセットで画像セグメンテーションをしてみましたので紹介します。

Tensorflow 2.8.0
Python 3.7.13 (default, Mar 16 2022, 17:37:17) \n[GCC 7.5.0]

U-Net

U-Netは、downsamplingとupsamplingからなるU字型をしたネットワークです。医療画像など、サンプル数が少ない問題での画像セグメンテーションで効果的です。

 

U-Net: Convolutional Networks for Biomedical Image Segmentation

Olaf Ronneberger, Philipp Fischer, Thomas Brox

Medical Image Computing and Computer-Assisted Intervention (MICCAI), Springer, LNCS, Vol.9351: 234–241, 2015, available at arXiv:1505.04597

Tensorflow 2で実装

Tensorflowで実装しました。

 

論文と異なる点

  • 入力画像サイズを128x128x3と小さくした。そのためネットワークの深さは浅めに。
  • upsamplingではConv2Dではなく、Conv2DTransposeを使って対称的なネットワークにした。
  • upsampingで画像はcropせずに、paddingを適宜いれて画像サイズもあわせて、そのままconcatenate()する。
  • パラメータは全部フリー。Trainable params: 5,460,163

試す

データセットoxford_iiit_pet/3.2.0で試しました。猫よりも犬のサンプルのが多いのが残念な感じです。

 

import tensorflow_datasets as tfds
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

未学習のネットワークにしては、それなりに犬猫を認識してフィットできている感じです。

トレーニング前

10 epoch トレーニング後

73 epoch後

予測

accuracy

57/57 [==============================] – 32s 556ms/step – loss: 0.3852 – accuracy: 0.8417 – val_loss: 0.3825 – val_accuracy: 0.8431

gpuはgoogle colabのT4です。

 

追加で63 epoch走らせました。

20~30 epoch以降は過学習がみられます。

 

layers.BatchNormalization()を入れると、過学習は抑えられましたが、どっちにせよ、validationのaccuracyで0.875くらい、lossで0.33くらいがこのネットワークの性能のようです。

コード

downsamplingは次のようにしました。

from tensorflow import keras
from tensorflow.keras import layers

downsampling_input = keras.Input(shape=(128, 128, 3), name="img")
x = layers.Conv2D(64, 3, activation='relu')(downsampling_input)
layer_64 = layers.Conv2D(64, 3, activation='relu')(x)
#layers.BatchNormalization()
x = layers.MaxPooling2D(2)(layer_64)

x = layers.Conv2D(128, 3, activation='relu')(x)
layer_128 = layers.Conv2D(128, 3, activation='relu', padding="same")(x)
x = layers.MaxPooling2D(2)(layer_128)

x = layers.Conv2D(256, 3, activation='relu')(x)
layer_256 = layers.Conv2D(256, 3, activation='relu')(x)
x = layers.MaxPooling2D(2)(layer_256)

downsampling_output = layers.Conv2D(512, 3, activation="relu")(x)

#x = layers.Conv2D(512, 3, activation='relu')(x)
#x = layers.Conv2D(512, 3, activation='relu')(x)
#x = layers.MaxPooling2D(2)(x)

#x = layers.Conv2D(1024, 3, activation='relu')(x)
#x = layers.Conv2D(1024, 3, activation='relu')(x)
#encoder_output = layers.MaxPooling2D(2)(x)

downsampling = keras.Model(downsampling_input, downsampling_output, name="downsampling")
downsampling.summary()

upsamplingは次のようにしました。


x = layers.Conv2DTranspose(256, 3, activation="relu")(downsampling_output)
x = layers.UpSampling2D(2)(x)

x = layers.concatenate([x, layer_256])
x = layers.Conv2DTranspose(256, 3, activation="relu")(x)
x = layers.Conv2DTranspose(128, 3, activation="relu")(x)
x = layers.UpSampling2D(2)(x)

x = layers.concatenate([x, layer_128])
x = layers.Conv2DTranspose(128, 3, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(64, 3, activation="relu")(x)
x = layers.UpSampling2D(2)(x)

x = layers.concatenate([x, layer_64])
x = layers.Conv2DTranspose(64, 3, activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, activation="relu")(x)
OUTPUT_CLASSES = 3
upsampling_output = layers.Conv2D(OUTPUT_CLASSES, 1, activation="relu")(x) #conv1x1

unet = keras.Model(downsampling_input, upsampling_output, name="unet")
unet.summary()

参考