@@ -5,8 +5,8 @@ using Preferences: load_preference
5
5
using Random: AbstractRNG
6
6
7
7
using .. MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
8
- MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends ,
9
- GPU_DEVICES, loaded, functional
8
+ MetalDevice, oneAPIDevice, XLADevice, UnknownDevice ,
9
+ supported_gpu_backends, GPU_DEVICES, loaded, functional
10
10
11
11
for dev in (CPUDevice, MetalDevice, oneAPIDevice)
12
12
msg = " `device_id` is not applicable for `$dev `."
@@ -107,31 +107,38 @@ special_aos(::AbstractArray) = false
107
107
recursive_array_eltype (:: Type{T} ) where {T} = ! isbitstype (T) && ! (T <: Number )
108
108
109
109
combine_devices (:: Nothing , :: Nothing ) = nothing
110
- combine_devices (:: Type{Nothing} , :: Type{Nothing} ) = Nothing
111
110
combine_devices (:: Nothing , dev:: AbstractDevice ) = dev
112
- combine_devices (:: Type{Nothing} , :: Type{T} ) where {T <: AbstractDevice } = T
113
111
combine_devices (dev:: AbstractDevice , :: Nothing ) = dev
114
- combine_devices (:: Type{T} , :: Type{Nothing} ) where {T <: AbstractDevice } = T
115
112
function combine_devices (dev1:: AbstractDevice , dev2:: AbstractDevice )
116
113
dev1 == dev2 && return dev1
114
+ dev1 isa UnknownDevice && return dev2
115
+ dev2 isa UnknownDevice && return dev1
117
116
throw (ArgumentError (" Objects are on different devices: $(dev1) and $(dev2) ." ))
118
117
end
118
+
119
+ combine_devices (:: Type{Nothing} , :: Type{Nothing} ) = Nothing
119
120
combine_devices (:: Type{T} , :: Type{T} ) where {T <: AbstractDevice } = T
121
+ combine_devices (:: Type{T} , :: Type{Nothing} ) where {T <: AbstractDevice } = T
122
+ combine_devices (:: Type{T} , :: Type{UnknownDevice} ) where {T <: AbstractDevice } = T
123
+ combine_devices (:: Type{Nothing} , :: Type{T} ) where {T <: AbstractDevice } = T
124
+ combine_devices (:: Type{UnknownDevice} , :: Type{T} ) where {T <: AbstractDevice } = T
125
+ combine_devices (:: Type{UnknownDevice} , :: Type{UnknownDevice} ) = UnknownDevice
120
126
function combine_devices (T1:: Type{<:AbstractDevice} , T2:: Type{<:AbstractDevice} )
121
127
throw (ArgumentError (" Objects are on devices with different types: $(T1) and $(T2) ." ))
122
128
end
123
129
124
130
for op in (:get_device , :get_device_type )
125
131
cpu_ret_val = op == :get_device ? CPUDevice () : CPUDevice
132
+ unknown_ret_val = op == :get_device ? UnknownDevice () : UnknownDevice
126
133
not_assigned_msg = " AbstractArray has some undefined references. Giving up, returning \
127
- $(cpu_ret_val ) ..."
134
+ $(unknown_ret_val ) ..."
128
135
129
136
@eval begin
130
137
function $ (op)(x:: AbstractArray{T} ) where {T}
131
138
if recursive_array_eltype (T)
132
139
if any (! isassigned (x, i) for i in eachindex (x))
133
140
@warn $ (not_assigned_msg)
134
- return $ (cpu_ret_val )
141
+ return $ (unknown_ret_val )
135
142
end
136
143
return mapreduce (MLDataDevices.$ (op), combine_devices, x)
137
144
end
@@ -147,13 +154,31 @@ for op in (:get_device, :get_device_type)
147
154
length (x) == 0 && return $ (op == :get_device ? nothing : Nothing)
148
155
return unrolled_mapreduce (MLDataDevices.$ (op), combine_devices, values (x))
149
156
end
157
+
158
+ function $ (op)(f:: F ) where {F <: Function }
159
+ Base. issingletontype (F) &&
160
+ return $ (op == :get_device ? UnknownDevice () : UnknownDevice)
161
+ return unrolled_mapreduce (MLDataDevices.$ (op), combine_devices,
162
+ map (Base. Fix1 (getfield, f), fieldnames (F)))
163
+ end
150
164
end
151
165
152
166
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
153
167
@eval $ (op)(:: $ (T)) = $ (op == :get_device ? nothing : Nothing)
154
168
end
155
169
end
156
170
171
+ get_device (_) = UnknownDevice ()
172
+ get_device_type (_) = UnknownDevice
173
+
174
+ fast_structure (:: AbstractArray ) = true
175
+ fast_structure (:: Union{Tuple, NamedTuple} ) = true
176
+ for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
177
+ @eval fast_structure (:: $ (T)) = true
178
+ end
179
+ fast_structure (:: Function ) = true
180
+ fast_structure (_) = false
181
+
157
182
function unrolled_mapreduce (f:: F , op:: O , itr) where {F, O}
158
183
return unrolled_mapreduce (f, op, itr, static_length (itr))
159
184
end
0 commit comments