Skip to content

Commit

Permalink
add a throttle macro
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 5, 2024
1 parent 7be1ca7 commit e8572ab
Showing 1 changed file with 55 additions and 7 deletions.
62 changes: 55 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,9 @@ _eltype(::AbstractArray{T}) where T = T
"""
throttle(f, timeout; leading=true, trailing=false)
Return a function that when invoked, will only be triggered at most once
during `timeout` seconds.
Return a function that when called, will only call the given `f` at most
once during `timeout` seconds. Any arguments passed to this new function
are passed to `f`.
Normally, the throttled function will run as much as it can, without ever
going more than once per `wait` duration; but if you'd like to disable the
Expand All @@ -561,17 +562,27 @@ the trailing edge, pass `trailing=true`.
# Examples
```jldoctest
julia> a = Flux.throttle(() -> println("Flux"), 2);
julia> noarg = Flux.throttle(() -> println("Flux"), 2);
julia> for i = 1:4 # a called in alternate iterations
a()
julia> for i in 1:4
noarg() # println called in alternate iterations
sleep(1)
end
Flux
Flux
julia> onearg = Flux.throttle(i -> println("step = ", i), 1);
julia> for i in 1:10
onearg(i)
sleep(0.3)
end
step = 1
step = 5
step = 9
```
"""
function throttle(f, timeout; leading=true, trailing=false)
function throttle(f, timeout::Real; leading=true, trailing=false)
cooldown = true
later = nothing
result = nothing
Expand Down Expand Up @@ -603,6 +614,44 @@ function throttle(f, timeout; leading=true, trailing=false)
end
end

"""
@throttle timeout expr
Evaluates the given expression at most once every `timeout` seconds.
Internally, it uses [`throttle`](@ref Flux.throttle). But instead of
defining a function outside the loop, it lets you place the code inside
the loop.
# Example
```jldoctest
julia> for i in 1:20
j = 100i
sleep(0.2)
Flux.@throttle 0.9 if iseven(i)
println("i = ", i, ", and j = ", j)
else
println("i = ", i)
end
end
i = 1
i = 6, and j = 600
i = 11
i = 16, and j = 1600
```
"""
macro throttle(timeout::Real, ex)
expr = macroexpand(__module__, ex)
vars = unique(_allsymbols(expr))
@gensym fast slow
Base.eval(__module__, :($fast($(vars...)) = $expr))

Check warning on line 647 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L647

Added line #L647 was not covered by tests
Base.eval(__module__, :(const $slow = $throttle($fast, $timeout)))
:($slow($(vars...))) |> esc
end

_allsymbols(s::Symbol) =[s]
_allsymbols(other) = Symbol[]
_allsymbols(ex::Expr) = vcat(_allsymbols.(ex.args)...)

Check warning on line 654 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L652-L654

Added lines #L652 - L654 were not covered by tests

"""
modules(m)
Expand Down Expand Up @@ -675,7 +724,6 @@ julia> loss() = rand();
julia> trigger = Flux.patience(() -> loss() < 1, 3);
julia> for i in 1:10
@info "Epoch \$i"
trigger() && break
Expand Down

0 comments on commit e8572ab

Please sign in to comment.