11# zero-preserving Traits
22# ----------------------
33"""
4- abstract type ZeroPreserving end
4+ abstract type ZeroPreserving <: Function end
55
66Holy Trait to indicate how a function interacts with abstract zero values:
77
@@ -15,10 +15,17 @@ To attempt to automatically determine this, either `ZeroPreserving(f, A::Abstrac
1515!!! warning
1616 incorrectly registering a function to be zero-preserving will lead to silently wrong results.
1717"""
18- abstract type ZeroPreserving end
19- struct StrongPreserving <: ZeroPreserving end
20- struct WeakPreserving <: ZeroPreserving end
21- struct NonPreserving <: ZeroPreserving end
18+ abstract type ZeroPreserving <: Function end
19+
20+ struct StrongPreserving{F} <: ZeroPreserving
21+ f:: F
22+ end
23+ struct WeakPreserving{F} <: ZeroPreserving
24+ f:: F
25+ end
26+ struct NonPreserving{F} <: ZeroPreserving
27+ f:: F
28+ end
2229
2330# Backport: remove in 1.12
2431@static if ! isdefined (Base, :haszero )
3643# TODO : non-concrete element types
3744function ZeroPreserving (f, T:: Type , Ts:: Type... )
3845 if all (_haszero, (T, Ts... ))
39- return iszero (f (zero (T), zero .(Ts)... )) ? WeakPreserving () : NonPreserving ()
46+ return iszero (f (zero (T), zero .(Ts)... )) ? WeakPreserving (f ) : NonPreserving (f )
4047 else
41- return NonPreserving ()
48+ return NonPreserving (f )
4249 end
4350end
4451
4552const _WEAK_FUNCTIONS = (:+ , :- )
4653for f in _WEAK_FUNCTIONS
4754 @eval begin
48- ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = WeakPreserving ()
55+ ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = WeakPreserving ($ f )
4956 end
5057end
5158
5259const _STRONG_FUNCTIONS = (:* ,)
5360for f in _STRONG_FUNCTIONS
5461 @eval begin
55- ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = StrongPreserving ()
62+ ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = StrongPreserving (
63+ $ f
64+ )
5665 end
5766end
5867
6170@interface I:: AbstractSparseArrayInterface function Base. map (
6271 f, A:: AbstractArray , Bs:: AbstractArray...
6372)
64- T = Base. Broadcast. combine_eltypes (f, (A, Bs... ))
73+ f_pres = ZeroPreserving (f, A, Bs... )
74+ return @interface I map (f_pres, A, Bs... )
75+ end
76+ @interface I:: AbstractSparseArrayInterface function Base. map (
77+ f:: ZeroPreserving , A:: AbstractArray , Bs:: AbstractArray...
78+ )
79+ T = Base. Broadcast. combine_eltypes (f. f, (A, Bs... ))
6580 C = similar (I, T, size (A))
6681 return @interface I map! (f, C, A, Bs... )
6782end
6883
69- @interface :: AbstractSparseArrayInterface function Base. map! (
84+ @interface I :: AbstractSparseArrayInterface function Base. map! (
7085 f, C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
7186)
72- return _map! (f, ZeroPreserving (f, A, Bs... ), C, A, Bs... )
87+ f_pres = ZeroPreserving (f, A, Bs... )
88+ return @interface I map! (f_pres, C, A, Bs... )
7389end
7490
75- function _map! (
76- f, :: StrongPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
77- )
78- checkshape (C, A, Bs... )
79- style = IndexStyle (C, A, Bs... )
80- unaliased = map (Base. Fix1 (Base. unalias, C), (A, Bs... ))
81- zero! (C)
82- for I in intersect (eachstoredindex .(Ref (style), unaliased)... )
83- @inbounds C[I] = f (ith_all (I, unaliased)... )
84- end
85- return C
86- end
87- function _map! (
88- f, :: WeakPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
91+ @interface :: AbstractSparseArrayInterface function Base. map! (
92+ f:: ZeroPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
8993)
9094 checkshape (C, A, Bs... )
91- style = IndexStyle (C, A, Bs... )
9295 unaliased = map (Base. Fix1 (Base. unalias, C), (A, Bs... ))
93- zero! (C)
94- for I in union (eachstoredindex .(Ref (style), unaliased)... )
95- @inbounds C[I] = f (ith_all (I, unaliased)... )
96+
97+ if f isa StrongPreserving
98+ style = IndexStyle (C, unaliased... )
99+ inds = intersect (eachstoredindex .(Ref (style), unaliased)... )
100+ zero! (C)
101+ elseif f isa WeakPreserving
102+ style = IndexStyle (C, unaliased... )
103+ inds = union (eachstoredindex .(Ref (style), unaliased)... )
104+ zero! (C)
105+ elseif f isa NonPreserving
106+ inds = eachindex (C, unaliased... )
107+ else
108+ error (lazy " unknown zero-preserving type $(typeof(f))" )
96109 end
97- return C
98- end
99- function _map! (f, :: NonPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray... )
100- checkshape (C, A, Bs... )
101- unaliased = map (Base. Fix1 (Base. unalias, C), (A, Bs... ))
102- for I in eachindex (C, A, Bs... )
103- @inbounds C[I] = f (ith_all (I, unaliased)... )
110+
111+ @inbounds for I in inds
112+ C[I] = f. f (ith_all (I, unaliased)... )
104113 end
114+
105115 return C
106116end
107117
0 commit comments