Cereal data- analisi con gli alberi di classificazione Emanuele Taufer file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 1/32
Cereal.dat Per migliorare la commercializzazione dei propri prodotti per la prima colazione, una società interviste 880 persone, registrando la loro età, il sesso, lo stato civile e se hanno un stile di vita attivo (sulla base del fatto che essi praticano sport almeno due volte a settimana). Ogni partecipante assaggia 3 diversi prodotti per la colazione e indica quello che preferisce. file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 2/32
Variabili BFAST : (Y) 1. se «Breakfast bar» 2. se «Oatmeal» 3 se «Cereal» AGECAT: «Under 31», «31-45», «46-60», «Over 60» GENDER: 1 se F, 0 se M ; MARITAL: 1 se sposato; 0 se no. ACTIVE: 1 se attivo ; 0 se no file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 3/32
Caricare i dati Cereal<-read.table("http://www.cs.unitn.it/~taufer/Data/Cereal.dat",header=T,sep="") head(cereal) ## AGECAT GENDER MARITAL ACTIVE BFAST ## 1 1 0 1 1 3 ## 2 3 0 1 0 1 ## 3 4 0 1 0 2 ## 4 2 1 1 1 2 ## 5 3 0 1 0 2 ## 6 4 0 1 0 3 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 4/32
Codificare le variabili come factors Le variabili del dataset sono codificate numericamente. È necessario (e più facile per l interpretazione dei risultati) codificarle come factors e introdurre etichette per i livelli Cereal$BFAST=factor(Cereal$BFAST,levels=c(1,2,3), labels=c("bf-bar","oatmeal","cereal")) Cereal$AGECAT=factor(Cereal$AGECAT,levels=c(1,2,3,4), labels=c("under 31","[31-45]","[46-60]","Over 60")) Cereal$GENDER=factor(Cereal$GENDER,levels=c(0,1), labels=c("m","f")) Cereal$MARITAL=factor(Cereal$MARITAL,levels=c(0,1), labels=c("non sposato","sposato")) Cereal$ACTIVE=factor(Cereal$ACTIVE,levels=c(0,1), labels=c("no","yes")) str(cereal) ## 'data.frame': 880 obs. of 5 variables: ## $ AGECAT : Factor w/ 4 levels "Under 31","[31-45]",..: 1 3 4 2 3 4 2 4 2 2... ## $ GENDER : Factor w/ 2 levels "M","F": 1 1 1 2 1 1 2 2 2 2... ## $ MARITAL: Factor w/ 2 levels "Non sposato",..: 2 2 2 2 2 2 2 1 2 2... ## $ ACTIVE : Factor w/ 2 levels "No","Yes": 2 1 1 2 1 1 1 1 2 1... ## $ BFAST : Factor w/ 3 levels "BF-bar","Oatmeal",..: 3 1 2 2 2 3 1 2 2 1... head(cereal) ## AGECAT GENDER MARITAL ACTIVE BFAST ## 1 Under 31 M Sposato Yes Cereal ## 2 [46-60] M Sposato No BF-bar ## 3 Over 60 M Sposato No Oatmeal ## 4 [31-45] F Sposato Yes Oatmeal ## 5 [46-60] M Sposato No Oatmeal ## 6 Over 60 M Sposato No Cereal file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 5/32
Validation set approach Dividiamo le unità in due parti: training set e test set. Stimiamo l albero di classificazione utilizzando il training set, e valutiamo la sua performance usando il test set. Ci sono 880 unità nel data set, prendiamone 280 per il test set Il seguente codice seleziona, in modo casuale, 600 unità (la loro posizione) da tutto il set di dati set.seed(1) train=sample(nrow(cereal), 600) train[1:10] ## [1] 234 328 503 797 177 787 826 577 549 54 Con gli indici ottenuti sopra possiamo suddividere i dati in due parti: Cereal.train=Cereal[train,] # take only units that correspond to train Cereal.test=Cereal[-train,] # take only units that DO NOT correspond to train file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 6/32
Costruzione dell albero di classificazione Per adattare un albero di classificazione (o regressione) possiamo usare la funzione tree() dalla libreria tree. L input minimo è molto semplice poiché è sufficiente indicare l equazione di regressione (o classificazione) ed i dati. (analogo a quanto già imparato per la funzione lm()) library(tree) Cereal.tree<-tree(BFAST~.,data=Cereal.train) Nota: BFAST~. indica che BFAST è la variabile dipendente e il punto che tutte le altre variabili nel set di dati devono essere utilizzate come variabili indipendenti file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 7/32
Summary L oggetto Cereal.tree contiene i risultati della stima summary(cereal.tree) ## ## Classification tree: ## tree(formula = BFAST ~., data = Cereal.train) ## Variables actually used in tree construction: ## [1] "AGECAT" ## Number of terminal nodes: 3 ## Residual mean deviance: 1.802 = 1076 / 597 ## Misclassification error rate: 0.4617 = 277 / 600 Per gli alberi di classificazione, la deviance riportata nell output di summary() è data da 2 m k n mk dove n mk è il numero di unità del nodo terminale m che appartiene alla classe k. log p^mk Una piccola deviance indica un albero che fornisce un buon adattamento ai dati (training). La devianza residua è la devianza divisa per 600 3 = 597. n T 0, in questo caso file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 8/32
Plot dell albero plot(cereal.tree,lwd=2) text(cereal.tree,pretty=0,cex=1.5,col="blue") file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 9/32
Descrizione dell albero Se il grafico non è leggibile, si può utilizzare la sua descrizione Cereal.tree ## node), split, n, deviance, yval, (yprob) ## * denotes terminal node ## ## 1) root 600 1306.0 Cereal ( 0.27000 0.35000 0.38000 ) ## 2) AGECAT: Under 31,[31-45] 266 486.6 Cereal ( 0.45865 0.07895 0.46241 ) * ## 3) AGECAT: [46-60],Over 60 334 628.0 Oatmeal ( 0.11976 0.56587 0.31437 ) ## 6) AGECAT: [46-60] 150 307.5 Cereal ( 0.16667 0.38000 0.45333 ) * ## 7) AGECAT: Over 60 184 281.6 Oatmeal ( 0.08152 0.71739 0.20109 ) * Ad esempio, il nodo 2) è terminale, contiene 266 unità e classifica l unità come Cereal (prob=0.4624). La probabilità di BF-Bar è 0.4586. L albero utilizza solo l età come variabile per la classificazione. file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 10/32
Sima del test error rate con i dati Cereal.test La funzione predict() può essere utilizzata per questo scopo. Nel caso di un albero di classificazione, l argomento type="class" dice ad R di fornire la classe di Y dell unità. Il codice sotto, utilizzando l albero stimato, fornisce le previsioni utilizzando il test set. Cereal.tree.pred=predict(Cereal.tree, Cereal.test, type="class") Cereal.tree.pred[1:10] ## [1] Cereal Cereal Cereal Oatmeal Cereal Cereal Oatmeal Cereal ## [9] Cereal Cereal ## Levels: BF-bar Oatmeal Cereal Per calcolare il tasso di errore di test abbiamo bisogno di confrontare le previsioni con la classe osservata nel test set. file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 11/32
Per questo, si definisca la variabile BFAST.test che contiene solo le classi osservate in Cereal.test e si costruisca una tavola di classificazione. BFAST.test=Cereal$BFAST[-train] table(cereal.tree.pred,bfast.test) ## BFAST.test ## Cereal.tree.pred BF-bar Oatmeal Cereal ## BF-bar 0 0 0 ## Oatmeal 3 53 22 ## Cereal 66 47 89 Dalla tavola possiamo stimare il test error rate (3+22+66+47)/nrow(Cereal.test) ## [1] 0.4928571 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 12/32
Costruire un albero più grande Si definisca maggior sensibilità utilizzando la funzione tree.control() setup1=tree.control(nrow(cereal), mincut = 5, minsize = 10, mindev = 0.001) E quindi si usi tree()con l opzione control=setup1 Cereal.tree=tree(BFAST~.,data=Cereal.train,control=setup1) summary(cereal.tree) ## ## Classification tree: ## tree(formula = BFAST ~., data = Cereal.train, control = setup1) ## Number of terminal nodes: 19 ## Residual mean deviance: 1.718 = 998.1 / 581 ## Misclassification error rate: 0.4117 = 247 / 600 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 13/32
Plot plot(cereal.tree,type="uniform",lwd=2) text(cereal.tree,pretty=0,cex=1.2,col="blue") file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 14/32
Tavola di classificazione (osservazioni test) Cereal.tree.pred=predict(Cereal.tree, Cereal.test, type="class") table(cereal.tree.pred,bfast.test) ## BFAST.test ## Cereal.tree.pred BF-bar Oatmeal Cereal ## BF-bar 23 4 17 ## Oatmeal 3 53 22 ## Cereal 43 43 72 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 15/32
Test error rate stimato Albero con 19 nodi (4+17+3+22+43+43)/nrow(Cereal.test) ## [1] 0.4714286 Compariamolo con il valore precedente (Albero a 3 nodi) (3+22+66+47)/nrow(Cereal.test) ## [1] 0.4928571 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 16/32
Potare l albero La funzione cv.tree() effettua una cross-validazione per ottenere il livello ottimale di complessità dell albero. Per adattare alberi di classificazione si usi l argomento FUN=prune.misclass (nel qual caso il tasso di errore sarà il criterio guida) L output di cv.tree() riporterà: il numero di nodi terminali di ogni albero considerato (size) il corrispondente tasso di errore (dev) altri parametri (non discussi a lezione). file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 17/32
set.seed (3) cv.cereal =cv.tree(cereal.tree,fun=prune.misclass ) cv.cereal ## $size ## [1] 19 10 8 7 6 3 2 1 ## ## $dev ## [1] 311 311 311 311 311 318 323 380 ## ## $k ## [1] -Inf 0.000000 1.000000 2.000000 4.000000 7.333333 11.000000 ## [8] 84.000000 ## ## $method ## [1] "misclass" ## ## attr(,"class") ## [1] "prune" "tree.sequence" dev in questo caso corrisponde alla stima di cross-validazione dell errore test. file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 18/32
Plot Riportiamo in un grafico i risultati per size e dev plot(cv.cereal$size,cv.cereal$dev,type="b", lwd=3,col="blue", xlab="terminal nodes", ylab="rss",main="cost complexity pruning" ) 6 nodi sembrano ottimali file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 19/32
L albero potato A questo punto si utilizzi la funzione prune.misclass() per potare l albero iniziale al numero di nodi scelti in base ai risultati ottenuti dalla cross-validazione. prune.cereal=prune.misclass(cereal.tree, best =6) plot(prune.cereal,lwd=2) text(prune.cereal,pretty =0,cex=1.3,col="blue") file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 20/32
plot(prune.cereal,lwd=2,type="uniform") text(prune.cereal,pretty =0,cex=1.3,col="blue") file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 21/32
Stima del test error tree.pred=predict(prune.cereal, Cereal.test, type="class") table(tree.pred, BFAST.test) ## BFAST.test ## tree.pred BF-bar Oatmeal Cereal ## BF-bar 29 3 23 ## Oatmeal 3 53 22 ## Cereal 37 44 66 (3+23+3+22+37+44)/nrow(Cereal.test) ## [1] 0.4714286 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 22/32
Bagging Il Bagging è un caso specifico di una Foresta Casuale con m = p. Quindi la funzione randomforest() dalla libreria randomforest può essere usata in entrambi i casi. library (randomforest) set.seed (1) bag.cereal=randomforest(bfast~.,data=cereal,subset =train, mtry=4, importance =TRUE, ntree=1000) bag.cereal ## ## Call: ## randomforest(formula = BFAST ~., data = Cereal, mtry = 4, importance = TRUE, n ## Type of random forest: classification ## Number of trees: 1000 ## No. of variables tried at each split: 4 ## ## OOB estimate of error rate: 45.67% ## Confusion matrix: ## BF-bar Oatmeal Cereal class.error ## BF-bar 61 20 81 0.6234568 ## Oatmeal 7 141 62 0.3285714 ## Cereal 37 67 124 0.4561404 L argomento mtry = 4 indica che tutti e 4 i predittori devono essere considerati ad ogni split dell albero; in altre parole, che stiamo facendo bagging. file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 23/32
Stima del test error tree.pred=predict(bag.cereal, Cereal.test, type="class") table(tree.pred, BFAST.test) ## BFAST.test ## tree.pred BF-bar Oatmeal Cereal ## BF-bar 26 4 19 ## Oatmeal 7 70 37 ## Cereal 36 26 55 (4+9+7+37+36+26)/nrow(Cereal.test) ## [1] 0.425 La stima dell errore test è migliore rispetto all albero potato file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 24/32
Foresta casuale La crescita di una foresta casuale procede esattamente come nel bagging, tranne che usiamo un valore inferiore dell argomento mtry. Per impostazione predefinita, random.forest() utilizza p/3 variabili per la costruzione di una foresta casuale di alberi di regressione, e p variabili quando si costruisce una foresta casuale di alberi di classificazione. Qua impostiamo mtry = 2. set.seed (1) rf.cereal =randomforest(bfast~.,data=cereal, subset =train, mtry=2, importance =TRUE) rf.cereal ## ## Call: ## randomforest(formula = BFAST ~., data = Cereal, mtry = 2, importance = TRUE, s ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 2 ## ## OOB estimate of error rate: 45.5% ## Confusion matrix: ## BF-bar Oatmeal Cereal class.error ## BF-bar 63 20 79 0.6111111 ## Oatmeal 7 132 71 0.3714286 ## Cereal 37 59 132 0.4210526 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 25/32
Stima del test error tree.pred=predict(rf.cereal, Cereal.test, type="class") table(tree.pred, BFAST.test) ## BFAST.test ## tree.pred BF-bar Oatmeal Cereal ## BF-bar 25 5 17 ## Oatmeal 5 58 30 ## Cereal 39 37 64 (5+17+5+30+39+37)/nrow(Cereal.test) ## [1] 0.475 peggiore rispetto al bagging file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 26/32
Importance statistics importance(rf.cereal) ## BF-bar Oatmeal Cereal MeanDecreaseAccuracy ## AGECAT 32.5288675 75.0068568 16.613496 73.445403 ## GENDER 0.3980847 0.5369111-5.132621-2.791728 ## MARITAL 13.0030759 1.6831823-2.940249 6.299546 ## ACTIVE 26.1612175 1.8601322-7.631455 13.980660 ## MeanDecreaseGini ## AGECAT 65.787271 ## GENDER 4.835760 ## MARITAL 6.640673 ## ACTIVE 11.613277 Ci sono due misure di importanza delle variabili. La prima si basa sulla diminuzione media della precisione nelle previsioni sui campioni out-of-bag quando una data variabile viene esclusa dal modello. La seconda è una misura della diminuzione totale dell impurità del nodo che deriva dallo split su una certa variabile. Nel caso di alberi di regressione, l impurità del nodo viene misurata attraverso il training RSS. Per gli alberi di classificazione attraverso l indice di Gini. file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 27/32
Importance plot I plot delle misure di importanza possono essere prodotti con la funzione varimpplot(). varimpplot(rf.cereal,pch=19,ce=1.5,col="blue",lwd=2) file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 28/32
Ottenere le probabilità di previsione con le foreste casuali Per fare previsioni dato un insieme di predittori, possiamo usare la funzione predict(). L opzione type="prob" fornisce le probabilità di appartenenza a ciascun livello della variabile Y. 1. Selezioniamo 5 intervistati casualmente set.seed(5) R.units<-Cereal[sample(nrow(Cereal),5),] R.units ## AGECAT GENDER MARITAL ACTIVE BFAST ## 177 [46-60] M Sposato Yes Cereal ## 603 Over 60 M Sposato Yes Oatmeal ## 806 [31-45] F Sposato Yes Cereal ## 250 Over 60 M Sposato Yes Oatmeal ## 92 [46-60] F Sposato No Cereal file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 29/32
2. Facciamo le previsioni Prob.Pr<-predict(rf.Cereal,R.units,type="prob") Prob.Pr ## BF-bar Oatmeal Cereal ## 177 0.028 0.338 0.634 ## 603 0.010 0.922 0.068 ## 806 0.722 0.002 0.276 ## 250 0.010 0.922 0.068 ## 92 0.000 0.450 0.550 ## attr(,"class") ## [1] "matrix" "votes" file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 30/32
3. Mettiamo Prob.Pr e R.units in un data.frame e confrontiamo df<-data.frame(r.units,prob.pr) df AGECAT GENDER MARITAL ACTIVE BFAST BF.bar Oatmeal Cereal 177 [46-60] M Sposato Yes Cereal 0.028 0.338 0.634 603 Over 60 M Sposato Yes Oatmeal 0.010 0.922 0.068 806 [31-45] F Sposato Yes Cereal 0.722 0.002 0.276 250 Over 60 M Sposato Yes Oatmeal 0.010 0.922 0.068 92 [46-60] F Sposato No Cereal 0.000 0.450 0.550 file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 31/32
Considerazioni conclusive Tutti i modelli testati hanno un test error rate stimato piuttosto elevato. L albero ottenuto col Bagging è quello con test error stimato minore (42.5%) Come già notato, tuttavia, l accento per questo probema è su un analisi di tipo inferenziale. Si noti che l albero potato suggerisce, in modo molto più chiaro rispetto all analisi con LDA e GLM, dei profili di consumo cui prestare attenzione Il bagging migliora la precisione ma si perde in capacità interpretativa. Tuttavia l utilizzo delle statistiche di importanza e l analisi in dettaglio dei profili di consumo ci permettono di interpretare correttamente il fenomeno. Vale la pena completare l analisi calcolando le probabilità di scelta per alcuni profili di interesse per capire più a fondo la precisione del modello per sottogruppi di consumatori file:///c:/users/emanuele.taufer/google%20drive/2%20corsi/3%20sqg/labs/l8-cereal-tree.html#(1) 32/32