在x和y坐标的numpy数组中寻找最近点的索引

我有两个2D numpy数组:x_array包含在x方向的位置信息,y_array包含在y方向的位置。

然后我有一个很长的x,y点列表。

对于列表中的每个点,我需要find最接近该点的位置(在数组中指定)的数组索引。

我天真地产生了一些代码工作,基于这个问题: 在numpy数组中find最接近的值

import time import numpy def find_index_of_nearest_xy(y_array, x_array, y_point, x_point): distance = (y_array-y_point)**2 + (x_array-x_point)**2 idy,idx = numpy.where(distance==distance.min()) return idy[0],idx[0] def do_all(y_array, x_array, points): store = [] for i in xrange(points.shape[1]): store.append(find_index_of_nearest_xy(y_array,x_array,points[0,i],points[1,i])) return store # Create some dummy data y_array = numpy.random.random(10000).reshape(100,100) x_array = numpy.random.random(10000).reshape(100,100) points = numpy.random.random(10000).reshape(2,5000) # Time how long it takes to run start = time.time() results = do_all(y_array, x_array, points) end = time.time() print 'Completed in: ',end-start 

我正在通过一个大的数据集来做这件事,并且真的想加快一点。 任何人都可以优化吗?

谢谢。


更新:解决scheme按照@silvado和@justin(下面)的build议

 # Shoe-horn existing data for entry into KDTree routines combined_x_y_arrays = numpy.dstack([y_array.ravel(),x_array.ravel()])[0] points_list = list(points.transpose()) def do_kdtree(combined_x_y_arrays,points): mytree = scipy.spatial.cKDTree(combined_x_y_arrays) dist, indexes = mytree.query(points) return indexes start = time.time() results2 = do_kdtree(combined_x_y_arrays,points_list) end = time.time() print 'Completed in: ',end-start 

上面的代码加快了我的代码(在100x100matrix中search5000个点)100倍。 有趣的是,使用scipy.spatial.KDTree(而不是scipy.spatial.cKDTree)给我的天真的解决scheme相当的时间,所以它是绝对值得使用的cKDTree版本…

scipy.spatial也有一个kd树实现: scipy.spatial.KDTree

该方法通常是首先使用点数据来构buildkd树。 其计算复杂度为N log N,其中N是数据点的数量。 范围查询和最近邻search可以用日志N复杂度来完成。 这比简单循环遍历所有点(复杂度N)要高效得多。

因此,如果您有重复范围或最近邻居查询,强烈build议使用kd树。

这是一个scipy.spatial.KDTree示例

 In [1]: from scipy import spatial In [2]: import numpy as np In [3]: A = np.random.random((10,2))*100 In [4]: A Out[4]: array([[ 68.83402637, 38.07632221], [ 76.84704074, 24.9395109 ], [ 16.26715795, 98.52763827], [ 70.99411985, 67.31740151], [ 71.72452181, 24.13516764], [ 17.22707611, 20.65425362], [ 43.85122458, 21.50624882], [ 76.71987125, 44.95031274], [ 63.77341073, 78.87417774], [ 8.45828909, 30.18426696]]) In [5]: pt = [6, 30] # <-- the point to find In [6]: A[spatial.KDTree(A).query(pt)[1]] # <-- the nearest point Out[6]: array([ 8.45828909, 30.18426696]) #how it works! In [7]: distance,index = spatial.KDTree(A).query(pt) In [8]: distance # <-- The distances to the nearest neighbors Out[8]: 2.4651855048258393 In [9]: index # <-- The locations of the neighbors Out[9]: 9 #then In [10]: A[index] Out[10]: array([ 8.45828909, 30.18426696]) 

如果您可以将数据按照正确的格式进行处理,那么scipy.spatial.distance的方法就是使用scipy.spatial.distance的方法:

http://docs.scipy.org/doc/scipy/reference/spatial.distance.html

特别是pdistcdist提供了快速计算配对距离的方法。