Source code for quantecon._gridtools

"""
Implements cartesian products and regular cartesian grids, and provides
a function that constructs a grid for a simplex as well as one that
determines the index of a point in the simplex.

"""
import numpy as np
import scipy.special
from numba import jit, njit
from .util.numba import comb_jit


def cartesian(nodes, order='C'):
    '''
    Cartesian product of a list of arrays

    Parameters
    ----------
    nodes : list(array_like(ndim=1))

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated

    Returns
    -------
    out : ndarray(ndim=2)
        each line corresponds to one point of the product space; if any
        array in `nodes` is empty, `out` has zero rows
    '''
    nodes = [np.asarray(e) for e in nodes]
    shapes = [e.shape[0] for e in nodes]

    dtype = np.result_type(*nodes)

    n = len(nodes)
    num_points = np.prod(shapes)
    out = np.zeros((num_points, n), dtype=dtype)

    if num_points == 0:
        # An empty factor makes the whole product empty. Return now: with a
        # zero among the shapes, at least one `_repeat_1d` call below would
        # evaluate `out.shape[0] // (K*N)` with K*N == 0.
        return out

    if order == 'C':
        # column i is tiled prod(shapes[:i]) times
        repetitions = np.cumprod([1] + shapes[:-1])
    else:
        # 'F': column i is tiled prod(shapes[i+1:]) times, so accumulate
        # the products from the right and flip back
        shapes.reverse()
        sh = [1] + shapes[:-1]
        repetitions = np.cumprod(sh)
        repetitions = repetitions.tolist()
        repetitions.reverse()

    for i in range(n):
        _repeat_1d(nodes[i], repetitions[i], out[:, i])

    return out
def mlinspace(a, b, nums, order='C'):
    '''
    Build a regular cartesian grid from per-dimension bounds and node counts.

    Parameters
    ----------
    a : array_like(ndim=1)
        lower bound of the grid along each dimension

    b : array_like(ndim=1)
        upper bound of the grid along each dimension

    nums : array_like(ndim=1)
        number of equally spaced nodes along each dimension

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated

    Returns
    -------
    out : ndarray(ndim=2)
        each line corresponds to one point of the product space
    '''
    lower = np.asarray(a, dtype='float64')
    upper = np.asarray(b, dtype='float64')
    counts = np.asarray(nums, dtype='int64')

    # One evenly spaced axis per entry of `nums`; the bounds are looked up
    # by position, so `a` and `b` must be at least as long as `nums`.
    axes = [
        np.linspace(lower[i], upper[i], count)
        for i, count in enumerate(counts)
    ]
    return cartesian(axes, order=order)
@njit
def _repeat_1d(x, K, out):
    '''
    Fill `out` with the elements of `x`, each repeated consecutively, and
    tile that whole pattern `K` times.

    With N = len(x) and L = len(out) // (K*N), `out` receives the C-order
    flattening of a 3-dimensional array T of shape (K, N, L) such that
    T[k, n, l] == x[n] for all k, n, l. In other words, every element of
    `x` appears L times in a row, and that block of N*L values is repeated
    K times.

    Parameters
    ----------
    x : ndarray(ndim=1)
        vector to be repeated

    K : scalar(int)
        number of times the whole block of repeated elements is tiled
        (the outermost repetition)

    out : ndarray(ndim=1)
        placeholder for the result; its length is expected to be a
        multiple of K*len(x)

    Returns
    -------
    None
    '''
    N = x.shape[0]
    if K * N == 0:
        # Nothing to repeat (empty `x` or K == 0); also avoids the integer
        # division by zero when computing L below.
        return
    L = out.shape[0] // (K * N)  # consecutive copies of each element

    # Iterate in (k, n, l) order so that `out` is written sequentially;
    # the flat position is k*N*L + n*L + l.
    ind = 0
    for k in range(K):
        for n in range(N):
            val = x[n]
            for l in range(L):
                out[ind] = val
                ind += 1
def cartesian_nearest_index(x, nodes, order='C'):
    """
    Return the index of the point closest to `x` within the cartesian
    product generated by `nodes`. Each array in `nodes` must be sorted in
    ascending order.

    Parameters
    ----------
    x : array_like(ndim=1 or 2)
        Point(s) to search the closest point(s) for.

    nodes : array_like(array_like(ndim=1))
        Array of sorted, nonempty arrays.

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated.

    Returns
    -------
    scalar(int) or ndarray(int, ndim=1)
        Index (indices) of the closest point(s) to `x`.

    Raises
    ------
    ValueError
        If `x` is not 1- or 2-dimensional, if any array in `nodes` is
        empty or not 1-dimensional, if the length of the point(s) in `x`
        differs from the number of arrays in `nodes`, or if `x` contains
        NaN.

    Examples
    --------
    >>> nodes = (np.arange(3), np.arange(2))
    >>> prod = qe.cartesian(nodes)
    >>> print(prod)
    [[0 0]
     [0 1]
     [1 0]
     [1 1]
     [2 0]
     [2 1]]

    Among the 6 points in the cartesian product `prod`, the closest to the
    point (0.6, 0.4) is `prod[2]`:

    >>> x = (0.6, 0.4)
    >>> qe.cartesian_nearest_index(x, nodes)  # Pass `nodes`, not `prod`
    2

    The closest to (-0.1, 1.2) and (2, 0) are `prod[1]` and `prod[4]`,
    respectively:

    >>> x = [(-0.1, 1.2), (2, 0)]
    >>> qe.cartesian_nearest_index(x, nodes)
    array([1, 4])

    Internally, the index in each dimension is searched by binary search
    and then the index in the cartesian product is calculated (*not* by
    constructing the cartesian product and then searching linearly over
    it).

    """
    x = np.asarray(x)
    if x.ndim not in (1, 2):
        raise ValueError('`x` must be a 1-dim or 2-dim array')
    is_1d = (x.ndim == 1)
    if is_1d:
        x = x[np.newaxis]

    # Convert every node array first and take the common dtype of the
    # converted arrays. Using only the type of the first element (e.g.
    # `int` for [0, 0.5, 1]) would silently truncate the other elements.
    arrays = [np.asarray(e) for e in nodes]
    for e in arrays:
        if e.ndim != 1 or e.shape[0] == 0:
            raise ValueError(
                'each array in `nodes` must be 1-dim and nonempty'
            )
    dtype = np.result_type(*[e.dtype for e in arrays])
    # The jitted search indexes the tuple at runtime, so all arrays must
    # share one dtype.
    nodes = tuple(np.asarray(e, dtype=dtype) for e in arrays)

    n = x.shape[1]
    if len(nodes) != n:
        msg = 'point `x`' if is_1d else 'points in `x`'
        msg += ' must have same length as `nodes`'
        raise ValueError(msg)

    if np.issubdtype(x.dtype, np.inexact) and np.isnan(x).any():
        # NaN fails both boundary comparisons in the jitted search and
        # falls through to the binary-search branch, whose result would
        # not identify a nearest node.
        raise ValueError('`x` must not contain NaN')

    out = _cartesian_nearest_indices(x, nodes, order=order)
    if is_1d:
        return out[0]
    return out
@njit(cache=True)
def _cartesian_nearest_indices(X, nodes, order='C'):
    """
    Jit-compiled core of `cartesian_nearest_index`.

    `X` must already be a 2-dim ndarray, and `nodes` must be a tuple (not a
    Python list) of sorted ndarrays that all share one dtype.

    Parameters
    ----------
    X : ndarray(ndim=2)
        Points to search the closest points for, one per row.

    nodes : tuple(ndarray(ndim=1))
        Tuple of sorted ndarrays of same dtype.

    order : str, optional(default='C')
        ('C' or 'F') order in which the product is enumerated.

    Returns
    -------
    ndarray(int, ndim=1)
        Indices of the closest points to the rows of `X`.

    """
    num_points, dim = X.shape

    grid_sizes = np.empty(dim, dtype=np.intp)
    for j in range(dim):
        grid_sizes[j] = len(nodes[j])

    # 'F' order (first coordinate varies fastest) equals the C-order flat
    # index computed on the reversed index and size vectors.
    rev = slice(None, None, -1 if order == 'F' else 1)

    nearest = np.empty(dim, dtype=np.intp)
    result = np.empty(num_points, dtype=np.intp)
    for t in range(num_points):
        for j in range(dim):
            grid = nodes[j]
            v = X[t, j]
            if v <= grid[0]:
                nearest[j] = 0
            elif v >= grid[-1]:
                nearest[j] = grid_sizes[j] - 1
            else:
                # Leftmost insertion point: grid[k-1] < v <= grid[k].
                # Ties between the two neighbours go to the lower node.
                k = np.searchsorted(grid, v)
                if grid[k] - v < v - grid[k-1]:
                    nearest[j] = k
                else:
                    nearest[j] = k - 1
        result[t] = _cartesian_index(nearest[rev], grid_sizes[rev])
    return result


@njit(cache=True)
def _cartesian_index(indices, nums_grids):
    """
    Row-major flat index of the multi-index `indices` in a grid that has
    `nums_grids[i]` points along axis i (last axis varies fastest).
    """
    idx = 0
    for i in range(len(indices)):
        # Horner form: idx = (...(i0*s1 + i1)*s2 + i2)...
        idx = idx * nums_grids[i] + indices[i]
    return idx


_msg_max_size_exceeded = 'Maximum allowed size exceeded'
@jit(nopython=True, cache=True)
def simplex_grid(m, n):
    r"""
    Construct an array consisting of the integer points in the
    (m-1)-dimensional simplex :math:`\{x \mid x_0 + \cdots + x_{m-1} = n
    \}`, or equivalently, the m-part compositions of n, which are listed
    in lexicographic order. The total number of the points (hence the
    length of the output array) is L = (n+m-1)!/(n!*(m-1)!) (i.e.,
    (n+m-1) choose (m-1)).

    Parameters
    ----------
    m : scalar(int)
        Dimension of each point. Must be a positive integer.

    n : scalar(int)
        Number which the coordinates of each point sum to. Must be a
        nonnegative integer.

    Returns
    -------
    out : ndarray(int, ndim=2)
        Array of shape (L, m) containing the integer points in the
        simplex, aligned in lexicographic order.

    Raises
    ------
    ValueError
        If L is too large to be represented (the count overflows).

    Notes
    -----
    A grid of the (m-1)-dimensional *unit* simplex with n subdivisions
    along each dimension can be obtained by `simplex_grid(m, n) / n`.

    Examples
    --------
    >>> simplex_grid(3, 4)
    array([[0, 0, 4],
           [0, 1, 3],
           [0, 2, 2],
           [0, 3, 1],
           [0, 4, 0],
           [1, 0, 3],
           [1, 1, 2],
           [1, 2, 1],
           [1, 3, 0],
           [2, 0, 2],
           [2, 1, 1],
           [2, 2, 0],
           [3, 0, 1],
           [3, 1, 0],
           [4, 0, 0]])

    >>> simplex_grid(3, 4) / 4
    array([[0.  , 0.  , 1.  ],
           [0.  , 0.25, 0.75],
           [0.  , 0.5 , 0.5 ],
           [0.  , 0.75, 0.25],
           [0.  , 1.  , 0.  ],
           [0.25, 0.  , 0.75],
           [0.25, 0.25, 0.5 ],
           [0.25, 0.5 , 0.25],
           [0.25, 0.75, 0.  ],
           [0.5 , 0.  , 0.5 ],
           [0.5 , 0.25, 0.25],
           [0.5 , 0.5 , 0.  ],
           [0.75, 0.  , 0.25],
           [0.75, 0.25, 0.  ],
           [1.  , 0.  , 0.  ]])

    References
    ----------
    A. Nijenhuis and H. S. Wilf, Combinatorial Algorithms, Chapter 5,
    Academic Press, 1978.

    """
    num_points = num_compositions_jit(m, n)
    if num_points == 0:
        # num_compositions_jit returns 0 when the count overflows np.intp
        raise ValueError(_msg_max_size_exceeded)

    out = np.empty((num_points, m), dtype=np.int_)

    # First composition in lexicographic order: (0, ..., 0, n).
    point = np.zeros(m, dtype=np.int_)
    point[m-1] = n
    out[0, :] = point

    # Successor step: zero the entry at `pivot`, put that value minus one
    # into the last slot, and increment the entry just left of `pivot`.
    # `pivot` restarts from the right end unless the moved value was 1,
    # in which case it keeps walking left.
    pivot = m
    for row in range(1, num_points):
        pivot -= 1
        moved = point[pivot]
        point[pivot] = 0
        point[m-1] = moved - 1
        point[pivot-1] += 1
        out[row, :] = point
        if moved != 1:
            pivot = m
    return out
def simplex_index(x, m, n):
    r"""
    Return the index of the point x in the lexicographic order of the
    integer points of the (m-1)-dimensional simplex :math:`\{x \mid x_0
    + \cdots + x_{m-1} = n\}`.

    Parameters
    ----------
    x : array_like(int, ndim=1)
        Integer point in the simplex, i.e., an array of m nonnegative
        integers that sum to n.

    m : scalar(int)
        Dimension of each point. Must be a positive integer.

    n : scalar(int)
        Number which the coordinates of each point sum to. Must be a
        nonnegative integer.

    Returns
    -------
    idx : scalar(int)
        Index of x.

    """
    if m == 1:
        return 0

    def _count(parts, total):
        # Number of `parts`-part compositions of `total`.
        return scipy.special.comb(total + parts - 1, parts - 1, exact=True)

    # tails[i] = x[i+1] + ... + x[-1]: the mass to the right of position i.
    tails = np.cumsum(x[-1:0:-1])[::-1]

    # Start from the last index and subtract the points that come after x.
    idx = _count(m, n) - 1
    i = 0
    while i < m - 1 and tails[i] != 0:
        idx -= _count(m - i, tails[i] - 1)
        i += 1
    return idx
def num_compositions(m, n):
    """
    Count the m-part compositions of n, i.e. (n+m-1) choose (m-1).

    Parameters
    ----------
    m : scalar(int)
        Number of parts of composition.

    n : scalar(int)
        Integer to decompose.

    Returns
    -------
    scalar(int)
        Total number of m-part compositions of n.

    """
    # Stars and bars: place m-1 separators among n+m-1 slots.
    # docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html
    slots = n + m - 1
    separators = m - 1
    return scipy.special.comb(slots, separators, exact=True)
@jit(nopython=True, cache=True)
def num_compositions_jit(m, n):
    """
    Numba jit version of `num_compositions`.

    Return `0` if the outcome exceeds the maximum value of `np.intp`.

    """
    top = n + m - 1
    return comb_jit(top, m - 1)