質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.61%

  • Python

    7516questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

  • Python 3.x

    5926questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python mnistを畳み込みで学習させてTSNE表示をさせたい

受付中

回答 1

投稿

  • 評価
  • クリップ 0
  • VIEW 152

python_man

score 3

 前提・実現したいこと

mnistの畳み込みニューラルネットワークでのTSNE表示を行いたいがエラーが出てしまう。
対処法を教えていただきたいです。

 発生している問題・エラーメッセージ

Found array with dim 4. Estimator expected <= 2.

 該当のソースコード

ソースコード

import numpy as np
import matplotlib.pyplot as plt
import keras
import tensorflow as tf

from keras.models import Sequential
from keras.datasets import mnist,fashion_mnist
from keras.layers import Dense, Dropout, Activation, Conv2D, MaxPooling2D, Flatten
from keras.optimizers import rmsprop
from keras.utils.np_utils import to_categorical

from sklearn.manifold import TSNE

np.random.seed(7)

session_conf = tf.ConfigProto(
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1
)

from keras import backend as K
tf.set_random_seed(7)

sess = tf.Session(graph=tf.get_default_graph(),config=session_conf)
K.set_session(sess)

in_n = 784
class_n = 10

hidden_layers = 2
hidden_units = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train  = x_train.reshape([-1, 28, 28, 1])
x_test   = x_test.reshape([-1, 28, 28, 1])
x_train  = x_train.astype('float32')
x_test   = x_test.astype('float32')
x_train /= 255
x_test  /= 255
y_train  = keras.utils.to_categorical(y_train, 10)
y_test   = keras.utils.to_categorical(y_test, 10)

warna = []
bentuk = []
warna_code = ['red','blue','green','black','magenta','cyan','grey','aqua','springgreen','salmon']
bentuk_code = ['o','s','D','*','v','p','8','h','+','X']

pattern_n = range((int)(x_train.size/in_n))

print(pattern_n)

num_classes = 10

model = Sequential()

model.add(Conv2D(32, (3, 3), 
input_shape=(28,28,1)))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3),))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

print(model.summary())

model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop(lr=0.0001, decay=1e-6),
metrics=['accuracy'])

fit = model.fit(x_train, y_train,
batch_size=32,
epochs=1, #shouldn't be raised to 100, because the overfitting occurs.
verbose=2,
validation_split=0.1
)

score = model.evaluate(x_test, y_test,
verbose=0
)
print('Test score:', score[0])
print('Test accuracy:', score[1])

output = model.predict(x_train)

warna_output = []
for i in pattern_n:
warna_output.append(warna_code[np.argmax(output[i,:])])

X_reduced_input = TSNE(n_components=2, random_state=0).fit_transform(x_train)
for i in pattern_n:
plt.scatter(X_reduced_input[i,0], X_reduced_input[i,1], c=warna[i], marker=bentuk[i])
plt.show()

for vis_layer in range(2, hidden_layers*2,2):
get_layer_output = K.function([model.layers[0].input],[model.layers[vis_layer].output])
hidden_output = get_layer_output([x_train])[0]
X_reduced = TSNE(n_components=2, random_state=0).fit_transform(hidden_output)
for i in pattern_n:
plt.scatter(X_reduced[i,0], X_reduced[i,1], c=warna[i], marker=bentuk[i])
plt.show()

K.clear_session()

 試したこと

ここに問題に対して試したことを記載してください。

 補足情報(FW/ツールのバージョンなど)

ここにより詳細な情報を記載してください。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

0

tsneに入れるときにXの次元を落としてください。

.reshape(X.shape[0], -1)みたく。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.61%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る

  • Python

    7516questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

  • Python 3.x

    5926questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。