Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 81208c7

Browse files
kozlov-alexey authored and AlexanderKalistratov committed
Add support for Series.operator.add in a new-style (#305)
* Add support for Series.operator.add in a new-style * Applying review comments and adding tests * More comments and refactoring from review * Bugfix in indexes join and minor changes
1 parent ce80fe3 commit 81208c7

File tree

4 files changed

+711
-91
lines changed

4 files changed

+711
-91
lines changed

sdc/datatypes/common_functions.py

Lines changed: 267 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@
3131
"""
3232

3333
import numpy
34+
import pandas
3435

36+
import numba
3537
from numba import types
3638
from numba.errors import TypingError
3739
from numba.extending import overload
3840
from numba import numpy_support
3941

4042
import sdc
41-
from sdc.str_arr_ext import (string_array_type, num_total_chars, append_string_array_to)
43+
from sdc.str_arr_ext import (string_array_type, num_total_chars, append_string_array_to,
44+
str_arr_is_na, pre_alloc_string_array, str_arr_set_na)
4245

4346

4447
class TypeChecker:
@@ -91,7 +94,7 @@ def check(self, data, accepted_type, name=''):
9194

9295

9396
def has_literal_value(var, value):
94-
'''Used during typing to check that variable var is a Numba literal value equal to value'''
97+
"""Used during typing to check that variable var is a Numba literal value equal to value"""
9598

9699
if not isinstance(var, types.Literal):
97100
return False
@@ -103,7 +106,7 @@ def has_literal_value(var, value):
103106

104107

105108
def has_python_value(var, value):
106-
'''Used during typing to check that variable var was resolved as Python type and has specific value'''
109+
"""Used during typing to check that variable var was resolved as Python type and has specific value"""
107110

108111
if not isinstance(var, type(value)):
109112
return False
@@ -114,13 +117,18 @@ def has_python_value(var, value):
114117
return var == value
115118

116119

120+
def check_index_is_numeric(ty_series):
    """Typing-time check that a series type carries a numeric index.

    Returns True only when ``ty_series.index`` is a Numba array type whose
    dtype is a Numba number type; any other index type yields False.
    """
    index_type = ty_series.index
    if not isinstance(index_type, types.Array):
        return False
    return isinstance(index_type.dtype, types.Number)
123+
124+
117125
def hpat_arrays_append(A, B):
    """Pure-Python stub for appending B to A; does nothing itself and returns None."""
119127

120128

121129
@overload(hpat_arrays_append)
122130
def hpat_arrays_append_overload(A, B):
123-
'''Function for appending underlying arrays (A and B) or list/tuple of arrays B to an array A'''
131+
"""Function for appending underlying arrays (A and B) or list/tuple of arrays B to an array A"""
124132

125133
if isinstance(A, types.Array):
126134
if isinstance(B, types.Array):
@@ -131,9 +139,7 @@ def _append_single_numeric_impl(A, B):
131139
elif isinstance(B, (types.UniTuple, types.List)):
132140
# TODO: this heavily relies on B being a homogeneous tuple/list - find a better way
133141
# to resolve common dtype of heterogeneous sequence of arrays
134-
np_dtypes = [numpy_support.as_dtype(A.dtype), numpy_support.as_dtype(B.dtype.dtype)]
135-
np_common_dtype = numpy.find_common_type([], np_dtypes)
136-
numba_common_dtype = numpy_support.from_dtype(np_common_dtype)
142+
numba_common_dtype = find_common_dtype_from_numpy_dtypes([A.dtype, B.dtype.dtype], [])
137143

138144
# TODO: refactor to use numpy.concatenate when Numba supports building a tuple at runtime
139145
def _append_list_numeric_impl(A, B):
@@ -181,3 +187,257 @@ def _append_list_string_array_impl(A, B):
181187
return new_data
182188

183189
return _append_list_string_array_impl
190+
191+
192+
@numba.njit
def _hpat_ensure_array_capacity(new_size, arr):
    """Ensure that a numpy array can hold at least ``new_size`` elements.

    Returns ``arr`` itself when its length is already at least ``new_size``.
    Otherwise a new array of the same dtype is allocated, its length obtained
    by repeatedly doubling the current length until it reaches ``new_size``,
    and the existing elements are copied into its head. The remaining tail
    of the new array is left uninitialized (``numpy.empty``).
    """

    k = len(arr)
    if k >= new_size:
        return arr

    # start doubling from at least 1: for an empty input the size would
    # otherwise stay at 0 forever and the loop below would never terminate
    n = max(k, 1)
    while n < new_size:
        n = 2 * n
    res = numpy.empty(n, arr.dtype)
    res[:k] = arr[:k]
    return res
208+
209+
210+
def find_common_dtype_from_numpy_dtypes(array_types, scalar_types):
    """Find the common Numba dtype for sequences of Numba dtypes that each map to a numpy dtype.

    Both sequences are converted to numpy dtypes, the common numpy dtype is
    resolved with ``numpy.find_common_type(array dtypes, scalar dtypes)`` and
    the result is converted back to a Numba dtype.
    """
    # NOTE(review): numpy.find_common_type is deprecated in recent NumPy releases;
    # confirm the supported NumPy range before upgrading
    to_numpy = numpy_support.as_dtype
    arrays_as_numpy = list(map(to_numpy, array_types))
    scalars_as_numpy = list(map(to_numpy, scalar_types))

    common_numpy_dtype = numpy.find_common_type(arrays_as_numpy, scalars_as_numpy)
    return numpy_support.from_dtype(common_numpy_dtype)
218+
219+
220+
def hpat_join_series_indexes(left, right):
    """Pure-Python stub for joining two series indexes; does nothing itself and returns None."""
222+
223+
224+
@overload(hpat_join_series_indexes)
def hpat_join_series_indexes_overload(left, right):
    """Function for joining arrays left and right in a way similar to pandas.join 'outer' algorithm

    Selects an implementation at typing time:
      * both arguments are Numba arrays whose common dtype is numeric - numeric implementation;
      * both arguments are StringArrays - string implementation;
      * anything else - None is returned, i.e. no implementation is provided.

    Each implementation returns a tuple (joined, lidx, ridx) of three sequences of the same length:
    joined holds the index values in ascending order, lidx and ridx hold for every joined element
    the position of that value in left and right respectively, or -1 when the value is absent
    on that side. Runs of equal values found on both sides produce every left/right pairing.
    """

    # TODO: eliminate code duplication by merging implementations for numeric and StringArray
    # requires equivalents of numpy.argsort and _hpat_ensure_array_capacity for StringArrays
    if (isinstance(left, types.Array) and isinstance(right, types.Array)):

        numba_common_dtype = find_common_dtype_from_numpy_dtypes([left.dtype, right.dtype], [])
        if isinstance(numba_common_dtype, types.Number):

            def hpat_join_series_indexes_impl(left, right):

                # allocate result arrays; the size is only an estimate, the arrays are grown on demand
                lsize = len(left)
                rsize = len(right)
                est_total_size = int(1.1 * (lsize + rsize))

                lidx = numpy.empty(est_total_size, numpy.int64)
                ridx = numpy.empty(est_total_size, numpy.int64)
                joined = numpy.empty(est_total_size, numba_common_dtype)

                # sort arrays saving the old positions
                sorted_left = numpy.argsort(left, kind='mergesort')
                sorted_right = numpy.argsort(right, kind='mergesort')

                # i, j - read positions in sorted left/right, k - write position in the results
                i, j, k = 0, 0, 0
                while (i < lsize and j < rsize):
                    joined = _hpat_ensure_array_capacity(k + 1, joined)
                    lidx = _hpat_ensure_array_capacity(k + 1, lidx)
                    ridx = _hpat_ensure_array_capacity(k + 1, ridx)

                    left_index = left[sorted_left[i]]
                    right_index = right[sorted_right[j]]

                    if (left_index < right_index):
                        joined[k] = left_index
                        lidx[k] = sorted_left[i]
                        ridx[k] = -1
                        i += 1
                        k += 1
                    elif (left_index > right_index):
                        joined[k] = right_index
                        lidx[k] = -1
                        ridx[k] = sorted_right[j]
                        j += 1
                        k += 1
                    else:
                        # NOTE(review): NaN values compare neither less, greater nor equal, so they
                        # reach this branch without forming a run - looks like this may drop rows or
                        # stall the loop when NaN is present in the indexes; confirm and handle explicitly

                        # find ends of sequences of equal index values in left and right
                        ni, nj = i, j
                        while (ni < lsize and left[sorted_left[ni]] == left_index):
                            ni += 1
                        while (nj < rsize and right[sorted_right[nj]] == right_index):
                            nj += 1

                        # the right-hand block is identical for every left element of the run
                        block_size = nj - j
                        to_ridx = numpy.array([sorted_right[pos] for pos in numpy.arange(j, nj, 1)], numpy.int64)

                        # join the blocks found into results
                        for s in numpy.arange(i, ni, 1):
                            to_joined = numpy.repeat(left_index, block_size)
                            to_lidx = numpy.repeat(sorted_left[s], block_size)

                            joined = _hpat_ensure_array_capacity(k + block_size, joined)
                            lidx = _hpat_ensure_array_capacity(k + block_size, lidx)
                            ridx = _hpat_ensure_array_capacity(k + block_size, ridx)

                            joined[k:k + block_size] = to_joined
                            lidx[k:k + block_size] = to_lidx
                            ridx[k:k + block_size] = to_ridx
                            k += block_size
                        i = ni
                        j = nj

                # fill the end of joined with remaining part of left or right
                if i < lsize:
                    block_size = lsize - i
                    joined = _hpat_ensure_array_capacity(k + block_size, joined)
                    lidx = _hpat_ensure_array_capacity(k + block_size, lidx)
                    ridx = _hpat_ensure_array_capacity(k + block_size, ridx)
                    ridx[k: k + block_size] = numpy.repeat(-1, block_size)
                    while i < lsize:
                        joined[k] = left[sorted_left[i]]
                        lidx[k] = sorted_left[i]
                        i += 1
                        k += 1

                elif j < rsize:
                    block_size = rsize - j
                    joined = _hpat_ensure_array_capacity(k + block_size, joined)
                    lidx = _hpat_ensure_array_capacity(k + block_size, lidx)
                    ridx = _hpat_ensure_array_capacity(k + block_size, ridx)
                    lidx[k: k + block_size] = numpy.repeat(-1, block_size)
                    while j < rsize:
                        joined[k] = right[sorted_right[j]]
                        ridx[k] = sorted_right[j]
                        j += 1
                        k += 1

                # drop the unused (uninitialized) tail of the over-allocated results
                return joined[:k], lidx[:k], ridx[:k]

            return hpat_join_series_indexes_impl

        else:
            # TODO: support joining indexes with common dtype=object - requires Numba
            # support of such numpy arrays in nopython mode, for now just return None
            return None

    elif (left == string_array_type and right == string_array_type):

        def hpat_join_series_indexes_impl(left, right):

            # allocate result arrays; the size is only an estimate, the arrays are grown on demand
            lsize = len(left)
            rsize = len(right)
            est_total_size = int(1.1 * (lsize + rsize))

            lidx = numpy.empty(est_total_size, numpy.int64)
            ridx = numpy.empty(est_total_size, numpy.int64)

            # use Series.sort_values since argsort for StringArrays not implemented
            original_left_series = pandas.Series(left)
            original_right_series = pandas.Series(right)

            # sort arrays saving the old positions
            left_series = original_left_series.sort_values(kind='mergesort')
            right_series = original_right_series.sort_values(kind='mergesort')
            sorted_left = left_series._index
            sorted_right = right_series._index

            # i, j - read positions in sorted left/right, k - write position in the results
            i, j, k = 0, 0, 0
            while (i < lsize and j < rsize):
                lidx = _hpat_ensure_array_capacity(k + 1, lidx)
                ridx = _hpat_ensure_array_capacity(k + 1, ridx)

                left_index = left[sorted_left[i]]
                right_index = right[sorted_right[j]]

                if (left_index < right_index):
                    lidx[k] = sorted_left[i]
                    ridx[k] = -1
                    i += 1
                    k += 1
                elif (left_index > right_index):
                    lidx[k] = -1
                    ridx[k] = sorted_right[j]
                    j += 1
                    k += 1
                else:
                    # find ends of sequences of equal index values in left and right
                    ni, nj = i, j
                    while (ni < lsize and left[sorted_left[ni]] == left_index):
                        ni += 1
                    while (nj < rsize and right[sorted_right[nj]] == right_index):
                        nj += 1

                    # the right-hand block is identical for every left element of the run
                    block_size = nj - j
                    to_ridx = numpy.array([sorted_right[pos] for pos in numpy.arange(j, nj, 1)], numpy.int64)

                    # join the blocks found into results
                    for s in numpy.arange(i, ni, 1):
                        to_lidx = numpy.repeat(sorted_left[s], block_size)

                        lidx = _hpat_ensure_array_capacity(k + block_size, lidx)
                        ridx = _hpat_ensure_array_capacity(k + block_size, ridx)

                        lidx[k:k + block_size] = to_lidx
                        ridx[k:k + block_size] = to_ridx
                        k += block_size
                    i = ni
                    j = nj

            # fill the end of joined with remaining part of left or right
            if i < lsize:
                block_size = lsize - i
                lidx = _hpat_ensure_array_capacity(k + block_size, lidx)
                ridx = _hpat_ensure_array_capacity(k + block_size, ridx)
                ridx[k: k + block_size] = numpy.repeat(-1, block_size)
                while i < lsize:
                    lidx[k] = sorted_left[i]
                    i += 1
                    k += 1

            elif j < rsize:
                block_size = rsize - j
                lidx = _hpat_ensure_array_capacity(k + block_size, lidx)
                ridx = _hpat_ensure_array_capacity(k + block_size, ridx)
                lidx[k: k + block_size] = numpy.repeat(-1, block_size)
                while j < rsize:
                    ridx[k] = sorted_right[j]
                    j += 1
                    k += 1

            # count total number of characters and allocate joined array
            total_joined_size = k
            num_chars_in_joined = 0
            for i in numpy.arange(total_joined_size):
                if lidx[i] != -1:
                    num_chars_in_joined += len(left[lidx[i]])
                elif ridx[i] != -1:
                    num_chars_in_joined += len(right[ridx[i]])

            joined = pre_alloc_string_array(total_joined_size, num_chars_in_joined)

            # iterate over joined and fill it with indexes using lidx and ridx indexers
            for i in numpy.arange(total_joined_size):
                if lidx[i] != -1:
                    joined[i] = left[lidx[i]]
                    if (str_arr_is_na(left, lidx[i])):
                        str_arr_set_na(joined, i)
                elif ridx[i] != -1:
                    joined[i] = right[ridx[i]]
                    if (str_arr_is_na(right, ridx[i])):
                        str_arr_set_na(joined, i)
                else:
                    str_arr_set_na(joined, i)

            # lidx and ridx are over-allocated: trim them to the length of joined so that
            # callers never see the uninitialized tail (same contract as the numeric implementation)
            return joined, lidx[:total_joined_size], ridx[:total_joined_size]

        return hpat_join_series_indexes_impl

    return None

0 commit comments

Comments
 (0)