Döntési fa R - -ben Osztályozási fa & Kód R-ben a példával

Tartalomjegyzék:

Anonim

Mik a döntési fák?

A döntési fák sokoldalú gépi tanulási algoritmusok, amelyek képesek mind osztályozási, mind regressziós feladatokat végrehajtani. Nagyon hatékony algoritmusok, amelyek képesek összetett adatkészletek illesztésére. Ezenkívül a döntési fák a véletlenszerű erdők alapvető elemei, amelyek a ma elérhető leghatékonyabb Machine Learning algoritmusok közé tartoznak.

Képzés és a döntési fák megjelenítése

Az első döntési fa felépítéséhez az R példában a következőképpen járunk el:

  • 1. lépés: Importálja az adatokat
  • 2. lépés: Tisztítsa meg az adatkészletet
  • 3. lépés: Hozzon létre vonatot / tesztkészletet
  • 4. lépés: Készítse el a modellt
  • 5. lépés: Tippeljen
  • 6. lépés: A teljesítmény mérése
  • 7. lépés: Hangolja be a hiperparamétereket

1. lépés: Importálja az adatokat

Ha kíváncsi a titanic sorsára, megnézheti ezt a videót a Youtube-on. Ennek az adatkészletnek az a célja, hogy megjósolja, melyik ember él nagyobb valószínűséggel a jégheggyel ütközés után. Az adatkészlet 13 változót és 1309 megfigyelést tartalmaz. Az adatsort az X változó rendezi.

set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)

Kimenet:

## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)

Kimenet:

## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S

A fej és a farok kimenetéből észreveheti, hogy az adatok nem keveredtek. Ez nagy kérdés! Amikor felosztja adatait a vonatkészlet és a tesztkészlet között, csak az utast választja ki az 1. és 2. osztályból (a 3. osztályból egyetlen utas sem szerepel a megfigyelések 80 százalékában), ami azt jelenti, hogy az algoritmus soha nem fogja látni A 3. osztályú utas jellemzői. Ez a hiba rossz előrejelzéshez vezet.

A probléma kiküszöbölésére használhatja a () függvénymintát.

shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)

Döntési fa R kód Magyarázat

  • minta (1: nrow (titanic)): Készítsen véletlenszerű indexlistát 1 és 1309 között (azaz a sorok maximális száma).

Kimenet:

## [1] 288 874 1078 633 887 992 

Ezt az indexet fogja használni a titán adatkészlet keveréséhez.

titanic <- titanic[shuffle_index, ]head(titanic)

Kimenet:

## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C

2. lépés: Tisztítsa meg az adatkészletet

Az adatok felépítése azt mutatja, hogy néhány változó NA-val rendelkezik. Az adatok tisztítását a következőképpen kell elvégezni

  • Drop változók home.dest, kabin, név, X és jegy
  • Hozzon létre faktorváltozókat a pclass számára, és túlélte
  • Dobd el az NA-t
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)

Kód Magyarázat

  • select (-c (home.dest, kabin, név, X, jegy)): Dobd felesleges változókat
  • pclass = faktor (pclass, szint = c (1,2,3), feliratok = c ('Felső', 'Közép', 'Alsó')): Adjon címkét a pclass változóhoz. 1-ből Felső, 2-ből MIddle és 3-ból alacsonyabb lesz
  • faktor (túlélte, szintek = c (0,1), címkék = c ('Nem', 'Igen')): Adjon hozzá címkét a túlélő változóhoz. 1 Nem lesz, 2 pedig Igen lesz
  • na.omit (): Távolítsa el az NA megfigyeléseit

Kimenet:

## Observations: 1,045## Variables: 8## $ pclass  Upper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived  No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex  male, male, female, female, male, male, female, male… ## $ age  61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp  0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch  0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare  32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked  S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C… 

3. lépés: Hozzon létre vonatot / tesztkészletet

A modell edzése előtt két lépést kell végrehajtania:

  • Vonat és tesztkészlet létrehozása: A modellt a vonatkészleten képzi ki, és a tesztkészleten teszteli a jóslatot (azaz nem látott adatokat).
  • Telepítse az rpart.plot alkalmazást a konzolról

Az általános gyakorlat az adatok 80/20-os felosztása, az adatok 80 százaléka a modell kiképzésére szolgál, 20 százalékuk pedig előrejelzésekre. Két külön adatkeretet kell létrehoznia. Addig nem szabad hozzányúlni a tesztkészlethez, amíg be nem fejezi a modell felépítését. Létrehozhat egy create_train_test () függvénynevet, amely három argumentumot tartalmaz.

create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}

Kód Magyarázat

  • függvény (adatok, méret = 0,8, vonat = IGAZ): Adja hozzá az argumentumokat a függvénybe
  • n_row = nrow (adatok): Számolja meg az adatkészlet sorainak számát
  • total_row = méret * n_row: Visszatérve az n-edik sorra a vonatkészlet összeállításához
  • vonat_minta <- 1: összesen_sor: Válassza ki az első sort az n-edik sorig
  • if (vonat == IGAZ) {} else {}: Ha a feltétel igazra állítja, adja vissza a vonatkészletet, különben a tesztkészletet.

Kipróbálhatja a funkcióját és ellenőrizheti a méretet.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)

Kimenet:

## [1] 836 8
dim(data_test)

Kimenet:

## [1] 209 8 

A vonat adatkészlet 1046, míg a teszt adatkészlet 262 sorral rendelkezik.

A prop.table () függvényt a () táblával kombinálva ellenőrizheti, hogy a véletlenszerűsítés folyamata helyes-e.

prop.table(table(data_train$survived))

Kimenet:

#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Kimenet:

#### No Yes## 0.5789474 0.4210526

Mindkét adatkészletben a túlélők száma azonos, körülbelül 40 százalék.

Telepítse az rpart.plot fájlt

Az rpart.plot nem érhető el a conda könyvtárakból. Telepítheti a konzolról:

install.packages("rpart.plot") 

4. lépés: Készítse el a modellt

Készen áll a modell elkészítésére. Az Rpart döntési fa függvényének szintaxisa:

rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree

Az osztály módszert használja, mert megjósolja az osztályt.

library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106

Kód Magyarázat

  • rpart (): Funkció, amely illeszkedik a modellhez. Az érvek a következők:
    • túlélte ~ .: A döntési fák képlete
    • data = data_train: Adatkészlet
    • method = 'class': Illesszen be egy bináris modellt
  • rpart.plot (fit, extra = 106): Ábrázolja a fát. Az extra funkciók 101-re vannak állítva a 2. osztály valószínűségének megjelenítésére (hasznos bináris válaszok esetén). A többi lehetőségről további információt a matricán talál.

Kimenet:

A gyökércsomópontból indul (0 mélység 3 felett, a grafikon teteje):

  1. A csúcson ez a túlélés teljes valószínűsége. Megmutatja az utasok arányát, akik túlélték a balesetet. Az utasok 41 százaléka életben maradt.
  2. Ez a csomópont azt kérdezi, hogy az utas neme-e férfi. Ha igen, akkor menjen le a gyökér bal gyermekcsomópontjához (2. mélység). 63 százaléka olyan férfi, amelynek túlélési valószínűsége 21 százalék.
  3. A második csomópontban azt kérdezi, hogy a férfi utas 3,5 évesnél idősebb-e. Ha igen, akkor a túlélés esélye 19 százalék.
  4. Így folytatja, hogy megértse, milyen tulajdonságok befolyásolják a túlélés valószínűségét.

Ne feledje, hogy a döntési fák sok tulajdonságának egyike az, hogy nagyon kevés adat-előkészítést igényelnek. Különösen nem igényelnek funkciók méretezését vagy központosítását.

Alapértelmezés szerint az rpart () függvény a Gini szennyeződés mértékét használja a hang felosztásához. Minél magasabb a Gini-együttható, annál több különböző példány van a csomóponton belül.

5. lépés: Tegyen jóslatot

Megjósolhatja a tesztadatkészletet. Előrejelzéshez használhatja a prediktív () függvényt. Az R döntési fa előrejelzésének alapvető szintaxisa:

predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level

Meg akarja jósolni, hogy melyik utasok élnek túl nagy valószínűséggel az ütközés után a tesztkészletből. Ez azt jelenti, hogy a 209 utas közül tudni fogja, melyik marad életben vagy sem.

predict_unseen <-predict(fit, data_test, type = 'class')

Kód Magyarázat

  • megjósolni (fit, data_test, type = 'class'): Megjósolja a tesztkészlet osztályát (0/1)

Az utast, aki nem jutott el, és azokat, akiknek sikerült.

table_mat <- table(data_test$survived, predict_unseen)table_mat

Kód Magyarázat

  • táblázat (data_test $ survived, pred__seen láthatatlan): Hozzon létre egy táblázatot, amely megszámolja, hogy hány utas van túlélőnek és elhunytnak, összehasonlítva a helyes döntési fa besorolásával R

Kimenet:

## predict_unseen## No Yes## No 106 15## Yes 30 58

A modell 106 halott utast jósolt meg, de 15 túlélőt halottnak minősített. Analógia alapján a modell 30 utast minősített túlélőként, miközben kiderült, hogy meghaltak.

6. lépés) Mérje meg a teljesítményt

Kiszámíthatja az osztályozási feladat pontossági mértékét a zavaros mátrix segítségével :

A zavartsági mátrix jobb választás az osztályozási teljesítmény értékelésére. Az általános elképzelés az, hogy megszámoljuk az igaz példányok hamis besorolását.

A zavaros mátrix minden sora egy tényleges célt képvisel, míg minden oszlop egy előre jelzett célt. Ennek a mátrixnak az első sora a halott utasokat veszi figyelembe (a hamis osztály): 106-ot helyesen minősítettek halottnak ( True negatív ), míg a maradékot tévesen túlélőnek ( hamis pozitív ). A második sor a túlélőket veszi figyelembe, a pozitív osztály 58 volt ( igaz pozitív ), míg az igaz negatív 30 volt.

A pontossági tesztet a zavaros mátrixból számíthatja ki :

Ez az igaz pozitív és igaz negatív aránya a mátrix összegéhez viszonyítva. Az R-vel az alábbiak szerint kódolhat:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Kód Magyarázat

  • sum (diag (table_mat)): Az átló összege
  • sum (table_mat): A mátrix összege.

A tesztkészlet pontosságát kinyomtathatja:

print(paste('Accuracy for test', accuracy_Test))

Kimenet:

## [1] "Accuracy for test 0.784688995215311" 

A tesztkészlet 78 százalékos pontszámmal rendelkezik. Megismételheti ugyanazt a gyakorlatot az edzésadatkészlettel.

7. lépés: Hangolja be a hiperparamétereket

Az R döntési fának különböző paraméterei vannak, amelyek szabályozzák az illesztés szempontjait. Az rpart döntésfa könyvtárban a paramétereket az rpart.control () függvény segítségével vezérelheti. A következő kódban bevezeti a hangolni kívánt paramétereket. A matricán más paraméterek találhatók.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

A következőképpen fogunk eljárni:

  • A függvény szerkesztése a pontosság visszaadásához
  • Hangolja be a maximális mélységet
  • Állítsa be a csomópont minimális számú mintáját, mielőtt fel tudna osztódni
  • Állítsa be a levél csomópontjának minimális számú mintáját

Írhat egy függvényt a pontosság megjelenítéséhez. Egyszerűen be kell csomagolnia a korábban használt kódot:

  1. megjósolni: megjósolni_láthatatlan <- megjósolni (illeszkedés, adat_teszt, type = 'osztály')
  2. Táblázat előállítása: table_mat <- tábla ($_test $ túlélte, pred__láthatatlan)
  3. Számítási pontosság: pontossági_teszt <- összeg (átló (tábla_mat)) / összeg (tábla_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}

Megpróbálhatja hangolni a paramétereket, és megnézheti, hogy javíthatja-e a modellt az alapértelmezett érték fölött. Emlékeztetőül: 0,78-nál nagyobb pontosságot kell elérni

control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)

Kimenet:

## [1] 0.7990431 

A következő paraméterrel:

minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0 

Nagyobb teljesítményt ér el, mint az előző modell. Gratulálok!

Összegzés

Összefoglalhatjuk a döntési fa algoritmus R-ben történő kiképzésének függvényeit

Könyvtár

Célkitűzés

funkció

osztály

paraméterek

részletek

rpart

Vonat besorolási fa R-ben

rpart ()

osztály

formula, df, módszer

rpart

Vonat regressziós fa

rpart ()

anova

formula, df, módszer

rpart

Ábrázold a fákat

rpart.plot ()

felszerelt modell

bázis

megjósolni

megjósolni ()

osztály

felszerelt modell, típus

bázis

megjósolni

megjósolni ()

prob

felszerelt modell, típus

bázis

megjósolni

megjósolni ()

vektor

felszerelt modell, típus

rpart

Ellenőrzési paraméterek

rpart.control ()

minsplit

Állítsa be a megfigyelések minimális számát a csomópontban, mielőtt az algoritmus végrehajtaná a felosztást

minvödör

Állítsa be a megfigyelések minimális számát az utolsó jegyzetben, azaz a levélben

maximális mélység

Állítsa be a végső fa bármelyik csomópontjának maximális mélységét. A gyökércsomópont 0 mélységű

rpart

Vonatmodell vezérlő paraméterrel

rpart ()

formula, df, módszer, kontroll

Megjegyzés: Képezze a modellt edzésadatokra, és tesztelje a teljesítményt egy láthatatlan adatkészleten, azaz tesztkészleten.