"""
Implements cartesian products and regular cartesian grids, and provides
a function that constructs a grid for a simplex as well as one that
determines the index of a point in the simplex.
"""
import numpy as np
import scipy.special
from numba import jit, njit
from .util.numba import comb_jit
def cartesian(nodes, order='C'):
    '''
    Cartesian product of a list of arrays

    Parameters
    ----------
    nodes : list(array_like(ndim=1))
        One array of values per dimension of the product.

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated. With
        'C' the last dimension varies fastest. Any other value is
        treated as 'F', where the first dimension varies fastest.

    Returns
    -------
    out : ndarray(ndim=2)
        each line corresponds to one point of the product space. The
        dtype is the common result type of the arrays in `nodes`. If
        any array in `nodes` is empty, the product is empty and an
        array with zero rows is returned.

    '''
    nodes = [np.asarray(e) for e in nodes]
    shapes = [e.shape[0] for e in nodes]

    dtype = np.result_type(*nodes)

    n = len(nodes)
    num_points = np.prod(shapes)
    out = np.zeros((num_points, n), dtype=dtype)

    if num_points == 0:
        # An empty factor makes the whole product empty: there is
        # nothing to fill in, so skip the fill step entirely.
        return out

    # repetitions[i] is the number of times the whole block built from
    # nodes[i] is tiled along column i: the product of the sizes of the
    # dimensions that vary more slowly than dimension i.
    if order == 'C':
        repetitions = np.cumprod([1] + shapes[:-1])
    else:
        repetitions = np.cumprod([1] + shapes[:0:-1])[::-1]

    for i in range(n):
        _repeat_1d(nodes[i], repetitions[i], out[:, i])

    return out
def mlinspace(a, b, nums, order='C'):
    '''
    Build a regular cartesian grid from per-dimension bounds and sizes

    Parameters
    ----------
    a : array_like(ndim=1)
        lower bounds in each dimension

    b : array_like(ndim=1)
        upper bounds in each dimension

    nums : array_like(ndim=1)
        number of nodes along each dimension

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated

    Returns
    -------
    out : ndarray(ndim=2)
        each line corresponds to one point of the product space

    '''
    lower = np.asarray(a, dtype='float64')
    upper = np.asarray(b, dtype='float64')
    counts = np.asarray(nums, dtype='int64')

    # One evenly spaced axis per entry of `nums`.
    axes = []
    for dim in range(len(counts)):
        axes.append(np.linspace(lower[dim], upper[dim], counts[dim]))

    return cartesian(axes, order=order)
@njit
def _repeat_1d(x, K, out):
    '''
    Fill `out` with a tiled and repeated copy of `x`.

    Each element of `x` is written L times in a row, where
    L = len(out) // (K * len(x)); the resulting block of len(x) * L
    values is then written K times back to back.

    Parameters
    ----------
    x : ndarray(ndim=1)
        vector to be repeated

    K : scalar(int)
        number of times the whole block is repeated

    out : ndarray(ndim=1)
        placeholder for the result; modified in place

    Returns
    -------
    None

    '''
    N = x.shape[0]

    if K * N == 0:
        # Nothing to write, and the division below would be by zero.
        return

    L = out.shape[0] // (K*N)  # consecutive repeats of each element

    # `out` enumerates in C-order the elements of a 3-dimensional
    # array T of dimensions (K, N, L) such that for all k, n, l we
    # have T[k, n, l] == x[n]. Iterating k, n, l in that nesting order
    # writes `out` front to back.
    for k in range(K):
        for n in range(N):
            val = x[n]
            base = (k*N + n) * L
            for l in range(L):
                out[base + l] = val
def cartesian_nearest_index(x, nodes, order='C'):
    """
    Return the index of the point closest to `x` within the cartesian
    product generated by `nodes`. Each array in `nodes` must be sorted
    in ascending order.

    Parameters
    ----------
    x : array_like(ndim=1 or 2)
        Point(s) to search the closest point(s) for.

    nodes : array_like(array_like(ndim=1))
        Array of sorted, non-empty arrays. They are converted to a
        common dtype determined from all of their values.

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated.

    Returns
    -------
    scalar(int) or ndarray(int, ndim=1)
        Index (indices) of the closest point(s) to `x`.

    Raises
    ------
    ValueError
        If `x` is not 1- or 2-dimensional, if the length of the
        point(s) in `x` differs from the number of arrays in `nodes`,
        or if an array in `nodes` is empty.

    Examples
    --------
    >>> nodes = (np.arange(3), np.arange(2))
    >>> prod = qe.cartesian(nodes)
    >>> print(prod)
    [[0 0]
     [0 1]
     [1 0]
     [1 1]
     [2 0]
     [2 1]]

    Among the 6 points in the cartesian product `prod`, the closest to
    the point (0.6, 0.4) is `prod[2]`:

    >>> x = (0.6, 0.4)
    >>> qe.cartesian_nearest_index(x, nodes)  # Pass `nodes`, not `prod`
    2

    The closest to (-0.1, 1.2) and (2, 0) are `prod[1]` and `prod[4]`,
    respectively:

    >>> x = [(-0.1, 1.2), (2, 0)]
    >>> qe.cartesian_nearest_index(x, nodes)
    array([1, 4])

    Internally, the index in each dimension is searched by binary search
    and then the index in the cartesian product is calculated (*not* by
    constructing the cartesian product and then searching linearly over
    it).

    """
    x = np.asarray(x)
    if x.ndim not in (1, 2):
        raise ValueError('`x` must be 1- or 2-dimensional')

    is_1d = x.ndim == 1
    if is_1d:
        # Treat a single point as a batch of one.
        x = x[np.newaxis]

    nodes = [np.asarray(e) for e in nodes]

    if len(nodes) != x.shape[1]:
        msg = 'point `x`' if is_1d else 'points in `x`'
        msg += ' must have same length as `nodes`'
        raise ValueError(msg)

    if any(e.size == 0 for e in nodes):
        raise ValueError('each array in `nodes` must be non-empty')

    # Derive the common dtype from the full arrays rather than from the
    # type of each first element, so that a grid such as [0, 0.5, 1] is
    # not truncated to integers.
    dtype = np.result_type(*nodes)
    nodes = tuple(np.asarray(e, dtype=dtype) for e in nodes)

    out = _cartesian_nearest_indices(x, nodes, order=order)
    if is_1d:
        return out[0]
    return out
@njit(cache=True)
def _cartesian_nearest_indices(X, nodes, order='C'):
    """
    The main body of `cartesian_nearest_index`, jit-compiled by Numba.
    Note that `X` must be a 2-dim ndarray, and a Python list is not
    accepted for `nodes`.

    Parameters
    ----------
    X : ndarray(ndim=2)
        Points to search the closest points for.

    nodes : tuple(ndarray(ndim=1))
        Tuple of sorted ndarrays of same dtype.

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated.

    Returns
    -------
    ndarray(int, ndim=1)
        Indices of the closest points to the points in `X`.

    """
    num_points, dim = X.shape  # num_points vectors of length dim

    grid_sizes = np.empty(dim, dtype=np.intp)
    for d in range(dim):
        grid_sizes[d] = len(nodes[d])

    pos = np.empty(dim, dtype=np.intp)
    out = np.empty(num_points, dtype=np.intp)

    # For 'F' order the per-dimension arrays are reversed before being
    # flattened into a single index.
    if order == 'F':
        step = -1
    else:
        step = 1
    view = slice(None, None, step)

    for t in range(num_points):
        for d in range(dim):
            grid = nodes[d]
            val = X[t, d]
            if val <= grid[0]:
                # At or below the grid: clamp to the first node.
                pos[d] = 0
            elif val >= grid[-1]:
                # At or above the grid: clamp to the last node.
                pos[d] = grid_sizes[d] - 1
            else:
                k = np.searchsorted(grid, val)
                # Pick the upper neighbour only if strictly closer;
                # ties go to the lower one.
                if grid[k] - val < val - grid[k-1]:
                    pos[d] = k
                else:
                    pos[d] = k - 1
        out[t] = _cartesian_index(pos[view], grid_sizes[view])

    return out
@njit(cache=True)
def _cartesian_index(indices, nums_grids):
    """
    Flatten per-dimension indices into a single index, with the last
    dimension varying fastest.

    Parameters
    ----------
    indices : ndarray(int, ndim=1)
        Index along each dimension.

    nums_grids : ndarray(int, ndim=1)
        Number of grid points along each dimension.

    Returns
    -------
    scalar(int)
        Sum over dimensions of the index times the product of the
        sizes of all later dimensions.

    """
    flat = 0
    stride = 1
    # Walk from the last dimension to the first, growing the stride.
    for j in range(len(indices) - 1, -1, -1):
        flat += stride * indices[j]
        stride *= nums_grids[j]
    return flat
# Message of the ValueError raised by `simplex_grid` when
# `num_compositions_jit` reports 0 for the number of grid points.
_msg_max_size_exceeded = 'Maximum allowed size exceeded'
@njit(cache=True)
def simplex_grid(m, n):
    r"""
    Construct an array consisting of the integer points in the
    (m-1)-dimensional simplex :math:`\{x \mid x_0 + \cdots + x_{m-1} = n
    \}`, or equivalently, the m-part compositions of n, which are listed
    in lexicographic order. The total number of the points (hence the
    length of the output array) is L = (n+m-1)!/(n!*(m-1)!) (i.e.,
    (n+m-1) choose (m-1)).

    Parameters
    ----------
    m : scalar(int)
        Dimension of each point. Must be a positive integer.

    n : scalar(int)
        Number which the coordinates of each point sum to. Must be a
        nonnegative integer.

    Returns
    -------
    out : ndarray(int, ndim=2)
        Array of shape (L, m) containing the integer points in the
        simplex, aligned in lexicographic order.

    Raises
    ------
    ValueError
        If `num_compositions_jit(m, n)` returns 0.

    Notes
    -----
    A grid of the (m-1)-dimensional *unit* simplex with n subdivisions
    along each dimension can be obtained by `simplex_grid(m, n) / n`.

    Examples
    --------
    >>> simplex_grid(3, 4)
    array([[0, 0, 4],
           [0, 1, 3],
           [0, 2, 2],
           [0, 3, 1],
           [0, 4, 0],
           [1, 0, 3],
           [1, 1, 2],
           [1, 2, 1],
           [1, 3, 0],
           [2, 0, 2],
           [2, 1, 1],
           [2, 2, 0],
           [3, 0, 1],
           [3, 1, 0],
           [4, 0, 0]])

    >>> simplex_grid(3, 4) / 4
    array([[ 0.  ,  0.  ,  1.  ],
           [ 0.  ,  0.25,  0.75],
           [ 0.  ,  0.5 ,  0.5 ],
           [ 0.  ,  0.75,  0.25],
           [ 0.  ,  1.  ,  0.  ],
           [ 0.25,  0.  ,  0.75],
           [ 0.25,  0.25,  0.5 ],
           [ 0.25,  0.5 ,  0.25],
           [ 0.25,  0.75,  0.  ],
           [ 0.5 ,  0.  ,  0.5 ],
           [ 0.5 ,  0.25,  0.25],
           [ 0.5 ,  0.5 ,  0.  ],
           [ 0.75,  0.  ,  0.25],
           [ 0.75,  0.25,  0.  ],
           [ 1.  ,  0.  ,  0.  ]])

    References
    ----------
    A. Nijenhuis and H. S. Wilf, Combinatorial Algorithms, Chapter 5,
    Academic Press, 1978.

    """
    num_points = num_compositions_jit(m, n)
    if num_points == 0:  # Overflow occurred
        raise ValueError(_msg_max_size_exceeded)
    out = np.empty((num_points, m), dtype=np.int_)

    # First composition in lexicographic order: (0, ..., 0, n).
    point = np.zeros(m, dtype=np.int_)
    point[m-1] = n
    out[0, :] = point

    pivot = m

    for row in range(1, num_points):
        pivot -= 1

        # Move one unit from position `pivot` to the position before
        # it and push the rest of that entry to the last position.
        # The order of the three writes matters when pivot == m-1.
        moved = point[pivot]
        point[pivot] = 0
        point[m-1] = moved - 1
        point[pivot-1] += 1

        out[row, :] = point

        if moved != 1:
            pivot = m

    return out
def simplex_index(x, m, n):
    r"""
    Return the index of the point x in the lexicographic order of the
    integer points of the (m-1)-dimensional simplex :math:`\{x \mid x_0
    + \cdots + x_{m-1} = n\}`.

    Parameters
    ----------
    x : array_like(int, ndim=1)
        Integer point in the simplex, i.e., an array of m nonnegative
        integers that sum to n.

    m : scalar(int)
        Dimension of each point. Must be a positive integer.

    n : scalar(int)
        Number which the coordinates of each point sum to. Must be a
        nonnegative integer.

    Returns
    -------
    idx : scalar(int)
        Index of x.

    """
    if m == 1:
        return 0

    def _count(parts, total):
        # Number of `parts`-part compositions of `total`; the same
        # quantity that `num_compositions` computes.
        return scipy.special.comb(total+parts-1, parts-1, exact=True)

    # tail_sums[i] == x[i+1] + ... + x[m-1]
    tail_sums = np.cumsum(x[-1:0:-1])[::-1]

    # Start from the last index and subtract off the points that come
    # after x in lexicographic order.
    idx = _count(m, n) - 1
    for i in range(m-1):
        remaining = tail_sums[i]
        if remaining == 0:
            break
        idx -= _count(m-i, remaining-1)

    return idx
def num_compositions(m, n):
    """
    Count the m-part compositions of n, i.e. compute
    (n+m-1) choose (m-1).

    Parameters
    ----------
    m : scalar(int)
        Number of parts of composition.

    n : scalar(int)
        Integer to decompose.

    Returns
    -------
    scalar(int)
        Total number of m-part compositions of n.

    """
    # docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html
    pool = n + m - 1
    chosen = m - 1
    # exact=True asks scipy for exact integer arithmetic
    return scipy.special.comb(pool, chosen, exact=True)
@njit(cache=True)
def num_compositions_jit(m, n):
    """
    Numba jit version of `num_compositions`. Return `0` if the outcome
    exceeds the maximum value of `np.intp`.

    Parameters
    ----------
    m : scalar(int)
        Number of parts of composition.

    n : scalar(int)
        Integer to decompose.

    Returns
    -------
    scalar(int)
        Result of `comb_jit(n+m-1, m-1)`.

    """
    pool = n + m - 1
    chosen = m - 1
    return comb_jit(pool, chosen)