回答編集履歴

1

修正

2017/11/13 16:05

投稿

mkgrei
mkgrei

スコア8560

test CHANGED
@@ -21,3 +21,123 @@
21
21
  キーの意味を知りたければ、以下を参考に。
22
22
 
23
23
  http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
24
+
25
+
26
+
27
+ ---
28
+
29
+ 追記:
30
+
31
+ コメントの内容を全部まとめてコードを書き直すと、以下の通り。
32
+
33
+ 走らせる環境がないので、タイプミス・インデントミスがあるかもしれません。あしからず。
34
+
35
+ ```python
36
+
37
+ # coding:utf-8
38
+
39
+ from sklearn import svm, cross_validation
40
+
41
+ from sklearn.metrics import classification_report, accuracy_score
42
+
43
+ import sys
44
+
45
+ from mfcc import *
46
+
47
+ import glob
48
+
49
+ import csv
50
+
51
+ import random
52
+
53
+ import numpy as np
54
+
55
+ import os
56
+
57
+ from sklearn.model_selection import train_test_split
58
+
59
+
60
+
61
+ def get_data(files, nfft, nceps):
62
+
63
+ data = None
64
+
65
+ label = np.array([])
66
+
67
+ for file_name in files:
68
+
69
+ feature = get_feature(file_name, nfft, nceps)
70
+
71
+ if data is None:
72
+
73
+ data = feature
74
+
75
+ else:
76
+
77
+ data = np.vstack((data, feature))
78
+
79
+
80
+
81
+ if file_name.split('/')[-1].startswith('dog'):
82
+
83
+ label = np.append(label, 0)
84
+
85
+ else:
86
+
87
+ label = np.append(label, 1)
88
+
89
+ return data, label
90
+
91
+
92
+
93
+ if __name__ == "__main__":
94
+
95
+ nfft = 2048 # FFTのサンプル数
96
+
97
+ nceps = 12 # MFCCの次元数
98
+
99
+
100
+
101
+ basedir = '/sound_animal/sounds'
102
+
103
+ files = glob.glob(os.path.join(basedir, '*.wav'))
104
+
105
+ data, label = get_data(files, nfft, nceps)
106
+
107
+
108
+
109
+ train_data, test_data, train_label, test_label = train_test_split(data, label, test_size=0.33, random_state=0, stratify=label)
110
+
111
+
112
+
113
+ feature_train_data = np.hstack((train_label.reshape(-1, 1), train_data))
114
+
115
+ feature_test_data = np.hstack((test_label.reshape(-1, 1), test_data))
116
+
117
+
118
+
119
+ with open("feature_data/train_data.txt", "w") as f:
120
+
121
+ writer = csv.writer(f)
122
+
123
+ writer.writerows(feature_train_data)
124
+
125
+ with open("feature_data/test_data.txt", "w") as f:
126
+
127
+ writer = csv.writer(f)
128
+
129
+ writer.writerows(feature_test_data)
130
+
131
+
132
+
133
+ clf = svm.SVC(kernel='linear', C=1)
134
+
135
+ clf.fit(train_data, train_label)
136
+
137
+ score = clf.score(test_data, test_label)
138
+
139
+
140
+
141
+ print(score)
142
+
143
+ ```