勉強などのメモ

勉強用のメモ

機械学習:Rでやってみる決定木

機械学習をJuliaでやろうと思ったものの、新しい言語を覚えるのもそれなりに労力がいる。
最近ようやくRに慣れてきたところなので、これを定着させることを目指してRでやることにした。

機械学習の勉強に役立つサイトとして有名な「kaggle」を活用させてもらうことにして、以下のデータセットを使用して決定木を試してみる。

Titanic: Machine Learning from Disaster | Kaggle

まず、データの読み込み。

#csvデータ読み込み
train <- read.csv('../all/train.csv')
test <- read.csv('../all/test.csv')

データの中身を軽く確認してみる。

summary(train)

>  PassengerId       Survived          Pclass     
> Min.   :  1.0   Min.   :0.0000   Min.   :1.000  
> 1st Qu.:223.5   1st Qu.:0.0000   1st Qu.:2.000  
> Median :446.0   Median :0.0000   Median :3.000  
> Mean   :446.0   Mean   :0.3838   Mean   :2.309  
> 3rd Qu.:668.5   3rd Qu.:1.0000   3rd Qu.:3.000  
> Max.   :891.0   Max.   :1.0000   Max.   :3.000  
>                                                 
>                                    Name         Sex           Age       
> Abbing, Mr. Anthony                  :  1   female:314   Min.   : 0.42  
> Abbott, Mr. Rossmore Edward          :  1   male  :577   1st Qu.:20.12  
> Abbott, Mrs. Stanton (Rosa Hunt)     :  1                Median :28.00  
> Abelson, Mr. Samuel                  :  1                Mean   :29.70  
> Abelson, Mrs. Samuel (Hannah Wizosky):  1                3rd Qu.:38.00  
> Adahl, Mr. Mauritz Nils Martin       :  1                Max.   :80.00  
> (Other)                              :885                NA's   :177    
>     SibSp           Parch             Ticket         Fare       
> Min.   :0.000   Min.   :0.0000   1601    :  7   Min.   :  0.00  
> 1st Qu.:0.000   1st Qu.:0.0000   347082  :  7   1st Qu.:  7.91  
> Median :0.000   Median :0.0000   CA. 2343:  7   Median : 14.45  
> Mean   :0.523   Mean   :0.3816   3101295 :  6   Mean   : 32.20  
> 3rd Qu.:1.000   3rd Qu.:0.0000   347088  :  6   3rd Qu.: 31.00  
> Max.   :8.000   Max.   :6.0000   CA 2144 :  6   Max.   :512.33  
>                                  (Other) :852                   
>         Cabin     Embarked
>            :687    :  2   
> B96 B98    :  4   C:168   
> C23 C25 C27:  4   Q: 77   
> G6         :  4   S:644   
> C22 C26    :  3           
> D          :  3           
> (Other)    :186           

年齢にNAがある模様。見にくいので欠損値の数のみ確認する。

#欠損値の数確認
is_na_train <- sapply(train, function(y) sum(is.na(y)))
is_na_train

>PassengerId    Survived      Pclass        Name         Sex         Age 
>          0           0           0           0           0         177 
>      SibSp       Parch      Ticket        Fare       Cabin    Embarked 
>          0           0           0           0           0           0 

欠損値があるのはAgeだけなのかな?と思って、データを目視確認したところ、ところどころブランク(空白)がある。これはNA判定されていない模様。そこで、空白はNAにするようにデータ読み込みし直すことに。

#ブランクもNAにして読み込む
train <- read.csv('../all/train.csv', stringsAsFactors=F, na.strings = c('NA', ''))
test <- read.csv('../all/test.csv', stringsAsFactors=F, na.strings = c('NA', ''))

「stringsAsFactors=F」というのは「文字列型をFactor型にしないように」というオプション。

再び欠損値数の確認。

#欠損値の数確認(再)
is_na_train <- sapply(train, function(y) sum(is.na(y)))
is_na_train

>PassengerId    Survived      Pclass        Name         Sex         Age 
>          0           0           0           0           0         177 
>      SibSp       Parch      Ticket        Fare       Cabin    Embarked 
>          0           0           0           0         687           2 

Age以外にも欠損値がありました。Cabinにかなりの数の欠損が。でも、モデルに使わないので無視。Embarkedの欠損が2件。
テストデータの方も確認しておく。

is_na_test <-  sapply(test, function(y) sum(is.na(y)))
is_na_test

>PassengerId      Pclass        Name         Sex         Age       SibSp 
>          0           0           0           0          86           0 
>      Parch      Ticket        Fare       Cabin    Embarked 
>          0           0           1         327           0 

モデルを作る前にデータのクリーニングをしなくてはならない。まず欠損値を埋める処理を行う。

#Ageの欠損値を埋める。とりあえずmedianを入れる
train$Age[is.na(train$Age)] <- median(train$Age, na.rm=T)
test$Age[is.na(test$Age)] <- median(test$Age, na.rm=T)

#Embarkedの欠損値を埋める。とりあえず数が最も多いSを入れる
#testには欠損値なし
train$Embarked[is.na(train$Embarked)] <- 'S'

#Fareの欠損値を埋める。medianを入れる
#trainには欠損値なし
test$Fare[is.na(test$Fare)] <- median(test$Fare, na.rm=T)

欠損値がなくなったかどうか確認。

#欠損値の数確認(再々)
is_na_train <- sapply(train, function(y) sum(is.na(y)))
is_na_train

>PassengerId    Survived      Pclass        Name         Sex         Age 
>          0           0           0           0           0           0 
>      SibSp       Parch      Ticket        Fare       Cabin    Embarked 
>          0           0           0           0         687           0 

Age、Embarkedの欠損値は0になった。テストデータの確認は省略。

次に、文字列型のデータをカテゴリ値に修正する。

#文字列をカテゴリ値に
train$Sex[train$Sex=='male'] <- 0
train$Sex[train$Sex=='female'] <- 1
train$Embarked[train$Embarked=='S'] <- 0
train$Embarked[train$Embarked=='C'] <- 1
train$Embarked[train$Embarked=='Q'] <- 2

test$Sex[test$Sex=='male'] <- 0
test$Sex[test$Sex=='female'] <- 1
test$Embarked[test$Embarked=='S'] <- 0
test$Embarked[test$Embarked=='C'] <- 1
test$Embarked[test$Embarked=='Q'] <- 2

これでデータクリーニングが終わったので、いよいよモデルを作る。

#決定木でモデルを作る
library(rpart)

model1 <- rpart(Survived ~ Pclass + Age + Sex + Fare, data=train)
model1

>n= 891 
>
>node), split, n, deviance, yval
>      * denotes terminal node
>
> 1) root 891 210.727300 0.3838384  
>   2) Sex=0 577  88.409010 0.1889081  
>     4) Age>=6.5 553  77.359860 0.1681736  
>       8) Pclass>=1.5 433  44.226330 0.1154734 *
>       9) Pclass< 1.5 120  27.591670 0.3583333 *
>     5) Age< 6.5 24   5.333333 0.6666667 *
>   3) Sex=1 314  60.105100 0.7420382  
>     6) Pclass>=2.5 144  36.000000 0.5000000  
>      12) Fare>=23.35 27   2.666667 0.1111111 *
>      13) Fare< 23.35 117  28.307690 0.5897436 *
>     7) Pclass< 2.5 170   8.523529 0.9470588 *

うーん、よくわからない。なので、図示してみる。

library(rpart.plot)

#図示してみる
rpart.plot(model1, extra = 1)

出力結果は以下の図。

f:id:prism0081:20180919125511j:plain

これの評価はさておき、作ったモデルをテストデータにあてはめて予測してみる。

#モデルで予測する
pred1 <- predict(model1, test)
pred1

>        1         2         3         4         5         6         7 
>0.1154734 0.5897436 0.1154734 0.1154734 0.5897436 0.1154734 0.5897436 
>        8         9        10        11        12        13        14 
>0.1154734 0.5897436 0.1154734 0.1154734 0.3583333 0.9470588 0.1154734 
(以下略)

0,1での出力じゃなくて、確率で出してくれるのね。kaggeleへの提出は0,1での予測が必要なので変換する。

#0,1で表現する
pred1 <- round(pred1)
pred1

>  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19 
>  0   1   0   0   1   0   1   0   1   0   0   0   1   0   1   1   0   0   1 
> 20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38 
>  1   0   0   1   0   1   0   1   0   0   0   0   0   1   0   0   0   1   1 
(以下略)

予測結果をkaggleで評価してもらうために、csvファイルに書き出す。

# 予測結果とPassengerIdをデータフレームへ
my_solution1 = data.frame(
  PassengerId = test$PassengerId,
  Survived = pred1
)

#結果のCSV書き出し
write.csv(my_solution1, file='./model1_solution.csv', row.names = F)

出力したcsvファイルをkaggleでアップロードして判定してもらったところ・・・

f:id:prism0081:20180919130530j:plain

まあまあの数字。
まったく同じ事をPythonでやってみたけど、Rの方がスコアは高かった。ライブラリの違いでしょうかね。

少しモデルの説明変数を増やしてみる。

#モデルの説明変数を増やしてみる
model2 <- rpart(Survived ~ Pclass + Age + Sex + Fare + SibSp + Parch + Embarked, data=train)

#図示してみる
rpart.plot(model2, extra = 1)

f:id:prism0081:20180919132255j:plain

大きくは変わってない印象・・・モデル1と同様に予測してみる。

#モデルで予測する
pred2 <- predict(model2, test)

#0,1で表現する
pred2 <- round(pred2)

# 予測結果とPassengerIdをデータフレームへ
my_solution2 = data.frame(
  PassengerId = test$PassengerId,
  Survived = pred2
)

#結果のCSV書き出し
write.csv(my_solution2, file='./model2_solution.csv', row.names = F)

kaggleにアップロードして評価してもらうと

f:id:prism0081:20180919132239j:plain

少しだけ改善しましたが、もっとチューニングが必要です。