質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

91.37%

  • Python 3.x

    2398questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

randomforestの結果であるpredict_probaをcsvに出力出来ず、苦慮しております。

解決済

回答 1

投稿 2017/11/27 19:33

  • 評価
  • クリップ 0
  • VIEW 37

akakage13

score 74

randomforestで二値分類の勉強をしております初心者でございます。

randomforestの結果であるpredict_probaをcsvに出力出来ず、苦慮しております。

今回のソースコードは以下の通りでございます。

# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

#訓練データ
gakusyuu_data = pd.read_csv("g_1.csv" , sep=",")

# 特徴データとラベルデータを取り出す
gakusyuu_data_except_arrival = gakusyuu_data.drop("result", axis=1)

features = gakusyuu_data_except_arrival.as_matrix()
targets = gakusyuu_data['result'].as_matrix()

##############モデル設定

model = RandomForestClassifier()

model.fit(features, targets)

# テストデータ読み込み
test_df = pd.read_csv("test_1.csv", sep=",")
# テストデータ作成
test_data = test_df[['ratio_1','ratio_2','ratio_3','ratio_4','ratio_5','ratio_6']].as_matrix()
# 予測
result = model.predict(test_data)

res_ratio = model.predict_proba(test_data)

print(result)
print(res_ratio)

# 結果をテストデータに反映
test_df['result'] = result


print(test_df)

test_df.to_csv('result.csv', index=None)

#res_ratio.to_csv('result_1.csv', index=None) #この部分を改変出来ず、苦慮しております。

g1_csvの内容は以下の通りでございます。

ratio_1    ratio_2    ratio_3    ratio_4    ratio_5    ratio_6    result
2    6    1    10    5    8    1
9    6    2    2    6    16    0
6    9    1    2    9    16    1
8    6    2    6    6    9    0
1    6    1    1    6    16    0
4    5    4    10    4    1    1
2    5    5    5    5    16    0
1    5    1    7    5    2    0
1    5    9    4    6    3    1


test_1.csvの内容は以下の通りでございます。

ratio_1    ratio_2    ratio_3    ratio_4    ratio_5    ratio_6
7    7    4    11    4    18

上記のソースコードを動かしますと

こ [1] [[ 0.4  0.6]]    ratio_1  ratio_2  ratio_3  ratio_4  ratio_5  ratio_6  result 0        7        7        4       11        4       18       1

このように、分類結果と、分類の根拠の割合であります predict_proba が

[[ 0.4  0.6]]と出力されます。

この割合を、result.csvに併せて、一緒にcsvに出力することが目的でございます。

例えば、

|ratio_1    ratio_2    ratio_3    ratio_4    ratio_5    ratio_6    result    answer_1    answer_2
7    7    4    11    4    18    1    0.4    0.6

このように、csvファイルとして、出力することが出来れば理想でございます。

numpyについて、いろいろ調べましたが、どうしても解決出来ません。

ヒント等でも頂けますと幸いです。

先輩方の御教示をよろしくお願いいたします。

   

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

checkベストアンサー

+1

こんな感じでよろしいでしょうか?

# 予測
result = model.predict(test_data)
res_ratio = model.predict_proba(test_data)

print(result)
print(res_ratio)

# 結果をテストデータに反映
test_df['result'] = result

# res_ratio を結合
test_df = pd.concat([test_df, pd.DataFrame(res_ratio, columns=['answer_1','answer_2'])], axis=1)

# 結果をCSVファイルに出力
test_df.to_csv('result.csv', index=None)

投稿 2017/11/27 20:11

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2017/11/27 20:40

    magichan様、早々の御教示、恐縮すると同時に、感謝の気持ちでいっぱいでございます。

    うまく動きました。ありがとうございました。

    今後ともよろしくお願いいたします。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

ただいまの回答率

91.37%

関連した質問

同じタグがついた質問を見る

  • Python 3.x

    2398questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。