@@ -11,68 +11,156 @@ void THNN_(RReLU_updateOutput)(
1111          real  upper ,
1212          bool  train ,
1313          bool  inplace ,
14+           bool  channelwise ,
1415          THGenerator  * generator )
1516{
16-   if  (train )
17+   if  (channelwise   &&   train )
1718  {
18-     // get default random generator 
19-     THTensor_ (resizeAs )(noise , input );
20-     if  (inplace )
19+     long  bs , ks ;
20+     long  nOutputPlane ;
2121    {
22-       TH_TENSOR_APPLY2 (real , input , real , noise ,
23-         if  (* input_data  <= 0 )
24-         {
25-           const  real  r  =  (real )THRandom_uniform (generator , lower , upper );
26-           * input_data  =  (* input_data ) *  r ;
27-           * noise_data  =  r ;
28-         }
29-         else 
30-         {
31-           * noise_data  =  1 ;
32-         }
33-       );
34-       THTensor_ (set )(output , input );
22+       long  input_ndim  =  THTensor_ (nDimension )(input );
23+       switch  (input_ndim )
24+       {
25+         case  1 :
26+           bs  =  1 ;
27+           ks  =  1 ;
28+           break ;
29+         case  2 :
30+           bs  =  input -> size [0 ];
31+           ks  =  1 ;
32+           break ;
33+         case  3 :
34+           bs  =  1 ;
35+           ks  =  input -> size [1 ] *  input -> size [2 ];
36+           break ;
37+         case  4 :
38+           bs  =  input -> size [0 ];
39+           ks  =  input -> size [2 ] *  input -> size [3 ];
40+           break ;
41+       }
42+       nOutputPlane  =  input -> size [(input_ndim  +  1 ) % 2 ];
3543    }
44+     // get default random generator 
45+     if  (inplace )
46+       THTensor_ (resizeAs )(noise , input );
3647    else 
48+       THTensor_ (resize1d )(noise , nOutputPlane );
49+ 
50+     real  * output_data  =  NULL ;
51+     real  * input_data  =  THTensor_ (data )(input );
52+     real  * noise_data  =  THTensor_ (data )(noise );
53+     if  (!inplace )
3754    {
3855      THTensor_ (resizeAs )(output , input );
39-       TH_TENSOR_APPLY3 (real , input , real , output , real , noise ,
40-         if  (* input_data  <= 0 )
41-         {
42-           const  real  r  =  (real )THRandom_uniform (generator , lower , upper );
43-           * output_data  =  (* input_data ) *  r ;
44-           * noise_data  =  r ;
45-         }
56+       output_data  =  THTensor_ (data )(output );
57+     }
58+     THTensor  * channel_noise  =  THTensor_ (newWithSize1d )(nOutputPlane );
59+     real  * channel_noise_data  =  THTensor_ (data )(channel_noise );
60+ 
61+     THIndex_t  i , j , k ;
62+ #pragma  omp parallel for private(j)
63+     for  (j  =  0 ; j  <  nOutputPlane ; ++ j )
64+       channel_noise_data [j ] =  (real )THRandom_uniform (generator , lower , upper );
65+ #pragma  omp parallel for private(j,k)
66+     for  (i  =  0 ; i  <  bs ; ++ i )
67+     {
68+       real *  n_input_data  =  input_data  +  i * nOutputPlane * ks ;
69+       real *  n_output_data  =  NULL ;
70+       real *  n_noise_data  =  NULL ;
71+       if  (inplace )
72+         n_noise_data  =  noise_data  +  i * nOutputPlane * ks ;
73+       else 
74+         n_output_data  =  output_data  +  i * nOutputPlane * ks ;
75+       for  (j  =  0 ; j  <  nOutputPlane ; ++ j )
76+       {
77+         const  real  r  =  channel_noise_data [j ];
78+         for  (k  =  0 ; k  <  ks ; ++ k )
79+           if  (inplace )
80+             if  (n_input_data [k ] <= 0 )
81+             {
82+               n_input_data [k ] =  r  *  n_input_data [k ];
83+               n_noise_data [k ] =  r ;
84+             }
85+             else 
86+               n_noise_data [k ] =  1 ;
87+           else 
88+             n_output_data [k ] =  (n_input_data [k ] >  0 ) ? n_input_data [k ] : r  *  n_input_data [k ];
89+         n_input_data  +=  ks ;
90+         if  (inplace )
91+           n_noise_data  +=  ks ;
4692        else 
47-         {
48-           * output_data  =  * input_data ;
49-           * noise_data  =  1 ;
50-         }
51-       );
93+           n_output_data  +=  ks ;
94+       }
5295    }
96+     if  (inplace )
97+       THTensor_ (set )(output , input );
98+     else 
99+       THTensor_ (set )(noise , channel_noise );
53100  }
54101  else 
55102  {
56-     const  real  negSlope  =  (lower  +  upper ) / 2 ;
57-     if  (inplace )
103+     if  (train )
58104    {
59-       TH_TENSOR_APPLY (real , input ,
60-         if  (* input_data  <= 0 )
61-         {
62-           * input_data  =  * input_data  *  negSlope ;
63-         }
64-       );
65-       THTensor_ (set )(output , input );
105+       // get default random generator 
106+       THTensor_ (resizeAs )(noise , input );
107+       if  (inplace )
108+       {
109+         TH_TENSOR_APPLY2 (real , input , real , noise ,
110+           if  (* input_data  <= 0 )
111+           {
112+             const  real  r  =  (real )THRandom_uniform (generator , lower , upper );
113+             * input_data  =  (* input_data ) *  r ;
114+             * noise_data  =  r ;
115+           }
116+           else 
117+           {
118+             * noise_data  =  1 ;
119+           }
120+         );
121+         THTensor_ (set )(output , input );
122+       }
123+       else 
124+       {
125+         THTensor_ (resizeAs )(output , input );
126+         TH_TENSOR_APPLY3 (real , input , real , output , real , noise ,
127+           if  (* input_data  <= 0 )
128+           {
129+             const  real  r  =  (real )THRandom_uniform (generator , lower , upper );
130+             * output_data  =  (* input_data ) *  r ;
131+             * noise_data  =  r ;
132+           }
133+           else 
134+           {
135+             * output_data  =  * input_data ;
136+             * noise_data  =  1 ;
137+           }
138+         );
139+       }
66140    }
67141    else 
68142    {
69-       THTensor_ (resizeAs )(output , input );
70-       TH_TENSOR_APPLY2 (real , input , real , output ,
71-         const  real  r  =  (* input_data ) <= 0  ? negSlope  : 1 ;
72-         * output_data  =  * input_data  *  r ;
73-       );
143+       const  real  negSlope  =  (lower  +  upper ) / 2 ;
144+       if  (inplace )
145+       {
146+         TH_TENSOR_APPLY (real , input ,
147+           if  (* input_data  <= 0 )
148+           {
149+             * input_data  =  * input_data  *  negSlope ;
150+           }
151+         );
152+         THTensor_ (set )(output , input );
153+       }
154+       else 
155+       {
156+         THTensor_ (resizeAs )(output , input );
157+         TH_TENSOR_APPLY2 (real , input , real , output ,
158+           const  real  r  =  (* input_data ) <= 0  ? negSlope  : 1 ;
159+           * output_data  =  * input_data  *  r ;
160+         );
161+       }
74162    }
75-   }   
163+   }
76164}
77165
78166void  THNN_ (RReLU_updateGradInput )(
@@ -84,24 +172,84 @@ void THNN_(RReLU_updateGradInput)(
84172          real  lower ,
85173          real  upper ,
86174          bool  train ,
87-           bool  inplace )
175+           bool  inplace ,
176+           bool  channelwise )
88177{
89178  if  (train  &&  upper  -  lower  >  1E-6 )    // e.g. if upper == lower, RReLU behaves like LeakyReLU 
90179  {
91-     // multiply the gradient by the noise tensor 
92-     if  (inplace )
180+     if  (channelwise  &&  !inplace )
93181    {
94-       THTensor_ (cmul )(gradOutput , gradOutput , noise );
95-       THTensor_ (set )(gradInput , gradOutput );
182+       long  bs , ks ;
183+       long  nOutputPlane ;
184+       {
185+         long  input_ndim  =  THTensor_ (nDimension )(input );
186+         switch  (input_ndim )
187+         {
188+           case  1 :
189+             bs  =  1 ;
190+             ks  =  1 ;
191+             break ;
192+           case  2 :
193+             bs  =  input -> size [0 ];
194+             ks  =  1 ;
195+             break ;
196+           case  3 :
197+             bs  =  1 ;
198+             ks  =  input -> size [1 ] *  input -> size [2 ];
199+             break ;
200+           case  4 :
201+             bs  =  input -> size [0 ];
202+             ks  =  input -> size [2 ] *  input -> size [3 ];
203+             break ;
204+         }
205+         nOutputPlane  =  input -> size [(input_ndim  +  1 ) % 2 ];
206+       }
207+ 
208+       const  real  * input_data  =  THTensor_ (data )(input );
209+       const  real  * gradOutput_data  =  THTensor_ (data )(gradOutput );
210+       THTensor_ (resizeAs )(gradInput , input );
211+       real  * gradInput_data  =  THTensor_ (data )(gradInput );
212+       const  real  * noise_data  =  THTensor_ (data )(noise );
213+ 
214+       THIndex_t  i , j , k ;
215+ #pragma  omp parallel for private(j,k)
216+       for  (i  =  0 ; i  <  bs ; ++ i )
217+       {
218+         const  real  * n_input_data  =  input_data  +  i * nOutputPlane * ks ;
219+         const  real  * n_gradOutput_data  =  gradOutput_data  +  i * nOutputPlane * ks ;
220+         real  * n_gradInput_data  =  gradInput_data  +  i * nOutputPlane * ks ;
221+ 
222+         for  (j  =  0 ; j  <  nOutputPlane ; ++ j )
223+         {
224+           const  real  r  =  noise_data [j ];
225+           for  (k  =  0 ; k  <  ks ; ++ k )
226+             if  (n_input_data [k ] >  0 )
227+               n_gradInput_data [k ] =  n_gradOutput_data [k ];
228+             else 
229+               n_gradInput_data [k ] =  n_gradOutput_data [k ] *  r ;
230+           n_input_data  +=  ks ;
231+           n_gradInput_data  +=  ks ;
232+           n_gradOutput_data  +=  ks ;
233+         }
234+       }
96235    }
97236    else 
98237    {
99-       THTensor_ (resizeAs )(gradInput , input );
100-       THTensor_ (cmul )(gradInput , gradOutput , noise );
101-     }    
238+       // multiply the gradient by the noise tensor 
239+       if  (inplace )
240+       {
241+         THTensor_ (cmul )(gradOutput , gradOutput , noise );
242+         THTensor_ (set )(gradInput , gradOutput );
243+       }
244+       else 
245+       {
246+         THTensor_ (resizeAs )(gradInput , input );
247+         THTensor_ (cmul )(gradInput , gradOutput , noise );
248+       }
249+     }
102250  }
103251  else 
104-   {  
252+   {
105253    // use constant factor for negative input values 
106254    const  real  negSlope  =  (lower  +  upper ) / 2 ;
107255    if  (inplace )
0 commit comments