Skip to content

Commit

Permalink
Merge pull request #19 from JuliaMolSim/fixes
Browse files Browse the repository at this point in the history
Simplify grid shape determination
  • Loading branch information
mfherbst authored Sep 29, 2020
2 parents 99b5e3e + b0a8327 commit 14cf3bb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
32 changes: 20 additions & 12 deletions src/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,19 @@ as a named tuple.
function evaluate(func::Functional; derivatives=0:1, rho::AbstractArray, kwargs...)
@assert all(0 .≤ derivatives .≤ 4)

# If we have an n_spin × size array, keep the shape when allocating output arrays
shape = size(rho)
if func.spin_dimensions.rho > 1
if size(rho, 1) == func.spin_dimensions.rho
shape = size(rho)[2:end]
else
shape = (Int(length(rho) / func.spin_dimensions.rho), )
# Determine the gridshape (i.e. the shape of the grid points without the spin components)
if ndims(rho) > 1
if size(rho, 1) != func.spin_dimensions.rho
error("First axis for multidimensional rho array should be equal " *
"to the number of spin components (== $(func.spin_dimensions.rho))")
end
gridshape = size(rho)[2:end]
else
if mod(length(rho), func.spin_dimensions.rho) != 0
error("Length of linear rho array should be divisible by number of spin " *
"components in rho (== $(func.spin_dimensions.rho)).")
end
gridshape = (Int(length(rho) / func.spin_dimensions.rho), )
end

# Output arguments, where memory is already allocated
Expand All @@ -69,13 +74,11 @@ function evaluate(func::Functional; derivatives=0:1, rho::AbstractArray, kwargs.
for symbol in vcat(ARGUMENTS[func.family][1 .+ derivatives]...)
if symbol in keys(kwargs)
outargs_allocated[symbol] = kwargs[symbol]
elseif symbol == :zk # For zk keep just the grid shape
outargs[symbol] = similar(rho, gridshape)
else
n_spin = getfield(func.spin_dimensions, symbol)
if n_spin > 1
outargs[symbol] = similar(rho, n_spin, shape...)
else
outargs[symbol] = similar(rho, shape)
end
outargs[symbol] = similar(rho, n_spin, gridshape...)
end
end

Expand All @@ -91,6 +94,11 @@ depend on the functional type (`rho` for all functionals, `sigma` for GGA and mG
`tau` and `lapl` for mGGA).
"""
function evaluate!(func::Functional; rho::AbstractArray, kwargs...)
mod(length(rho), func.spin_dimensions.rho) != 0 && error(
"Length of rho array should be divisible by number of spin " *
"components in rho (== $(func.spin_dimensions.rho))."
)

n_p = Int(length(rho) / func.spin_dimensions.rho)
kwargs = Dict(kwargs)
for argument in vcat(INPUT[func.family]..., ARGUMENTS[func.family]...)
Expand Down
25 changes: 12 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,45 +98,44 @@ end

@testset "LDA / GGA evaluate with spin" begin
shape = (2, 3, 4)
rho = abs.(randn(shape))
sigma = abs.(randn(shape))
rho = reshape(abs.(randn(shape)), 1, shape...)
sigma = reshape(abs.(randn(shape)), 1, shape...)

# Duplicate rho and sigma for spin = 2 tests
rho2 = 0.5vcat(reshape(rho, 1, shape...), reshape(rho, 1, shape...))
sigma2 = 0.25vcat(reshape(sigma, 1, shape...), reshape(sigma, 1, shape...),
reshape(sigma, 1, shape...))
rho2 = 0.5vcat(rho, rho)
sigma2 = 0.25vcat(sigma, sigma, sigma)

# LSDA
for sym in (:lda_x, :lda_c_vwn)
res = evaluate(Functional(sym, n_spin=1), rho=rho, zk=zeros(shape))
@test size(res.zk) == shape
@test size(res.vrho) == shape
@test size(res.vrho) == (1, shape...)

res2 = evaluate(Functional(sym, n_spin=2), rho=rho2)
@test size(res2.zk) == shape
@test size(res2.vrho) == (2, shape...)

@test res.zk res2.zk
@test res.vrho res2.vrho[1, :, :, :]
@test res.vrho res2.vrho[2, :, :, :]
@test res.vrho[1, :, :, :] res2.vrho[1, :, :, :]
@test res.vrho[1, :, :, :] res2.vrho[2, :, :, :]
end

# GGA
for sym in (:gga_x_pbe, :gga_c_pbe)
res = evaluate(Functional(sym, n_spin=1), rho=rho, sigma=sigma)
@test size(res.zk) == shape
@test size(res.vrho) == shape
@test size(res.vsigma) == shape
@test size(res.vrho) == (1, shape...)
@test size(res.vsigma) == (1, shape...)

res2 = evaluate(Functional(sym, n_spin=2), rho=rho2, sigma=sigma2)
@test size(res2.zk) == shape
@test size(res2.vrho) == (2, shape...)
@test size(res2.vsigma) == (3, shape...)

@test res.zk res2.zk
@test res.vrho res2.vrho[1, :, :, :]
@test res.vrho res2.vrho[2, :, :, :]
@test 4res.vsigma dropdims(sum(res2.vsigma, dims=1), dims=1)
@test res.vrho[1, :, :, :] res2.vrho[1, :, :, :]
@test res.vrho[1, :, :, :] res2.vrho[2, :, :, :]
@test 4res.vsigma sum(res2.vsigma, dims=1)
end
end

Expand Down

0 comments on commit 14cf3bb

Please sign in to comment.