回答編集履歴
3
おかしい日本語を修正
answer
CHANGED
@@ -1,6 +1,8 @@
|
|
1
|
-
ソースコードを見る限り、y_train, y_testの使い方を間違えているため
|
1
|
+
ソースコードを見る限り、y_train, y_testの使い方を間違えているため
|
2
|
-
|
2
|
+
エラーが発生していると予測しました。
|
3
3
|
|
4
|
+
`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
5
|
+
|
4
6
|
単なる数値データ、クラス数が3だったら
|
5
7
|
0, 1, 2を [1, 0, 0], [0, 1, 0], [0, 0, 1]
|
6
8
|
に変換するのがto_categoricalの本来の使い方のはずです。
|
@@ -25,7 +27,7 @@
|
|
25
27
|
print(categorical_labels.shape)
|
26
28
|
# (3, 3)
|
27
29
|
```
|
28
|
-
|
30
|
+
したがってy_train, y_testをto_categorical関数を利用して変換するのではなく
|
29
31
|
そのままy_train, y_testを利用すればエラーは解消されるのではないでしょうか。
|
30
32
|
|
31
33
|
### 参考
|
2
説明追記
answer
CHANGED
@@ -1,8 +1,10 @@
|
|
1
|
+
ソースコードを見る限り、y_train, y_testの使い方を間違えているためエラーが発生していると予測しました。
|
1
|
-
`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
2
|
+
理由としては`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
2
3
|
|
3
|
-
|
4
|
+
単なる数値データ、クラス数が3だったら
|
4
5
|
0, 1, 2を [1, 0, 0], [0, 1, 0], [0, 0, 1]
|
5
6
|
に変換するのがto_categoricalの本来の使い方のはずです。
|
7
|
+
|
6
8
|
以下に例を用意してみました。
|
7
9
|
```python
|
8
10
|
from tflearn.data_utils import to_categorical
|
@@ -23,6 +25,8 @@
|
|
23
25
|
print(categorical_labels.shape)
|
24
26
|
# (3, 3)
|
25
27
|
```
|
28
|
+
質問への直接的な回答としてはy_train, y_testをto_categorical関数を利用して変換するのではなく
|
29
|
+
そのままy_train, y_testを利用すればエラーは解消されるのではないでしょうか。
|
26
30
|
|
27
31
|
### 参考
|
28
32
|
- [tflearn > data_utils - to_categorical](http://tflearn.org/data_utils/#to_categorical)
|
1
説明追記
answer
CHANGED
@@ -1,5 +1,9 @@
|
|
1
1
|
`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
2
2
|
|
3
|
+
は単なる数値データ、クラス数が3だったら
|
4
|
+
0, 1, 2を [1, 0, 0], [0, 1, 0], [0, 0, 1]
|
5
|
+
に変換するのがto_categoricalの本来の使い方のはずです。
|
6
|
+
以下に例を用意してみました。
|
3
7
|
```python
|
4
8
|
from tflearn.data_utils import to_categorical
|
5
9
|
import numpy as np
|
@@ -10,6 +14,12 @@
|
|
10
14
|
#(3,)
|
11
15
|
|
12
16
|
categorical_labels = to_categorical(label, classes)
|
17
|
+
print(categorical_labels)
|
18
|
+
"""
|
19
|
+
array([[1., 0., 0.],
|
20
|
+
[0., 1., 0.],
|
21
|
+
[0., 0., 1.]])
|
22
|
+
"""
|
13
23
|
print(categorical_labels.shape)
|
14
24
|
# (3, 3)
|
15
25
|
```
|