MeanShift Clustering - 02: A Python Example


Intro

A worked example of using MeanShift.

Loading the Data

from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

from sklearn.datasets import load_iris
import pandas as pd
pd.set_option('display.max_rows', 500)     # maximum number of rows to print
pd.set_option('display.max_columns', 500)  # maximum number of columns to print

# Check whether the input is array-like; if not, convert it to an array
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors

iris_df = pd.DataFrame(
    load_iris()["data"],
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
iris_df["target"] = load_iris()["target"]
iris_df.head()

   sepal_length  sepal_width  petal_length  petal_width  target
0           5.1          3.5           1.4          0.2       0
1           4.9          3.0           1.4          0.2       0
2           4.7          3.2           1.3          0.2       0
3           4.6          3.1           1.5          0.2       0
4           5.0          3.6           1.4          0.2       0

iris_df.groupby(by="target").describe()

target  feature        count  mean   std       min  25%    50%   75%    max
0       sepal_length   50.0   5.006  0.352490  4.3  4.800  5.00  5.200  5.8
0       sepal_width    50.0   3.428  0.379064  2.3  3.200  3.40  3.675  4.4
0       petal_length   50.0   1.462  0.173664  1.0  1.400  1.50  1.575  1.9
0       petal_width    50.0   0.246  0.105386  0.1  0.200  0.20  0.300  0.6
1       sepal_length   50.0   5.936  0.516171  4.9  5.600  5.90  6.300  7.0
1       sepal_width    50.0   2.770  0.313798  2.0  2.525  2.80  3.000  3.4
1       petal_length   50.0   4.260  0.469911  3.0  4.000  4.35  4.600  5.1
1       petal_width    50.0   1.326  0.197753  1.0  1.200  1.30  1.500  1.8
2       sepal_length   50.0   6.588  0.635880  4.9  6.225  6.50  6.900  7.9
2       sepal_width    50.0   2.974  0.322497  2.2  2.800  3.00  3.175  3.8
2       petal_length   50.0   5.552  0.551895  4.5  5.100  5.55  5.875  6.9
2       petal_width    50.0   2.026  0.274650  1.4  1.800  2.00  2.300  2.5

Judging from these statistics, petal_length and petal_width differ the most across the three species, so we use those two features for plotting.
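As a quick, hypothetical sanity check of that claim, one could compare the spread of the per-class means with the overall standard deviation of each feature; the petal features come out clearly ahead. (This ratio is just an informal heuristic, not part of the original analysis.) The plot below then uses only the two petal features.

# Rough separation check: std of the three class means divided by the overall std,
# per feature; a larger ratio suggests the classes sit further apart on that feature
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
class_means = iris_df.groupby("target")[feature_cols].mean()
print(class_means.std() / iris_df[feature_cols].std())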

# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["red", "yellow", "blue"]
marker = ["o", "*", "+"]
for k, col, mark in zip(range(3), colors, marker):
    sub_data = iris_df.query("target==%s" % k)
    plt.plot(sub_data.petal_length, sub_data.petal_width, "o",
             markerfacecolor=col, markeredgecolor='k', markersize=5)
plt.show()

You can see that the red points are well separated from the rest, while some of the blue and yellow points are intermingled.

Clustering with the Default Parameters

# ms = MeanShift(bin_seeding=True, cluster_all=False)
bandwidth = 0.726
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)
print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()
# colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
colors = ["yellow", "red", "blue"]
marker = ["o", "*", "+"]
for k, col, mark in zip(range(n_clusters_), colors, marker):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(iris_df[my_members].petal_length, iris_df[my_members].petal_width, ".",
             markerfacecolor=col, markeredgecolor='k', markersize=6)
    plt.plot(cluster_center[0], cluster_center[1], 'o',
             markerfacecolor=col, markeredgecolor='k', markersize=14)
    # draw a circle with radius = bandwidth around each cluster center
    circle = plt.Circle((cluster_center[0], cluster_center[1]), bandwidth,
                        color='black', fill=False)
    plt.gcf().gca().add_artist(circle)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()

number of estimated clusters : 3

From the plot, the red group forms a cluster of its own and is clustered cleanly, while the blue and yellow clusters overlap; points in the overlap region get their label from the nearest cluster center.
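One quick way to see the "nearest cluster center" rule in action is to compare MeanShift's prediction for a point in the overlap region with a direct nearest-center lookup. The coordinates below are made up purely for illustration:

# Compare MeanShift's label with the index of the nearest cluster center
# (overlap_point is a made-up example in the petal_length / petal_width plane)
overlap_point = np.array([[4.8, 1.6]])
print(ms.predict(overlap_point))
print(np.argmin(np.linalg.norm(ms.cluster_centers_ - overlap_point, axis=1)))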

The estimate_bandwidth Method

Given the raw data to be clustered, estimate_bandwidth suggests a bandwidth. The basic logic:

- First draw a sample of the points.
- Compute the maximum distance between each sampled point and the points around it.
- Average those distances.

Logically, this amounts to picking a fairly large distance so that more points are covered (a small sketch of this logic follows the call below).

estimate_bandwidth(iris_df[["petal_length", "petal_width"]])

0.7266371274126329
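Below is a minimal sketch of that logic, assuming sklearn's defaults for estimate_bandwidth (quantile=0.3, no sub-sampling); rough_bandwidth and its parameters are names made up for this illustration, not part of sklearn's API. On the two petal features it should land close to the 0.7266 suggested above.

def rough_bandwidth(X, quantile=0.3, n_samples=None, seed=0):
    # Sketch of the logic described above: for each (sampled) point, take the
    # distance to its farthest neighbor within the nearest `quantile` fraction
    # of the data, then average those distances over all points.
    rng = np.random.RandomState(seed)
    X = np.asarray(X, dtype=float)
    if n_samples is not None:                     # optional sub-sampling step
        X = X[rng.permutation(len(X))[:n_samples]]
    n_neighbors = max(1, int(len(X) * quantile))  # neighborhood size
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
    d, _ = nbrs.kneighbors(X, return_distance=True)
    return d.max(axis=1).mean()

rough_bandwidth(iris_df[["petal_length", "petal_width"]])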

Compute the distances and check

from sklearn.neighbors import NearestNeighbors

nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(iris_df.iloc[:, [2, 3]])

NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=-1, n_neighbors=150, p=2, radius=1.0)

d, index = nbrs.kneighbors(iris_df.iloc[:, [2, 3]], return_distance=True)

from functools import reduce  # python 3
# Flatten the distance matrix into one long list, dropping column 0 (distance to self)
total_distance = reduce(lambda x, y: x + y,
                        np.array(pd.DataFrame(d).iloc[:, 1:150]).tolist())

from scipy import stats

stats.describe(total_distance)

DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)

pd.DataFrame({"total_distance":total_distance}).describe()

       total_distance
count    22350.000000
mean         2.185682
std          1.617862
min          0.000000
25%          0.640312
50%          1.941649
75%          3.544009
max          6.262587

From these numbers, the suggested bandwidth (about 0.727) is fairly close to the 25% quantile of the pairwise distances.
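For a quick comparison, one can look at a few quantiles of the pairwise distances computed above (a small sketch; the quantile values here are just examples):

# Compare the suggested bandwidth with a few quantiles of the pairwise distances
for q in (0.10, 0.25, 0.30):
    print(q, np.quantile(total_distance, q))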

That wraps up this brief introduction to MeanShift. In some business scenarios the algorithm works quite well; as always, analyze each problem on its own terms.

2021-03-31, Jiulong Lake, Jiangning District, Nanjing

