Kerasでファインチューニング

最終更新: 2017-02-27 00:42

Kerasでファインチューニング

VGG-16の場合

import argparse
import numpy as np
from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import Deconvolution2D, Dense, Flatten, Input
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing import image

parser = argparse.ArgumentParser()
parser.add_argument('filename')
args = parser.parse_args()

input_tensor = Input(shape=(224, 224, 3))
vgg16 = VGG16(weights='imagenet', include_top=False,
              input_tensor=input_tensor)

top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dense(10, activation='softmax'))
model = Model(input=vgg16.input, output=top_model(vgg16.output))

img = image.load_img(args.filename, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

pred = model.predict(x)

ポイントとしてはinput_tensorを指定しないとFlattenのサイズが確定しないからエラーがいっぱいでることに注意すること

参考ページ