Skip to content

Commit 15a67d7

Browse files
GuyAv46 and alonre24 authored
String tensor support (#90)
* added support for string tensors, when the data is given by VALUE (and not with BLOB) * updated test.py for setting and getting string tensors by VALUE * added support for tensorset from numpy string array as blob * updated test.py to test tensorset with numpy string array * linting * small fix * Review fixes: Added a comment. Deleted numpy_string2blob and replaced with a single line using join. Deleted utils.recursive_bytetransform_str and sets 'target' to a decode function Co-authored-by: alonre24 <[email protected]>
1 parent 35b1e3f commit 15a67d7

File tree

5 files changed

+34
-10
lines changed

5 files changed

+34
-10
lines changed

redisai/command_builder.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,12 @@ def tensorset(
176176
args = ["AI.TENSORSET", key, dtype, *shape, "BLOB", blob]
177177
elif isinstance(tensor, (list, tuple)):
178178
try:
179-
dtype = utils.dtype_dict[dtype.lower()]
179+
# NumPy string dtypes come in many variants (differing, for example, in the maximal string length),
180+
# but they all share the same 'num' attribute. This is a way to check if a dtype is a kind of string.
181+
if np.dtype(dtype).num == np.dtype("str").num:
182+
dtype = utils.dtype_dict["str"]
183+
else:
184+
dtype = utils.dtype_dict[dtype.lower()]
180185
except KeyError:
181186
raise TypeError(
182187
f"``{dtype}`` is not supported by RedisAI. Currently "

redisai/postprocessor.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ def tensorget(res, as_numpy, as_numpy_mutable, meta_only):
4242
mutable=False,
4343
)
4444
else:
45-
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
45+
if rai_result["dtype"] == "STRING":
46+
def target(b):
47+
return b.decode()
48+
else:
49+
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
4650
utils.recursive_bytetransform(rai_result["values"], target)
4751
return rai_result
4852

redisai/utils.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"uint32": "UINT32",
1616
"uint64": "UINT64",
1717
"bool": "BOOL",
18+
"str": "STRING",
1819
}
1920

2021
allowed_devices = {"CPU", "GPU"}
@@ -24,11 +25,15 @@
2425
def numpy2blob(tensor: np.ndarray) -> tuple:
2526
"""Convert the numpy input from user to `Tensor`."""
2627
try:
27-
dtype = dtype_dict[str(tensor.dtype)]
28+
if tensor.dtype.num == np.dtype("str").num:
29+
dtype = dtype_dict["str"]
30+
blob = "".join([string + "\0" for string in tensor.flat])
31+
else:
32+
dtype = dtype_dict[str(tensor.dtype)]
33+
blob = tensor.tobytes()
2834
except KeyError:
2935
raise TypeError(f"RedisAI doesn't support tensors of type {tensor.dtype}")
3036
shape = tensor.shape
31-
blob = bytes(tensor.data)
3237
return dtype, shape, blob
3338

3439

@@ -38,7 +43,9 @@ def blob2numpy(
3843
"""Convert `BLOB` result from RedisAI to `np.ndarray`."""
3944
mm = {"FLOAT": "float32", "DOUBLE": "float64"}
4045
dtype = mm.get(dtype, dtype.lower())
41-
if mutable:
46+
if dtype == 'string':
47+
a = np.array(value.decode().split('\0')[:-1], dtype='str')
48+
elif mutable:
4249
a = np.fromstring(value, dtype=dtype)
4350
else:
4451
a = np.frombuffer(value, dtype=dtype)

test/test.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ def test_set_non_numpy_tensor(self):
117117
self.assertEqual([2, 2], result["shape"])
118118
self.assertEqual("BOOL", result["dtype"])
119119

120+
con.tensorset("x", (12, 'a', 'G', 'four'), dtype="str", shape=(2, 2))
121+
result = con.tensorget("x", as_numpy=False)
122+
self.assertEqual(['12', 'a', 'G', 'four'], result["values"])
123+
self.assertEqual([2, 2], result["shape"])
124+
self.assertEqual("STRING", result["dtype"])
125+
120126
with self.assertRaises(TypeError):
121127
con.tensorset("x", (2, 3, 4, 5), dtype="wrongtype", shape=(2, 2))
122128
con.tensorset("x", (2, 3, 4, 5), dtype="int8", shape=(2, 2))
@@ -156,6 +162,12 @@ def test_numpy_tensor(self):
156162
self.assertEqual(values.dtype, "bool")
157163
self.assertTrue(np.array_equal(values, [True, False]))
158164

165+
input_array = np.array(["a", "bb", "⚓⚓⚓", "d♻d♻"]).reshape((2, 2))
166+
con.tensorset("x", input_array)
167+
values = con.tensorget("x")
168+
self.assertEqual(values.dtype.num, np.dtype("str").num)
169+
self.assertTrue(np.array_equal(values, [['a', 'bb'], ["⚓⚓⚓", "d♻d♻"]]))
170+
159171
input_array = np.array([2, 3])
160172
con.tensorset("x", input_array)
161173
values = con.tensorget("x")
@@ -174,10 +186,6 @@ def test_numpy_tensor(self):
174186
np.put(ret, 0, 1)
175187
self.assertEqual(ret[0], 1)
176188

177-
stringarr = np.array("dummy")
178-
with self.assertRaises(TypeError):
179-
con.tensorset("trying", stringarr)
180-
181189
# AI.MODELSET is deprecated by AI.MODELSTORE.
182190
def test_deprecated_modelset(self):
183191
model_path = os.path.join(MODEL_DIR, "graph.pb")

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ envlist = linters,tests
66
max-complexity = 10
77
ignore = E501,C901
88
srcdir = ./redisai
9-
exclude =.git,.tox,dist,doc,*/__pycache__/*
9+
exclude =.git,.tox,dist,doc,*/__pycache__/*,venv
1010

1111
[testenv:tests]
1212
whitelist_externals = find

0 commit comments

Comments
 (0)