回答編集履歴
3
おかしい日本語を修正
test
CHANGED
@@ -1,6 +1,10 @@
|
|
1
|
-
ソースコードを見る限り、y_train, y_testの使い方を間違えているため
|
1
|
+
ソースコードを見る限り、y_train, y_testの使い方を間違えているため
|
2
2
|
|
3
|
+
エラーが発生していると予測しました。
|
4
|
+
|
5
|
+
|
6
|
+
|
3
|
-
|
7
|
+
`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
4
8
|
|
5
9
|
|
6
10
|
|
@@ -52,7 +56,7 @@
|
|
52
56
|
|
53
57
|
```
|
54
58
|
|
55
|
-
|
59
|
+
したがってy_train, y_testをto_categorical関数を利用して変換するのではなく
|
56
60
|
|
57
61
|
そのままy_train, y_testを利用すればエラーは解消されるのではないでしょうか。
|
58
62
|
|
2
説明追記
test
CHANGED
@@ -1,12 +1,16 @@
|
|
1
|
+
ソースコードを見る限り、y_train, y_testの使い方を間違えているためエラーが発生していると予測しました。
|
2
|
+
|
1
|
-
`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
3
|
+
理由としては`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
2
4
|
|
3
5
|
|
4
6
|
|
5
|
-
|
7
|
+
単なる数値データ、クラス数が3だったら
|
6
8
|
|
7
9
|
0, 1, 2を [1, 0, 0], [0, 1, 0], [0, 0, 1]
|
8
10
|
|
9
11
|
に変換するのがto_categoricalの本来の使い方のはずです。
|
12
|
+
|
13
|
+
|
10
14
|
|
11
15
|
以下に例を用意してみました。
|
12
16
|
|
@@ -48,6 +52,10 @@
|
|
48
52
|
|
49
53
|
```
|
50
54
|
|
55
|
+
質問への直接的な回答としてはy_train, y_testをto_categorical関数を利用して変換するのではなく
|
56
|
+
|
57
|
+
そのままy_train, y_testを利用すればエラーは解消されるのではないでしょうか。
|
58
|
+
|
51
59
|
|
52
60
|
|
53
61
|
### 参考
|
1
説明追記
test
CHANGED
@@ -1,6 +1,14 @@
|
|
1
1
|
`tf_correct.npy`に格納されているy_train.shape=(1470,21),y_test.shape=(630,21)を見る限り、もう既にto_categoricalされている状態に見えます。
|
2
2
|
|
3
3
|
|
4
|
+
|
5
|
+
は単なる数値データ、クラス数が3だったら
|
6
|
+
|
7
|
+
0, 1, 2を [1, 0, 0], [0, 1, 0], [0, 0, 1]
|
8
|
+
|
9
|
+
に変換するのがto_categoricalの本来の使い方のはずです。
|
10
|
+
|
11
|
+
以下に例を用意してみました。
|
4
12
|
|
5
13
|
```python
|
6
14
|
|
@@ -22,6 +30,18 @@
|
|
22
30
|
|
23
31
|
categorical_labels = to_categorical(label, classes)
|
24
32
|
|
33
|
+
print(categorical_labels)
|
34
|
+
|
35
|
+
"""
|
36
|
+
|
37
|
+
array([[1., 0., 0.],
|
38
|
+
|
39
|
+
[0., 1., 0.],
|
40
|
+
|
41
|
+
[0., 0., 1.]])
|
42
|
+
|
43
|
+
"""
|
44
|
+
|
25
45
|
print(categorical_labels.shape)
|
26
46
|
|
27
47
|
# (3, 3)
|