程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

基於MATLAB與Python的DBSCAN算法代碼

編輯:Python

接上文,我們詳細介紹了DBSCAN與幾種常見聚類算法的對比與流程,DBSCAN聚類算法最為特殊,它是一種基於密度的聚類方法,聚類前不需要預先指定聚類的個數,接下來將DBSCAN分析代碼分享
Python代碼:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn import metrics
UNCLASSIFIED = 0
NOISE = -1
# 計算數據點兩兩之間的距離
def getDistanceMatrix(datas):
N, D = np.shape(datas)
dists = np.zeros([N, N])
for i in range(N):
for j in range(N):
vi = datas[i, :]
vj = datas[j, :]
dists[i, j] = np.sqrt(np.dot((vi - vj), (vi - vj)))
return dists
# 尋找以點cluster_id 為中心,eps 為半徑的圓內的所有點的id
def find_points_in_eps(point_id, eps, dists):
index = (dists[point_id] <= eps)
return np.where(index == True)[0].tolist()
# 聚類擴展
# dists : 所有數據兩兩之間的距離 N x N
# labs : 所有數據的標簽 labs N,
# cluster_id : 一個簇的標號
# eps : 密度評估半徑
# seeds: 用來進行簇擴展的點
# min_points: 半徑內最少的點數
def expand_cluster(dists, labs, cluster_id, seeds, eps, min_points):
i = 0
while i < len(seeds):
# 獲取一個臨近點
Pn = seeds[i]
# 如果該點被標記為NOISE 則重新標記
if labs[Pn] == NOISE:
labs[Pn] = cluster_id
# 如果該點沒有被標記過
elif labs[Pn] == UNCLASSIFIED:
# 進行標記,並計算它的臨近點 new_seeds
labs[Pn] = cluster_id
new_seeds = find_points_in_eps(Pn, eps, dists)
# 如果 new_seeds 足夠長則把它加入到seed 隊列中
if len(new_seeds) >= min_points:
seeds = seeds + new_seeds
i = i + 1
def dbscan(datas, eps, min_points):
# 計算 所有點之間的距離
dists = getDistanceMatrix(datas)
# 將所有點的標簽初始化為UNCLASSIFIED
n_points = datas.shape[0]
labs = [UNCLASSIFIED] * n_points
cluster_id = 0
# 遍歷所有點
for point_id in range(0, n_points):
# 如果當前點已經處理過了
if not (labs[point_id] == UNCLASSIFIED):
continue
# 沒有處理過則計算臨近點
seeds = find_points_in_eps(point_id, eps, dists)
# 如果臨近點數量過少則標記為 NOISE
if len(seeds) < min_points:
labs[point_id] = NOISE
else:
# 否則就開啟一輪簇的擴張
cluster_id = cluster_id + 1
# 標記當前點
labs[point_id] = cluster_id
expand_cluster(dists, labs, cluster_id, seeds, eps, min_points)
return labs, cluster_id
# 繪圖
def draw_cluster(datas, labs, n_cluster):
plt.cla()
colors = [plt.cm.Spectral(each)
for each in np.linspace(0, 1, n_cluster)]
for i, lab in enumerate(labs):
if lab == NOISE:
plt.scatter(datas[i, 0], datas[i, 0], s=16., color=(0, 0, 0))
else:
plt.scatter(datas[i, 0], datas[i, 0], s=16., color=colors[lab - 1])
plt.show()
if __name__ == "__main__":
## 數據1
# centers = [[1, 1], [-1, -1], [1, -1]]
# datas, labels_true = make_blobs(n_samples=750, centers=centers, cluster_std=0.4,
# random_state=0)
## 數據2
file_name = "spiral"
with open(file_name + ".txt", "r", encoding="utf-8") as f:#
lines = f.read().splitlines()
lines = [line.split("\t")[:-1] for line in lines]
datas = np.array(lines).astype(np.float32)
# 數據正則化
datas = StandardScaler().fit_transform(datas)
eps = 20#半徑
min_points = 0
labs, cluster_id = dbscan(datas, eps=eps, min_points=min_points)
print("labs of my dbscan")
print(labs)
db = DBSCAN(eps=eps, min_samples=min_points).fit(datas)
skl_labels = db.labels_
print("labs of sk-DBSCAN")
print(skl_labels)
draw_cluster(datas, labs, cluster_id)

MATLAB代碼如下:

data=xlsread('C:/Users/zhichu/Desktop/附件1 弱覆蓋柵格數據(篩選).csv');%導入數據
x=data(:,1);
y=data(:,2);
figure('Name','散點圖分布','NumberTitle','off');
scatter(x,y,0.5,'k')
axis([0,2499,0,2499])
epsilon=20;%基站間最大聚類距離,自己根據需要設置
minpts=1;%最小聚類數
idx=dbscan([x,y],epsilon,minpts);
length(unique(idx))
[gc,grps]=groupcounts(idx)
sortrows([gc,grps],'descend')
figure('Name','DBSCAN聚類結果','NumberTitle','off');
gscatter(x,y,idx,[],[],1,'doleg','off')
xlabel('x坐標');ylabel('y坐標')

該函數在面對幾十萬條數據時也能計算出聚類的結果,因此大家在面對大型數據的DBSCAN聚類問題時可以選用內置的這個函數,前提是你的MATLAB版本要高於2019且安裝好了統計與機器學習工具箱( Statistics and Machine LearningToolbox)。


  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved