質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.61%

tensorflowにおけるモデル構築Lambda文について

解決済

回答 1

投稿

  • 評価
  • クリップ 0
  • VIEW 1,058
退会済みユーザー

退会済みユーザー

 質問

tensorflowのimage内の関数random_cropを使うと以下のエラーが出ます。
tf.random_cropでは次元の数は変わらないと思うのですが。
どのようにすれば解消できますか。

 エラーメッセージ

以下はモデルを初期化する際に発生するメッセージです。

Dimensions must be equal, but are 4 and 3 for 'lambda_20/random_crop/GreaterEqual' (op: 'GreaterEqual') with input shapes: [4], [3].

 ソースコード

def multiresolution_model():
    inputs = Input(shape=(entire_x, entire_y, 3))

    high = Lambda(lambda image: tf.image.resize_images(image, (img_width, img_height)))(inputs) #こちらは問題なく通る
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(high)
    x = BatchNormalization()(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block1_pool')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    x = BatchNormalization()(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block2_pool')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block3_pool')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block4_pool')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block5_pool')(x)
    flattened_high = Flatten(name='flatten')(x)

    #ここが問題の文
    low = Lambda(lambda image: tf.random_crop(image, [img_height, img_width, 3]))(inputs)#次元が違うとして止まる
    ######
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1-2')(low)
    x = BatchNormalization()(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2-2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block1_pool-2')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2-2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block2_pool-2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3-2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block3_pool-2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3-2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block4_pool-2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2-2')(x)
    x = BatchNormalization()(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3-2')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block5_pool-2')(x)
    flattened_low = Flatten(name='flatten-2')(x)

    merge = concatenate([flattened_low, flattened_high])    
    x = Dense(4096, activation='relu', name='fc1')(merge)
    x = Dropout(0.5, name='dropout1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    x = Dropout(0.5, name='dropout2')(x)
    predictions = Dense(nb_classes, activation='softmax', name='predictions')(x)
    model = Model(inputs=inputs, outputs=predictions)

    return model

 補足

入力層でクロップ画像とリサイズ画像の2パターンを並列に扱うモデルを組み立てる際にcropがモデル中で定義できなかったので質問させていただきました。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

checkベストアンサー

0

Tensorflow の画像処理系の関数は画像1枚だけを想定しているので、ミニバッチに対して、まとめて処理する場合は、tf.map_fn() を使いましょう。

map on the list of tensors unpacked from elems on dimension 0.

Python の map() 関数同様に以下のように使います。

tf.map_fn(lambda img: tf.random_crop(img, [500, 500, 3]), input_tensor)

 サンプルコード

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

# 画像を読み込む。
img = np.array(Image.open('test.jpg'))

# 画像を複製する。
imgs = np.tile(img, (4, 1, 1, 1))
print(imgs.shape)  # (4, 1960, 1960, 3)

# 計算グラフを作成する。
h, w = img.shape[:2]
input_tensor = tf.placeholder(tf.uint8, [None, h, w, 3])
output = tf.map_fn(lambda img: tf.random_crop(img, [500, 500, 3]), input_tensor)

# 実行する。
with tf.Session() as sess:
    cropped_imgs = sess.run(output, feed_dict={input_tensor: imgs})
print(cropped_imgs.shape)  # (4, 500, 500, 3)

# 描画する。
fig, ax_list = plt.subplots(2, 2, figsize=(8, 8))
for cropped, ax in zip(cropped_imgs, ax_list.ravel()):
    ax.set_axis_off()
    ax.imshow(cropped)
plt.show()

イメージ説明
入力

イメージ説明
出力

 追記

次元は変わらないですが、crop しているので、形状は変わります。

import tensorflow as tf

h, w = 1000, 1000
input_tensor = tf.placeholder(tf.uint8, [None, h, w, 3])
output = tf.map_fn(lambda img: tf.random_crop(img, [500, 500, 3]), input_tensor)

print(input_tensor.shape)
print(output.shape)
(?, 1000, 1000, 3)
(?, 500, 500, 3)

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2018/11/06 13:38

    おそらくこのエラーコードがデータの次元が変更されてしまっていることから来ているものだと思ったのです。tf.mat_fnを通す前後で軸が変更されてしまうことはありますか?

    キャンセル

  • 2018/11/06 14:19

    次元や軸の位置は変わりません。Crop しているので、形状は変わります。

    キャンセル

  • 2018/11/06 15:21

    ありがとうございました。解決しました。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.61%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る