Skip to content

Commit

Permalink
refactor to separate logic specific for GeneralMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
konovod committed Jan 11, 2024
1 parent acd837b commit 866630d
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 93 deletions.
20 changes: 19 additions & 1 deletion src/linalg/eig.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ module LA
a.eigs(b: b, need_left: need_left, need_right: need_right, overwrite_a: overwrite_a, overwrite_b: overwrite_b)
end

class Matrix(T)
def eigvals
to_general.eigvals(overwrite_a: true)
end

def eigs(*, left = false)
to_general.eigs(overwrite_a: true, left: left)
end

def eigs(*, need_left : Bool, need_right : Bool)
to_general.eigs(overwrite_a: true, need_left: need_left, need_right: need_right)
end

def eigs(*, b : self, need_left = false, need_right = false)
to_general.eigs(b: b.to_general, overwrite_a: true, overwrite_b: true, need_left: need_left, need_right: need_right)
end
end

class GeneralMatrix(T) < Matrix(T)
def eigvals(*, overwrite_a = false)
vals, aleft, aright = eigs(overwrite_a: overwrite_a, need_left: false, need_right: false)
Expand Down Expand Up @@ -140,7 +158,7 @@ module LA
end

# generalized eigenvalues problem
def eigs(*, b : Matrix(T), need_left : Bool, need_right : Bool, overwrite_a = false, overwrite_b = false)
def eigs(*, b : GeneralMatrix(T), need_left : Bool, need_right : Bool, overwrite_a = false, overwrite_b = false)
raise ArgumentError.new("a matrix must be square") unless square?
raise ArgumentError.new("b matrix must have same size as a") unless b.size == self.size
{% if T == Complex %}
Expand Down
2 changes: 1 addition & 1 deletion src/linalg/expm.cr
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module LA
4.250000000000000e+000,
} # m_vals = 13

abstract class Matrix(T)
class GeneralMatrix(T) < Matrix(T)
# Computes matrix exponential
#
# It exploits triangularity (if any) of A.
Expand Down
162 changes: 104 additions & 58 deletions src/linalg/linalg.cr
Original file line number Diff line number Diff line change
Expand Up @@ -85,51 +85,45 @@ module LA
tril! if flags.lower_triangular?
end

# Calculate matrix inversion inplace
# Method selects optimal algorithm depending on `MatrixFlags`
# `transpose` returned if matrix is orthogonal
# `trtri` is used if matrix is triangular
# `potrf`, `potri` are used if matrix is positive definite
# `hetrf`, `hetri` are used if matrix is hermitian
# `sytrf`, `sytri` are used if matrix is symmetric
# `getrf`, `getri` are used otherwise
def inv!
raise ArgumentError.new("can't invert nonsquare matrix") unless square?
return transpose! if flags.orthogonal?
n = self.nrows
if flags.triangular?
lapack(trtri, uplo, 'N'.ord.to_u8, n, self, n)
adjust_triangular
elsif flags.positive_definite?
lapack(potrf, uplo, n, self, n)
lapack(potri, uplo, n, self, n)
adjust_symmetric
elsif {{T == Complex}} && flags.hermitian?
{% if T == Complex %}
ipiv = Slice(Int32).new(n)
lapack(hetrf, uplo, n, self, n, ipiv)
lapack(hetri, uplo, n, self, n, ipiv, worksize: [n])
adjust_symmetric
{% else %}
raise "error" # to prevent type inference of nil
{% end %}
elsif flags.symmetric?
ipiv = Slice(Int32).new(n)
lapack(sytrf, uplo, n, self, n, ipiv)
lapack(sytri, uplo, n, self, n, ipiv, worksize: [2*n])
adjust_symmetric
else
ipiv = Slice(Int32).new(n)
lapack(getrf, n, n, self, n, ipiv)
lapack(getri, n, self, n, ipiv)
end
self
end

# Calculate matrix inversion
# See `#inv!` for details on algorithm
def inv
clone.inv!
to_general.inv!
end

def svd
to_general.svd(overwrite_a: true)
end

def svdvals
to_general.svdvals(overwrite_a: true)
end

def solve(b : GeneralMatrix(T), *, overwrite_b = false)
to_general.solve(b, overwrite_b: overwrite_b)
end

def solve(b : self)
to_general.solve(b.to_general, overwrite_a: true, overwrite_b: true)
end

def det
to_general.det(overwrite_a: true)
end

def balance(*, permute = true, scale = true, separate = false)
a = to_general
s = a.balance!(permute: permute, scale: scale, separate: separate)
{a, s}
end

def hessenberg(*, calc_q = false)
x = to_general
x.hessenberg!(calc_q: calc_q)
end

def hessenberg
to_general.hessenberg!
end

# Calculate Moore–Penrose inverse of a matrix
Expand All @@ -149,7 +143,7 @@ module LA
e
end
}
s_dash : GeneralMatrix(T) = GeneralMatrix(T).diag(s_inverse)
s_dash = self.class.diag(s_inverse)

append_rows = 0
append_cols = 0
Expand All @@ -174,6 +168,73 @@ module LA
return v * s_dash * u
end

def norm(kind : MatrixNorm = MatrixNorm::Frobenius)
result = of_real_type(0)
case kind
in .frobenius?
each { |v| result += of_real_type(v*v) }
Math.sqrt(result)
in .one?
sums = of_real_type(Slice, ncolumns)
each_with_index do |v, row, col|
sums[col] += v.abs
end
sums.max
in .inf?
sums = of_real_type(Slice, nrows)
each_with_index do |v, row, col|
sums[row] += v.abs
end
sums.max
in .max_abs?
each { |v| result = v.abs if result < v.abs }
result
end
end
end

class GeneralMatrix(T) < Matrix(T)
# Calculate matrix inversion inplace
# Method selects optimal algorithm depending on `MatrixFlags`
# `transpose` returned if matrix is orthogonal
# `trtri` is used if matrix is triangular
# `potrf`, `potri` are used if matrix is positive definite
# `hetrf`, `hetri` are used if matrix is hermitian
# `sytrf`, `sytri` are used if matrix is symmetric
# `getrf`, `getri` are used otherwise
def inv!
raise ArgumentError.new("can't invert nonsquare matrix") unless square?
return transpose! if flags.orthogonal?
n = self.nrows
if flags.triangular?
lapack(trtri, uplo, 'N'.ord.to_u8, n, self, n)
adjust_triangular
elsif flags.positive_definite?
lapack(potrf, uplo, n, self, n)
lapack(potri, uplo, n, self, n)
adjust_symmetric
elsif {{T == Complex}} && flags.hermitian?
{% if T == Complex %}
ipiv = Slice(Int32).new(n)
lapack(hetrf, uplo, n, self, n, ipiv)
lapack(hetri, uplo, n, self, n, ipiv, worksize: [n])
adjust_symmetric
{% else %}
raise "error" # to prevent type inference of nil
{% end %}
elsif flags.symmetric?
ipiv = Slice(Int32).new(n)
lapack(sytrf, uplo, n, self, n, ipiv)
lapack(sytri, uplo, n, self, n, ipiv, worksize: [2*n])
adjust_symmetric
else
ipiv = Slice(Int32).new(n)
lapack(getrf, n, n, self, n, ipiv)
lapack(getri, n, self, n, ipiv)
end
self
end

# Solves matrix equation `self*x = b` and returns x
#
# Method returns matrix of same size as `b`
Expand Down Expand Up @@ -342,12 +403,6 @@ module LA
separate ? s : Matrix(T).diag(s.raw)
end

def balance(*, permute = true, scale = true, separate = false)
a = clone
s = a.balance!(permute: permute, scale: scale, separate: separate)
{a, s}
end

def hessenberg!(*, calc_q = false)
raise ArgumentError.new("matrix must be square") unless square?
# idea from scipy.
Expand Down Expand Up @@ -381,15 +436,6 @@ module LA
self
end

def hessenberg(*, calc_q = false)
x = self.clone
x.hessenberg!(calc_q: calc_q)
end

def hessenberg
clone.hessenberg!
end

# returns matrix norm
#
# `kind` defines type of norm
Expand Down
22 changes: 14 additions & 8 deletions src/linalg/lu.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
module LA
class Matrix(T)
def lu
to_general.lu(overwrite_a: true)
end

# Compute pivoted LU decomposition of a matrix and returns it in a compact form,
# useful for solving linear equation systems.
#
# Uses `getrf` LAPACK routine
def lu_factor : LUMatrix(T)
to_general.lu_factor!
end
end

class GeneralMatrix(T) < Matrix(T)
# Compute pivoted LU decomposition of a matrix
#
Expand Down Expand Up @@ -64,14 +78,6 @@ module LA
clear_flags
LUMatrix(T).new(self, ipiv)
end

# Compute pivoted LU decomposition of a matrix and returns it in a compact form,
# useful for solving linear equation systems.
#
# Uses `getrf` LAPACK routine
def lu_factor : LUMatrix(T)
clone.lu_factor!
end
end

enum Enums::LUTranspose
Expand Down
14 changes: 14 additions & 0 deletions src/linalg/qr.cr
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# TODO - inline docs

module LA
class Matrix(T)
def qr_raw(*, pivoting = false)
to_general.qr_raw(overwrite_a: true)
end

def qr_r(*, pivoting = false)
to_general.qr_r(overwrite_a: true)
end

def qr(*, pivoting = false)
to_general.qr(overwrite_a: true)
end
end

class GeneralMatrix(T) < Matrix(T)
private def qr_initial(a, pivoting)
m = a.nrows
Expand Down
15 changes: 15 additions & 0 deletions src/linalg/rq_lq_ql.cr
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,22 @@ end

end

private macro decomposition_wrap(letters)
def {{letters}}
to_general.{{letters}}(overwrite_a: true)
end
def {{letters}}_r
to_general.{{letters}}_r(overwrite_a: true)
end
end

module LA
class Matrix(T)
decomposition_wrap(rq)
decomposition_wrap(lq)
decomposition_wrap(ql)
end

class GeneralMatrix(T) < Matrix(T)
decomposition(rq, triu)
decomposition(lq, tril)
Expand Down
12 changes: 11 additions & 1 deletion src/linalg/schur.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ module LA
a.qz(b, overwrite_a: overwrite_a, overwrite_b: overwrite_b)
end

class Matrix(T)
def schur
to_general.schur(overwrite_a: true)
end

def qz(b : self)
to_general.qz(b, overwrite_a: true, overwrite_b: true)
end
end

class GeneralMatrix(T) < Matrix(T)
def schur(*, overwrite_a = false)
raise ArgumentError.new("matrix must be square") unless square?
Expand All @@ -32,7 +42,7 @@ module LA
{% end %}
end

def qz(b, overwrite_a = false, overwrite_b = false)
def qz(b : self, overwrite_a = false, overwrite_b = false)
raise ArgumentError.new("matrix must be square") unless square?
a = overwrite_a ? self : clone
bb = overwrite_b ? b : b.clone
Expand Down
24 changes: 0 additions & 24 deletions src/matrix/sparse/sparse_matrix.cr
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,5 @@ module LA::Sparse
def select_index(& : (Int32, Int32) -> Bool)
select_with_index { |v, i, j| yield(i, j) }
end

def norm(kind : MatrixNorm = MatrixNorm::Frobenius)
result = T.additive_identity
case kind
in .frobenius?
each { |v| result += v*v }
Math.sqrt(result)
in .one?
sums = Slice(T).new(ncolumns, T.additive_identity)
each_with_index do |v, row, col|
sums[col] += v.abs
end
sums.max
in .inf?
sums = Slice(T).new(nrows, T.additive_identity)
each_with_index do |v, row, col|
sums[row] += v.abs
end
sums.max
in .max_abs?
each { |v| result = v.abs if result < v }
result
end
end
end
end

0 comments on commit 866630d

Please sign in to comment.