# Source code for sisl.linalg.base

from functools import partial as _partial

# Create a _copy_ of the scipy.linalg.solve routine and implement
# our own refine keyword.
from numpy import atleast_1d, atleast_2d
from scipy.linalg.lapack import get_lapack_funcs, _compute_lwork
from scipy.linalg.misc import LinAlgError, _datacopied
from scipy._lib._util import _asarray_validated

import scipy.linalg as sl
import scipy.sparse.linalg as ssl


# Public API of this module; starts empty and is extended further down,
# next to the definition of each exported routine.
__all__ = []


def inv(a, overwrite_a=False):
    """
    Inverts a matrix

    The inverse is computed with LAPACK: an LU factorisation (``getrf``)
    followed by the inversion of the factors (``getri``). The input is
    *not* checked for NaN/Inf entries.

    Parameters
    ----------
    a : (N, N) array_like
       the matrix to be inverted.
    overwrite_a : bool, optional
       whether we are allowed to overwrite the matrix `a`

    Returns
    -------
    x : (N, N) ndarray
        The inverted matrix

    Raises
    ------
    ValueError
        if `a` is not a square 2D matrix, or if LAPACK reports an illegal
        argument.
    LinAlgError
        if `a` is singular.
    """
    a1 = atleast_2d(_asarray_validated(a, check_finite=False))
    # Reject anything that is not a plain square matrix up front; without the
    # ndim test an (N, N, M) array would slip through and only fail inside the
    # LAPACK wrapper with a cryptic message.
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('Input a needs to be a square matrix.')

    # When the conversion above already produced a private copy, LAPACK may
    # work in-place on it without touching the caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    getrf, getri, getri_lwork = get_lapack_funcs(('getrf', 'getri',
                                                  'getri_lwork'), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info == 0:
        lwork = _compute_lwork(getri_lwork, a1.shape[0])
        # Over-allocate the workspace by 1%.
        # NOTE(review): presumably mirrors the safety margin used by
        # scipy.linalg.inv -- confirm.
        lwork = int(1.01 * lwork)
        x, info = getri(lu, piv, lwork=lwork, overwrite_lu=True)

    # `info` stems from getrf if the factorisation failed, otherwise from getri
    if info > 0:
        raise LinAlgError("Singular matrix")
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal '
                         'getrf|getri' % -info)
    return x
def solve(a, b, overwrite_a=False, overwrite_b=False):
    """
    Solve a linear system ``a x = b``

    The system is handed to the LAPACK ``gesv`` driver. The inputs are
    *not* checked for NaN/Inf entries.

    Parameters
    ----------
    a : (N, N) array_like
       left-hand-side matrix
    b : (N, NRHS) or (N,) array_like
       right-hand-side matrix. A 1D `b` is treated as a single right-hand
       side (or, when `a` is 1x1, as one right-hand side per element).
    overwrite_a : bool, optional
       whether we are allowed to overwrite the matrix `a`
    overwrite_b : bool, optional
       whether we are allowed to overwrite the matrix `b`

    Returns
    -------
    x : (N, NRHS) ndarray
        solution matrix; flattened to 1D when `b` was passed as 1D

    Raises
    ------
    ValueError
        if `a` is not a square 2D matrix, if the number of rows of `b` does
        not match `a`, or if LAPACK reports an illegal argument.
    LinAlgError
        if `a` is singular.
    """
    a1 = atleast_2d(_asarray_validated(a, check_finite=False))
    b1 = atleast_1d(_asarray_validated(b, check_finite=False))
    n = a1.shape[0]

    # When the conversions above already produced private copies, LAPACK may
    # work in-place on them without touching the caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    # Without the ndim test an (N, N, M) array would slip through and only
    # fail inside the LAPACK wrapper with a cryptic message.
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('LHS needs to be a square matrix.')

    if n != b1.shape[0]:
        # Last chance to catch 1x1 scalar a and 1D b arrays
        if not (n == 1 and b1.size != 0):
            raise ValueError('Input b has to have same number of rows as '
                             'input a')

    # regularize 1D b arrays to 2D
    b_is_1D = b1.ndim == 1
    if b_is_1D:
        if n == 1:
            # 1x1 system: every element of b is its own right-hand side
            b1 = b1[None, :]
        else:
            b1 = b1[:, None]

    gesv = get_lapack_funcs('gesv', (a1, b1))
    _, _, x, info = gesv(a1, b1, overwrite_a=overwrite_a, overwrite_b=overwrite_b)
    if info > 0:
        raise LinAlgError("Singular matrix")
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal '
                         'gesv' % -info)

    # Hand back the same dimensionality the caller supplied
    if b_is_1D:
        return x.ravel()
    return x
def _append(name, suffix):
    """Return `name` joined with each entry of `suffix` (helper for filling ``__all__``)."""
    return [name + tail for tail in suffix]


# ---------------------------------------------------------------------------
# Dense routines
#
# Naming convention: the plain name never touches its input arrays, while the
# ``*_destroy`` variant allows the backend to overwrite them.
# ---------------------------------------------------------------------------

# Linear systems
solve_destroy = _partial(solve, overwrite_a=True, overwrite_b=True)
__all__.extend(_append('solve', ['', '_destroy']))

# Matrix inversion
inv_destroy = _partial(inv, overwrite_a=True)
__all__.extend(_append('inv', ['', '_destroy']))

# General eigenvalue problem, inputs preserved.
# The left/right variants only add one keyword on top of the base partial.
# NOTE(review): scipy.linalg.eig already defaults to ``right=True``, so
# `eig_right` behaves like `eig` and `eig_left` returns both sets of
# eigenvectors -- confirm this is intended.
eig = _partial(sl.eig, check_finite=False,
               overwrite_a=False, overwrite_b=False)
eig_left = _partial(eig, left=True)
eig_right = _partial(eig, right=True)
__all__.extend(_append('eig', ['', '_left', '_right']))

# General eigenvalue problem, inputs may be overwritten
eig_destroy = _partial(sl.eig, check_finite=False,
                       overwrite_a=True, overwrite_b=True)
eig_left_destroy = _partial(eig_destroy, left=True)
eig_right_destroy = _partial(eig_destroy, right=True)
__all__.extend(_append('eig_', ['destroy', 'left_destroy', 'right_destroy']))

# Symmetric/hermitian eigenvalue problem, inputs preserved.
# ``turbo=True`` is the default flavour here (`eigh` and `eigh_dc` are the
# very same object); `eigh_qr` switches it off.
# NOTE(review): `turbo` was deprecated by SciPy (1.5, from memory) in favour
# of `driver` -- confirm against the supported SciPy range.
eigh = _partial(sl.eigh, check_finite=False,
                overwrite_a=False, overwrite_b=False, turbo=True)
eigh_dc = eigh
eigh_qr = _partial(eigh, turbo=False)
__all__.extend(_append('eigh', ['', '_dc', '_qr']))

# Symmetric/hermitian eigenvalue problem, inputs may be overwritten
eigh_destroy = _partial(sl.eigh, check_finite=False,
                        overwrite_a=True, overwrite_b=True, turbo=True)
eigh_dc_destroy = eigh_destroy
eigh_qr_destroy = _partial(eigh_destroy, turbo=False)
__all__.extend(_append('eigh_', ['destroy', 'dc_destroy', 'qr_destroy']))

# Singular value decomposition
svd = _partial(sl.svd, check_finite=False, overwrite_a=False)
svd_destroy = _partial(sl.svd, check_finite=False, overwrite_a=True)
__all__.extend(_append('svd', ['', '_destroy']))

# ---------------------------------------------------------------------------
# Sparse routines (plain re-exports from scipy.sparse.linalg)
# ---------------------------------------------------------------------------

# General eigenvalue problem
eigs = ssl.eigs
__all__.extend(['eigs'])

# Symmetric/hermitian eigenvalue problem
eigsh = ssl.eigsh
__all__.extend(['eigsh'])