EM 알고리즘 구현 – 오류 수정본 –

저번에 구현한 EM 알고리즘에 오류가 있어서 수정해봤다.

각 클러스터 중심을 구할 때 k-means 와 비슷하게 아예 확률값이 큰 클러스터에 포인트를 할당해 해당 클러스터 파라메터 계산시에만 확률값을 사용했는데 이런식으로 하면 안되고 한 포인트가 가지는 각 클러스터의 확률값을 가지고 이를 기반으로 각 클러스터 파라메터 재계산을 해야한다.

사실 코딩할 때 애매모호 해서 예전에 들었던 k-means와 거의 비슷하다는 강의록을 기억해 이런식으로 코딩했었다. 하지만 다시 확인해보고 여러 책을 뒤져보니 잘못 코딩한게 명백해졌다.

그래서 코드 수정을 했다.
게다가 k-means와 거의 100% 비슷하다는 언급을 취소하고 약 90%정도 비슷하다 정도로 이야기 하고 싶다.

/*
EM clustering algorithm
Created by gogamza 2009.06.28
You can see full description in http://www.freesearch.pe.kr/1262
*/
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using boost::math::normal;
using namespace std;



//data structure of data
typedef struct{
    double datapoint;
    int clusterid;
    //double prob;
    map clusterProb;
} t_data;

//data structure of each cluster
typedef struct{
    double mean;
    double stddev;
    int clusterid;
} t_cluster;

//cluster vector
vector cluster;


//clacualte bayes probability 
double getProb(const int x, const int clusterid){
    if(cluster.size() == 0)
        assert(0);
    double priorP = 1.0/cluster.size();
    
    vector::const_iterator iter;
    double numer = 0.0;    
    double denorm = 0.0;
    for(iter = cluster.begin();iter != cluster.end();iter++){
        if(clusterid == iter->clusterid)
            numer = priorP *  pdf(normal(iter->mean, iter->stddev), x);
        denorm += priorP * pdf(normal(iter->mean, iter->stddev), x);
    }
    return numer/denorm;
}


int expectation(t_data& data){
    vector::iterator iter;
    double x = data.datapoint;
    double max = -1;
    for(iter = cluster.begin();iter != cluster.end();iter++){
        double prob = getProb(x, iter->clusterid);
        data.clusterProb[iter->clusterid] = prob;
        if(max < prob){
            max =  prob;
            data.clusterid = iter->clusterid;
        }
        //cout << iter->clusterid << prob << endl;
    }
    return 0;
}


void maximization(std::vector& datas){
    vector::iterator iter;
    for(iter = cluster.begin();iter != cluster.end();iter++){
        vector::iterator iter2;
        double numer = 0.0;
        double denorm = 0.0;
        double mean = 0.0;
        double stddev = 0.0;
        double numer1 = 0.0;
        double denorm1 = 0.0;
        for(iter2 = datas.begin();iter2 != datas.end();iter2++){
            map::iterator ele = iter2->clusterProb.find(iter->clusterid);
            if(ele !=  iter2->clusterProb.end()){
                numer += iter2->datapoint * ele->second;
                denorm += ele->second;
            }
        }
        mean = numer/denorm;
        vector::iterator iter3 = datas.begin();
        while(iter3 != datas.end()){
            map::iterator ele = iter3->clusterProb.find(iter->clusterid);
            if(ele !=  iter3->clusterProb.end()){
                numer1 += ele->second * pow(mean - iter3->datapoint, 2);
                denorm1 += ele->second;
            }
            iter3++;
        }

        stddev = sqrt(numer1/denorm1);
        iter->mean = mean;
        iter->stddev = stddev;
            
    }
}


int main( int argc, char **argv ){
    //init
    boost::mt19937    Random4;

    vector datas;

    boost::normal_distribution<>  norm_gen1(6,3);
    boost::normal_distribution<>  norm_gen2(8,3);
    boost::normal_distribution<>  norm_gen3(20,3);

    boost::variate_generator< boost::mt19937, boost::normal_distribution<> > genG1( Random4, norm_gen1 );

    boost::variate_generator< boost::mt19937, boost::normal_distribution<> > genG2(Random4, norm_gen2);

    boost::variate_generator< boost::mt19937, boost::normal_distribution<> > genG3(Random4, norm_gen3);

    
    t_cluster init_cluster1 = {2,3,0};
    cluster.push_back(init_cluster1);
    t_cluster init_cluster2 = {9,3,1};
    cluster.push_back(init_cluster2);
    t_cluster init_cluster3 = {18,3,2};
    cluster.push_back(init_cluster3);
    //random number generation    
    pair initProb[] = {make_pair(0,0.0), make_pair(1,0.0), make_pair(2,0.0)};
    for( int i = 0; i < 1000; ++i ){
        t_data data1 = {0, 0.0, map(initProb, initProb + sizeof(initProb)/sizeof(initProb[0]))};
        data1.datapoint = genG1();
        datas.push_back(data1);

        t_data data2 = {0, 0.0, map(initProb, initProb + sizeof(initProb)/sizeof(initProb[0]))};
        data2.datapoint = genG2();
        datas.push_back(data2);
        
        t_data data3 = {0, 0.0, map(initProb, initProb + sizeof(initProb)/sizeof(initProb[0]))};
        data3.datapoint = genG3();
        datas.push_back(data3);
    }

    //run
    int runcnt = 0;
    //while(runcnt < 50){
    while(1){
        vector::iterator iter = datas.begin();
        //expectation
        for(;iter != datas.end(); iter++){
            if(expectation(*iter))
                cerr << "error occur!\n" << endl;
        }
        //maximization
        maximization(datas);
        cout << "iter : " << runcnt << endl;
        vector::iterator iter2 = cluster.begin();
        for(; iter2 != cluster.end(); iter2++){
            cout << "cluster id : " << iter2->clusterid << '\n'
                 << "mean : " << iter2->mean << '\n'
                 << "stddev :" << iter2->stddev << "\n\n" << endl;
        }
        
        runcnt++;
    }
    
    return 0;
}

위 알고리즘을 돌려보면 두 클러스터는 이전 클러스터의 중간값으로 근접해 가고, 나머지 한개의 클러스터는 멀리 떨어진 평균 20을 가지는 점으로 수렴하는걸 볼 수 있다.

세 가지 다른 정규분포에 의해서 나온 점들을 제너레이션 했으나 결국 2개의 클러스터로 클러스터링 되었다는 재밋는 사실을 알 수 있었다.
넘버 제너러이션 시점에서 표준편차를 적은값으로 주면 정확하게 3개의 클러스터로 만들어질 거란  생각을 해본다.

CC BY-NC 4.0 EM 알고리즘 구현 – 오류 수정본 – by from __future__ import dream is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.