[Algorithm] Least Square Method

[Math] Least Square Method

최소 자승법(or 최소 제곱법) 이라고 부르며 이는 여러개의 X,Y 에 대해 이를 만족시키는 함수 f(x)를 찾는 과정이다.

이 f(x)는 1차식일수도 있고, 고차함수일수도 있다.

해당 P(x) 를 보고 대략 이런 함수일것이다 를 유추한후 그 식에 맞춰 함수의 상수를 구하면 된다. (P(x)는 실험으로 얻은 결과 이산 함수)

만일 P(x) = {(1,2),(3,4)} 이라면 y=ax+b 의 식으로 유추할 때 a=1 , b=0 이 되겠다.

여러개의 매칭 쌍이 존재할때 이를 가장 근사할 수 있는 함수식 f(x) 를 찾아야 한다.

일단 P(x) 의 형태가 직선과 유사하다고 판별될 때를 알아보자. 즉 우리가 예상한 추정함수의 형태는 y=ax+b 가 된다.

실험 이산함수 P(x) 와 우리가 추측한 함수 f(x) 의 오차가 가장 적게 하는 a,b를 찾으면 된다.

그래서 아래의 제곱에러 식을 사용한다.

위 식에서 E가 최소가 되도록 하는 a,b 를 찾는다.

이러한 최소 제곱직선의 최소값을 구하려면 a,b에 대한 편미분이 0 이 되어야 한다.

따라서, 식을 a,b로 각각 편미분 하면 아래와 같다.

이를 행렬을 통해 적고 역행렬을 이용해 a,b를 구할수 있다.

이제 역행렬을 구하면 아래와 같이 된다.

마지막 식으로의 유도가 힘들었는데, 정리하면 아래와 같다.


이를 직접 processing으로 확인한 결과는 아래와 같다.

lease square method.pde

//uses
import java.util.ArrayList;  
import java.awt.Point;  
//var
int baseX;  
int baseY;  
ArrayList<Point> pts=new ArrayList<Point>();  
float a, b;  
boolean ldw=false;

//implement
void setup() {  
  size(600, 600);
  baseX=width/2;
  baseY=height/2;
}
void draw() {  
  background(255, 255, 255);
  line(0, baseY, width, baseY);
  line(baseX, 0, baseX, height);
  fill(103, 153, 255);
  noStroke();
  for (int i=0; i<pts.size(); i++) {
    ellipse(pts.get(i).x+baseX, baseY-pts.get(i).y, 5, 5);
  }
  stroke(1);
  if (ldw==true) {
    DrawLinearGraph();
    DrawExpression();
  }
}
boolean once_flag=false;  
void DrawLinearGraph() {  
  float x1=-baseX;
  float x2=baseX;
  float y1=a*x1+b;
  float y2=a*x2+b;

  y1*=-1;  //pixel coordinate => math coordinate
  y2*=-1;
  y1+=baseY;
  y2+=baseY;
  x1+=baseX;
  x2+=baseX;
  if (once_flag==false) {
    println(x1, y1, x2, y2);
    once_flag=true;
  }
  line(x1, y1, x2, y2);
}
void DrawExpression() {  
  String exp="y="+Float.toString(a)+"x";
  if (b<0) {
    exp+=Float.toString(b);
  } else {
    exp+="+"+Float.toString(b);
  }
  textSize(16);
  fill(0, 0, 0);
  text(exp, 10, 20);
}
void keyPressed() {  
  if (key=='A' || key=='a') {
    float sx2=SX2();
    float sxy=SXY();
    float sx=SX();
    float sy=SY();
    float n=pts.size();
    float detM=(n*sx2-(sx*sx));
    a=(n*sxy-sx*sy)/detM;
    b=(sx2*sy - sx*sxy)/detM;
    ldw=true;
    /*print info*/
    for (int i=0; i<pts.size(); i++) {
      println("(", pts.get(i).x, ",", pts.get(i).y, ")");
    }
    println("a : ", a, " b : ", b);
  } else if (key=='S' || key=='s') {
    Point pt=new Point();
    pt.x=mouseX-baseX;
    pt.y=baseY-mouseY;
    pts.add(pt);
  } else if (key=='d' || key=='D') {
    pts.clear();
    ldw=false;
  }
}
float SX2() {  
  float r=0;
  for (int i=0; i<pts.size(); i++) {
    r+=pts.get(i).x*pts.get(i).x;
  }
  return r;
}
float SXY() {  
  float r=0;
  for (int i=0; i<pts.size(); i++) {
    r+=pts.get(i).x*pts.get(i).y;
  }
  return r;
}
float SX() {  
  float r=0;
  for (int i=0; i<pts.size(); i++) {
    r+=pts.get(i).x;
  }
  return r;
}
float SY() {  
  float r=0;
  for (int i=0; i<pts.size(); i++) {
    r+=pts.get(i).y;
  }
  return r;
}

References

변경이력

  • 2016-09-03 첫 글 작성