        Machine Learning Algorithms - Random Forest: Implementing a Decision Tree from Scratch by Brute Force in R (2)

        2021-01-13 16:22

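        The previous post (Machine Learning Algorithms - Random Forest: A First Look at Decision Trees (1)) covered the basic concepts of decision trees and the criteria for evaluating a split, and worked through the Gini impurity of a single split on a single variable by hand. It lays the conceptual groundwork; if you are not familiar with that material, read it first before continuing.

        This post trains a decision tree from scratch with hand-written, brute-force R functions (readers who are interested are welcome to contribute a Python implementation of this code). The results computed by hand earlier serve as a positive control for checking that the functions below give correct results.

        Training a decision tree - determining the split threshold at the root node

        Gini impurity can be used to judge the best split at each step, so how do we determine the optimal split variable and threshold?

        The crudest approach is to try every possible threshold of every variable as a split and pick the combination with the lowest Gini impurity. It is not the fastest way to solve the problem, but it is the easiest to understand.

        Defining a function to compute Gini impurity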

        data <- data.frame(x=c(0,0.5,1.1,1.8,1.9,2,2.5,3,3.6,3.7),
        y=c(1,0.5,1.5,2.1,2.8,2,2.2,3,3.3,3.5),
        color=c(rep('blue',3),rep('red',2),rep('green',5)))

        data

        ##      x   y color
        ## 1  0.0 1.0  blue
        ## 2  0.5 0.5  blue
        ## 3  1.1 1.5  blue
        ## 4  1.8 2.1   red
        ## 5  1.9 2.8   red
        ## 6  2.0 2.0 green
        ## 7  2.5 2.2 green
        ## 8  3.0 3.0 green
        ## 9  3.6 3.3 green
        ## 10 3.7 3.5 green

        First, define a function that computes the Gini impurity of a single branch.

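        Gini_impurity <- function(branch){
          # branch: vector of class labels that fall into one branch
          len_branch <- length(branch)
          # An empty branch contributes no impurity
          if(len_branch==0){
            return(0)
          }
          # Count the items of each class
          table_branch <- table(branch)
          # Probability of drawing class x and then mislabelling it: p * (1 - p)
          wrong_probability <- function(x, total) (x/total*(1-x/total))
          return(sum(sapply(table_branch, wrong_probability, total=len_branch)))
        }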

        A quick test shows that it works.

        Gini_impurity(c(rep('a',2),rep('b',3)))

        ## [1] 0.48
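The test value can be checked by hand against the usual Gini impurity formula. In the sketch below, $p_k$ denotes the fraction of items in the branch that belong to class $k$ (this notation is chosen for the note and does not appear in the original code); the second line applies it to the test vector of two `a` and three `b`:

```latex
G = \sum_{k} p_k \,(1 - p_k) = 1 - \sum_{k} p_k^{2}

G_{\text{test}} = \tfrac{2}{5}\cdot\tfrac{3}{5} + \tfrac{3}{5}\cdot\tfrac{2}{5} = \tfrac{12}{25} = 0.48
```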

        Next, define a function that computes the total (size-weighted) Gini impurity of a split.

        Gini_impurity_for_split_branch <- function(threshold, data, variable_column, 
        class_column, Init_gini_impurity=NULL){
        total = nrow(data)
        left <- data[data[variable_column]<threshold,][[class_column]]
        left_len = length(left)
        left_table = table(left)
        left_gini <- Gini_impurity(left)

        right <- data[data[variable_column]>=threshold,][[class_column]]
        right_len = length(right)
        right_table = table(right)
        right_gini <- Gini_impurity(right)
        total_gini <- left_gini * left_len / total + right_gini * right_len /total

        result = c(variable_column,threshold,
        paste(names(left_table), left_table, collapse="; ", sep=" x "),
        paste(names(right_table), right_table, collapse="; ", sep=" x "),
        total_gini)

        names(result) <- c("Variable", "Threshold", "Left_branch", "Right_branch", "Gini_impurity")

        if(!is.null(Init_gini_impurity)){
        Gini_gain <- Init_gini_impurity - total_gini
        result = c(variable_column, threshold,
        paste(names(left_table), left_table, collapse="; ", sep=" x "),
        paste(names(right_table), right_table, collapse="; ", sep=" x "),
        Gini_gain)

        names(result) <- c("Variable", "Threshold", "Left_branch", "Right_branch", "Gini_gain")
        }

        return(result)
        }

        A quick test agrees with the results computed by hand earlier:

        as.data.frame(rbind(Gini_impurity_for_split_branch(2, data, 'x', 'color'), 
        Gini_impurity_for_split_branch(2, data, 'y', 'color')))

        ##   Variable Threshold       Left_branch       Right_branch     Gini_impurity
        ## 1        x         2 blue x 3; red x 2          green x 5              0.24
        ## 2        y         2          blue x 3 green x 5; red x 2 0.285714285714286
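As an independent cross-check of the two numbers above, the sketch below recomputes the weighted impurity of a split using the equivalent form `1 - sum(p^2)`. The helper names `gini_alt` and `weighted_gini` are made up for this illustration and are not part of the original code.

```r
data <- data.frame(x = c(0, 0.5, 1.1, 1.8, 1.9, 2, 2.5, 3, 3.6, 3.7),
                   y = c(1, 0.5, 1.5, 2.1, 2.8, 2, 2.2, 3, 3.3, 3.5),
                   color = c(rep('blue', 3), rep('red', 2), rep('green', 5)))

# Gini impurity written as 1 - sum(p^2); hypothetical helper for this check
gini_alt <- function(labels) {
  if (length(labels) == 0) return(0)
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Impurity of each side, weighted by the share of rows that land on it
weighted_gini <- function(values, labels, threshold) {
  left  <- labels[values <  threshold]
  right <- labels[values >= threshold]
  n <- length(labels)
  length(left) / n * gini_alt(left) + length(right) / n * gini_alt(right)
}

weighted_gini(data$x, data$color, 2)
weighted_gini(data$y, data$color, 2)
```

Both calls should agree with the table: 0.24 for `x` and about 0.2857 (that is, 2/7) for `y`.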

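        Brute-force search for the root node and threshold

        Using the functions defined above, iterate over every possible variable and threshold.

        First, here is the calculation for the variable x: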

        uniq_x <- sort(unique(data$x))
        delimiter_x <- zoo::rollmean(uniq_x,2)
        impurity_x <- as.data.frame(do.call(rbind, lapply(delimiter_x, Gini_impurity_for_split_branch,
        data=data, variable_column='x', class_column='color')))
        print(impurity_x)

        ##   Variable Threshold                  Left_branch                 Right_branch     Gini_impurity
        ## 1        x      0.25                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
        ## 2        x       0.8                     blue x 2 blue x 1; green x 5; red x 2             0.425
        ## 3        x      1.45                     blue x 3           green x 5; red x 2 0.285714285714286
        ## 4        x      1.85            blue x 3; red x 1           green x 5; red x 1 0.316666666666667
        ## 5        x      1.95            blue x 3; red x 2                    green x 5              0.24
        ## 6        x      2.25 blue x 3; green x 1; red x 2                    green x 4 0.366666666666667
        ## 7        x      2.75 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
        ## 8        x       3.3 blue x 3; green x 3; red x 2                    green x 2             0.525
        ## 9        x      3.65 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778
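The candidate thresholds in the table are the midpoints between neighbouring unique values of `x`, which is what a rolling mean with window 2 produces. For readers without the `zoo` package, a minimal base-R sketch of the same idea follows; the variable name `midpoints` is an assumption for this illustration.

```r
x <- c(0, 0.5, 1.1, 1.8, 1.9, 2, 2.5, 3, 3.6, 3.7)
uniq_x <- sort(unique(x))

# Pair each value with its right-hand neighbour and average the two
midpoints <- (head(uniq_x, -1) + tail(uniq_x, -1)) / 2
print(midpoints)
```

The nine values should match the `Threshold` column above. Any threshold strictly between two neighbouring values yields the same partition, so the midpoint is simply a convenient representative.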

        Now wrap this in two more functions: one computes the Gini impurity of every possible split when a single variable is used as the decision node, and the other does the same for every variable in turn.

        Gini_impurity_for_all_possible_branches_of_one_variable <- function(data, variable, class, Init_gini_impurity=NULL){
        uniq_value <- sort(unique(data[[variable]]))
        delimiter_value <- zoo::rollmean(uniq_value,2)
        impurity <- as.data.frame(do.call(rbind, lapply(delimiter_value,
        Gini_impurity_for_split_branch, data=data,
        variable_column=variable,
        class_column=class,
        Init_gini_impurity=Init_gini_impurity)))
        if(is.null(Init_gini_impurity)){
        decreasing = F
        } else {
        decreasing = T
        }
        # The score column is stored as text, so convert it to numeric before ordering
        impurity <- impurity[order(as.numeric(impurity[[colnames(impurity)[5]]]), decreasing = decreasing),]
        return(impurity)
        }

        Gini_impurity_for_all_possible_branches_of_all_variables <- function(data, variables, class, Init_gini_impurity=NULL){
        one_split_gini <- do.call(rbind, lapply(variables,
        Gini_impurity_for_all_possible_branches_of_one_variable,
        data=data, class=class,
        Init_gini_impurity=Init_gini_impurity))
        if(is.null(Init_gini_impurity)){
        decreasing = F
        } else {
        decreasing = T
        }
        # The score column is stored as text, so convert it to numeric before ordering
        one_split_gini[order(as.numeric(one_split_gini[[colnames(one_split_gini)[5]]]), decreasing = decreasing),]
        }

        Test it:

        Gini_impurity_for_all_possible_branches_of_one_variable(data, 'x', 'color')

        ##   Variable Threshold                  Left_branch                 Right_branch     Gini_impurity
        ## 5        x      1.95            blue x 3; red x 2                    green x 5              0.24
        ## 3        x      1.45                     blue x 3           green x 5; red x 2 0.285714285714286
        ## 4        x      1.85            blue x 3; red x 1           green x 5; red x 1 0.316666666666667
        ## 6        x      2.25 blue x 3; green x 1; red x 2                    green x 4 0.366666666666667
        ## 2        x       0.8                     blue x 2 blue x 1; green x 5; red x 2             0.425
        ## 7        x      2.75 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
        ## 8        x       3.3 blue x 3; green x 3; red x 2                    green x 2             0.525
        ## 1        x      0.25                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
        ## 9        x      3.65 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778
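One detail worth knowing when reusing these tables: each row is assembled with `c()` from a mix of text and numbers, and `c()` coerces everything to character, so the `Threshold` and `Gini_impurity` columns hold text. The sketch below illustrates the coercion and a possible conversion step before doing numeric comparisons; the `vals` and `res` objects are made-up examples, not output of the functions above.

```r
# c() can only hold one type, so mixing text and numbers yields text
row <- c("x", 1.95, 0.24)
class(row)

# Hypothetical impurity values stored as text
vals <- c("0.24", "5e-04", "0.1")
sort(vals)               # compared character by character
sort(as.numeric(vals))   # compared as numbers

# Converting the relevant columns of a (made-up) result table before use
res <- data.frame(Threshold = c("1.95", "1.45"),
                  Gini_impurity = c("0.24", "0.285714285714286"))
res$Threshold <- as.numeric(res$Threshold)
res$Gini_impurity <- as.numeric(res$Gini_impurity)
res[which.min(res$Gini_impurity), ]
```

For the values in this post the two orderings happen to coincide, because every score has the form `0.xxx`; the difference only shows up with values printed in scientific notation or with negative numbers.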

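        Every threshold of both variables is tried as a split and its Gini impurity is computed; the output is sorted by Gini impurity in ascending order. Splitting on variable x at threshold 1.95 (which produces the same partition as the threshold 2 used above) gives the best result for this step.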

        variables <- c('x', 'y')
        Gini_impurity_for_all_possible_branches_of_all_variables(data, variables, class="color")

        ##    Variable Threshold                  Left_branch                 Right_branch     Gini_impurity
        ## 5         x      1.95            blue x 3; red x 2                    green x 5              0.24
        ## 3         x      1.45                     blue x 3           green x 5; red x 2 0.285714285714286
        ## 31        y      1.75                     blue x 3           green x 5; red x 2 0.285714285714286
        ## 4         x      1.85            blue x 3; red x 1           green x 5; red x 1 0.316666666666667
        ## 6         x      2.25 blue x 3; green x 1; red x 2                    green x 4 0.366666666666667
        ## 41        y      2.05          blue x 3; green x 1           green x 4; red x 2 0.416666666666667
        ## 2         x       0.8                     blue x 2 blue x 1; green x 5; red x 2             0.425
        ## 21        y      1.25                     blue x 2 blue x 1; green x 5; red x 2             0.425
        ## 51        y      2.15 blue x 3; green x 1; red x 1           green x 4; red x 1              0.44
        ## 7         x      2.75 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
        ## 71        y       2.9 blue x 3; green x 2; red x 2                    green x 3 0.457142857142857
        ## 61        y       2.5 blue x 3; green x 2; red x 1           green x 3; red x 1 0.516666666666667
        ## 8         x       3.3 blue x 3; green x 3; red x 2                    green x 2             0.525
        ## 81        y      3.15 blue x 3; green x 3; red x 2                    green x 2             0.525
        ## 1         x      0.25                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
        ## 11        y      0.75                     blue x 1 blue x 2; green x 5; red x 2 0.533333333333333
        ## 9         x      3.65 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778
        ## 91        y       3.4 blue x 3; green x 4; red x 2                    green x 1 0.577777777777778
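The functions above also accept an `Init_gini_impurity` argument, in which case they report the Gini gain (the impurity before the split minus the weighted impurity after it) instead of the impurity itself. Below is a minimal self-contained sketch of that quantity for the best split found above; `gini` is a helper name assumed here for illustration.

```r
data <- data.frame(x = c(0, 0.5, 1.1, 1.8, 1.9, 2, 2.5, 3, 3.6, 3.7),
                   y = c(1, 0.5, 1.5, 2.1, 2.8, 2, 2.2, 3, 3.3, 3.5),
                   color = c(rep('blue', 3), rep('red', 2), rep('green', 5)))

# Gini impurity of a vector of class labels (hypothetical helper)
gini <- function(labels) {
  if (length(labels) == 0) return(0)
  p <- table(labels) / length(labels)
  sum(p * (1 - p))
}

# Impurity of the whole data set, before any split
root_gini <- gini(data$color)

# Weighted impurity after splitting on x at 1.95
is_left <- data$x < 1.95
split_gini <- mean(is_left) * gini(data$color[is_left]) +
  mean(!is_left) * gini(data$color[!is_left])

gini_gain <- root_gini - split_gini
c(root = root_gini, split = split_gini, gain = gini_gain)
```

With 3 blue, 2 red and 5 green points, the root impurity is 0.3 × 0.7 + 0.2 × 0.8 + 0.5 × 0.5 = 0.62, so the gain for splitting at x = 1.95 is 0.62 - 0.24 = 0.38. Ranking candidates by largest gain selects the same split as ranking by smallest impurity, because the root impurity is the same for every candidate.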

        • https://victorzhou.com/blog/intro-to-random-forests/

        • https://victorzhou.com/blog/gini-impurity/

        • https://stats.stackexchange.com/questions/192310/is-random-forest-suitable-for-very-small-data-sets

        • https://towardsdatascience.com/understanding-random-forest-58381e0602d2

        • https://www.stat.berkeley.edu/~breiman/RandomForests/reg_philosophy.html

        • https://medium.com/@williamkoehrsen/random-forest-simple-explanation-377895a60d2d

