[Algorithm] K-means

k-meansUnsupervised Learning 에 속하는 Clustering Algorithm 이다.

Unsupervised Learning 은 Training Data에 Label 이 없기 때문에, 비슷한 군집 끼리 분류를 해야한다.

클러스터를 정의하는 방법에는 많은 방법이 있지만, 여기서 소개할 k-means는 클러스터와 데이터가 가까운것을 찾는다.

이 문제는 NP-hard 로써 각 클러스터와 데이터간의 유클리드 거리를 최소로 하는 집합을 찾는것이 목표이다.

알고리즘은 간단하다.

1. 먼저 임의의 K개의 클러스터의 위치를 설정한다.
2. 모든 점들은 자기와 가장 가까운 클러스터에 속하게 된다.
3. i 번째 클러스터에 속한 점들의 무게중심으로 클러스터가 이동한다.
4. 모든 클러스터가 더이상 이동하지 않으면 종료한다. 그렇지 않으면 2번으로 되돌아간다.

  • 이를 수식으로 적으면 아래와 같다.

위 수식을 설명하면, K개의 클러스터와 모든 데이터의 거리의 최소값을 찾는다는 뜻이다.

where뒤의 식은 mj는 클러스터의 중심점인데, 무게중심을 개선한다는 뜻이다.

실제로 k-means 자체는 간단하다.

이를 간단하게 구현해 보았다.

processing source code

kmeans.pde
/*
* @Project  : k-means visualization
 * @Architecture : Kim Bom
 * kmeans.pde
 *
 * @Created by KimBom On 2016. 06. 12...
 * @Copyright (C) 2016 KimBom. All rights reserved.
 */
import java.util.LinkedList;  
import java.awt.Point;  
LinkedList<Point> list=new LinkedList<Point>();  //point set  
final int padding=50;  //for drawing  
final int radius=3;  //for drawing  
final int pt_size=5000;  
final int K=7;  //K-means, num of cluster  
Cluster[] cluster=new Cluster[K];  
//maximan num of cluster is 7
color[] rgbs=new color[]{  
  color(241, 95, 95), 
  color(250, 237, 125), 
  color(134, 229, 127), 
  color(92, 209, 229), 
  color(67, 116, 217), 
  color(217, 65, 197), 
  color(255, 255, 255)
};
boolean isEnd=false;  
void setup() {  
  size(900, 900);
  //init point position
  for (int i=0; i<pt_size; i++) {
    int x=(int)random(padding, width-padding);
    int y=(int)random(padding, width-padding);
    Point p=new Point(x, y);
    list.add(p);
  }
  //set cluster;
  try {
    for (int i=0; i<K; i++) {
      int x=(int)random(padding, width-padding);
      int y=(int)random(padding, width-padding);
      cluster[i]=new Cluster(x, y, rgbs[i]);
    }
  }
  catch(ArrayIndexOutOfBoundsException e) {
    e.printStackTrace();
    isEnd=true;
  }
  frameRate(5);
}

void draw() {  
  background(33, 33, 33);
  if (!isEnd) {

    SetCluster();
  } else {
    fill(255,255,255);
    textSize(20);
    text("k-means end",20,30);
  }
  DrawClusterPoint();
  if (!isEnd && ResetCluster()) {
    isEnd=true;
  }
}
void DrawClusterPoint() {  
  for (int i=0; i<K; i++) {
    fill(cluster[i].rgb);
    cluster[i].draw_x(radius);
    for (Point p : cluster[i].set) {
      ellipse(p.x, p.y, radius, radius);
    }
  }
}
//calculate Euclidean distance
float getDistance(float x, float y) {  
  return pow(x, 2)+pow(y, 2);
}
void SetCluster() {  
  for (Cluster c : cluster) {
    c.set.clear();
  }
  for (Point p : list) {
    int min_idx=0;
    float min_value=Float.MAX_VALUE;
    for (int i=0; i<cluster.length; i++) {
      float d=getDistance(p.x-cluster[i].base.x, p.y-cluster[i].base.y);
      if (d<min_value) {
        min_idx=i;
        min_value=d;
      }
    }
    cluster[min_idx].set.add(p);
  }
}
boolean ResetCluster() {  
  //center position : sum(x),sum(y) div size(x,y)
  boolean b=true;
  for (int i=0; i<K; i++) {
    int sum_x=0;
    int sum_y=0;
    for (Point p : cluster[i].set) {
      sum_x+=p.x;
      sum_y+=p.y;
    }
    sum_x/=cluster[i].set.size();
    sum_y/=cluster[i].set.size();
    if (sum_x!=cluster[i].base.x || sum_y!=cluster[i].base.y) {
      cluster[i].base=new Point(sum_x, sum_y);
      b&=false;
    } else {
      b&=true;
    }
  }
  return b;
}
cluster.pde
/*
* @Project  : k-means visualization
* @Architecture : Kim Bom
* cluster.pde
*
* @Created by KimBom On 2016. 06. 12...
* @Copyright (C) 2016 KimBom. All rights reserved.
*/
import java.util.LinkedList;  
import java.awt.Point;  
class Cluster {  
  public LinkedList<Point> set;
  public Point base;
  public color rgb;
  public Cluster(int x, int y,color c) {
    this.base=new Point(x,y);
    this.set=new LinkedList<Point>();
    rgb=c;
  }
  public void draw_x(int radius){
    stroke(red(rgb),green(rgb),blue(rgb));
     line(this.base.x-radius,this.base.y-radius,this.base.x+radius,this.base.y+radius);
     line(this.base.x+radius,this.base.y-radius,this.base.x-radius,this.base.y+radius);
  }
};