Categorical Variables in Trees I








I find it remarkable that very few of the current implementations of tree algorithms in R exploit an important “trick” that was already described in the original, seminal CART book of 1984!

“For a two-class problem, the search for the best categorical split can be reduced to M steps using the Gini criterion”

Breiman, L., Friedman, J., Olshen, R. and Stone, C. (1984). Classification and Regression Trees. Wadsworth International Group.

See also http://www.math.ccu.edu.tw/~yshih/papers/spl.pdf

and p. 11 in http://www.cs.ubbcluj.ro/~csatol/mestint/pdfs/Gehrke_Loh_DecTree.pdf

(The effect of dummifying categorical variables on performance is nicely elaborated upon in the post “Are categorical variables getting lost in your random forests?”)

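This implies that instead of having to search all \(2^{k-1}-1\) binary partitions of the \(k\) levels, it is sufficient to order the levels (by the proportion of one class or, for regression, by the mean response) and try only the \(k-1\) splits along that ordering, i.e. on the order of \(k\) of them!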
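A minimal sketch of why the ordered search suffices, shown here for a numeric response under squared-error loss (the setting of the simulated data used later in this post) rather than the two-class Gini case of the quote. The toy data and the helper `rss_split` are made up for illustration; the number of levels is kept small so that brute force is still feasible.

```r
set.seed(1)
k <- 8                                    # few levels, so all partitions can be enumerated
x <- factor(sample(seq_len(k), 2000, replace = TRUE), levels = seq_len(k))
y <- rnorm(k)[as.integer(x)] + rnorm(length(x))   # level effect plus noise
lev <- levels(x)

# Residual sum of squares when the levels in `left` go to the left child
# and all remaining levels go to the right child (hypothetical helper).
rss_split <- function(left) {
  in_left <- x %in% left
  sum((y[in_left] - mean(y[in_left]))^2) +
    sum((y[!in_left] - mean(y[!in_left]))^2)
}

# Brute force: level 1 always stays on the right, so every partition into
# two non-empty groups is visited exactly once.
n_brute <- 2^(k - 1) - 1
brute <- Inf
for (code in seq_len(n_brute)) {
  pick <- (code %/% 2^(0:(k - 2))) %% 2 == 1      # binary digits of `code`
  brute <- min(brute, rss_split(lev[-1][pick]))
}

# Ordered search: sort the levels by mean response, then try only the
# k - 1 splits that cut this ordering in two.
ord <- names(sort(tapply(y, x, mean)))
n_ordered <- k - 1
ordered <- min(sapply(seq_len(n_ordered), function(j) rss_split(ord[1:j])))

cat("splits tried:", n_brute, "vs", n_ordered, "\n")
cat("same optimum:", isTRUE(all.equal(brute, ordered)), "\n")
```

Both searches should arrive at the same minimal residual sum of squares; the ordered one simply evaluates far fewer candidates.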

The practical consequences of this exponential vs. linear scaling are severe, especially for modern datasets, which often contain many categorical variables with a large number of levels (\(k > 1000\) is not rare).
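To put rough numbers on the scaling: counting every partition of the levels into two non-empty groups once gives \(2^{k-1}-1\) candidate splits, while the ordered search needs \(k-1\). The values of `k` below are arbitrary illustrative choices.

```r
k <- c(5, 10, 20, 1000)
counts <- data.frame(
  levels     = k,
  exhaustive = 2^(k - 1) - 1,   # all binary partitions of k levels
  ordered    = k - 1            # cuts along the sorted levels
)
print(counts)
```

For 1000 levels the exhaustive count still fits into a double, but only as a number with roughly 300 digits, which no search can enumerate.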

Of the popular packages that I tried, only rpart, gbm and RWeka avoid the “exponential trap” and impose (almost) no limit on the number of levels:

# read in the Rossmann data set from this kaggle competition:
#  https://www.kaggle.com/c/rossmann-store-sales
#train = read.csv('H:/kaggle/Rossmann/input/train.csv.gz',as.is=T)

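# oh well, fake data serve to illustrate the point just as well:
set.seed(1)  # seed added here so that the simulated data are reproducible
N = 5*10^5
train = data.frame(Sales = 0, Store = sample(1:1000, N, replace = TRUE))
train$Sales = 0.1*train$Store + rnorm(N)  # sales grow linearly with the store id, plus noise
train$Store = factor(train$Store)         # turn the store id into a factor with 1000 levels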

cat ("There are ", length(levels(train$Store)),"stores/levels in this dataset.\n")
## There are  1000 stores/levels in this dataset.

Libraries that fail to exploit the linear search strategy:

The tree package

library(tree, quietly=TRUE)
try({fit = tree(Sales ~ Store, data = train)})

The party package

library(party, quietly=TRUE)
try({fit = ctree(Sales ~ Store, data = train)})
detach("package:party", unload=TRUE)
## Warning: Error in matrix(0, nrow = mi, ncol = nl) : 
##   invalid 'nrow' value (too large or NA)
## In addition: Warning message:
## In matrix(0, nrow = mi, ncol = nl) :
##   NAs introduced by coercion to integer range

The partykit package

library(partykit, quietly=TRUE)
try({fit = lmtree(Sales ~ Store, data = train)})
detach("package:partykit", unload=TRUE)
## Warning: Error in matrix(0, nrow = mi, ncol = nl) : 
##   invalid 'nrow' value (too large or NA)
## In addition: Warning message:
## In matrix(0, nrow = mi, ncol = nl) :
##   NAs introduced by coercion to integer range

The randomForest package

library(randomForest, quietly=TRUE)
try({fit = randomForest(Sales ~ Store, data = train,  ntree=1)})
## Warning: Error in randomForest.default(m, y, ...) : 
##   Can not handle categorical predictors with more than 53 categories.

Base lm runs out of memory

try({fit0 = lm(Sales ~ Store, data = train)})
## Warning: Cannot allocate vector of size...
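A likely explanation (my reading, not something the error message states): lm() builds a dense model matrix with one column per coefficient, so 5*10^5 rows times 1000 columns of doubles already need several gigabytes before any decomposition is computed. For a model with a single factor the fitted values are just the per-level means, which can be had cheaply. The sketch below checks this on a small made-up data set; the names `store`, `sales` and `group_fit` are illustrative.

```r
set.seed(1)
N <- 2000
k <- 20
store <- factor(sample(seq_len(k), N, replace = TRUE), levels = seq_len(k))
sales <- 0.1 * as.integer(store) + rnorm(N)

# Dummy-coded linear model: feasible at this small scale.
fit_lm <- lm(sales ~ store)

# Per-level means, aligned with the rows of the data (ave() defaults to mean).
group_fit <- ave(sales, store)

cat("identical fits:",
    isTRUE(all.equal(unname(fitted(fit_lm)), group_fit)), "\n")

# Rough size of the dense model matrix at the scale used in this post, in GiB.
mm_gib <- 5e5 * 1000 * 8 / 2^30
cat("model matrix at full scale: about", round(mm_gib, 1), "GiB\n")
```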

Libraries that shine:

The rpart package

library(rpart, quietly=TRUE)
fit = rpart(Sales ~ Store, data = train)
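To see the ordering trick at work, one can look at which levels rpart sends left and right at the root. The sketch below uses a smaller made-up data set (`toy`, 50 levels) and relies on the `splits` and `csplit` components of the fitted object as described in the rpart.object documentation, where 1 means "left", 3 means "right" and 2 means "level absent at this node". If the levels were ordered by mean response before splitting, the two groups should not overlap in their level means.

```r
library(rpart)

set.seed(1)
N <- 20000
k <- 50
toy <- data.frame(
  Store = factor(sample(seq_len(k), N, replace = TRUE), levels = seq_len(k))
)
toy$Sales <- 0.1 * as.integer(toy$Store) + rnorm(N)

fit_toy <- rpart(Sales ~ Store, data = toy)

# With a single predictor there are no competing or surrogate splits,
# so the first row of `splits` is the primary split at the root.
root_row  <- fit_toy$splits[1, "index"]    # row of csplit for this split
direction <- fit_toy$csplit[root_row, ]    # one entry per factor level

level_means <- tapply(toy$Sales, toy$Store, mean)
left_means  <- level_means[direction == 1]
right_means <- level_means[direction == 3]

# The level means of the two children should form two non-overlapping blocks.
contiguous <- max(left_means) < min(right_means) ||
  max(right_means) < min(left_means)
cat("levels left/right:", length(left_means), "/", length(right_means), "\n")
cat("split is a cut in the mean ordering:", contiguous, "\n")
```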

The gbm package

Well, almost: there is a limit, but the maximum number of levels is very high (1024):

library(gbm, quietly=TRUE)
fit = gbm(Sales ~ Store, data = train, interaction.depth = 8, n.trees=1)
## Distribution not specified, assuming gaussian ...

(It is not clear to me why there is a limit at all.)

The RWeka package

library(RWeka, quietly=TRUE)
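# J48 handles only classification problems, so binarize the response first
# (wrapped in factor() so that the response is explicitly nominal):
train$Sales2 = factor(train$Sales > mean(train$Sales))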
fit = J48(Sales2 ~ Store, data = train)

The h2o package



One thought on “Categorical Variables in Trees I”

  1. Markus, I was just pointed to your blog through the following discussion on SO:

    https://stackoverflow.com/questions/44455090/how-to-deal-with-large-number-of-factors-categories-within-partykit

    And we had discussed this privately before: The main reason for “partykit” not to implement this shortcut is that it is only available for certain special cases of the very general frameworks provided by ctree() and mob(). We’re currently working on modularizing our internal partykit infrastructure so that such special cases can – in principle – be added. We need more work though for which models/scores/loss functions the trick actually works and where it fails. This is not at the top of our to-do list but at least it’s on it 😉
