Skip to content

Commit 8575786

Browse files
committed
increase delaydiffeq tol, shorten adtype names
1 parent aa54f88 commit 8575786

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

main.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@ import Zygote
1212

1313
# AD backends to test.
1414
ADTYPES = Dict(
15-
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
15+
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
1616
"ForwardDiff" => AutoForwardDiff(),
17-
"ReverseDiff" => AutoReverseDiff(; compile = false),
18-
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
19-
"MooncakeReverse" => AutoMooncake(),
20-
"MooncakeForward" => AutoMooncakeForward(),
21-
"EnzymeForward" => AutoEnzyme(;
22-
mode = set_runtime_activity(Forward, true),
23-
function_annotation = Const,
17+
"ReverseDiff" => AutoReverseDiff(; compile=false),
18+
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
19+
"MooncakeRvs" => AutoMooncake(),
20+
"MooncakeFwd" => AutoMooncakeForward(),
21+
"EnzymeFwd" => AutoEnzyme(;
22+
mode=set_runtime_activity(Forward, true),
23+
function_annotation=Const,
2424
),
25-
"EnzymeReverse" => AutoEnzyme(;
26-
mode = set_runtime_activity(Reverse, true),
27-
function_annotation = Const,
25+
"EnzymeRvs" => AutoEnzyme(;
26+
mode=set_runtime_activity(Reverse, true),
27+
function_annotation=Const,
2828
),
2929
"Zygote" => AutoZygote(),
3030
)
@@ -162,35 +162,36 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
162162
result = run_ad(
163163
model,
164164
adtype;
165-
varinfo = vi,
166-
params = params,
167-
test = WithBackend(ref_backend),
168-
benchmark = true,
165+
varinfo=vi,
166+
params=params,
167+
test=WithBackend(ref_backend),
168+
benchmark=true,
169169
)
170170
else
171171
# Some models are more numerically sensitive
172172
rtol = if model_name == "dppl_logistic_regression"
173173
1e-1
174174
elseif model_name == "lux_nn"
175175
1e-2
176-
elseif model_name == "ordinarydiffeq"
176+
elseif model_name == "ordinarydiffeq" || model_name == "delaydiffeq"
177177
1e-3
178178
else
179179
sqrt(eps())
180180
end
181181
result = run_ad(
182182
model,
183183
adtype;
184-
rng = Xoshiro(468),
185-
test = WithBackend(ref_backend),
186-
benchmark = true,
187-
rtol = rtol,
184+
rng=Xoshiro(468),
185+
test=WithBackend(ref_backend),
186+
benchmark=true,
187+
rtol=rtol,
188188
)
189189
end
190190
# If reached here - nothing went wrong
191191
println(result.grad_time / result.primal_time)
192192
catch e
193-
@show e
193+
showerror(stderr, e)
194+
println()
194195
if e isa ADIncorrectException
195196
# If not, check for NaN's and report those
196197
if any(isnan, e.grad_expected) || any(isnan, e.grad_actual)

0 commit comments

Comments
 (0)