程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> JAVA編程 >> JAVA綜合教程 >> Kmeans算法java實現

Kmeans算法java實現

編輯:JAVA綜合教程

Kmeans算法java實現


K-means算法是機器學習的基本算法,也是其中很簡單的一個。最近剛剛仔細學習了這個算法,參照網上現有的資源,自己重新寫了一遍。

自己獨立寫真是錯誤百出,找bug找得好辛苦,現將代碼整理如下:

package algorithm;

import java.util.*;

/**
 * A minimal K-means clustering implementation for two-dimensional points.
 *
 * <p>Each point is a {@code double[]} whose first two elements are its x and y
 * coordinates. If no data has been supplied, a built-in sample set of 15 points
 * is clustered. Initial centers are {@code k} points drawn at random from the
 * data set (distinct indices, though the points themselves may coincide if the
 * data set contains duplicates), so results differ from run to run.
 */
public class K_means {
	/** An error decrease at or below this value is treated as convergence. */
	private static final double CONVERGENCE_TOLERANCE = 1e-10;
	/** Hard cap on iterations so the loop always terminates. */
	private static final int MAX_ITERATIONS = 1000;

	private ArrayList<double[]> dataSet; // the points to cluster
	private ArrayList<double[]> center; // current center of each cluster
	private ArrayList<ArrayList<double[]>> cluster; // points assigned to each cluster
	private int k; // number of clusters
	private int m; // number of completed iterations
	private int dataSetLength; // number of points in the data set
	private ArrayList<Double> wc; // total squared error recorded after each assignment pass
	private final Random random = new Random(); // shared source for center selection and tie-breaking

	/**
	 * Creates a clusterer for {@code k} clusters.
	 *
	 * @param k requested number of clusters; values below 1 are treated as 1
	 */
	public K_means(int k) {
		this.k = Math.max(k, 1);
		dataSet = new ArrayList<double[]>();
		center = new ArrayList<double[]>();
		cluster = new ArrayList<ArrayList<double[]>>();
		m = 0;
		dataSetLength = 0;
		wc = new ArrayList<Double>();
	}

	/**
	 * Returns the cluster list; element {@code i} holds the points currently
	 * assigned to center {@code i}.
	 */
	private ArrayList<ArrayList<double[]>> getCluster() {
		return cluster;
	}

	/**
	 * Prepares a run: loads the sample data when the data set is empty, clamps
	 * {@code k} to the number of points, resets the iteration state, and picks
	 * the initial centers and empty clusters.
	 */
	private void init() {
		if (dataSet == null) {
			dataSet = new ArrayList<double[]>();
		}
		if (dataSet.isEmpty()) {
			double[][] dataSetArray = new double[][] { { 8, 2 }, { 3, 4 },
					{ 2, 5 }, { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 },
					{ 5, 3 }, { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 },
					{ 8, 6 } };
			for (double[] point : dataSetArray) {
				dataSet.add(point);
			}
		}
		dataSetLength = dataSet.size();
		if (k > dataSetLength) {
			k = dataSetLength;
		}
		m = 0;
		wc.clear();
		center = initCenter();
		cluster = initCluster();
	}

	/**
	 * Chooses {@code k} initial centers from the data set using distinct random
	 * indices. Distinct indices do not guarantee distinct points: the data set
	 * itself may contain duplicates.
	 *
	 * @return a new list holding the chosen initial centers
	 */
	private ArrayList<double[]> initCenter() {
		ArrayList<Integer> indices = new ArrayList<Integer>(dataSetLength);
		for (int i = 0; i < dataSetLength; i++) {
			indices.add(i);
		}
		// After shuffling, the first k entries are k distinct, uniformly chosen indices.
		Collections.shuffle(indices, random);

		ArrayList<double[]> initialCenters = new ArrayList<double[]>(k);
		for (int i = 0; i < k; i++) {
			initialCenters.add(dataSet.get(indices.get(i)));
		}
		return initialCenters;
	}

	/**
	 * Builds a fresh list of {@code k} empty clusters.
	 *
	 * @return a new cluster list with {@code k} empty entries
	 */
	private ArrayList<ArrayList<double[]>> initCluster() {
		ArrayList<ArrayList<double[]>> emptyClusters = new ArrayList<ArrayList<double[]>>(k);
		for (int i = 0; i < k; i++) {
			emptyClusters.add(new ArrayList<double[]>());
		}
		return emptyClusters;
	}

	/**
	 * Returns the squared Euclidean distance between two 2-D points. The square
	 * root is omitted because only relative ordering and sums are needed.
	 *
	 * @param point  a data point (x at index 0, y at index 1)
	 * @param center a cluster center (x at index 0, y at index 1)
	 * @return the squared distance between the two points
	 */
	private double distance(double[] point, double[] center) {
		double x = point[0] - center[0];
		double y = point[1] - center[1];
		return x * x + y * y;
	}

	/**
	 * Returns the index of the smallest of the first {@code k} distances. When
	 * a later entry ties with the current minimum, it replaces it with
	 * probability one half.
	 *
	 * @param distance distances from one point to each of the {@code k} centers
	 * @return index of the nearest center
	 */
	private int minDistance(double[] distance) {
		double minDistance = distance[0];
		int minLocation = 0;
		for (int i = 1; i < k; i++) {
			if (minDistance > distance[i]) {
				minDistance = distance[i];
				minLocation = i;
			} else if (distance[i] == minDistance && random.nextBoolean()) {
				// Tie: pick between the candidates at random.
				minLocation = i;
			}
		}
		return minLocation;
	}

	/** Assigns every point in the data set to the cluster of its nearest center. */
	private void setCluster() {
		double[] dist = new double[k];
		for (int i = 0; i < dataSetLength; i++) {
			double[] point = dataSet.get(i);
			for (int j = 0; j < k; j++) {
				dist[j] = distance(point, center.get(j));
			}
			cluster.get(minDistance(dist)).add(point);
		}
	}

	/**
	 * Moves each center to the mean of the points in its cluster. A center whose
	 * cluster is empty is left where it is.
	 *
	 * @return the center list, updated in place
	 */
	private ArrayList<double[]> updateCenter() {
		for (int i = 0; i < k; i++) {
			ArrayList<double[]> members = cluster.get(i);
			int n = members.size();
			if (n == 0) {
				continue;
			}
			double[] newCenter = new double[2];
			for (double[] point : members) {
				newCenter[0] += point[0];
				newCenter[1] += point[1];
			}
			newCenter[0] /= n;
			newCenter[1] /= n;
			// Replace rather than mutate: the initial centers alias data set points.
			center.set(i, newCenter);
		}
		return center;
	}

	/**
	 * Appends to {@code wc} the sum of squared distances between every point
	 * and the center of the cluster it is currently assigned to.
	 */
	private void errorSquare() {
		double errorValue = 0;
		for (int i = 0; i < k; i++) {
			for (double[] point : cluster.get(i)) {
				errorValue += distance(point, center.get(i));
			}
		}
		wc.add(errorValue);
	}

	/**
	 * Runs the clustering loop: assign points, record the error, and stop once
	 * the error no longer decreases between consecutive iterations; otherwise
	 * recompute the centers and repeat. On convergence the number of completed
	 * iterations is printed. The clusters from the final assignment pass are
	 * kept for {@link #getCluster()}.
	 */
	private void kmeans() {
		init();

		while (true) {
			setCluster();
			errorSquare();
			// Braces matter here: only stop when the error has actually stopped improving.
			if (m != 0 && wc.get(m - 1) - wc.get(m) <= CONVERGENCE_TOLERANCE) {
				System.out.println(m);
				break;
			}
			if (m >= MAX_ITERATIONS) {
				break; // safety net against a non-terminating run
			}
			center = updateCenter();
			m++;
			cluster = initCluster();
		}
	}

	/**
	 * Prints every point of a list to standard output, one per line, followed
	 * by a separator line.
	 *
	 * @param dataArray     points to print (x at index 0, y at index 1)
	 * @param dataArrayName label used as the prefix of each printed line
	 */
	public void printDataArray(ArrayList<double[]> dataArray,
			String dataArrayName) {
		for (int i = 0; i < dataArray.size(); i++) {
			System.out.println("print:" + dataArrayName + "[" + i + "]={"
					+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
		}
		System.out.println("===================================");
	}

	/** Runs the clustering and reports the elapsed wall-clock time. */
	private void execute() {
		long startTime = System.currentTimeMillis();
		System.out.println("kmeans begins");
		kmeans();
		long endTime = System.currentTimeMillis();
		System.out.println("kmeans running time=" + (endTime - startTime)
				+ "ms");
		System.out.println("kmeans ends");
		System.out.println();
	}

	/**
	 * Demo entry point: clusters the built-in sample data into 10 clusters and
	 * prints the members of each cluster.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		K_means kl = new K_means(10);
		kl.execute();
		ArrayList<ArrayList<double[]>> cluster = kl.getCluster();

		for (int i = 0; i < cluster.size(); i++) {
			kl.printDataArray(cluster.get(i), "cluster[" + i + "]");
		}
	}
}

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved