K-means 算法是機器學習的基本算法,也是其中很簡單的一個。最近剛仔細學習了這個算法,參照網上現有的資源,自己重新寫了一遍。
獨立編寫時真是錯誤百出,找 bug 找得好辛苦。

整理如下:


package algorithm;
import java.util.*;
/**
 * Minimal k-means clustering over two-dimensional points.
 *
 * <p>Every point is a {@code double[]} whose elements 0 and 1 are read as the
 * two coordinates. When no data has been loaded, {@link #init()} falls back to
 * a built-in sample of 15 points. Initial centroids are drawn at random, so
 * the resulting clusters differ from run to run.
 */
public class K_means {
    /**
     * Upper bound on assignment passes. Convergence is detected by exact
     * equality of two floating-point sums, so this cap guarantees termination
     * even if rounding keeps the error from ever repeating exactly.
     */
    private static final int MAX_ITERATIONS = 10000;

    /** Single random source for centroid selection and tie-breaking. */
    private final Random random = new Random();

    private ArrayList<double[]> dataSet; // data points
    private ArrayList<double[]> center; // current centroids, one per cluster
    private ArrayList<ArrayList<double[]>> cluster; // points assigned to each centroid
    private int k; // number of clusters
    private int m; // number of completed iterations
    private int dataSetLength; // number of data points
    private ArrayList<Double> wc; // total squared error recorded at each iteration

    /**
     * Creates a clusterer for {@code k} clusters.
     *
     * @param k requested number of clusters; values below 1 are treated as 1
     */
    public K_means(int k) {
        if (k < 1) {
            k = 1;
        }
        this.k = k;
        dataSet = new ArrayList<double[]>();
        center = new ArrayList<double[]>();
        cluster = new ArrayList<ArrayList<double[]>>();
        m = 0;
        dataSetLength = 0;
        wc = new ArrayList<Double>();
    }

    /** Returns the list of clusters; element {@code i} holds the points of cluster {@code i}. */
    private ArrayList<ArrayList<double[]>> getCluster() {
        return cluster;
    }

    /**
     * Prepares a run: loads the built-in sample when the data set is empty,
     * clamps {@code k} to the number of points, resets the iteration state and
     * creates the initial centroids and empty clusters.
     */
    private void init() {
        if (dataSet == null) {
            dataSet = new ArrayList<double[]>();
        }
        if (dataSet.isEmpty()) {
            double[][] dataSetArray = new double[][] { { 8, 2 }, { 3, 4 },
                    { 2, 5 }, { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 },
                    { 5, 3 }, { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 },
                    { 8, 6 } };
            for (double[] point : dataSetArray) {
                dataSet.add(point);
            }
        }
        dataSetLength = dataSet.size();
        if (k > dataSetLength) {
            k = dataSetLength;
        }
        // Start from a clean slate so a repeated run does not reuse old errors.
        m = 0;
        wc.clear();
        center = initCenter();
        cluster = initCluster();
    }

    /**
     * Picks {@code k} initial centroids from the data set at {@code k} distinct
     * random indices. Distinct indices do not imply distinct coordinates: the
     * data set may contain the same point more than once.
     *
     * @return a new list with the chosen centroids
     */
    private ArrayList<double[]> initCenter() {
        ArrayList<Integer> indices = new ArrayList<Integer>(dataSetLength);
        for (int i = 0; i < dataSetLength; i++) {
            indices.add(i);
        }
        // After shuffling, the first k entries are k distinct random indices.
        Collections.shuffle(indices, random);

        ArrayList<double[]> initial = new ArrayList<double[]>(k);
        for (int i = 0; i < k; i++) {
            initial.add(dataSet.get(indices.get(i)));
        }
        return initial;
    }

    /**
     * Builds a fresh list of {@code k} empty clusters.
     *
     * @return a new list containing {@code k} empty point lists
     */
    private ArrayList<ArrayList<double[]>> initCluster() {
        ArrayList<ArrayList<double[]>> fresh = new ArrayList<ArrayList<double[]>>(k);
        for (int i = 0; i < k; i++) {
            fresh.add(new ArrayList<double[]>());
        }
        return fresh;
    }

    /**
     * Squared Euclidean distance between two points, using coordinates 0 and 1.
     * The square root is omitted; only the ordering of distances is needed.
     *
     * @param point  the data point
     * @param center the centroid
     * @return the squared distance
     */
    private double distance(double[] point, double[] center) {
        double x = point[0] - center[0];
        double y = point[1] - center[1];
        return x * x + y * y;
    }

    /**
     * Finds the index of the smallest entry in {@code distance}.
     *
     * @param distance distances from one point to every centroid
     * @return index of the smallest distance; when a later entry ties with the
     *         current minimum it replaces it with probability one half
     */
    private int minDistance(double[] distance) {
        double minDistance = distance[0];
        int minLocation = 0;
        for (int i = 1; i < distance.length; i++) {
            if (minDistance > distance[i]) {
                minDistance = distance[i];
                minLocation = i;
            } else if (distance[i] == minDistance) {
                // Tie: pick between the two candidates at random.
                if (random.nextBoolean()) {
                    minLocation = i;
                }
            }
        }
        return minLocation;
    }

    /** Assigns every data point to the cluster of its nearest centroid. */
    private void setCluster() {
        double[] dist = new double[k];
        for (int i = 0; i < dataSetLength; i++) {
            double[] point = dataSet.get(i);
            for (int j = 0; j < k; j++) {
                dist[j] = distance(point, center.get(j));
            }
            cluster.get(minDistance(dist)).add(point);
        }
    }

    /**
     * Moves each centroid to the mean of the points assigned to it. A centroid
     * whose cluster is empty is left where it is.
     *
     * @return the centroid list, updated in place
     */
    private ArrayList<double[]> updateCenter() {
        for (int i = 0; i < k; i++) {
            ArrayList<double[]> members = cluster.get(i);
            int n = members.size();
            if (n == 0) {
                continue;
            }
            double[] newCenter = new double[2];
            for (double[] point : members) {
                newCenter[0] += point[0];
                newCenter[1] += point[1];
            }
            newCenter[0] = newCenter[0] / n;
            newCenter[1] = newCenter[1] / n;
            center.set(i, newCenter);
        }
        return center;
    }

    /**
     * Appends to {@code wc} the sum of squared distances between every point
     * and the centroid of the cluster it is currently assigned to.
     */
    private void errorSquare() {
        double errorValue = 0;
        for (int i = 0; i < k; i++) {
            for (double[] point : cluster.get(i)) {
                errorValue += distance(point, center.get(i));
            }
        }
        wc.add(errorValue);
    }

    /**
     * Runs the clustering loop: assign points, record the error, and stop once
     * the error equals that of the previous iteration; otherwise recompute the
     * centroids and repeat. On convergence the iteration count is printed.
     */
    private void kmeans() {
        init();
        while (true) {
            setCluster();
            errorSquare();
            // Both statements belong to the convergence check; leaving the
            // loop unconditionally here would stop after the second pass.
            if (m != 0 && wc.get(m) - wc.get(m - 1) == 0) {
                System.out.println(m);
                break;
            }
            if (m >= MAX_ITERATIONS) {
                break;
            }
            center = updateCenter();
            m++;
            cluster = initCluster();
        }
    }

    /**
     * Prints every point of {@code dataArray} on its own line, followed by a
     * separator line.
     *
     * @param dataArray     points to print; coordinates 0 and 1 are shown
     * @param dataArrayName label placed in front of each index
     */
    public void printDataArray(ArrayList<double[]> dataArray,
            String dataArrayName) {
        for (int i = 0; i < dataArray.size(); i++) {
            System.out.println("print:" + dataArrayName + "[" + i + "]={"
                    + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
        }
        System.out.println("===================================");
    }

    /** Runs the algorithm and prints start/end markers with the elapsed time. */
    private void execute() {
        long startTime = System.currentTimeMillis();
        System.out.println("kmeans begins");
        kmeans();
        long endTime = System.currentTimeMillis();
        System.out.println("kmeans running time=" + (endTime - startTime)
                + "ms");
        System.out.println("kmeans ends");
        System.out.println();
    }

    /**
     * Clusters the built-in sample into 10 clusters and prints each cluster.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        K_means kl = new K_means(10);
        kl.execute();
        ArrayList<ArrayList<double[]>> cluster = kl.getCluster();
        for (int i = 0; i < cluster.size(); i++) {
            kl.printDataArray(cluster.get(i), "cluster[" + i + "]");
        }
    }
}