Fully Convolutional Network in Keras

Last updated: 2017-03-23 12:47

Fully Convolutional Network in Keras

from keras.models import Model
from keras.layers import Input, Conv2D, Conv2DTranspose
from keras.initializers import Constant
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.callbacks import ModelCheckpoint
from PIL import Image
import numpy as np
import argparse
import copy
import os


nb_classes = 21  # 20 Pascal VOC object classes + background

# Bilinear interpolation (reference: https://github.com/warmspringwinds/tf-image-segmentation/blob/master/tf_image_segmentation/utils/upsampling.py)
def bilinear_upsample_weights(factor, number_of_classes):
    """Build Conv2DTranspose weights that perform bilinear upsampling by `factor`.

    Returns a float32 array of shape
    (filter_size, filter_size, number_of_classes, number_of_classes)
    with filter_size = 2*factor - factor % 2. Only the diagonal
    (same-class in -> out) filters are non-zero, so each class channel is
    upsampled independently and classes do not mix.
    """
    filter_size = factor * 2 - factor % 2
    factor = (filter_size + 1) // 2
    # Kernel centre lies on a pixel for odd sizes, between pixels for even.
    if filter_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:filter_size, :filter_size]
    # Separable triangular (tent) kernel = 2-D bilinear interpolation weights.
    upsample_kernel = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weights = np.zeros((filter_size, filter_size, number_of_classes, number_of_classes),
                       dtype=np.float32)
    # Fix: the original used Python-2-only `xrange`, which raises NameError
    # on Python 3; `range` behaves identically here on both versions.
    for i in range(number_of_classes):
        weights[:, :, i, i] = upsample_kernel
    return weights

def fcn_32s():
    """Build the FCN-32s segmentation model: a VGG16 backbone, a 1x1
    class-scoring convolution, and a single x32 learned upsampling back to
    (roughly) input resolution.
    """
    image = Input(shape=(None, None, 3))
    backbone = VGG16(weights='imagenet', include_top=False, input_tensor=image)
    # Per-pixel class scores on the coarse VGG16 feature map.
    score = Conv2D(filters=nb_classes,
                   kernel_size=(1, 1))(backbone.output)
    # One-shot x32 upsampling, initialised to bilinear interpolation so the
    # network starts from a sensible upsampler before fine-tuning.
    upsampled = Conv2DTranspose(filters=nb_classes,
                                kernel_size=(64, 64),
                                strides=(32, 32),
                                padding='same',
                                activation='sigmoid',
                                kernel_initializer=Constant(bilinear_upsample_weights(32, nb_classes)))(score)
    model = Model(inputs=image, outputs=upsampled)
    # Freeze the first 15 layers (early VGG blocks); fine-tune the rest.
    for frozen in model.layers[:15]:
        frozen.trainable = False
    return model

def load_image(path):
    """Load an image and preprocess it for VGG16.

    Returns a (1, H, W, 3) float32 batch; H and W are the original
    dimensions rounded down to multiples of 32 so the x32 upsampling in the
    model reproduces the input size exactly.
    """
    original = Image.open(path)
    width, height = original.size
    resized = original.resize(((width // 32) * 32, (height // 32) * 32))
    batch = np.expand_dims(np.array(resized, dtype=np.float32), axis=0)
    return preprocess_input(batch)

def load_label(path):
    """Load a segmentation label PNG as a one-hot tensor.

    Returns a (1, H, W, nb_classes) float32 array; H and W are the original
    dimensions rounded down to multiples of 32 to match `load_image`.
    Pixels with value 255 (VOC 'void'/boundary marker) are folded into
    class 0 (background).
    """
    img_org = Image.open(path)
    w, h = img_org.size
    img = img_org.resize(((w//32)*32, (h//32)*32))
    img = np.array(img, dtype=np.uint8)
    img[img == 255] = 0  # treat void pixels as background
    # Vectorized one-hot encoding via fancy indexing: identity-row lookup
    # replaces the original per-pixel Python double loop (O(H*W) interpreter
    # iterations) with a single C-level numpy operation.
    y = np.eye(nb_classes, dtype=np.float32)[img]
    return y[np.newaxis, ...]

def generate_arrays_from_file(path, image_dir, label_dir):
    """Yield (image_batch, label_batch) pairs forever, re-reading the list
    file on each pass so it can be edited between epochs.

    path: text file with one sample basename per line.
    image_dir / label_dir: directories holding '<name>.jpg' / '<name>.png'.
    """
    while True:
        # Fix: the original used a bare open()/close() pair, so the file
        # handle leaked whenever the generator was abandoned mid-file or a
        # load raised; `with` guarantees closure in every case.
        with open(path) as f:
            for line in f:
                filename = line.rstrip('\n')
                x = load_image(os.path.join(image_dir, filename + '.jpg'))
                y = load_label(os.path.join(label_dir, filename + '.png'))
                yield (x, y)

def model_predict(model, input_path, output_path):
    """Run the model on one image and save the predicted label map as a
    palettised PNG at the original resolution.

    Requires a 'palette.png' in the working directory whose palette is
    copied onto the output image.
    """
    source = Image.open(input_path)
    width, height = source.size
    # Model input dimensions must be multiples of 32.
    resized = source.resize(((width // 32) * 32, (height // 32) * 32))
    batch = preprocess_input(
        np.expand_dims(np.array(resized, dtype=np.float32), axis=0))
    scores = model.predict(batch)
    # Per-pixel argmax over the class axis gives the label map.
    labels = scores[0].argmax(axis=-1).astype(np.uint8)
    result = Image.fromarray(labels, mode='P').resize((width, height))
    # Borrow the colour palette from a reference image.
    palette_im = Image.open('palette.png')
    result.palette = copy.copy(palette_im.palette)
    result.save(output_path)

if __name__ == '__main__':
    # Guard the training script so importing this module for its functions
    # no longer triggers argument parsing and a full training run.
    parser = argparse.ArgumentParser()
    parser.add_argument('train_data')   # text file listing sample basenames, one per line
    parser.add_argument('image_dir')    # directory of <name>.jpg inputs
    parser.add_argument('label_dir')    # directory of <name>.png labels
    args = parser.parse_args()

    # Count samples so one epoch of fit_generator sees each file once.
    # Fix: the original `open()` inside the genexp leaked the file handle.
    with open(args.train_data) as f:
        nb_data = sum(1 for _ in f)

    model = fcn_32s()
    # NOTE(review): sigmoid + binary_crossentropy treats the 21 classes as
    # independent; softmax + categorical_crossentropy is the usual choice for
    # mutually exclusive labels — kept as-is to preserve behavior.
    model.compile(loss="binary_crossentropy", optimizer='sgd')
    for epoch in range(100):
        model.fit_generator(
            generate_arrays_from_file(args.train_data, args.image_dir, args.label_dir),
            steps_per_epoch=nb_data,
            epochs=1)
        # Dump a prediction on a fixed test image after every epoch so
        # training progress can be inspected visually.
        model_predict(model, 'test.jpg', 'predict-{}.format.png'.replace('.format', '').format(epoch) if False else 'predict-{}.png'.format(epoch))

Evaluation — references for computing the mean IoU (intersection-over-union) segmentation metric:

http://stackoverflow.com/questions/31653576/how-to-calculate-the-mean-iu-score-in-image-segmentation

https://gist.github.com/meetshah1995/6a5ad112559ef1536d0191f8b9fe8d1e

For some applications, e.g. in the graphics domain, the contour quality significantly contributes to the perceived segmentation quality
http://www.bmva.org/bmvc/2013/Papers/paper0032/paper0032.pdf