回答編集履歴
1
修正
answer
CHANGED
@@ -9,4 +9,64 @@
|
|
9
9
|
```
|
10
10
|
|
11
11
|
train_test_split の各キーワード引数（test_size、random_state、stratify など）の意味を知りたければ、以下の公式ドキュメントを参考にしてください。
|
12
|
-
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
|
12
|
+
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
|
13
|
+
|
14
|
+
---
|
15
|
+
追記:
|
16
|
+
コメントの内容を全部まとめてコードを書き直すと、以下の通り。
|
17
|
+
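走らせる環境がないので、タイプミス・インデントミスがあるかもしれません。あしからず。なお、sklearn.cross_validation は scikit-learn 0.20 で削除されているため、新しいバージョンでは import 行から cross_validation を外してください（このコードでは使っていません）。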
|
18
|
+
```python
|
19
|
+
# coding:utf-8
|
20
|
+
from sklearn import svm, cross_validation
|
21
|
+
from sklearn.metrics import classification_report, accuracy_score
|
22
|
+
import sys
|
23
|
+
from mfcc import *
|
24
|
+
import glob
|
25
|
+
import csv
|
26
|
+
import random
|
27
|
+
import numpy as np
|
28
|
+
import os
|
29
|
+
from sklearn.model_selection import train_test_split
|
30
|
+
|
31
|
+
def get_data(files, nfft, nceps):
    """Extract features for each audio file and derive a binary label per file.

    Args:
        files: iterable of paths to audio files.
        nfft: FFT sample count, forwarded to ``get_feature``.
        nceps: number of MFCC dimensions, forwarded to ``get_feature``.

    Returns:
        ``(data, label)``. ``data`` is the per-file features stacked with
        ``np.vstack`` (``None`` when ``files`` is empty). ``label`` is a float
        array holding 0 for files whose base name starts with "dog" and 1 for
        every other file.

    NOTE(review): assumes ``get_feature`` returns one row-vector per file so
    that rows of ``data`` line up with entries of ``label`` -- TODO confirm.
    """
    features = []
    labels = []
    for file_name in files:
        features.append(get_feature(file_name, nfft, nceps))
        # os.path.basename understands the platform's separator; splitting on
        # '/' misses Windows paths (glob + os.path.join yields backslashes
        # there), which would label every file as 1.
        labels.append(0 if os.path.basename(file_name).startswith('dog') else 1)

    if not features:
        return None, np.array([])
    # Stack once instead of calling np.vstack inside the loop, which copies
    # the growing array on every iteration (quadratic in the number of files).
    # A single file still yields a 2-D array with one row, matching one label.
    data = np.vstack(features)
    return data, np.array(labels, dtype=float)
|
46
|
+
|
47
|
+
if __name__ == "__main__":
    nfft = 2048  # number of FFT samples
    nceps = 12  # number of MFCC dimensions

    basedir = '/sound_animal/sounds'
    # glob returns files in arbitrary (filesystem) order; sort them so that
    # random_state=0 below produces the same split on every machine.
    files = sorted(glob.glob(os.path.join(basedir, '*.wav')))
    if not files:
        sys.exit("No .wav files found in %s" % basedir)
    data, label = get_data(files, nfft, nceps)

    # Hold out a third of the samples for testing; stratify keeps the
    # dog (0) / other (1) ratio the same in both splits.
    train_data, test_data, train_label, test_label = train_test_split(
        data, label, test_size=0.33, random_state=0, stratify=label)

    # Each CSV row is: label, feature_1, ..., feature_n.
    feature_train_data = np.hstack((train_label.reshape(-1, 1), train_data))
    feature_test_data = np.hstack((test_label.reshape(-1, 1), test_data))

    out_dir = "feature_data"
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for out_name, rows in (("train_data.txt", feature_train_data),
                           ("test_data.txt", feature_test_data)):
        # The csv module requires newline=''; without it Windows gets an
        # extra blank line after every row.
        with open(os.path.join(out_dir, out_name), "w", newline='') as f:
            csv.writer(f).writerows(rows)

    clf = svm.SVC(kernel='linear', C=1)
    clf.fit(train_data, train_label)
    score = clf.score(test_data, test_label)  # mean accuracy on the test split

    print(score)
|
72
|
+
```
|