Skip to content

Commit fc3ef02

Browse files
committed
Add pad_sequences util.
1 parent e8c0f87 commit fc3ef02

File tree

3 files changed

+249
-0
lines changed

3 files changed

+249
-0
lines changed

keras_core/testing/test_case.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
4141
def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
4242
np.testing.assert_almost_equal(x1, x2, decimal=decimal)
4343

44+
def assertAllEqual(self, x1, x2, msg=None):
45+
self.assertEqual(len(x1), len(x2), msg=msg)
46+
for e1, e2 in zip(x1, x2):
47+
if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)):
48+
self.assertAllEqual(e1, e2, msg=msg)
49+
else:
50+
self.assertEqual(e1, e2, msg=msg)
51+
4452
def assertLen(self, iterable, expected_len, msg=None):
4553
self.assertEqual(len(iterable), expected_len, msg=msg)
4654

keras_core/utils/sequence_utils.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import numpy as np
2+
3+
from keras_core.api_export import keras_core_export
4+
5+
6+
@keras_core_export(
7+
[
8+
"keras_core.utils.pad_sequences",
9+
"keras_core.preprocessing.sequence.pad_sequences",
10+
]
11+
)
12+
def pad_sequences(
13+
sequences,
14+
maxlen=None,
15+
dtype="int32",
16+
padding="pre",
17+
truncating="pre",
18+
value=0.0,
19+
):
20+
"""Pads sequences to the same length.
21+
22+
This function transforms a list (of length `num_samples`)
23+
of sequences (lists of integers)
24+
into a 2D NumPy array of shape `(num_samples, num_timesteps)`.
25+
`num_timesteps` is either the `maxlen` argument if provided,
26+
or the length of the longest sequence in the list.
27+
28+
Sequences that are shorter than `num_timesteps`
29+
are padded with `value` until they are `num_timesteps` long.
30+
31+
Sequences longer than `num_timesteps` are truncated
32+
so that they fit the desired length.
33+
34+
The position where padding or truncation happens is determined by
35+
the arguments `padding` and `truncating`, respectively.
36+
Pre-padding or removing values from the beginning of the sequence is the
37+
default.
38+
39+
>>> sequence = [[1], [2, 3], [4, 5, 6]]
40+
>>> keras_core.utils.pad_sequences(sequence)
41+
array([[0, 0, 1],
42+
[0, 2, 3],
43+
[4, 5, 6]], dtype=int32)
44+
45+
>>> keras_core.utils.pad_sequences(sequence, value=-1)
46+
array([[-1, -1, 1],
47+
[-1, 2, 3],
48+
[ 4, 5, 6]], dtype=int32)
49+
50+
>>> keras_core.utils.pad_sequences(sequence, padding='post')
51+
array([[1, 0, 0],
52+
[2, 3, 0],
53+
[4, 5, 6]], dtype=int32)
54+
55+
>>> keras_core.utils.pad_sequences(sequence, maxlen=2)
56+
array([[0, 1],
57+
[2, 3],
58+
[5, 6]], dtype=int32)
59+
60+
Args:
61+
sequences: List of sequences (each sequence is a list of integers).
62+
maxlen: Optional Int, maximum length of all sequences. If not provided,
63+
sequences will be padded to the length of the longest individual
64+
sequence.
65+
dtype: (Optional, defaults to `"int32"`). Type of the output sequences.
66+
To pad sequences with variable length strings, you can use `object`.
67+
padding: String, "pre" or "post" (optional, defaults to `"pre"`):
68+
pad either before or after each sequence.
69+
truncating: String, "pre" or "post" (optional, defaults to `"pre"`):
70+
remove values from sequences larger than
71+
`maxlen`, either at the beginning or at the end of the sequences.
72+
value: Float or String, padding value. (Optional, defaults to 0.)
73+
74+
Returns:
75+
NumPy array with shape `(len(sequences), maxlen)`
76+
"""
77+
if not hasattr(sequences, "__len__"):
78+
raise ValueError("`sequences` must be iterable.")
79+
num_samples = len(sequences)
80+
81+
lengths = []
82+
sample_shape = ()
83+
flag = True
84+
85+
# take the sample shape from the first non empty sequence
86+
# checking for consistency in the main loop below.
87+
88+
for x in sequences:
89+
try:
90+
lengths.append(len(x))
91+
if flag and len(x):
92+
sample_shape = np.asarray(x).shape[1:]
93+
flag = False
94+
except TypeError as e:
95+
raise ValueError(
96+
"`sequences` must be a list of iterables. "
97+
f"Found non-iterable: {str(x)}"
98+
) from e
99+
100+
if maxlen is None:
101+
maxlen = np.max(lengths)
102+
103+
is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
104+
dtype, np.unicode_
105+
)
106+
if isinstance(value, str) and dtype != object and not is_dtype_str:
107+
raise ValueError(
108+
f"`dtype` {dtype} is not compatible with `value`'s type: "
109+
f"{type(value)}\nYou should set `dtype=object` for variable length "
110+
"strings."
111+
)
112+
113+
x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
114+
for idx, s in enumerate(sequences):
115+
if not len(s):
116+
continue # empty list/array was found
117+
if truncating == "pre":
118+
trunc = s[-maxlen:]
119+
elif truncating == "post":
120+
trunc = s[:maxlen]
121+
else:
122+
raise ValueError(f'Truncating type "{truncating}" not understood')
123+
124+
# check `trunc` has expected shape
125+
trunc = np.asarray(trunc, dtype=dtype)
126+
if trunc.shape[1:] != sample_shape:
127+
raise ValueError(
128+
f"Shape of sample {trunc.shape[1:]} of sequence at "
129+
f"position {idx} is different from expected shape "
130+
f"{sample_shape}"
131+
)
132+
133+
if padding == "post":
134+
x[idx, : len(trunc)] = trunc
135+
elif padding == "pre":
136+
x[idx, -len(trunc) :] = trunc
137+
else:
138+
raise ValueError(f'Padding type "{padding}" not understood')
139+
return x
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from keras_core import testing
2+
from keras_core.utils import sequence_utils
3+
4+
5+
class PadSequencesTest(testing.TestCase):
6+
def test_pad_sequences(self):
7+
a = [[1], [1, 2], [1, 2, 3]]
8+
9+
# test padding
10+
b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre")
11+
self.assertAllClose(b, [[0, 0, 1], [0, 1, 2], [1, 2, 3]])
12+
b = sequence_utils.pad_sequences(a, maxlen=3, padding="post")
13+
self.assertAllClose(b, [[1, 0, 0], [1, 2, 0], [1, 2, 3]])
14+
15+
# test truncating
16+
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre")
17+
self.assertAllClose(b, [[0, 1], [1, 2], [2, 3]])
18+
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post")
19+
self.assertAllClose(b, [[0, 1], [1, 2], [1, 2]])
20+
21+
# test value
22+
b = sequence_utils.pad_sequences(a, maxlen=3, value=1)
23+
self.assertAllClose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]])
24+
25+
def test_pad_sequences_str(self):
26+
a = [["1"], ["1", "2"], ["1", "2", "3"]]
27+
28+
# test padding
29+
b = sequence_utils.pad_sequences(
30+
a, maxlen=3, padding="pre", value="pad", dtype=object
31+
)
32+
self.assertAllEqual(
33+
b, [["pad", "pad", "1"], ["pad", "1", "2"], ["1", "2", "3"]]
34+
)
35+
b = sequence_utils.pad_sequences(
36+
a, maxlen=3, padding="post", value="pad", dtype="<U3"
37+
)
38+
self.assertAllEqual(
39+
b, [["1", "pad", "pad"], ["1", "2", "pad"], ["1", "2", "3"]]
40+
)
41+
42+
# test truncating
43+
b = sequence_utils.pad_sequences(
44+
a, maxlen=2, truncating="pre", value="pad", dtype=object
45+
)
46+
self.assertAllEqual(b, [["pad", "1"], ["1", "2"], ["2", "3"]])
47+
b = sequence_utils.pad_sequences(
48+
a, maxlen=2, truncating="post", value="pad", dtype="<U3"
49+
)
50+
self.assertAllEqual(b, [["pad", "1"], ["1", "2"], ["1", "2"]])
51+
52+
with self.assertRaisesRegex(
53+
ValueError, "`dtype` int32 is not compatible with "
54+
):
55+
sequence_utils.pad_sequences(
56+
a, maxlen=2, truncating="post", value="pad"
57+
)
58+
59+
def test_pad_sequences_vector(self):
60+
a = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]]
61+
62+
# test padding
63+
b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre")
64+
self.assertAllClose(
65+
b,
66+
[
67+
[[0, 0], [0, 0], [1, 1]],
68+
[[0, 0], [2, 1], [2, 2]],
69+
[[3, 1], [3, 2], [3, 3]],
70+
],
71+
)
72+
b = sequence_utils.pad_sequences(a, maxlen=3, padding="post")
73+
self.assertAllClose(
74+
b,
75+
[
76+
[[1, 1], [0, 0], [0, 0]],
77+
[[2, 1], [2, 2], [0, 0]],
78+
[[3, 1], [3, 2], [3, 3]],
79+
],
80+
)
81+
82+
# test truncating
83+
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre")
84+
self.assertAllClose(
85+
b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 2], [3, 3]]]
86+
)
87+
88+
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post")
89+
self.assertAllClose(
90+
b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2]]]
91+
)
92+
93+
# test value
94+
b = sequence_utils.pad_sequences(a, maxlen=3, value=1)
95+
self.assertAllClose(
96+
b,
97+
[
98+
[[1, 1], [1, 1], [1, 1]],
99+
[[1, 1], [2, 1], [2, 2]],
100+
[[3, 1], [3, 2], [3, 3]],
101+
],
102+
)

0 commit comments

Comments
 (0)