Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 17bc9aa

Browse files
authored
feat: add fallbacks for unknown objects (#87)
* feat: add fallbacks for unknown objects * feat: handle RNGs and undef arrays gracefully * test: RNG movement * test: functions and closures
1 parent 0d6c6a8 commit 17bc9aa

14 files changed

+215
-21
lines changed

.buildkite/pipeline.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
steps:
22
- label: "Triggering Pipelines (Pull Request)"
3-
if: "build.pull_request.base_branch == 'main'"
3+
if: build.branch != "main" && build.tag == null
44
agents:
55
queue: "juliagpu"
66
plugins:

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.2.1"
4+
version = "1.3.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/MLDataDevicesAMDGPUExt.jl

+2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray)
4949
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
5050
return Internal.get_device(parent_x)
5151
end
52+
Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device())
5253

5354
Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
55+
Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice
5456

5557
# Set Device
5658
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)

ext/MLDataDevicesCUDAExt.jl

+4
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray)
2929
return MLDataDevices.get_device(parent_x)
3030
end
3131
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))
32+
Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device())
33+
Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device())
3234

3335
Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice
36+
Internal.get_device_type(::CUDA.RNG) = CUDADevice
37+
Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice
3438

3539
# Set Device
3640
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)

ext/MLDataDevicesChainRulesCoreExt.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt
33
using Adapt: Adapt
44
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable
55

6-
using MLDataDevices: AbstractDevice, get_device, get_device_type
6+
using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
77

88
@non_differentiable get_device(::Any)
99
@non_differentiable get_device_type(::Any)
1010

1111
function ChainRulesCore.rrule(
1212
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
13-
∇adapt_storage = let x = x
14-
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
13+
∇adapt_storage = let dev = get_device(x)
14+
if dev === nothing || dev isa UnknownDevice
15+
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
16+
Δ -> (NoTangent(), NoTangent(), Δ)
17+
else
18+
Δ -> (NoTangent(), NoTangent(), dev(Δ))
19+
end
1520
end
1621
return Adapt.adapt_storage(to, x), ∇adapt_storage
1722
end

ext/MLDataDevicesGPUArraysExt.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt
22

33
using Adapt: Adapt
44
using GPUArrays: GPUArrays
5-
using MLDataDevices: CPUDevice
5+
using MLDataDevices: Internal, CPUDevice
66
using Random: Random
77

88
Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()
99

10+
Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state)
11+
Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state)
12+
1013
end

src/internal.jl

+32-7
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ using Preferences: load_preference
55
using Random: AbstractRNG
66

77
using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
8-
MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends,
9-
GPU_DEVICES, loaded, functional
8+
MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
9+
supported_gpu_backends, GPU_DEVICES, loaded, functional
1010

1111
for dev in (CPUDevice, MetalDevice, oneAPIDevice)
1212
msg = "`device_id` is not applicable for `$dev`."
@@ -107,31 +107,38 @@ special_aos(::AbstractArray) = false
107107
recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)
108108

109109
combine_devices(::Nothing, ::Nothing) = nothing
110-
combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
111110
combine_devices(::Nothing, dev::AbstractDevice) = dev
112-
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
113111
combine_devices(dev::AbstractDevice, ::Nothing) = dev
114-
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
115112
function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice)
116113
dev1 == dev2 && return dev1
114+
dev1 isa UnknownDevice && return dev2
115+
dev2 isa UnknownDevice && return dev1
117116
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
118117
end
118+
119+
combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
119120
combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T
121+
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
122+
combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T
123+
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
124+
combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T
125+
combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice
120126
function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice})
121127
throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
122128
end
123129

124130
for op in (:get_device, :get_device_type)
125131
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
132+
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
126133
not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \
127-
$(cpu_ret_val)..."
134+
$(unknown_ret_val)..."
128135

129136
@eval begin
130137
function $(op)(x::AbstractArray{T}) where {T}
131138
if recursive_array_eltype(T)
132139
if any(!isassigned(x, i) for i in eachindex(x))
133140
@warn $(not_assigned_msg)
134-
return $(cpu_ret_val)
141+
return $(unknown_ret_val)
135142
end
136143
return mapreduce(MLDataDevices.$(op), combine_devices, x)
137144
end
@@ -147,13 +154,31 @@ for op in (:get_device, :get_device_type)
147154
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
148155
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
149156
end
157+
158+
function $(op)(f::F) where {F <: Function}
159+
Base.issingletontype(F) &&
160+
return $(op == :get_device ? UnknownDevice() : UnknownDevice)
161+
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
162+
map(Base.Fix1(getfield, f), fieldnames(F)))
163+
end
150164
end
151165

152166
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
153167
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
154168
end
155169
end
156170

171+
get_device(_) = UnknownDevice()
172+
get_device_type(_) = UnknownDevice
173+
174+
fast_structure(::AbstractArray) = true
175+
fast_structure(::Union{Tuple, NamedTuple}) = true
176+
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
177+
@eval fast_structure(::$(T)) = true
178+
end
179+
fast_structure(::Function) = true
180+
fast_structure(_) = false
181+
157182
function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
158183
return unrolled_mapreduce(f, op, itr, static_length(itr))
159184
end

src/public.jl

+16-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end
1212
# TODO: Later we might want to add the client field here?
1313
struct XLADevice <: AbstractAcceleratorDevice end
1414

15+
# Fallback for when we don't know the device type
16+
struct UnknownDevice <: AbstractDevice end
17+
1518
"""
1619
functional(x::AbstractDevice) -> Bool
1720
functional(::Type{<:AbstractDevice}) -> Bool
@@ -229,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """
229232
!!! note
230233
231234
Trigger Packages must be loaded for this to return the correct device.
232-
233-
!!! warning
234-
235-
RNG types currently don't participate in device determination. We will remove this
236-
restriction in the future.
237235
"""
238236

239237
# Query Device from Array
@@ -245,6 +243,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we retur
245243
246244
$(GET_DEVICE_ADMONITIONS)
247245
246+
## Special Returned Values
247+
248+
- `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
249+
range, etc.
250+
- `UnknownDevice()` -- denotes that the device type is unknown
251+
248252
See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch
249253
based on device type.
250254
"""
@@ -258,6 +262,12 @@ itself. This value is often a compile time constant and is recommended to be use
258262
of [`get_device`](@ref) where ever defining dispatches based on the device type.
259263
260264
$(GET_DEVICE_ADMONITIONS)
265+
266+
## Special Returned Values
267+
268+
- `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
269+
range, etc.
270+
- `UnknownDevice` -- denotes that the device type is unknown
261271
"""
262272
function get_device_type end
263273

@@ -345,7 +355,7 @@ end
345355

346356
for op in (:get_device, :get_device_type)
347357
@eval function $(op)(x)
348-
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
358+
Internal.fast_structure(x) && return Internal.$(op)(x)
349359
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
350360
end
351361
end

test/amdgpu_tests.jl

+29
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions
5757
@test ps_xpu.e == ps.e
5858
@test ps_xpu.d == ps.d
5959
@test ps_xpu.rng_default isa rngType
60+
@test get_device(ps_xpu.rng_default) isa AMDGPUDevice
61+
@test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice
6062
@test ps_xpu.rng == ps.rng
63+
@test get_device(ps_xpu.rng) === nothing
64+
@test get_device_type(ps_xpu.rng) <: Nothing
6165

6266
if MLDataDevices.functional(AMDGPUDevice)
6367
@test ps_xpu.one_elem isa ROCArray
@@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions
8387
@test ps_cpu.e == ps.e
8488
@test ps_cpu.d == ps.d
8589
@test ps_cpu.rng_default isa Random.TaskLocalRNG
90+
@test get_device(ps_cpu.rng_default) === nothing
91+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8692
@test ps_cpu.rng == ps.rng
93+
@test get_device(ps_cpu.rng) === nothing
94+
@test get_device_type(ps_cpu.rng) <: Nothing
8795

8896
if MLDataDevices.functional(AMDGPUDevice)
8997
@test ps_cpu.one_elem isa Array
@@ -118,6 +126,27 @@ using FillArrays, Zygote # Extensions
118126
end
119127
end
120128

129+
@testset "Functions" begin
130+
if MLDataDevices.functional(AMDGPUDevice)
131+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
132+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
133+
134+
f(x, y) = () -> (x, x .^ 2, y)
135+
136+
ff = f([1, 2, 3], 1)
137+
@test get_device(ff) isa CPUDevice
138+
@test get_device_type(ff) <: CPUDevice
139+
140+
ff_xpu = ff |> AMDGPUDevice()
141+
@test get_device(ff_xpu) isa AMDGPUDevice
142+
@test get_device_type(ff_xpu) <: AMDGPUDevice
143+
144+
ff_cpu = ff_xpu |> cpu_device()
145+
@test get_device(ff_cpu) isa CPUDevice
146+
@test get_device_type(ff_cpu) <: CPUDevice
147+
end
148+
end
149+
121150
@testset "Wrapped Arrays" begin
122151
if MLDataDevices.functional(AMDGPUDevice)
123152
x = rand(10, 10) |> AMDGPUDevice()

test/cuda_tests.jl

+29
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions
5656
@test ps_xpu.e == ps.e
5757
@test ps_xpu.d == ps.d
5858
@test ps_xpu.rng_default isa rngType
59+
@test get_device(ps_xpu.rng_default) isa CUDADevice
60+
@test get_device_type(ps_xpu.rng_default) <: CUDADevice
5961
@test ps_xpu.rng == ps.rng
62+
@test get_device(ps_xpu.rng) === nothing
63+
@test get_device_type(ps_xpu.rng) <: Nothing
6064

6165
if MLDataDevices.functional(CUDADevice)
6266
@test ps_xpu.one_elem isa CuArray
@@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions
8286
@test ps_cpu.e == ps.e
8387
@test ps_cpu.d == ps.d
8488
@test ps_cpu.rng_default isa Random.TaskLocalRNG
89+
@test get_device(ps_cpu.rng_default) === nothing
90+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8591
@test ps_cpu.rng == ps.rng
92+
@test get_device(ps_cpu.rng) === nothing
93+
@test get_device_type(ps_cpu.rng) <: Nothing
8694

8795
if MLDataDevices.functional(CUDADevice)
8896
@test ps_cpu.one_elem isa Array
@@ -143,6 +151,27 @@ using FillArrays, Zygote # Extensions
143151
end
144152
end
145153

154+
@testset "Functions" begin
155+
if MLDataDevices.functional(CUDADevice)
156+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
157+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
158+
159+
f(x, y) = () -> (x, x .^ 2, y)
160+
161+
ff = f([1, 2, 3], 1)
162+
@test get_device(ff) isa CPUDevice
163+
@test get_device_type(ff) <: CPUDevice
164+
165+
ff_xpu = ff |> CUDADevice()
166+
@test get_device(ff_xpu) isa CUDADevice
167+
@test get_device_type(ff_xpu) <: CUDADevice
168+
169+
ff_cpu = ff_xpu |> cpu_device()
170+
@test get_device(ff_cpu) isa CPUDevice
171+
@test get_device_type(ff_cpu) <: CPUDevice
172+
end
173+
end
174+
146175
@testset "Wrapped Arrays" begin
147176
if MLDataDevices.functional(CUDADevice)
148177
x = rand(10, 10) |> CUDADevice()

test/metal_tests.jl

+29
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
5555
@test ps_xpu.e == ps.e
5656
@test ps_xpu.d == ps.d
5757
@test ps_xpu.rng_default isa rngType
58+
@test get_device(ps_xpu.rng_default) isa MetalDevice
59+
@test get_device_type(ps_xpu.rng_default) <: MetalDevice
5860
@test ps_xpu.rng == ps.rng
61+
@test get_device(ps_xpu.rng) === nothing
62+
@test get_device_type(ps_xpu.rng) <: Nothing
5963

6064
if MLDataDevices.functional(MetalDevice)
6165
@test ps_xpu.one_elem isa MtlArray
@@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
8185
@test ps_cpu.e == ps.e
8286
@test ps_cpu.d == ps.d
8387
@test ps_cpu.rng_default isa Random.TaskLocalRNG
88+
@test get_device(ps_cpu.rng_default) === nothing
89+
@test get_device_type(ps_cpu.rng_default) <: Nothing
8490
@test ps_cpu.rng == ps.rng
91+
@test get_device(ps_cpu.rng) === nothing
92+
@test get_device_type(ps_cpu.rng) <: Nothing
8593

8694
if MLDataDevices.functional(MetalDevice)
8795
@test ps_cpu.one_elem isa Array
@@ -107,6 +115,27 @@ using FillArrays, Zygote # Extensions
107115
end
108116
end
109117

118+
@testset "Functions" begin
119+
if MLDataDevices.functional(MetalDevice)
120+
@test get_device(tanh) isa MLDataDevices.UnknownDevice
121+
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice
122+
123+
f(x, y) = () -> (x, x .^ 2, y)
124+
125+
ff = f([1, 2, 3], 1)
126+
@test get_device(ff) isa CPUDevice
127+
@test get_device_type(ff) <: CPUDevice
128+
129+
ff_xpu = ff |> MetalDevice()
130+
@test get_device(ff_xpu) isa MetalDevice
131+
@test get_device_type(ff_xpu) <: MetalDevice
132+
133+
ff_cpu = ff_xpu |> cpu_device()
134+
@test get_device(ff_cpu) isa CPUDevice
135+
@test get_device_type(ff_cpu) <: CPUDevice
136+
end
137+
end
138+
110139
@testset "Wrapper Arrays" begin
111140
if MLDataDevices.functional(MetalDevice)
112141
x = rand(Float32, 10, 10) |> MetalDevice()

test/misc_tests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,6 @@ end
154154
@testset "undefined references array" begin
155155
x = Matrix{Any}(undef, 10, 10)
156156

157-
@test get_device(x) isa CPUDevice
158-
@test get_device_type(x) <: CPUDevice
157+
@test get_device(x) isa MLDataDevices.UnknownDevice
158+
@test get_device_type(x) <: MLDataDevices.UnknownDevice
159159
end

0 commit comments

Comments
 (0)