PythonNote030---sklearn近邻api使用

网友投稿 691 2022-10-08

PythonNote030---sklearn近邻api使用

PythonNote030---sklearn近邻api使用

Intro

近邻相关计算的api,底层用了kdtree,速度更快。简单整理总结下,备查。

Case1

import numpy as npimport pandas as pd from sklearn.neighbors import

samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5], [2, 0, 0], [2, 1, 0]]df = pd.DataFrame(np.array(samples), columns=["x", "y", "z"])

x

y

z

0

0.0

0.0

0.0

1

0.0

0.5

0.0

2

1.0

1.0

0.5

3

2.0

0.0

0.0

4

2.0

1.0

0.0

df.values

array([[0. , 0. , 0. ], [0. , 0.5, 0. ], [1. , 1. , 0.5], [2. , 0. , 0. ], [2. , 1. , 0. ]])

# 半径or距离为1.6neigh = NearestNeighbors(radius=1.6)neigh.fit(df.values)

NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=None, n_neighbors=5, p=2, radius=1.6)

计算samples里的点和点x的距离,并且筛选出距离小于radius(1.6)的点

第一个array返回具体距离第二个array返回index

x = [1., 1., 1.]rng = neigh.radius_neighbors(X=[x],return_distance=True)

(array([array([1.5 , 0.5 , 1.41421356])], dtype=object), array([array([1, 2, 4], dtype=int64)], dtype=object))

可以看到samples的5个点,只有index为1、2、4这三个点距离点x在1.6以内。以第4个位置的点[2,1,0]为例,距离为​​sqrt((2-1)^2+(1-1)^2+(0-1)^2)=sqrt(2)​​

radius_neighbors还有其他几个参数:

radius=None,半径应该可以重新制定return_distance=True 是否返回距离sort_results=False 是否排序

Case2

多个点时,返回多个array,每个array是满足条件的array。举个实际案例,方便理解。 滴滴打车,有5辆车,3个用户,分别找出每个用户附近满足距离的车辆。

car_list = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5], [2, 0, 0], [2, 1, 0]]car_df = pd.DataFrame(np.array(car_list), columns=["x", "y", "z"])

x

y

z

0

0.0

0.0

0.0

1

0.0

0.5

0.0

2

1.0

1.0

0.5

3

2.0

0.0

0.0

4

2.0

1.0

0.0

user_list = [[1., 1., 1.],[0,0,0],[3,4,5]]user_df = pd.DataFrame(np.array(user_list), columns=["x", "y", "z"])

x

y

z

0

1.0

1.0

1.0

1

0.0

0.0

0.0

2

3.0

4.0

5.0

# 半径or距离为1.6neigh = NearestNeighbors(radius=1.6)# car_df建立kdtreeneigh.fit(car_df.values)

NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=None, n_neighbors=5, p=2, radius=1.6)

result = neigh.radius_neighbors(X=user_df, return_distance=True, sort_results=True)

(array([array([1.5 , 0.5 , 1.41421356]), array([0. , 0.5, 1.5]), array([], dtype=float64)], dtype=object), array([array([1, 2, 4], dtype=int64), array([0, 1, 2], dtype=int64), array([], dtype=int64)], dtype=object))

排序并没有什么用~

def convert_ndarray_list(nd_array): result = [] for a in nd_array: result.append(a.tolist()) return

pd.concat([user_df,pd.DataFrame({"car_index":convert_ndarray_list(result[0]),"distance":convert_ndarray_list(result[1])})],axis=1)

x

y

z

car_index

distance

0

1.0

1.0

1.0

[1.5, 0.5, 1.4142135623730951]

[1, 2, 4]

1

0.0

0.0

0.0

[0.0, 0.5, 1.5]

[0, 1, 2]

2

3.0

4.0

5.0

[]

[]

Case3

找到每个乘客最近的一辆车

result = neigh.kneighbors(X=user_df, return_distance=True)

(array([[0.5 , 1.41421356, 1.5 , 1.73205081, 1.73205081], [0. , 0.5 , 1.5 , 2. , 2.23606798], [5.7662813 , 5.91607978, 6.4807407 , 6.80073525, 7.07106781]]), array([[2, 4, 1, 0, 3], [0, 1, 2, 3, 4], [2, 4, 3, 1, 0]], dtype=int64))

上面计算了乘客到各个车辆的距离,并且做了排序,如果只取一个,n_neighbors=1即可

result = neigh.kneighbors(X=user_df, n_neighbors=1, return_distance=True)

(array([[0.5 ], [0. ], [5.7662813]]), array([[2], [0], [2]], dtype=int64))

pd.concat([user_df,pd.DataFrame({"car_index":convert_ndarray_list(result[0]),"distance":convert_ndarray_list(result[1])})],axis=1)

x

y

z

car_index

distance

0

1.0

1.0

1.0

[0.5]

[2]

1

0.0

0.0

0.0

[0.0]

[0]

2

3.0

4.0

5.0

[5.766281297335398]

[2]

Ref

​​[1] 于南京市江宁区九龙湖

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:微信小程序的轮播图swiper问题(swiper实现轮播图)
下一篇:MeanShift聚类-02python案例
相关文章

 发表评论

暂时没有评论,来抢沙发吧~