The second example in the book uses the naive Bayes algorithm to identify spam text messages.

First, load the packages we'll need:

library(tidyverse) # data cleaning
library(here) # locating the data file
library(tidytext) # tokenizing and casting to a sparse matrix
library(e1071) # modeling
library(gmodels) # model evaluation

Cleaning the data took some effort, because the book does its text processing with the tm package, which I have never used (and don't even have installed), so reading the book's code I could only guess at what it was doing. Still, I managed in the end to write a tidyverse-style version of the cleaning code:

sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(type = factor(type),               # turn the label into a factor
         row = row_number()) %>%            # record the row number
  unnest_tokens(word, text) %>%             # tokenize
  anti_join(stop_words) %>%                 # drop stop words
  filter(!str_detect(word, '\\d')) %>%      # drop tokens containing digits
  cast_sparse(row, word) %>%                # cast to a document-term sparse matrix
  as.matrix() %>% 
  as_tibble() %>% 
  select(which(colSums(.) > 4)) %>%         # keep words appearing at least 5 times
  bind_cols(read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
              mutate(type = factor(type),
                     row = row_number()) %>% 
              unnest_tokens(word, text) %>% 
              anti_join(stop_words) %>% 
              filter(!str_detect(word, '\\d')) %>%
              select(-3) %>%                # drop the word column
              distinct()) %>%               # one row per message: labels + row ids
  mutate_if(is.numeric, factor, levels = c(0, 1), labels = c('No', 'Yes'))

It's a long chain, but it is still a dozen or so lines shorter than the book's code, reads more coherently, and, best of all, needs only a single variable name.

Let's break it down.

The raw data looks like this:

(sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')))
## # A tibble: 5,574 x 2
##    type  text                                                              
##    <chr> <chr>                                                             
##  1 ham   Go until jurong point, crazy.. Available only in bugis n great wo~
##  2 ham   Ok lar... Joking wif u oni...                                     
##  3 spam  Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 200~
##  4 ham   U dun say so early hor... U c already then say...                 
##  5 ham   Nah I don't think he goes to usf, he lives around here though     
##  6 spam  FreeMsg Hey there darling it's been 3 week's now and no word back~
##  7 ham   Even my brother is not like to speak with me. They treat me like ~
##  8 ham   As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vet~
##  9 spam  WINNER!! As a valued network customer you have been selected to r~
## 10 spam  Had your mobile 11 months or more? U R entitled to Update to the ~
## # ... with 5,564 more rows

Next, convert the label variable type to a factor and add a row variable recording the row number:

sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(type = factor(type),
         row = row_number())

Then tokenize with tidytext's unnest_tokens function, remove stop words with anti_join, and drop tokens containing digits with a combination of filter and str_detect. The data now looks like this:

(sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(type = factor(type),
         row = row_number()) %>% 
  unnest_tokens(word, text) %>% 
  anti_join(stop_words) %>% 
  filter(!str_detect(word, '\\d')))
## # A tibble: 34,390 x 3
##    type    row word  
##    <fct> <int> <chr> 
##  1 ham       1 jurong
##  2 ham       1 crazy 
##  3 ham       1 bugis 
##  4 ham       1 world 
##  5 ham       1 la    
##  6 ham       1 buffet
##  7 ham       1 cine  
##  8 ham       1 amore 
##  9 ham       1 wat   
## 10 ham       2 lar   
## # ... with 34,380 more rows

At this point the vocabulary comes to 7,440 distinct words.
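
A quick way to confirm that count (just a sketch; sms here refers to the tokenized tibble printed above):

n_distinct(sms$word) # dplyr; should match the 7,440 columns of the matrix below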

This is where the trouble started. The data has to be reshaped into a sparse matrix: every word gets its own column, with a 1 if the word appears in a given message and a 0 otherwise. With 5,000-odd rows and 7,000-odd words, that means a matrix or data frame of roughly 5,000 × 7,000. I first tried to do this with the tidyr package; that did solve the problem, but it produced a 5 GB data frame and was far too slow. The book handles this step with functions from the tm package, and a web search showed that the Matrix package can do it too, but either one would break up the pipeline.

Then it occurred to me that tidytext was unlikely to lack a function for this, and sure enough it has cast_sparse, which calls sparseMatrix from the Matrix package. That still didn't completely solve the problem, because cast_sparse returns a matrix of class dgCMatrix, which cannot be turned into a data frame directly. Another search revealed that it can first be converted to an ordinary matrix and then to a data frame. Part of the data now looks like this:

(sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(type = factor(type),
         row = row_number()) %>% 
  unnest_tokens(word, text) %>% 
  anti_join(stop_words) %>% 
  filter(!str_detect(word, '\\d')) %>% 
  cast_sparse(row, word) %>% 
  as.matrix() %>% 
  as_tibble())
## # A tibble: 5,454 x 7,440
##    jurong crazy bugis world    la buffet  cine amore   wat   lar joking
##     <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>
##  1      1     1     1     1     1      1     1     1     1     0      0
##  2      0     0     0     0     0      0     0     0     0     1      1
##  3      0     0     0     0     0      0     0     0     0     0      0
##  4      0     0     0     0     0      0     0     0     0     0      0
##  5      0     0     0     0     0      0     0     0     0     0      0
##  6      0     0     0     0     0      0     0     0     0     0      0
##  7      0     0     0     0     0      0     0     0     0     0      0
##  8      0     0     0     0     0      0     0     0     0     0      0
##  9      0     0     0     0     0      0     0     0     0     0      0
## 10      0     0     0     0     0      0     0     0     0     0      0
## # ... with 5,444 more rows, and 7,429 more variables: wif <dbl>,
## #   oni <dbl>, free <dbl>, entry <dbl>, wkly <dbl>, comp <dbl>, win <dbl>,
## #   fa <dbl>, cup <dbl>, final <dbl>, tkts <dbl>, text <dbl>,
## #   receive <dbl>, question <dbl>, std <dbl>, txt <dbl>, rate <dbl>,
## #   apply <dbl>, dun <dbl>, hor <dbl>, nah <dbl>, usf <dbl>, lives <dbl>,
## #   freemsg <dbl>, hey <dbl>, darling <dbl>, `week's` <dbl>, word <dbl>,
## #   fun <dbl>, tb <dbl>, xxx <dbl>, chgs <dbl>, send <dbl>, rcv <dbl>,
## #   brother <dbl>, speak <dbl>, treat <dbl>, aids <dbl>, patent <dbl>,
## #   request <dbl>, melle <dbl>, oru <dbl>, minnaminunginte <dbl>,
## #   nurungu <dbl>, vettam <dbl>, set <dbl>, callertune <dbl>,
## #   callers <dbl>, press <dbl>, copy <dbl>, friends <dbl>, winner <dbl>,
## #   valued <dbl>, network <dbl>, customer <dbl>, selected <dbl>,
## #   receivea <dbl>, prize <dbl>, reward <dbl>, claim <dbl>, call <dbl>,
## #   code <dbl>, valid <dbl>, hours <dbl>, mobile <dbl>, months <dbl>,
## #   entitled <dbl>, update <dbl>, colour <dbl>, mobiles <dbl>,
## #   camera <dbl>, gonna <dbl>, home <dbl>, talk <dbl>, stuff <dbl>,
## #   anymore <dbl>, tonight <dbl>, cried <dbl>, chances <dbl>, cash <dbl>,
## #   pounds <dbl>, cost <dbl>, day <dbl>, tsandcs <dbl>, reply <dbl>,
## #   hl <dbl>, info <dbl>, urgent <dbl>, won <dbl>, week <dbl>,
## #   membership <dbl>, jackpot <dbl>, www.dbuk.net <dbl>, lccltd <dbl>,
## #   pobox <dbl>, searching <dbl>, words <dbl>, breather <dbl>,
## #   promise <dbl>, wont <dbl>, ...
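
For the record, here is the intermediate object that caused the detour (a minimal sketch; tokens stands for the tokenized tibble from a few steps back, and dtm is just a throwaway name):

dtm <- tokens %>% cast_sparse(row, word) # tokens: columns type, row, word
class(dtm) # "dgCMatrix", a Matrix-package sparse matrix, hence the two conversions
dim(dtm)   # 5454 x 7440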

The next step, following the book, is to drop the low-frequency words and keep only those that appear at least five times. This caused a little trouble too: I first tried to remove the low-frequency columns with the select_* family of functions, but couldn't make it work, and eventually found a simpler approach online. The number of variables drops from 7,440 to 1,312, and the data frame shrinks from over 300 MB to around 50 MB:

(sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(type = factor(type),
         row = row_number()) %>% 
  unnest_tokens(word, text) %>% 
  anti_join(stop_words) %>% 
  filter(!str_detect(word, '\\d')) %>% 
  cast_sparse(row, word) %>% 
  as.matrix() %>% 
  as_tibble() %>% 
  select(which(colSums(.) > 4)))
## # A tibble: 5,454 x 1,312
##    crazy bugis world    la  cine   wat   lar joking   wif  free entry  wkly
##    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl>
##  1     1     1     1     1     1     1     0      0     0     0     0     0
##  2     0     0     0     0     0     0     1      1     1     0     0     0
##  3     0     0     0     0     0     0     0      0     0     1     1     1
##  4     0     0     0     0     0     0     0      0     0     0     0     0
##  5     0     0     0     0     0     0     0      0     0     0     0     0
##  6     0     0     0     0     0     0     0      0     0     0     0     0
##  7     0     0     0     0     0     0     0      0     0     0     0     0
##  8     0     0     0     0     0     0     0      0     0     0     0     0
##  9     0     0     0     0     0     0     0      0     0     0     0     0
## 10     0     0     0     0     0     0     0      0     0     1     0     0
## # ... with 5,444 more rows, and 1,300 more variables: comp <dbl>,
## #   win <dbl>, cup <dbl>, final <dbl>, text <dbl>, receive <dbl>,
## #   question <dbl>, std <dbl>, txt <dbl>, rate <dbl>, apply <dbl>,
## #   dun <dbl>, nah <dbl>, usf <dbl>, freemsg <dbl>, hey <dbl>,
## #   darling <dbl>, word <dbl>, fun <dbl>, xxx <dbl>, send <dbl>,
## #   brother <dbl>, speak <dbl>, treat <dbl>, request <dbl>, set <dbl>,
## #   callertune <dbl>, callers <dbl>, press <dbl>, copy <dbl>,
## #   friends <dbl>, winner <dbl>, valued <dbl>, network <dbl>,
## #   customer <dbl>, selected <dbl>, prize <dbl>, reward <dbl>,
## #   claim <dbl>, call <dbl>, code <dbl>, valid <dbl>, hours <dbl>,
## #   mobile <dbl>, months <dbl>, entitled <dbl>, update <dbl>,
## #   colour <dbl>, mobiles <dbl>, camera <dbl>, gonna <dbl>, home <dbl>,
## #   talk <dbl>, stuff <dbl>, anymore <dbl>, tonight <dbl>, cash <dbl>,
## #   pounds <dbl>, cost <dbl>, day <dbl>, reply <dbl>, hl <dbl>,
## #   info <dbl>, urgent <dbl>, won <dbl>, week <dbl>, pobox <dbl>,
## #   searching <dbl>, words <dbl>, promise <dbl>, wont <dbl>,
## #   wonderful <dbl>, times <dbl>, date <dbl>, sunday <dbl>, credit <dbl>,
## #   click <dbl>, wap <dbl>, link <dbl>, message <dbl>, http <dbl>,
## #   watching <dbl>, eh <dbl>, remember <dbl>, naughty <dbl>, wet <dbl>,
## #   fine <dbl>, feel <dbl>, england <dbl>, dont <dbl>, miss <dbl>,
## #   team <dbl>, news <dbl>, ur <dbl>, national <dbl>, `i‘m` <dbl>,
## #   ha <dbl>, ü <dbl>, pay <dbl>, da <dbl>, ...
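
As an aside, newer versions of dplyr (1.0 and later) offer a tidier way to write the same column filter; a sketch:

sms %>% select(where(~ sum(.x) > 4)) # keep columns whose total count exceeds 4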

The sparse matrix is now built, but the data still lacks its labels, so I merged the row numbers and labels back in with a big block of repeated code. For now I haven't thought of a simpler way to do this without breaking up the pipeline (though see the sketch after the code below). The last step, as the book instructs, is to convert all the predictor variables to factors:

sms <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(type = factor(type),
         row = row_number()) %>% 
  unnest_tokens(word, text) %>% 
  anti_join(stop_words) %>% 
  filter(!str_detect(word, '\\d')) %>% 
  cast_sparse(row, word) %>% 
  as.matrix() %>% 
  as_tibble() %>% 
  select(which(colSums(.) > 4)) %>% 
  bind_cols(read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
              mutate(type = factor(type),
                     row = row_number()) %>% 
              unnest_tokens(word, text) %>% 
              anti_join(stop_words) %>% 
              filter(!str_detect(word, '\\d')) %>%
              select(-3) %>% 
              distinct()) %>% 
  mutate_if(is.numeric, factor, levels = c(0, 1), labels = c('No', 'Yes'))
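
One possibly simpler alternative (an untested sketch, and it costs one extra variable name): since cast_sparse keeps the row ids as rownames, the labels can be read once and joined back by row id instead of repeating the cleaning code:

labels <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  transmute(label = factor(type),              # 'label', to avoid clashing with
            row = as.character(row_number()))  # the word column named 'type'

sms2 <- read_csv(here('content', 'post', 'data', '02-sms_spam.csv')) %>% 
  mutate(row = row_number()) %>% 
  unnest_tokens(word, text) %>% 
  anti_join(stop_words) %>% 
  filter(!str_detect(word, '\\d')) %>% 
  cast_sparse(row, word) %>% 
  as.matrix() %>% 
  as_tibble(rownames = 'row') %>% # expose the rownames as a column
  left_join(labels, by = 'row')   # only the 5,454 surviving messages match
# the frequency filter and the factor conversion would then follow as before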

With the data cleaned, we can now create the training and test sets:

set.seed(0424)
sms_train <- sms %>% sample_n(4169)
sms_test <- sms %>% setdiff(sms_train)

Here I ran into another problem. Before dropping the low-frequency words, i.e. with 7,000-odd columns, sample_n threw an error, but after dropping them it ran fine. My guess (and it is only a guess) is that sample_n requires the number of variables not to exceed the value of its n argument?
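
In case anyone hits the same snag, a plain index-based split (just a sketch, not the book's approach) sidesteps both sample_n and setdiff. It's also worth noting that setdiff has set semantics and silently drops duplicated rows, which is presumably why the test set below contains 890 messages rather than 5,454 − 4,169 = 1,285; the index split keeps every row:

set.seed(0424)
idx <- sample(nrow(sms), 4169) # draw 4,169 row indices for training
sms_train <- sms[idx, ]
sms_test  <- sms[-idx, ]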

Now for the model. The training data passed to naiveBayes drops the last two columns (the row number and the label). One thing to watch out for: because the vocabulary itself contains the word type, the original type variable was automatically renamed to type1.

sms_class <- naiveBayes(sms_train[, -1313:-1314], sms_train$type1)
sms_pred <- predict(sms_class, sms_test)
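
A quick check that those column indices point where they should (the last two columns are the renamed label and the row id):

names(sms_train)[1313:1314] # expect 'type1' and 'row'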

Now evaluate the model on the test data:

CrossTable(sms_pred, sms_test$type1, 
           prop.chisq = FALSE, prop.t = FALSE,
           dnn = c('predicted', 'actual'))
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  890 
## 
##  
##              | actual 
##    predicted |       ham |      spam | Row Total | 
## -------------|-----------|-----------|-----------|
##          ham |       789 |        14 |       803 | 
##              |     0.983 |     0.017 |     0.902 | 
##              |     0.996 |     0.143 |           | 
## -------------|-----------|-----------|-----------|
##         spam |         3 |        84 |        87 | 
##              |     0.034 |     0.966 |     0.098 | 
##              |     0.004 |     0.857 |           | 
## -------------|-----------|-----------|-----------|
## Column Total |       792 |        98 |       890 | 
##              |     0.890 |     0.110 |           | 
## -------------|-----------|-----------|-----------|
## 
## 

That looks pretty good.
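
For a single summary number, overall accuracy can be read straight off the table; as a quick check:

mean(sms_pred == sms_test$type1) # (789 + 84) / 890, roughly 0.981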

Following the book, let's try again with the laplace parameter changed:

sms_class1 <- naiveBayes(sms_train[, -1313:-1314], sms_train$type1, laplace = 1)
sms_pred1 <- predict(sms_class1, sms_test)

CrossTable(sms_pred1, sms_test$type1, 
           prop.chisq = FALSE, prop.t = FALSE,
           dnn = c('predicted', 'actual'))
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  890 
## 
##  
##              | actual 
##    predicted |       ham |      spam | Row Total | 
## -------------|-----------|-----------|-----------|
##          ham |       790 |        16 |       806 | 
##              |     0.980 |     0.020 |     0.906 | 
##              |     0.997 |     0.163 |           | 
## -------------|-----------|-----------|-----------|
##         spam |         2 |        82 |        84 | 
##              |     0.024 |     0.976 |     0.094 | 
##              |     0.003 |     0.837 |           | 
## -------------|-----------|-----------|-----------|
## Column Total |       792 |        98 |       890 | 
##              |     0.890 |     0.110 |           | 
## -------------|-----------|-----------|-----------|
## 
## 

The model has indeed improved to a degree. Raw accuracy is essentially unchanged ((790 + 82) / 890 ≈ 98.0%, versus 98.1% before), but the trade-off moved in the right direction: although two more spam messages now get through, one fewer legitimate message is lost, and for a spam filter a discarded legitimate message is the costlier mistake.