From cb5b32b7c91acd6769c7cf35b2fb2a23e976535d Mon Sep 17 00:00:00 2001 From: Maxence Gollier Date: Wed, 22 Jan 2025 14:32:09 -0500 Subject: [PATCH 1/2] preallocate get value for qrm_get in qrm_spfct --- src/wrapper/qr_mumps_api.jl | 9 ++++----- src/wrapper/qr_mumps_common.jl | 12 +++++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/wrapper/qr_mumps_api.jl b/src/wrapper/qr_mumps_api.jl index 3528235..db7dd72 100644 --- a/src/wrapper/qr_mumps_api.jl +++ b/src/wrapper/qr_mumps_api.jl @@ -861,16 +861,15 @@ for (finame, frname, elty) in ((:sqrm_spfct_get_i8_c, :sqrm_spfct_get_r4_c, :Flo @eval begin function qrm_get(spfct :: qrm_spfct{$elty}, str :: String) if (str ∈ PICNTL) || (str ∈ STATS) - val = Ref{Clonglong}(0) - err = $finame(spfct, str, val) + err = $finame(spfct, str, spfct.get_val_long) elseif str ∈ RCNTL - val = Ref{Float32}(0) - err = $frname(spfct, str, val) + err = $frname(spfct, str, spfct.get_val_float) else err = Int32(23) end qrm_check(err) - return val[] + (str ∈ RCNTL) && return spfct.get_val_float[] + return spfct.get_val_long[] end end end diff --git a/src/wrapper/qr_mumps_common.jl b/src/wrapper/qr_mumps_common.jl index ab79222..5c3da02 100644 --- a/src/wrapper/qr_mumps_common.jl +++ b/src/wrapper/qr_mumps_common.jl @@ -75,13 +75,15 @@ the factorization, namely, the factors with all the symbolic information needed solve phase. """ mutable struct qrm_spfct{T} <: Factorization{T} - cperm_in :: Vector{Cint} - ptr_rp :: Ref{Ptr{Cint}} - ptr_cp :: Ref{Ptr{Cint}} - fct :: c_spfct{T} + cperm_in :: Vector{Cint} + ptr_rp :: Ref{Ptr{Cint}} + ptr_cp :: Ref{Ptr{Cint}} + get_val_long :: Ref{Clonglong} + get_val_float :: Ref{Float32} + fct :: c_spfct{T} function qrm_spfct{T}() where T - spfct = new(Cint[], Ref{Ptr{Cint}}(), Ref{Ptr{Cint}}(), c_spfct{T}()) + spfct = new(Cint[], Ref{Ptr{Cint}}(), Ref{Ptr{Cint}}(), Ref{Clonglong}(0), Ref{Float32}(0), c_spfct{T}()) finalizer(qrm_spfct_destroy!, spfct) return spfct end From e1fc075dc3219e6b9b4264db08249f8a27769298 Mon Sep 17 00:00:00 2001 From: Maxence Gollier Date: Wed, 22 Jan 2025 15:15:16 -0500 Subject: [PATCH 2/2] add unit tests for qrm_get allocs and minimize changes in wrapper --- src/wrapper/qr_mumps_api.jl | 9 +++++---- src/wrapper/qr_mumps_common.jl | 12 ++++++------ test/test_qrm.jl | 24 ++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/wrapper/qr_mumps_api.jl b/src/wrapper/qr_mumps_api.jl index db7dd72..b09a1c5 100644 --- a/src/wrapper/qr_mumps_api.jl +++ b/src/wrapper/qr_mumps_api.jl @@ -861,15 +861,16 @@ for (finame, frname, elty) in ((:sqrm_spfct_get_i8_c, :sqrm_spfct_get_r4_c, :Flo @eval begin function qrm_get(spfct :: qrm_spfct{$elty}, str :: String) if (str ∈ PICNTL) || (str ∈ STATS) - err = $finame(spfct, str, spfct.get_val_long) + val = spfct.ref_int + err = $finame(spfct, str, val) elseif str ∈ RCNTL - err = $frname(spfct, str, spfct.get_val_float) + val = spfct.ref_float + err = $frname(spfct, str, val) else err = Int32(23) end qrm_check(err) - (str ∈ RCNTL) && return spfct.get_val_float[] - return spfct.get_val_long[] + return val[] end end end diff --git a/src/wrapper/qr_mumps_common.jl b/src/wrapper/qr_mumps_common.jl index 5c3da02..53bc2e7 100644 --- a/src/wrapper/qr_mumps_common.jl +++ b/src/wrapper/qr_mumps_common.jl @@ -75,12 +75,12 @@ the factorization, namely, the factors with all the symbolic information needed solve phase. """ mutable struct qrm_spfct{T} <: Factorization{T} - cperm_in :: Vector{Cint} - ptr_rp :: Ref{Ptr{Cint}} - ptr_cp :: Ref{Ptr{Cint}} - get_val_long :: Ref{Clonglong} - get_val_float :: Ref{Float32} - fct :: c_spfct{T} + cperm_in :: Vector{Cint} + ptr_rp :: Ref{Ptr{Cint}} + ptr_cp :: Ref{Ptr{Cint}} + ref_int :: Ref{Clonglong} + ref_float:: Ref{Float32} + fct :: c_spfct{T} function qrm_spfct{T}() where T spfct = new(Cint[], Ref{Ptr{Cint}}(), Ref{Ptr{Cint}}(), Ref{Clonglong}(0), Ref{Float32}(0), c_spfct{T}()) diff --git a/test/test_qrm.jl b/test/test_qrm.jl index ac3fb6c..c7a42be 100644 --- a/test/test_qrm.jl +++ b/test/test_qrm.jl @@ -546,3 +546,27 @@ end @test norm(b - A'*(A*x)) ≥ norm(b - A'*(A*x_refined)) end end + +@testset "allocations" begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + tol = (real(T) == Float32) ? 1e-3 : 1e-12 + transp = (T <: Real) ? 't' : 'c' + + for I in (Int32 , Int64) + A = sprand(T, m, n, 0.3) + A = convert(SparseMatrixCSC{T,I}, A) + + spmat = qrm_spmat_init(T) + qrm_spmat_init!(spmat, A) + + spfct = qrm_spfct_init(spmat) + + qrm_set(spfct, "qrm_rd_eps", tol) + qrm_analyse!(spmat, spfct) + qrm_factorize!(spmat, spfct) + @test (@allocated qrm_get(spfct, "qrm_rd_num")) == 0 + + end + end + +end