Skip to content

Commit 8628729

Browse files
authored
Add atomic float support (#544)
1 parent c6c0ec1 commit 8628729

File tree

4 files changed

+176
-104
lines changed

4 files changed

+176
-104
lines changed

src/context.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# XXX: rework this -- it doesn't work well when altering the state
88

9-
export driver, driver!, device, device!, context, context!, global_queue, synchronize
9+
export driver, driver!, device, device!, context, context!, global_queue, synchronize, is_integrated
1010

1111
"""
1212
driver() -> ZeDriver
@@ -111,6 +111,40 @@ function device!(i::Int)
111111
return device!(devs[i])
112112
end
113113

114+
"""
115+
is_integrated(dev::ZeDevice=device()) -> Bool
116+
117+
Check if the given device is an integrated GPU (i.e., integrated with the host processor).
118+
119+
Integrated GPUs share memory with the CPU and are typically found in laptop and desktop
120+
processors with integrated graphics.
121+
122+
# Arguments
123+
- `dev::ZeDevice`: The device to check. Defaults to the current device.
124+
125+
# Returns
126+
- `true` if the device is integrated, `false` otherwise (e.g., discrete GPU).
127+
128+
# Examples
129+
```julia
130+
if is_integrated()
131+
println("Running on integrated graphics")
132+
else
133+
println("Running on discrete GPU")
134+
end
135+
136+
# Check a specific device
137+
dev = devices()[1]
138+
is_integrated(dev)
139+
```
140+
141+
See also: [`device`](@ref), [`devices`](@ref)
142+
"""
143+
function is_integrated(dev::ZeDevice=device())
144+
props = oneL0.properties(dev)
145+
return (props.flags & oneL0.ZE_DEVICE_PROPERTY_FLAG_INTEGRATED) != 0
146+
end
147+
114148
const global_contexts = Dict{ZeDriver,ZeContext}()
115149

116150
"""

src/device/atomics.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Atomic operation device overrides and fallbacks
2+
3+
# Fallback wrappers for Float32 atomic_inc!/atomic_dec!
4+
# Intel Level Zero doesn't support these directly for floating-point types,
5+
# so we implement them using atomic_add!/atomic_sub!
6+
7+
@device_override @inline function SPIRVIntrinsics.atomic_inc!(p::LLVMPtr{Float32, AS}) where {AS}
8+
SPIRVIntrinsics.atomic_add!(p, Float32(1))
9+
end
10+
11+
@device_override @inline function SPIRVIntrinsics.atomic_dec!(p::LLVMPtr{Float32, AS}) where {AS}
12+
SPIRVIntrinsics.atomic_sub!(p, Float32(1))
13+
end
14+
15+
# Float64 fallbacks (if Float64 is supported on device)
16+
@device_override @inline function SPIRVIntrinsics.atomic_inc!(p::LLVMPtr{Float64, AS}) where {AS}
17+
SPIRVIntrinsics.atomic_add!(p, Float64(1))
18+
end
19+
20+
@device_override @inline function SPIRVIntrinsics.atomic_dec!(p::LLVMPtr{Float64, AS}) where {AS}
21+
SPIRVIntrinsics.atomic_sub!(p, Float64(1))
22+
end

src/oneAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Base.Experimental.@MethodTable(method_table)
3434
include("device/runtime.jl")
3535
include("device/array.jl")
3636
include("device/quirks.jl")
37+
include("device/atomics.jl")
3738

3839
# essential stuff
3940
include("context.jl")

test/device/intrinsics.jl

Lines changed: 118 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -274,149 +274,164 @@ end
274274

275275
############################################################################################
276276

277-
@testset "atomics (low level)" begin
277+
# @testset "atomics (low level)" begin
278278

279-
@testset "atomic_add($T)" for T in [Int32, UInt32]
280-
a = oneArray([zero(T)])
279+
@testset "atomic_add($T)" for T in [Int32, UInt32, Float32]
280+
if oneAPI.is_integrated() && T == Float32
281+
continue
282+
end
283+
a = oneArray([zero(T)])
281284

282-
function kernel(a, b)
283-
oneAPI.atomic_add!(pointer(a), b)
284-
return
285+
function kernel(a, b)
286+
oneAPI.atomic_add!(pointer(a), b)
287+
return
288+
end
289+
290+
@oneapi items=256 kernel(a, one(T))
291+
@test Array(a)[1] == T(256)
285292
end
286293

287-
@oneapi items=256 kernel(a, one(T))
288-
@test Array(a)[1] == T(256)
289-
end
294+
@testset "atomic_sub($T)" for T in [Int32, UInt32, Float32]
295+
if oneAPI.is_integrated() && T == Float32
296+
continue
297+
end
298+
a = oneArray([T(256)])
290299

291-
@testset "atomic_sub($T)" for T in [Int32, UInt32]
292-
a = oneArray([T(256)])
300+
function kernel(a, b)
301+
oneAPI.atomic_sub!(pointer(a), b)
302+
return
303+
end
293304

294-
function kernel(a, b)
295-
oneAPI.atomic_sub!(pointer(a), b)
296-
return
305+
@oneapi items=256 kernel(a, one(T))
306+
@test Array(a)[1] == T(0)
297307
end
298308

299-
@oneapi items=256 kernel(a, one(T))
300-
@test Array(a)[1] == T(0)
301-
end
309+
@testset "atomic_inc($T)" for T in [Int32, UInt32]
310+
a = oneArray([zero(T)])
302311

303-
@testset "atomic_inc($T)" for T in [Int32, UInt32]
304-
a = oneArray([zero(T)])
312+
function kernel(a)
313+
oneAPI.atomic_inc!(pointer(a))
314+
return
315+
end
305316

306-
function kernel(a)
307-
oneAPI.atomic_inc!(pointer(a))
308-
return
317+
@oneapi items=256 kernel(a)
318+
@test Array(a)[1] == T(256)
309319
end
310320

311-
@oneapi items=256 kernel(a)
312-
@test Array(a)[1] == T(256)
313-
end
321+
@testset "atomic_dec($T)" for T in [Int32, UInt32]
322+
a = oneArray([T(256)])
314323

315-
@testset "atomic_dec($T)" for T in [Int32, UInt32]
316-
a = oneArray([T(256)])
324+
function kernel(a)
325+
oneAPI.atomic_dec!(pointer(a))
326+
return
327+
end
317328

318-
function kernel(a)
319-
oneAPI.atomic_dec!(pointer(a))
320-
return
329+
@oneapi items=256 kernel(a)
330+
@test Array(a)[1] == T(0)
321331
end
322332

323-
@oneapi items=256 kernel(a)
324-
@test Array(a)[1] == T(0)
325-
end
333+
@testset "atomic_min($T)" for T in [Int32, UInt32, Float32]
334+
if oneAPI.is_integrated() && T == Float32
335+
continue
336+
end
337+
a = oneArray([T(256)])
326338

327-
@testset "atomic_min($T)" for T in [Int32, UInt32]
328-
a = oneArray([T(256)])
339+
function kernel(a, T)
340+
i = get_global_id()
341+
oneAPI.atomic_min!(pointer(a), T(i))
342+
return
343+
end
329344

330-
function kernel(a, T)
331-
i = get_global_id()
332-
oneAPI.atomic_min!(pointer(a), i%T)
333-
return
345+
@oneapi items=256 kernel(a, T)
346+
@test Array(a)[1] == one(T)
334347
end
335348

336-
@oneapi items=256 kernel(a, T)
337-
@test Array(a)[1] == one(T)
338-
end
349+
@testset "atomic_max($T)" for T in [Int32, UInt32, Float32]
350+
if oneAPI.is_integrated() && T == Float32
351+
continue
352+
end
353+
a = oneArray([zero(T)])
339354

340-
@testset "atomic_max($T)" for T in [Int32, UInt32]
341-
a = oneArray([zero(T)])
355+
function kernel(a, T)
356+
i = get_global_id()
357+
oneAPI.atomic_max!(pointer(a), T(i))
358+
return
359+
end
342360

343-
function kernel(a, T)
344-
i = get_global_id()
345-
oneAPI.atomic_max!(pointer(a), i%T)
346-
return
361+
@oneapi items=256 kernel(a, T)
362+
@test Array(a)[1] == T(256)
347363
end
348364

349-
@oneapi items=256 kernel(a, T)
350-
@test Array(a)[1] == T(256)
351-
end
352-
353-
@testset "atomic_and($T)" for T in [Int32, UInt32]
354-
a = oneArray([T(1023)])
365+
@testset "atomic_and($T)" for T in [Int32, UInt32]
366+
a = oneArray([T(1023)])
355367

356-
function kernel(a, T)
357-
i = get_global_id() - 1
358-
k = 1
359-
for i = 1:i
360-
k *= 2
368+
function kernel(a, T)
369+
i = get_global_id() - 1
370+
k = 1
371+
for i = 1:i
372+
k *= 2
373+
end
374+
b = 1023 - k # 1023 - 2^i
375+
oneAPI.atomic_and!(pointer(a), T(b))
376+
return
361377
end
362-
b = 1023 - k # 1023 - 2^i
363-
oneAPI.atomic_and!(pointer(a), T(b))
364-
return
365-
end
366378

367-
@oneapi items=10 kernel(a, T)
368-
@test Array(a)[1] == zero(T)
369-
end
379+
@oneapi items=10 kernel(a, T)
380+
@test Array(a)[1] == zero(T)
381+
end
370382

371-
@testset "atomic_or($T)" for T in [Int32, UInt32]
372-
a = oneArray([zero(T)])
383+
@testset "atomic_or($T)" for T in [Int32, UInt32]
384+
a = oneArray([zero(T)])
373385

374-
function kernel(a, T)
375-
i = get_global_id()
376-
b = 1 # 2^(i-1)
377-
for i = 1:i
378-
b *= 2
386+
function kernel(a, T)
387+
i = get_global_id()
388+
b = 1 # 2^(i-1)
389+
for i = 1:i
390+
b *= 2
391+
end
392+
b ÷= 2
393+
oneAPI.atomic_or!(pointer(a), T(b))
394+
return
379395
end
380-
b ÷= 2
381-
oneAPI.atomic_or!(pointer(a), T(b))
382-
return
383-
end
384396

385-
@oneapi items=10 kernel(a, T)
386-
@test Array(a)[1] == T(1023)
387-
end
397+
@oneapi items=10 kernel(a, T)
398+
@test Array(a)[1] == T(1023)
399+
end
388400

389-
@testset "atomic_xor($T)" for T in [Int32, UInt32]
390-
a = oneArray([T(1023)])
401+
@testset "atomic_xor($T)" for T in [Int32, UInt32]
402+
a = oneArray([T(1023)])
391403

392-
function kernel(a, T)
393-
i = get_global_id()
394-
b = 1 # 2^(i-1)
395-
for i = 1:i
396-
b *= 2
404+
function kernel(a, T)
405+
i = get_global_id()
406+
b = 1 # 2^(i-1)
407+
for i = 1:i
408+
b *= 2
409+
end
410+
b ÷= 2
411+
oneAPI.atomic_xor!(pointer(a), T(b))
412+
return
397413
end
398-
b ÷= 2
399-
oneAPI.atomic_xor!(pointer(a), T(b))
400-
return
414+
415+
@oneapi items=10 kernel(a, T)
416+
@test Array(a)[1] == zero(T)
401417
end
402418

403-
@oneapi items=10 kernel(a, T)
404-
@test Array(a)[1] == zero(T)
405-
end
419+
@testset "atomic_xchg($T)" for T in [Int32, UInt32, Float32]
420+
if oneAPI.is_integrated() && T == Float32
421+
continue
422+
end
423+
a = oneArray([zero(T)])
406424

407-
@testset "atomic_xchg($T)" for T in [Int32, UInt32, Float32]
408-
a = oneArray([zero(T)])
425+
function kernel(a, b)
426+
oneAPI.atomic_xchg!(pointer(a), b)
427+
return
428+
end
409429

410-
function kernel(a, b)
411-
oneAPI.atomic_xchg!(pointer(a), b)
412-
return
430+
@oneapi items=256 kernel(a, one(T))
431+
@test Array(a)[1] == one(T)
413432
end
414433

415-
@oneapi items=256 kernel(a, one(T))
416-
@test Array(a)[1] == one(T)
417-
end
418-
419-
end
434+
# end
420435

421436

422437

0 commit comments

Comments
 (0)