Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
nrows, ncols = a0.shape
nbins = [len(b) for b in bins]
hist_shapes = [nb+1 for nb in nbins]
# a marginally faster implementation would be to use searchsorted,
# like numpy histogram itself does
# https://github.com/numpy/numpy/blob/9c98662ee2f7daca3f9fae9d5144a9a8d3cabe8c/numpy/lib/histograms.py#L864-L882
# for now we stick with `digitize` because it's easy to understand how it works
# the maximum possible value of digitize is nbins (i.e. len(b))
# for right=False:
# - 0 corresponds to a < b[0]
# - i corresponds to b[i-1] <= a < b[i]
# - nbins corresponds to a >= b[-1]
each_bin_indices = [digitize(a, b) for a, b in zip(args, bins)]
# product of the bins gives the joint distribution
if N_inputs > 1:
bin_indices = ravel_multi_index(each_bin_indices, hist_shapes)
else:
bin_indices = each_bin_indices[0]
# total number of unique bin indices
N = reduce(lambda x, y: x*y, hist_shapes)
bin_counts = _dispatch_bincount(bin_indices, weights, N, hist_shapes,
block_size=block_size)
# just throw out everything outside of the bins, as np.histogram does
# TODO: make this optional?
slices = (slice(None),) + (N_inputs * (slice(1, -1),))
return bin_counts[slices]