Skip to content

Commit 7c2dba4

Browse files
dkarraschKristofferC
authored andcommitted
Complete size checks in BLAS.[sy/he]mm! (#45605)
(cherry picked from commit da13d78)
1 parent 11e7072 commit 7c2dba4

File tree

3 files changed

+56
-10
lines changed

3 files changed

+56
-10
lines changed

stdlib/LinearAlgebra/src/blas.jl

+42-10
Original file line numberDiff line numberDiff line change
@@ -1566,11 +1566,27 @@ for (mfname, elty) in ((:dsymm_,:Float64),
15661566
require_one_based_indexing(A, B, C)
15671567
m, n = size(C)
15681568
j = checksquare(A)
1569-
if j != (side == 'L' ? m : n)
1570-
throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)"))
1571-
end
1572-
if size(B,2) != n
1573-
throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
1569+
M, N = size(B)
1570+
if side == 'L'
1571+
if j != m
1572+
throw(DimensionMismatch(lazy"A has first dimension $j but needs to match first dimension of C, $m"))
1573+
end
1574+
if N != n
1575+
throw(DimensionMismatch(lazy"B has second dimension $N but needs to match second dimension of C, $n"))
1576+
end
1577+
if j != M
1578+
throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $M"))
1579+
end
1580+
else
1581+
if j != n
1582+
throw(DimensionMismatch(lazy"B has second dimension $j but needs to match second dimension of C, $n"))
1583+
end
1584+
if N != j
1585+
throw(DimensionMismatch(lazy"A has second dimension $N but needs to match first dimension of B, $j"))
1586+
end
1587+
if M != m
1588+
throw(DimensionMismatch(lazy"A has first dimension $M but needs to match first dimension of C, $m"))
1589+
end
15741590
end
15751591
chkstride1(A)
15761592
chkstride1(B)
@@ -1640,11 +1656,27 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64),
16401656
require_one_based_indexing(A, B, C)
16411657
m, n = size(C)
16421658
j = checksquare(A)
1643-
if j != (side == 'L' ? m : n)
1644-
throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)"))
1645-
end
1646-
if size(B,2) != n
1647-
throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
1659+
M, N = size(B)
1660+
if side == 'L'
1661+
if j != m
1662+
throw(DimensionMismatch(lazy"A has first dimension $j but needs to match first dimension of C, $m"))
1663+
end
1664+
if N != n
1665+
throw(DimensionMismatch(lazy"B has second dimension $N but needs to match second dimension of C, $n"))
1666+
end
1667+
if j != M
1668+
throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $M"))
1669+
end
1670+
else
1671+
if j != n
1672+
throw(DimensionMismatch(lazy"B has second dimension $j but needs to match second dimension of C, $n"))
1673+
end
1674+
if N != j
1675+
throw(DimensionMismatch(lazy"A has second dimension $N but needs to match first dimension of B, $j"))
1676+
end
1677+
if M != m
1678+
throw(DimensionMismatch(lazy"A has first dimension $M but needs to match first dimension of C, $m"))
1679+
end
16481680
end
16491681
chkstride1(A)
16501682
chkstride1(B)

stdlib/LinearAlgebra/test/blas.jl

+8
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,19 @@ Random.seed!(100)
227227
@test_throws DimensionMismatch BLAS.symm('R','U',Cmn,Cnn)
228228
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn)
229229
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm)
230+
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn)
231+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnm,one(elty),Cmn)
232+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnn,one(elty),Cnm)
233+
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cmn,one(elty),Cnn)
230234
if elty <: BlasComplex
231235
@test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn)
232236
@test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn)
233237
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn)
234238
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm)
239+
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn)
240+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnm,one(elty),Cmn)
241+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnn,one(elty),Cnm)
242+
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cmn,one(elty),Cnn)
235243
end
236244
end
237245
end

stdlib/LinearAlgebra/test/symmetric.jl

+6
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ end
352352
C = zeros(eltya,n,n)
353353
@test Hermitian(aherm) * a aherm * a
354354
@test a * Hermitian(aherm) a * aherm
355+
# rectangular multiplication
356+
@test [a; a] * Hermitian(aherm) [a; a] * aherm
357+
@test Hermitian(aherm) * [a a] aherm * [a a]
355358
@test Hermitian(aherm) * Hermitian(aherm) aherm*aherm
356359
@test_throws DimensionMismatch Hermitian(aherm) * Vector{eltya}(undef, n+1)
357360
LinearAlgebra.mul!(C,a,Hermitian(aherm))
@@ -360,6 +363,9 @@ end
360363
@test Symmetric(asym) * Symmetric(asym) asym*asym
361364
@test Symmetric(asym) * a asym * a
362365
@test a * Symmetric(asym) a * asym
366+
# rectangular multiplication
367+
@test Symmetric(asym) * [a a] asym * [a a]
368+
@test [a; a] * Symmetric(asym) [a; a] * asym
363369
@test_throws DimensionMismatch Symmetric(asym) * Vector{eltya}(undef, n+1)
364370
LinearAlgebra.mul!(C,a,Symmetric(asym))
365371
@test C a*asym

0 commit comments

Comments
 (0)