U-Netをtensorflow2で実装して、oxford_iiit_petデータセットで画像セグメンテーションをしてみましたので紹介します。
Tensorflow 2.8.0
Python 3.7.13 (default, Mar 16 2022, 17:37:17) \n[GCC 7.5.0]
U-Net
U-Netは、downsamplingとupsamplingからなるU字型をしたネットワークです。医療画像など、サンプル数が少ない問題での画像セグメンテーションで効果的です。
U-Net: Convolutional Networks for Biomedical Image Segmentation
Olaf Ronneberger, Philipp Fischer, Thomas Brox
Medical Image Computing and Computer-Assisted Intervention (MICCAI), Springer, LNCS, Vol.9351: 234–241, 2015, available at arXiv:1505.04597
Tensorflow 2で実装
Tensorflowで実装しました。
論文と異なる点
- 入力画像サイズを128x128x3と小さくした。そのためネットワークの深さは浅めに。
- upsamplingではConv2Dではなく、Conv2DTransposeを使って対称的なネットワークにした。
- upsampingで画像はcropせずに、paddingを適宜いれて画像サイズもあわせて、そのままconcatenate()する。
- パラメータは全部フリー。Trainable params: 5,460,163
試す
データセットoxford_iiit_pet/3.2.0で試しました。猫よりも犬のサンプルのが多いのが残念な感じです。
import tensorflow_datasets as tfds
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
未学習のネットワークにしては、それなりに犬猫を認識してフィットできている感じです。
トレーニング前
10 epoch トレーニング後
73 epoch後
予測
accuracy
57/57 [==============================] – 32s 556ms/step – loss: 0.3852 – accuracy: 0.8417 – val_loss: 0.3825 – val_accuracy: 0.8431
gpuはgoogle colabのT4です。
追加で63 epoch走らせました。
20~30 epoch以降は過学習がみられます。
layers.BatchNormalization()を入れると、過学習は抑えられましたが、どっちにせよ、validationのaccuracyで0.875くらい、lossで0.33くらいがこのネットワークの性能のようです。
コード
downsamplingは次のようにしました。
from tensorflow import keras
from tensorflow.keras import layers
downsampling_input = keras.Input(shape=(128, 128, 3), name="img")
x = layers.Conv2D(64, 3, activation='relu')(downsampling_input)
layer_64 = layers.Conv2D(64, 3, activation='relu')(x)
#layers.BatchNormalization()
x = layers.MaxPooling2D(2)(layer_64)
x = layers.Conv2D(128, 3, activation='relu')(x)
layer_128 = layers.Conv2D(128, 3, activation='relu', padding="same")(x)
x = layers.MaxPooling2D(2)(layer_128)
x = layers.Conv2D(256, 3, activation='relu')(x)
layer_256 = layers.Conv2D(256, 3, activation='relu')(x)
x = layers.MaxPooling2D(2)(layer_256)
downsampling_output = layers.Conv2D(512, 3, activation="relu")(x)
#x = layers.Conv2D(512, 3, activation='relu')(x)
#x = layers.Conv2D(512, 3, activation='relu')(x)
#x = layers.MaxPooling2D(2)(x)
#x = layers.Conv2D(1024, 3, activation='relu')(x)
#x = layers.Conv2D(1024, 3, activation='relu')(x)
#encoder_output = layers.MaxPooling2D(2)(x)
downsampling = keras.Model(downsampling_input, downsampling_output, name="downsampling")
downsampling.summary()
upsamplingは次のようにしました。
x = layers.Conv2DTranspose(256, 3, activation="relu")(downsampling_output)
x = layers.UpSampling2D(2)(x)
x = layers.concatenate([x, layer_256])
x = layers.Conv2DTranspose(256, 3, activation="relu")(x)
x = layers.Conv2DTranspose(128, 3, activation="relu")(x)
x = layers.UpSampling2D(2)(x)
x = layers.concatenate([x, layer_128])
x = layers.Conv2DTranspose(128, 3, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(64, 3, activation="relu")(x)
x = layers.UpSampling2D(2)(x)
x = layers.concatenate([x, layer_64])
x = layers.Conv2DTranspose(64, 3, activation="relu")(x)
x = layers.Conv2DTranspose(64, 3, activation="relu")(x)
OUTPUT_CLASSES = 3
upsampling_output = layers.Conv2D(OUTPUT_CLASSES, 1, activation="relu")(x) #conv1x1
unet = keras.Model(downsampling_input, upsampling_output, name="unet")
unet.summary()