実現したいこと
- 学習用データを用いて作成した決定木モデルを、そのモデルと検証用データを入力としたpredict関数を用いて正解率の評価をしたい。
- 作成した決定木モデルの図示をしたい。
前提
R言語を用いて、データ「cust1.csv」の決定木分析を行いたいです。cust1.csvはこちらのリンクからダウンロードできます。
https://ux.getuploader.com/subtest/download/1
cust1.csvのデータ項目は以下の通りで、目的変数がhotel、説明変数がidを除くそれ以外となっています。
gender 性別 1 : 男性 2 : 女性 age 年齢 income 収入 1 : 100万円未満 4 : 100万円以上 5 : 300万円以上 6 : 500万円以上 7 : 700万円以上 8 :1000万円以上 marriage 婚姻 1 : 既婚 2 : 未婚 usage クレジットカード年間総利用額(以下からジャンルごとの金額) retail 小売 rest 飲食 enter エンタメ trans 交通 other その他 hotel 宿泊利用
目的変数であるhotelは0-1変数で、0が宿泊未利用、1が宿泊利用を表しています。ここでは説明変数から目的変数を予測する決定木モデルを作成し、未知のデータに対してどれだけ予測精度が良いのかを検証していくことが目標です。
このデータを変数cust.1に格納し、学習用データをcustTrainに、検証用データをcustTestにそれぞれ格納しております。(後述のソースコードに記載しております)
発生している問題・エラーメッセージ
発生している問題点は、以下の2点です。
- ライブラリpartykitを用いた決定木の箱ひげ図の表示がすべて同じであり、かつ、データの四分位範囲が極端に狭いことであることからも結果がおかしいこと。
R
1install.packages("rpart.plot") 2library(rpart.plot) 3install.packages("partykit") 4library(partykit) 5 6model.rp <- rpart(hotel~., data = custTrain) # すべての説明変数からhotelの値を予測する 7rpart.plot(model.rp) 8model.rpo <- as.party(model.rp, type='simple') 9plot(model.rpo) #以下の画像がこの行の実行結果です
2. 作成した決定木モデルと未知のデータcustTestを入力としてpredict関数で予測を行った結果の値が0.03865975と0.11043432の2つのみになってしまうこと。
理想としては、すべてのデータに対する予測値を表示させるのではなく、混同行列のような形式で正解率を表示させたいのですが、調べてもわかりませんでした。
R
1# 未知データ(検証用データ)に対して予測を行う 2model.pred <- predict(model.rp, custTest) 3model.pred #この行の実行結果が以下です
・・・(前略) 65975 0.03865975 0.03865975 0.03865975 0.03865975 0.17823882 0.11043432 125737 125740 125743 125746 125749 125752 125755 0.11043432 0.11043432 0.03865975 0.11043432 0.03865975 0.03865975 0.03865975 125758 125761 125764 125767 125770 125773 125776 0.03865975 0.11043432 0.11043432 0.03865975 0.03865975 0.03865975 0.03865975 125779 125782 125785 125788 125791 125794 125797 0.03865975 0.03865975 0.17823882 0.03865975 0.11043432 0.03865975 0.03865975 125800 125803 125806 125809 125812 125815 125818 0.03865975 0.03865975 0.17823882 0.11043432 0.03865975 0.03865975 0.03865975 125821 125824 125827 125830 125833 125836 125839 0.03865975 0.17823882 0.11043432 0.03865975 0.03865975 0.11043432 0.17823882 125842 125845 125848 125851 125854 125858 125861 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 125864 125867 125870 125874 125877 125880 125885 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.17823882 0.03865975 125888 125891 125894 125897 125900 125903 125906 0.17823882 0.03865975 0.03865975 0.11043432 0.03865975 0.03865975 0.17823882 125909 125912 125915 125919 125922 125925 125928 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 125931 125934 125937 125940 125943 125946 125949 0.03865975 0.03865975 0.03865975 0.17823882 0.03865975 0.03865975 0.11043432 125952 125955 125958 125961 125964 125967 125970 0.03865975 0.03865975 0.03865975 0.11043432 0.03865975 0.03865975 0.03865975 125973 125976 125979 125982 125985 125989 125992 0.03865975 0.11043432 0.03865975 0.03865975 0.17823882 0.03865975 0.03865975 125996 125999 126002 126005 126008 126011 126014 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.11043432 0.17823882 126017 126020 126023 126026 126029 126032 126035 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 126038 126041 126044 126047 126050 126053 126056 0.17823882 0.03865975 0.03865975 0.11043432 0.03865975 0.03865975 0.17823882 126059 126062 126065 126068 126071 126074 126077 0.17823882 0.03865975 0.17823882 0.11043432 0.11043432 0.03865975 0.11043432 126080 126083 126086 126090 126095 126098 126101 0.03865975 0.11043432 0.03865975 0.03865975 0.03865975 0.03865975 0.11043432 126104 126108 126111 126114 126117 126120 126123 0.17823882 0.11043432 0.03865975 0.03865975 0.11043432 0.11043432 0.03865975 126126 126130 126133 126136 126139 126142 126145 0.17823882 0.03865975 0.03865975 0.11043432 0.03865975 0.11043432 0.17823882 126148 126151 126154 126157 126160 126163 126166 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.11043432 0.03865975 126169 126172 126175 126178 126181 126184 126187 0.11043432 0.03865975 0.17823882 0.17823882 0.17823882 0.17823882 0.17823882 126190 126193 126196 126199 126202 126205 126208 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.11043432 0.17823882 126211 126214 126217 126220 126223 126226 126229 0.11043432 0.03865975 0.03865975 0.11043432 0.03865975 0.03865975 0.03865975 126232 126235 126238 126241 126244 126247 126250 0.11043432 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 0.03865975 126253 126256 126259 126262 126265 126268 126272 0.03865975 0.03865975 0.03865975 0.11043432 0.03865975 0.03865975 0.11043432
該当のソースコード
ソースコードの全体を掲載いたします。
R
1install.packages("rpart") 2library(rpart) 3install.packages("rpart.plot") 4library(rpart.plot) 5install.packages("partykit") 6library(partykit) 7 8# データのインポート 9getwd() 10cust <- read.csv(file = "cust1.csv") 11head(cust) 12 13# id列を説明変数から除外 14cust$id = NULL 15 16# データの各列のサマリーの確認 17summary(cust) 18 19# 異常値の除外 20cust.1 <- cust 21cust.1$age <- ifelse(cust.1$age>122, NA, cust.1$age) 22cust.1$retail <- ifelse(cust.1$retail<0, NA, cust.1$retail) 23cust.1$rest <- ifelse(cust.1$rest<0, NA, cust.1$rest) 24cust.1$trans <- ifelse(cust.1$trans<0, NA, cust.1$trans) 25cust.1$other <- ifelse(cust.1$other<0, NA, cust.1$other) 26cust.1$usage <- ifelse(cust.1$usage<0, NA, cust.1$usage) 27summary(cust.1) 28 29# 欠損値の削除 30cust.1 <- na.omit(cust.1) 31 32# 箱ひげ図を用いて、外れ値を確認する 33boxplot(cust.1$age) 34boxplot(cust.1$retail) 35boxplot(cust.1$rest) 36boxplot(cust.1$enter) 37boxplot(cust.1$trans) 38boxplot(cust.1$other) 39boxplot(cust.1$usage) 40boxplot(cust.1[,c("age","retail","rest","enter","trans","other","usage")]) 41 42# 外れ値を除外する 43higher.retail = quantile(cust.1$retail, probs=0.99) 44cust.1$retail <- ifelse(cust.1$retail>higher.retail, NA, cust.1$retail) 45higher.rest = quantile(cust.1$rest, probs=0.99) 46cust.1$rest <- ifelse(cust.1$rest>higher.rest, NA, cust.1$rest) 47higher.trans = quantile(cust.1$trans, probs=0.99) 48cust.1$trans <- ifelse(cust.1$trans>higher.trans, NA, cust.1$trans) 49higher.other = quantile(cust.1$other, probs=0.99) 50cust.1$other <- ifelse(cust.1$other>higher.other, NA, cust.1$other) 51higher.usage = quantile(cust.1$usage, probs=0.99) 52cust.1$usage <- ifelse(cust.1$usage>higher.usage, NA, cust.1$usage) 53 54# 欠損値の削除 55cust.1 <- na.omit(cust.1) 56summary(cust.1) 57 58# custを学習用データと検証用データの2つに分類しておく 59index <- which(1:nrow(cust.1)%%3 == 0) 60custTrain <- cust.1[-index,] #3の倍数以外を学習用データにする 61custTest <- cust.1[index,] #3の倍数は検証用データ 62head(custTrain) 63head(custTest) 64 65# 学習用データを用いて決定木モデルを作成 66model.rp <- rpart(hotel~., data = custTrain) 67model.rp 68rpart.plot(model.rp) 69 70model.rpo <- as.party(model.rp, type='simple') 71plot(model.rpo) 72 73# 未知データ(検証用データ)に対して予測を行う 74model.pred <- predict(model.rp, custTest) 75head(model.pred) 76
試したこと
上記ソースコードの17行目「summary(cust)」の出力結果から、「cust1.csv」には異常値があること、また、33~40行目にある箱ひげ図の出力結果から、外れ値があることが判明しました。
ageに122歳以上のデータがあることや、ジャンルごとのクレジットカードの利用金額に負の値があることから、これらのデータをいったん欠損値NAに変え、「na.omit(cust.1)」によってNAの行を除外しています。
R
1# 異常値の除外 2cust.1 <- cust 3cust.1$age <- ifelse(cust.1$age>122, NA, cust.1$age) 4cust.1$retail <- ifelse(cust.1$retail<0, NA, cust.1$retail) 5cust.1$rest <- ifelse(cust.1$rest<0, NA, cust.1$rest) 6cust.1$trans <- ifelse(cust.1$trans<0, NA, cust.1$trans) 7cust.1$other <- ifelse(cust.1$other<0, NA, cust.1$other) 8cust.1$usage <- ifelse(cust.1$usage<0, NA, cust.1$usage) 9summary(cust.1)
また、箱ひげ図の結果から異常に大きな値(外れ値)があることを確認したため、データの上位0.01%を除外してみました。
R
1# 外れ値を除外する 2higher.retail = quantile(cust.1$retail, probs=0.99) 3cust.1$retail <- ifelse(cust.1$retail>higher.retail, NA, cust.1$retail) 4higher.rest = quantile(cust.1$rest, probs=0.99) 5cust.1$rest <- ifelse(cust.1$rest>higher.rest, NA, cust.1$rest) 6higher.trans = quantile(cust.1$trans, probs=0.99) 7cust.1$trans <- ifelse(cust.1$trans>higher.trans, NA, cust.1$trans) 8higher.other = quantile(cust.1$other, probs=0.99) 9cust.1$other <- ifelse(cust.1$other>higher.other, NA, cust.1$other) 10higher.usage = quantile(cust.1$usage, probs=0.99) 11cust.1$usage <- ifelse(cust.!1$usage>higher.usage, NA, cust.1$usage) 12# 欠損値の削除 13cust.1 <- na.omit(cust.1) 14summary(cust.1)
Rによる決定木分析についてネットで何時間も調べていましたが、結局問題解決には至りませんでした。どなたかお力をお借りしたいです。よろしくお願いいたします。
補足情報(FW/ツールのバージョンなど)
R Gui 4.3.2 (Windows版)
追記2(検証用データによるモデルの評価について)
R
1#学習用データを用いて決定木モデルを作成 2install.packages("rpart") 3library(rpart) 4install.packages("partykit") 5library(partykit) 6install.packages("dplyr") 7library(dplyr) 8model < 9rpart(hotel~., data = custTrain, cp=0.002, method=' anova') 10plot(as.party(model), type='simple') 11 12# 未知データ(検証用データ)に対して予測を行う 13threshold <- 0.1 14model.pred <- predict(model, newdata = custTest, type = "vector") >= threshold 15table(model.pred, custTest[,4])
実行結果
model.pred 0 1 FALSE 29647 1434 TRUE 7825 1526
そもそも閾値を0.5にした時点でFALSEの判定(hote=0)しか表示されない点で何かがおかしいと感じております。type = "vector"としているのは、以下のサイトを参考にしたためです。
https://www.appsloveworld.com/r/100/148/invalid-prediction-for-rpart-object-error
回答1件
あなたの回答
tips
プレビュー