聚類算法入門:k-means

一、聚類定義

  • 聚類分析(cluster analysis)就是給你一堆雜七雜八的樣本數據把它們分成幾個組,組內成員有一定的相似,不同組之間成員有一定的差別。
  • 區別與分類分析(classification analysis) 你事先並不知道有哪幾類、劃分每個類別的標準。
  • 比如垃圾分類就是分類算法,你知道豬能吃的是溼垃圾,不能吃的是幹垃圾……;打掃房間時你把雜物都分分類,這是聚類,你事先不知道每個類別的標準。

二、劃分聚類方法: K-means:

對於給定的樣本集,按照樣本之間的距離(也就是相似程度)大小,將樣本集劃分為K個簇(即類別)。讓簇內的點儘量緊密的連在一起,而讓簇間的距離儘量的大。

  • 步驟1:隨機取k個初始中心點
  • 步驟2:對於每個樣本點計算到這k箇中心點的距離,將樣本點歸到與之距離最小的那個中心點的簇。這樣每個樣本都有自己的簇了
  • 步驟3:對於每個簇,根據裡面的所有樣本點重新計算得到一個新的中心點,如果中心點發生變化回到步驟2,未發生變化轉到步驟4
  • 步驟4:得出結果
    就像這樣


聚類算法入門:k-means


缺點:
初始值敏感、採用迭代方法,得到的結果只是局部最優、K值的選取不好把握、對於不是凸的數據集比較難收斂
如何衡量Kmeans 算法的精確度?
SSE(Sum of Square Error) 誤差平方和, SSE越小,精確度越高。

三、改進算法-二分Kmeans

  • 首先將所有點作為一個簇,然後將其一分為二。
  • 每次選擇一個簇一分為二,選取簇的依據取決於其是否能最大程度降低SSE即選取聚類後SSE最小的一個簇進行劃分。
  • 直至有k個簇

四、Kmeans Code

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
# %matplotlib inline

def K_Means(X, sp, K):

# 計算臨近點
def near(p):
dis = [np.sum(np.square(x-p)) for x in sp]
return dis.index(min(dis))
# 打印結果

def print_result(sp_list):
#打印中心點迭代軌跡
sp_list = [np.array([x[k] for x in sp_list]) for k in range(K)]
for k in range(K):
plt.plot(sp_list[k][:,0], sp_list[k][:,1], 'k->', label='type{}'.format(k))

#分類打印其他點
p_list = [[] for k in range(K)]
for p in X:
i = near(p)
p_list[i].append(p)
p_list = [np.array(x) for x in p_list]

color = ['r','g','b']
for i in range(K):
plt.plot(p_list[i][:,0], p_list[i][:,1],color[i]+'o')

plt.title('K-Means Result')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend('123')
plt.show()

# 迭代中心點
sp_list = []
sp_list.append(sp)
while True:
count = np.zeros(K)
sp_t = np.zeros((K,2))
for p in X:
i = near(p)
count[i] += 1
sp_t[i] += p
sp_t = np.array([sp_t[i]/count[i] for i in range(K)])
SSE = np.sum(np.square(sp-sp_t))
if SSE < 0.001:
break
sp = sp_t
sp_list.append(sp)
print_result(sp_list)
print('聚類中心:')
for p in sp:
print(p, end=',')

if __name__ == '__main__':
data = scio.loadmat('ex7data2.mat')
X = data['X']
K = 3

sp = np.array([[3, 3], [6, 2], [8, 5]]) # starting point
K_Means(X, sp, K)
聚類算法入門:k-means

kmeans聚類結果

K為3聚類中心: [1.95399466 5.02557006],[3.04367119 1.01541041],[6.03366736 3.00052511]
如需要測試數據請留言

本文由作者授權轉載並稍加修改:https://tawn0000.github.io


分享到:


相關文章: