iris 데이터에 대한 베이지언 분류기 토이 예제

패턴인식 겨울 학교 첫날에 정규분포를 가정한 데이터에 대한 베이지언 분류기를 만들어 보는 토이 예제가 나와서 간단하게 R로 구현해 봤다.

몇몇 함수를 먼저 정의할 필요가 있다. 특히나 공분산 행렬을 만드는…

물론 R은 cov()라는 공분산 행렬을 만드는 함수가 있으나 목적상 직접 만들어 보자.

# 학습셋과 테스트셋을 구분
set.seed(1234)
idx <- sample(1:50, size = 45)
sub_set_train <- subset(iris, Species == "setosa")[idx, ]
sub_ver_train <- subset(iris, Species == "versicolor")[idx, ]
sub_vir_train <- subset(iris, Species == "virginica")[idx, ]

sub_set_test <- subset(iris, Species == "setosa")[-idx, ]
sub_ver_test <- subset(iris, Species == "versicolor")[-idx, ]
sub_vir_test <- subset(iris, Species == "virginica")[-idx, ]

# cov 함수 정의

user_cov <- function(x) {
    nc <- ncol(x)
    cov_set <- matrix(0, nrow = nc, ncol = nc)
    mean_set <- colMeans(x)
    for (i in 1:nrow(x)) {
        y <- x[i, ]
        cov_set <- cov_set + (t(as.matrix(y - mean_set)) %*% as.matrix(y - mean_set))
    }

    # 공분산 행렬
    cov_set <- 1/(nrow(x) - 1) * cov_set
    return(cov_set)
}


# 우도 함수
likelihood <- function(means, cov, x) {
    1/((2 * pi)^(4/2) * det(cov)^(1/2)) * exp(-(1/2) * (as.matrix(x - means) %*% 
        solve(cov) %*% t(as.matrix(x - means))))
}

필요한 준비는 다 되었는데….

이제 각 클래스의 평균과 covariance matrix를 구하고 unseen데이터들을 이 우도함수에 태워서 가장 큰 확률 값을 가지는 것에 클래스를 할당한다. 모든 클래스의 prior확률값도 같게 맞춰줬고 전확률도 동일하다고 가정했다.

set.means <- colMeans(sub_set_train[, -5])
set.cov <- user_cov(sub_set_train[, -5])
# setosa의 covariance matrix
set.cov
##              Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length      0.13146     0.10851     0.018939    0.011566
## Sepal.Width       0.10851     0.15674     0.012394    0.011157
## Petal.Length      0.01894     0.01239     0.026182    0.005652
## Petal.Width       0.01157     0.01116     0.005652    0.010677


ver.means <- colMeans(sub_ver_train[, -5])
ver.cov <- user_cov(sub_ver_train[, -5])
# versicolor의 covariance matrix
ver.cov
##              Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length      0.27219     0.08283      0.17370     0.05588
## Sepal.Width       0.08283     0.09164      0.07521     0.03735
## Petal.Length      0.17370     0.07521      0.19968     0.07199
## Petal.Width       0.05588     0.03735      0.07199     0.03859

vir.means <- colMeans(sub_vir_train[, -5])
vir.cov <- user_cov(sub_vir_train[, -5])
# virginicad의 covariance matrix
vir.cov
##              Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length      0.35118     0.08673      0.27689     0.03603
## Sepal.Width       0.08673     0.10245      0.06417     0.04476
## Petal.Length      0.27689     0.06417      0.29586     0.04056
## Petal.Width       0.03603     0.04476      0.04056     0.07689

sub_set_test$lik_set <- apply(t(sub_set_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(set.means, set.cov, t(x))
})


sub_set_test$lik_ver <- apply(t(sub_set_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(ver.means, ver.cov, t(x))
})


sub_set_test$lik_vir <- apply(t(sub_set_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(vir.means, vir.cov, t(x))
})

# setosa 테스트셋 각 클래스 확률
sub_set_test
##    Sepal.Length Sepal.Width Petal.Length Petal.Width Species  lik_set
## 7           4.6         3.4          1.4         0.3  setosa  3.26678
## 24          5.1         3.3          1.7         0.5  setosa  0.33225
## 25          4.8         3.4          1.9         0.2  setosa  0.04426
## 36          5.0         3.2          1.2         0.2  setosa  2.93948
## 49          5.3         3.7          1.5         0.2  setosa 10.30504
##      lik_ver   lik_vir
## 7  6.000e-21 7.931e-32
## 24 7.732e-17 1.728e-28
## 25 1.690e-19 1.493e-28
## 36 1.054e-22 4.453e-37
## 49 3.221e-27 8.949e-40



sub_ver_test$lik_set <- apply(t(sub_ver_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(set.means, set.cov, t(x))
})


sub_ver_test$lik_ver <- apply(t(sub_ver_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(ver.means, ver.cov, t(x))
})


sub_ver_test$lik_vir <- apply(t(sub_ver_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(vir.means, vir.cov, t(x))
})

# versicolor 테스트셋 각 클래스 확률
sub_ver_test
##    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species    lik_set
## 57          6.3         3.3          4.7         1.6 versicolor 1.440e-104
## 74          6.1         2.8          4.7         1.2 versicolor  3.780e-97
## 75          6.4         2.9          4.3         1.3 versicolor  6.768e-80
## 86          6.0         3.4          4.5         1.6 versicolor  4.403e-95
## 99          5.1         2.5          3.0         1.1 versicolor  1.644e-32
##     lik_ver   lik_vir
## 57 1.036871 9.549e-03
## 74 0.073948 6.406e-03
## 75 2.745005 8.117e-05
## 86 0.238806 1.947e-03
## 99 0.007279 4.059e-08



sub_vir_test$lik_set <- apply(t(sub_vir_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(set.means, set.cov, t(x))
})


sub_vir_test$lik_ver <- apply(t(sub_vir_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(ver.means, ver.cov, t(x))
})


sub_vir_test$lik_vir <- apply(t(sub_vir_test[, c(1, 2, 3, 4)]), 2, function(x) {
    likelihood(vir.means, vir.cov, t(x))
})

# virginica 테스트셋 각 클래스 확률
sub_vir_test
##     Sepal.Length Sepal.Width Petal.Length Petal.Width   Species    lik_set
## 107          4.9         2.5          4.5         1.7 virginica 2.275e-109
## 124          6.3         2.7          4.9         1.8 virginica 8.717e-128
## 125          6.7         3.3          5.7         2.1 virginica 4.988e-183
## 136          7.7         3.0          6.1         2.3 virginica 3.121e-226
## 149          6.2         3.4          5.4         2.3 virginica 1.744e-177
##       lik_ver lik_vir
## 107 3.636e-04 0.01635
## 124 2.158e-02 0.44757
## 125 1.315e-03 1.17220
## 136 5.520e-08 0.04965
## 149 2.787e-07 0.25756

결과적으로 iris는 꽤 정규분포에 잘 적합되는 데이터가 아닐 수 없다.

CC BY-NC 4.0 iris 데이터에 대한 베이지언 분류기 토이 예제 by from __future__ import dream is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.