Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit e9a2ed7

Browse files
fix: handle bitstypes and wrapped arrays in isleaf (#88)
* bitstype and wrapped arrays * fixes * fix import * bound * cleanup * chore: fix min version of LinearAlgebra * chore: run formatter --------- Co-authored-by: Avik Pal <[email protected]> Co-authored-by: Avik Pal <[email protected]>
1 parent 7a39f50 commit e9a2ed7

File tree

4 files changed

+51
-16
lines changed

4 files changed

+51
-16
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.4.0"
4+
version = "1.4.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
99
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213

@@ -53,6 +54,7 @@ Compat = "4.15"
5354
FillArrays = "1"
5455
Functors = "0.4.8"
5556
GPUArrays = "10, 11"
57+
LinearAlgebra = "1.10"
5658
MLUtils = "0.4.4"
5759
Metal = "1"
5860
Preferences = "1.4"

src/MLDataDevices.jl

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Functors: Functors, fleaves
55
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
66
using Random: AbstractRNG, Random
77
using Compat: @compat
8+
using LinearAlgebra: Transpose, Adjoint
89

910
abstract type AbstractDevice <: Function end
1011
abstract type AbstractCPUDevice <: AbstractDevice end

src/public.jl

+3
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,6 @@ data movement if `isleaf(x::T) == true`.
397397
If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`.
398398
"""
399399
isleaf(x) = Functors.isleaf(x)
400+
401+
isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
402+
isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false

test/misc_tests.jl

+44-15
Original file line numberDiff line numberDiff line change
@@ -160,21 +160,50 @@ end
160160
end
161161

162162
@testset "isleaf" begin
163-
# Functors.isleaf fallback
164-
@test MLDataDevices.isleaf(rand(2))
165-
@test !MLDataDevices.isleaf((rand(2),))
163+
@testset "basics" begin
164+
# Functors.isleaf fallback
165+
@test MLDataDevices.isleaf(rand(2))
166+
@test !MLDataDevices.isleaf((rand(2),))
167+
168+
struct Tleaf
169+
x::Any
170+
end
171+
Functors.@functor Tleaf
172+
MLDataDevices.isleaf(::Tleaf) = true
173+
Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x))
174+
175+
cpu = cpu_device()
176+
t = Tleaf(ones(2))
177+
y = cpu(t)
178+
@test y.x == 2 .* ones(2)
179+
y = cpu([(t,)])
180+
@test y[1][1].x == 2 .* ones(2)
181+
end
182+
183+
@testset "shared parameters" begin
184+
# from
185+
x = rand(1)
186+
m = (; a=x, b=x')
187+
count = Ref(0)
188+
mcopy = Functors.fmap(m; exclude=MLDataDevices.isleaf) do x
189+
count[] += 1
190+
return copy(x)
191+
end
192+
@test count[] == 1
193+
@test mcopy.a === mcopy.b'
194+
end
166195

167-
struct Tleaf
168-
x::Any
196+
@testset "bitstypes and wrapped types" begin
197+
struct BitsType
198+
x::Int32
199+
y::Float64
200+
end
201+
202+
for x in [1.0, 'a', BitsType(1, 2.0)]
203+
@test MLDataDevices.isleaf([x])
204+
@test !MLDataDevices.isleaf([x]')
205+
@test !MLDataDevices.isleaf(transpose([x]))
206+
@test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2)))
207+
end
169208
end
170-
Functors.@functor Tleaf
171-
MLDataDevices.isleaf(::Tleaf) = true
172-
Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x))
173-
174-
cpu = cpu_device()
175-
t = Tleaf(ones(2))
176-
y = cpu(t)
177-
@test y.x == 2 .* ones(2)
178-
y = cpu([(t,)])
179-
@test y[1][1].x == 2 .* ones(2)
180209
end

0 commit comments

Comments
 (0)