Skip to content

Commit

Permalink
Add sparse LU routines
Browse files Browse the repository at this point in the history
  • Loading branch information
jchristopherson committed Feb 12, 2024
1 parent 740eb9a commit e9983ce
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 1 deletion.
44 changes: 43 additions & 1 deletion src/linalg.f90
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,32 @@ module linalg
!! - LA_SINGULAR_MATRIX_ERROR: Occurs as a warning if @p a is found to be
!! singular.
!!
!! @par Syntax (Sparse Matrices)
!! @code{.f90}
!! subroutine lu_factor(class(csr_matrix) a, type(msr_matrix) lu, integer(int32) ju(:), optional real(real64) droptol, optional class(errors) err)
!! @endcode
!!
!! @param[in] a The M-by-N sparse matrix to factor.
!! @param[out] lu The factored matrix, stored in MSR format. The diagonal is
!! stored inverted.
!! @param[out] ju An M-element array used to track the starting row index of
!! the U matrix.
!! @param[in] droptol An optional threshold value used to determine when
!! to drop small terms as part of the factorization of matrix A. The
!! default value is set to the square root of machine precision (~1e-8).
!! @param[in,out] err An optional errors-based object that if provided can
!! be used to retrieve information relating to any errors encountered
!! during execution. If not provided, a default implementation of the
!! errors class is used internally to provide error handling. Possible
!! errors and warning messages that may be encountered are as follows.
!! - LA_ARRAY_SIZE_ERROR: Occurs if @p ju is not sized correctly.
!! - LA_OUT_OF_MEMORY_ERROR: Occurs if there is an issue with internal
!! memory allocations.
!! - LA_MATRIX_FORMAT_ERROR: Occurs if @p a is improperly formatted.
!! - LA_SINGULAR_MATRIX_ERROR: Occurs if @p a is singular.
!!
!! @par Notes
!! This routine utilizes the LAPACK routine DGETRF.
!! The dense routine utilizes the LAPACK routine DGETRF.
!!
!! @par See Also
!! - [Wikipedia](https://en.wikipedia.org/wiki/LU_decomposition)
Expand Down Expand Up @@ -818,6 +842,7 @@ module linalg
interface lu_factor
module procedure :: lu_factor_dbl
module procedure :: lu_factor_cmplx
module procedure :: csr_lu_factor
end interface

!> @brief Extracts the L and U matrices from the condensed [L\\U] storage
Expand Down Expand Up @@ -2379,6 +2404,7 @@ module linalg
module procedure :: solve_lu_mtx_cmplx
module procedure :: solve_lu_vec
module procedure :: solve_lu_vec_cmplx
module procedure :: csr_lu_solve
end interface

! ------------------------------------------------------------------------------
Expand Down Expand Up @@ -5616,6 +5642,22 @@ module subroutine msr_assign_to_csr(csr, msr)
type(csr_matrix), intent(out) :: csr
class(msr_matrix), intent(in) :: msr
end subroutine

module subroutine csr_lu_factor(a, lu, ju, droptol, err)
class(csr_matrix), intent(in) :: a
type(msr_matrix), intent(out) :: lu
integer(int32), intent(out), dimension(:) :: ju
real(real64), intent(in), optional :: droptol
class(errors), intent(inout), optional, target :: err
end subroutine

module subroutine csr_lu_solve(lu, b, ju, x, err)
class(msr_matrix), intent(in) :: lu
real(real64), intent(in), dimension(:) :: b
integer(int32), intent(in), dimension(:) :: ju
real(real64), intent(out), dimension(:) :: x
class(errors), intent(inout), optional, target :: err
end subroutine
end interface

! ------------------------------------------------------------------------------
Expand Down
188 changes: 188 additions & 0 deletions src/linalg_sparse.f90
Original file line number Diff line number Diff line change
Expand Up @@ -1138,5 +1138,193 @@ module subroutine msr_assign_to_csr(csr, msr)
csr = msr_to_csr(msr)
end subroutine

! ******************************************************************************
! LU PRECONDITIONER ROUTINES
! ------------------------------------------------------------------------------
module subroutine csr_lu_factor(a, lu, ju, droptol, err)
! Arguments
class(csr_matrix), intent(in) :: a
type(msr_matrix), intent(out) :: lu
integer(int32), intent(out), dimension(:) :: ju
real(real64), intent(in), optional :: droptol
class(errors), intent(inout), optional, target :: err

! Local Variables
integer(int32) :: i, m, n, nn, nnz, lfil, iwk, ierr, flag
integer(int32), allocatable, dimension(:) :: jlu, jw
real(real64), allocatable, dimension(:) :: alu, w
real(real64) :: dt
class(errors), pointer :: errmgr
type(errors), target :: deferr

! Initialization
if (present(err)) then
errmgr => err
else
errmgr => deferr
end if
if (present(droptol)) then
dt = droptol
else
dt = sqrt(epsilon(dt))
end if
m = size(a, 1)
n = size(a, 2)
nnz = nonzero_count(a)

! Input Check
if (size(ju) /= m) then
call errmgr%report_error("csr_lu_factor", &
"U row tracking array is not sized correctly.", LA_ARRAY_SIZE_ERROR)
return
end if

! Parameter Determination
lfil = 1
do i = 1, m
lfil = max(lfil, a%row_indices(i+1) - a%row_indices(i))
end do
iwk = max(lfil * m, nnz) ! somewhat arbitrary - can be adjusted

! Local Memory Allocation
allocate(alu(iwk), w(n+1), jlu(iwk), jw(2 * n), stat = flag)
if (flag /= 0) go to 10

! Factorization
do
! Factor the matrix
call ilut(n, a%values, a%column_indices, a%row_indices, lfil, dt, &
alu, jlu, ju, iwk, w, jw, ierr)

! Check the error flag
if (ierr == 0) then
! Success
exit
else if (ierr > 0) then
! Zero pivot
else if (ierr == -1) then
! The input matrix is not formatted correctly
go to 20
else if (ierr == -2 .or. ierr == -3) then
! ALU and JLU are too small - try something larger
! This is the main reason for the loop - to offload worrying about
! workspace size from the user
iwk = min(iwk + m + n, m * n)
deallocate(alu)
deallocate(jlu)
allocate(alu(iwk), jlu(iwk), stat = flag)
if (flag /= 0) go to 10
else if (ierr == -4) then
! Illegal value for LFIL - reset and try again
lfil = n
else if (ierr == -5) then
! Zero row encountered
go to 30
else
! We should never get here, but just in case
go to 40
end if
end do

! Determine the actual number of non-zero elements
nnz = jlu(m+1) - 1

! Copy the contents to the output arrays
lu%m = m
lu%n = n
lu%nnz = nnz
nn = m + 1 + nnz - min(m, n)
allocate(lu%values(nn), source = alu(:nn), stat = flag)
if (flag /= 0) go to 10
allocate(lu%indices(nn), source = jlu(:nn), stat = flag)

! End
return

! Memory Error
10 continue
call errmgr%report_error("csr_lu_factor", &
"Memory allocation error.", LA_OUT_OF_MEMORY_ERROR)
return

! Matrix Format Error
20 continue
call errmgr%report_error("csr_lu_factor", &
"The input matrix was incorrectly formatted. A row with more " // &
"than N entries was found.", LA_MATRIX_FORMAT_ERROR)
return

! Zero Row Error
30 continue
call errmgr%report_error("csr_lu_factor", &
"A row with all zeros was encountered in the matrix.", &
LA_SINGULAR_MATRIX_ERROR)
return

! Unknown Error
40 continue
call errmgr%report_error("csr_solve_sparse_direct", "ILUT encountered " // &
"an unknown error. The error code from the ILUT routine is " // &
"provided in the output.", ierr)
return

! Zero Pivot Error
50 continue
call errmgr%report_error("csr_lu_factor", &
"A zero pivot was encountered.", LA_SINGULAR_MATRIX_ERROR)
return
end subroutine

! ------------------------------------------------------------------------------
module subroutine csr_lu_solve(lu, b, ju, x, err)
! Arguments
class(msr_matrix), intent(in) :: lu
real(real64), intent(in), dimension(:) :: b
integer(int32), intent(in), dimension(:) :: ju
real(real64), intent(out), dimension(:) :: x
class(errors), intent(inout), optional, target :: err

! Local Variables
integer(int32) :: m, n
class(errors), pointer :: errmgr
type(errors), target :: deferr

! Initialization
if (present(err)) then
errmgr => err
else
errmgr => deferr
end if
m = size(lu, 1)
n = size(lu, 2)

! Input Check
if (m /= n) then
call errmgr%report_error("csr_lu_solve", &
"The input matrix is expected to be square.", LA_ARRAY_SIZE_ERROR)
return
end if
if (size(x) /= m) then
call errmgr%report_error("csr_lu_solve", &
"Inner matrix dimension mismatch.", LA_ARRAY_SIZE_ERROR)
return
end if
if (size(b) /= m) then
call errmgr%report_error("csr_lu_solve", &
"The output array dimension does not match the rest of the problem.", &
LA_ARRAY_SIZE_ERROR)
return
end if
if (size(ju) /= m) then
call errmgr%report_error("csr_lu_solve", &
"The U row tracking array is not sized correctly.", &
LA_ARRAY_SIZE_ERROR)
return
end if

! Process
call lusol(m, b, x, lu%values, lu%indices, ju)
end subroutine

! ------------------------------------------------------------------------------
end submodule
3 changes: 3 additions & 0 deletions tests/linalg_test.f90
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ program main
rst = test_msr_1()
if (.not.rst) flag = 104

rst = test_csr_lu_factor_1()
if (.not.rst) flag = 105

! End
if (flag /= 0) stop flag
end program
50 changes: 50 additions & 0 deletions tests/test_sparse.f90
Original file line number Diff line number Diff line change
Expand Up @@ -581,5 +581,55 @@ function test_msr_1() result(rst)
end if
end function

! ------------------------------------------------------------------------------
function test_csr_lu_factor_1() result(rst)
! Arguments
logical :: rst

! Local Variables
integer(int32) :: i, ipiv(4), ju(4)
real(real64) :: dense(4, 4), check(4, 4), x(4), b(4)
type(csr_matrix) :: sparse
type(msr_matrix) :: slu

! Initialization
rst = .true.
dense = reshape([ &
5.0d0, 0.0d0, 0.0d0, 0.0d0, &
0.0d0, 8.0d0, 0.0d0, 6.0d0, &
0.0d0, 0.0d0, 3.0d0, 0.0d0, &
0.0d0, 0.0d0, 0.0d0, 5.0d0], [4, 4])
sparse = dense
call random_number(b)

! Compute the factorization of the sparse matrix
call lu_factor(sparse, slu, ju)

! Compute the factorization of the dense matrix
call lu_factor(dense, ipiv)

! Test - the diagonal must be inverted
check = slu
do i = 1, size(check, 1)
check(i,i) = 1.0d0 / check(i,i)
end do
if (.not.assert(check, dense)) then
rst = .false.
print "(A)", "Test Failed: test_csr_lu_factor_1 -1"
end if

! Solve the sparse system
call solve_lu(slu, b, ju, x)

! Now solve the dense system for comparison
call solve_lu(dense, ipiv, b)

! Test
if (.not.assert(x, b)) then
rst = .false.
print "(A)", "Test Failed: test_csr_lu_factor_1 -2"
end if
end function

! ------------------------------------------------------------------------------
end module

0 comments on commit e9983ce

Please sign in to comment.