用(n-1)d数组索引n维数组

什么是最优雅的方式来访问一个n维数组与一个给定维度(n-1)维数组,如虚拟示例

a = np.random.random_sample((3,4,4)) b = np.random.random_sample((3,4,4)) idx = np.argmax(a, axis=0) 

我现在如何访问idx a来获取最大值,就好像我已经使用了a.max(axis=0) ? 或者如何检索由idx中的idx指定的值?

我想过使用np.meshgrid但我认为这是一个矫枉过正的问题。 请注意,维度axis可以是任何有用的轴(0,1,2),并且事先不知道。 有没有一个优雅的方式来做到这一点?

利用advanced-indexing

 m,n = a.shape[1:] I,J = np.ogrid[:m,:n] a_max_values = a[idx, I, J] b_max_values = b[idx, I, J] 

一般情况下:

 def argmax_to_max(arr, argmax, axis): """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)""" new_shape = list(arr.shape) del new_shape[axis] grid = np.ogrid[tuple(map(slice, new_shape))] grid.insert(axis, argmax) return arr[tuple(grid)] 

不幸的是,这样的自然操作应该会更尴尬。

为了用一个(n-1) dim数组索引一个n dimarrays,我们可以简化一下,为我们提供所有轴的索引网格,

 def all_idx(idx, axis): grid = np.ogrid[tuple(map(slice, idx.shape))] grid.insert(axis, idx) return tuple(grid) 

因此,使用它来索引input数组 –

 axis = 0 a_max_values = a[all_idx(idx, axis=axis)] b_max_values = b[all_idx(idx, axis=axis)]