# Source code for quantecon.util.numba

```python
"""
Utilities to support Numba jitted functions

"""
import numpy as np
from numba import jit, types
from numba.extending import overload
from numba.np.linalg import _LAPACK

# Single-letter BLAS/LAPACK "kind" code for each supported Numba dtype:
# 's' -> float32, 'd' -> float64, 'c' -> complex64, 'z' -> complex128.
_blas_kinds = dict(zip(
    (types.float32, types.float64, types.complex64, types.complex128),
    'sdcz',
))

def _numba_linalg_solve(a, b):  # pragma: no cover
pass

@overload(_numba_linalg_solve, jit_options={'cache': True})
def _numba_linalg_solve_ol(a, b):
    """
    Solve the linear equation ax = b by directly calling a Numba internal
    function. The data in `a` and `b` are interpreted in Fortran order,
    and the dtype of `a` and `b` must be the same, one of {float32,
    float64, complex64, complex128}. `a` and `b` are modified in place,
    and the solution is stored in `b`. *No error check is made for the
    inputs.* Only works inside a Numba-jitted function.

    This is registered as the Numba overload of `_numba_linalg_solve`:
    at compile time it receives the Numba types of `a` and `b` and
    returns the implementation compiled for those types.

    Parameters
    ----------
    a : ndarray(ndim=2)
        2-dimensional ndarray of shape (n, n).

    b : ndarray(ndim=1 or 2)
        1-dimensional ndarray of shape (n,) or 2-dimensional ndarray of
        shape (n, nrhs).

    Returns
    -------
    r : scalar(int)
        r = 0 if successful.

    Notes
    -----
    From github.com/numba/numba/blob/main/numba/np/linalg.py

    """
    # Resolved once per type signature at compile time: Numba's LAPACK
    # `xgesv` wrapper for this dtype and its kind letter as a code point.
    numba_xgesv = _LAPACK().numba_xgesv(a.dtype)
    kind = ord(_blas_kinds[a.dtype])

    def _numba_linalg_solve_impl(a, b):  # pragma: no cover
        n = a.shape[-1]
        # A 1-d `b` is treated as a single right-hand side.
        if b.ndim == 1:
            nrhs = 1
        else:  # b.ndim == 2
            nrhs = b.shape[-1]
        # 32-bit integer workspace for the pivot indices written by xgesv.
        F_INT_nptype = np.int32
        ipiv = np.empty(n, dtype=F_INT_nptype)

        r = numba_xgesv(
            kind,         # kind
            n,            # n
            nrhs,         # nrhs
            a.ctypes,     # a
            n,            # lda
            ipiv.ctypes,  # ipiv
            b.ctypes,     # b
            n             # ldb
        )
        return r

    return _numba_linalg_solve_impl

@jit(types.intp(types.intp, types.intp), nopython=True, cache=True)
def comb_jit(N, k):
    """
    Numba jitted function that computes N choose k. Return `0` if the
    outcome exceeds the maximum value of `np.intp` or if N < 0, k < 0,
    or k > N.

    Parameters
    ----------
    N : scalar(int)

    k : scalar(int)

    Returns
    -------
    val : scalar(int)

    Notes
    -----
    Based on scipy.special._comb_int_long
    (github.com/scipy/scipy/blob/v1.0.0/scipy/special/_comb.pyx). Unlike
    that routine, the common factor of the running value and the divisor
    is cancelled before each multiplication. No intermediate value then
    exceeds the final result, so `0` is returned only when the result
    itself does not fit in `np.intp`.

    """
    INTP_MAX = np.iinfo(np.intp).max
    if N < 0 or k < 0 or k > N:
        return 0
    if k == 0:
        return 1
    if k == 1:
        return N
    if N == INTP_MAX:
        # M = N + 1 below would overflow.
        return 0

    M = N + 1
    nterms = min(k, N - k)

    # Invariant: after iteration j, val == C(N, j). Since j <= N // 2,
    # these values are nondecreasing, so each is <= the final result.
    val = 1
    for j in range(1, nterms + 1):
        # C(N, j) = val * (M - j) / j exactly. With g = gcd(val, j), the
        # numbers val // g and j // g are coprime, hence j // g divides
        # (M - j) and C(N, j) = (val // g) * ((M - j) // (j // g)).
        # This avoids forming the possibly-overflowing val * (M - j).
        x, y = val, j
        while y != 0:
            x, y = y, x % y
        q = val // x
        r = (M - j) // (j // x)
        # Overflow check; r >= 1 because M - j >= N + 1 - nterms > 0.
        if q > INTP_MAX // r:
            return 0
        val = q * r

    return val
```