Source code for quantecon.gridtools

"""
Implements cartesian products and regular cartesian grids, and provides
a function that constructs a grid for a simplex as well as one that
determines the index of a point in the simplex.

"""
import numpy as np
import scipy.special
from numba import jit, njit
from .util.numba import comb_jit


[docs]def cartesian(nodes, order='C'): ''' Cartesian product of a list of arrays Parameters ---------- nodes : list(array_like(ndim=1)) order : str, optional(default='C') ('C' or 'F') order in which the product is enumerated Returns ------- out : ndarray(ndim=2) each line corresponds to one point of the product space ''' nodes = [np.array(e) for e in nodes] shapes = [e.shape[0] for e in nodes] dtype = nodes[0].dtype n = len(nodes) l = np.prod(shapes) out = np.zeros((l, n), dtype=dtype) if order == 'C': repetitions = np.cumprod([1] + shapes[:-1]) else: shapes.reverse() sh = [1] + shapes[:-1] repetitions = np.cumprod(sh) repetitions = repetitions.tolist() repetitions.reverse() for i in range(n): _repeat_1d(nodes[i], repetitions[i], out[:, i]) return out
[docs]def mlinspace(a, b, nums, order='C'): ''' Constructs a regular cartesian grid Parameters ---------- a : array_like(ndim=1) lower bounds in each dimension b : array_like(ndim=1) upper bounds in each dimension nums : array_like(ndim=1) number of nodes along each dimension order : str, optional(default='C') ('C' or 'F') order in which the product is enumerated Returns ------- out : ndarray(ndim=2) each line corresponds to one point of the product space ''' a = np.array(a, dtype='float64') b = np.array(b, dtype='float64') nums = np.array(nums, dtype='int64') nodes = [np.linspace(a[i], b[i], nums[i]) for i in range(len(nums))] return cartesian(nodes, order=order)
@njit def _repeat_1d(x, K, out): ''' Repeats each element of a vector many times and repeats the whole result many times Parameters ---------- x : ndarray(ndim=1) vector to be repeated K : scalar(int) number of times each element of x is repeated (inner iterations) out : ndarray(ndim=1) placeholder for the result Returns ------- None ''' N = x.shape[0] L = out.shape[0] // (K*N) # number of outer iterations # K # number of inner iterations # the result out should enumerate in C-order the elements # of a 3-dimensional array T of dimensions (K,N,L) # such that for all k,n,l, we have T[k,n,l] == x[n] for n in range(N): val = x[n] for k in range(K): for l in range(L): ind = k*N*L + n*L + l out[ind] = val _msg_max_size_exceeded = 'Maximum allowed size exceeded'
[docs]@jit(nopython=True, cache=True) def simplex_grid(m, n): r""" Construct an array consisting of the integer points in the (m-1)-dimensional simplex :math:`\{x \mid x_0 + \cdots + x_{m-1} = n \}`, or equivalently, the m-part compositions of n, which are listed in lexicographic order. The total number of the points (hence the length of the output array) is L = (n+m-1)!/(n!*(m-1)!) (i.e., (n+m-1) choose (m-1)). Parameters ---------- m : scalar(int) Dimension of each point. Must be a positive integer. n : scalar(int) Number which the coordinates of each point sum to. Must be a nonnegative integer. Returns ------- out : ndarray(int, ndim=2) Array of shape (L, m) containing the integer points in the simplex, aligned in lexicographic order. Notes ----- A grid of the (m-1)-dimensional *unit* simplex with n subdivisions along each dimension can be obtained by `simplex_grid(m, n) / n`. Examples -------- >>> simplex_grid(3, 4) array([[0, 0, 4], [0, 1, 3], [0, 2, 2], [0, 3, 1], [0, 4, 0], [1, 0, 3], [1, 1, 2], [1, 2, 1], [1, 3, 0], [2, 0, 2], [2, 1, 1], [2, 2, 0], [3, 0, 1], [3, 1, 0], [4, 0, 0]]) >>> simplex_grid(3, 4) / 4 array([[ 0. , 0. , 1. ], [ 0. , 0.25, 0.75], [ 0. , 0.5 , 0.5 ], [ 0. , 0.75, 0.25], [ 0. , 1. , 0. ], [ 0.25, 0. , 0.75], [ 0.25, 0.25, 0.5 ], [ 0.25, 0.5 , 0.25], [ 0.25, 0.75, 0. ], [ 0.5 , 0. , 0.5 ], [ 0.5 , 0.25, 0.25], [ 0.5 , 0.5 , 0. ], [ 0.75, 0. , 0.25], [ 0.75, 0.25, 0. ], [ 1. , 0. , 0. ]]) References ---------- A. Nijenhuis and H. S. Wilf, Combinatorial Algorithms, Chapter 5, Academic Press, 1978. """ L = num_compositions_jit(m, n) if L == 0: # Overflow occured raise ValueError(_msg_max_size_exceeded) out = np.empty((L, m), dtype=np.int_) x = np.zeros(m, dtype=np.int_) x[m-1] = n for j in range(m): out[0, j] = x[j] h = m for i in range(1, L): h -= 1 val = x[h] x[h] = 0 x[m-1] = val - 1 x[h-1] += 1 for j in range(m): out[i, j] = x[j] if val != 1: h = m return out
[docs]def simplex_index(x, m, n): r""" Return the index of the point x in the lexicographic order of the integer points of the (m-1)-dimensional simplex :math:`\{x \mid x_0 + \cdots + x_{m-1} = n\}`. Parameters ---------- x : array_like(int, ndim=1) Integer point in the simplex, i.e., an array of m nonnegative itegers that sum to n. m : scalar(int) Dimension of each point. Must be a positive integer. n : scalar(int) Number which the coordinates of each point sum to. Must be a nonnegative integer. Returns ------- idx : scalar(int) Index of x. """ if m == 1: return 0 decumsum = np.cumsum(x[-1:0:-1])[::-1] idx = num_compositions(m, n) - 1 for i in range(m-1): if decumsum[i] == 0: break idx -= num_compositions(m-i, decumsum[i]-1) return idx
[docs]def num_compositions(m, n): """ The total number of m-part compositions of n, which is equal to (n+m-1) choose (m-1). Parameters ---------- m : scalar(int) Number of parts of composition. n : scalar(int) Integer to decompose. Returns ------- scalar(int) Total number of m-part compositions of n. """ # docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html return scipy.special.comb(n+m-1, m-1, exact=True)
[docs]@jit(nopython=True, cache=True) def num_compositions_jit(m, n): """ Numba jit version of `num_compositions`. Return `0` if the outcome exceeds the maximum value of `np.intp`. """ return comb_jit(n+m-1, m-1)