|
| 1 | +import math |
| 2 | +import ast |
| 3 | +from marklogic.vector_util import VectorUtil |
| 4 | +from marklogic import Client |
| 5 | + |
| 6 | +VECTOR = [3.14, 1.59, 2.65] |
| 7 | +EXPECTED_BASE64 = "AAAAAAMAAADD9UhAH4XLP5qZKUA=" |
| 8 | +ACCEPTABLE_DELTA = 0.0001 |
| 9 | + |
| 10 | + |
| 11 | +def test_encode_and_decode_with_python(): |
| 12 | + encoded = VectorUtil.base64_encode(VECTOR) |
| 13 | + assert encoded == EXPECTED_BASE64 |
| 14 | + |
| 15 | + decoded = VectorUtil.base64_decode(encoded) |
| 16 | + assert len(decoded) == len(VECTOR) |
| 17 | + for a, b in zip(decoded, VECTOR): |
| 18 | + assert abs(a - b) < ACCEPTABLE_DELTA |
| 19 | + |
| 20 | + |
| 21 | +def test_decode_known_base64(): |
| 22 | + decoded = VectorUtil.base64_decode(EXPECTED_BASE64) |
| 23 | + assert len(decoded) == len(VECTOR) |
| 24 | + for a, b in zip(decoded, VECTOR): |
| 25 | + assert abs(a - b) < ACCEPTABLE_DELTA |
| 26 | + |
| 27 | + |
| 28 | +def test_encode_and_decode_with_server(client: Client): |
| 29 | + """ |
| 30 | + Encode a vector in Python, decode it on the MarkLogic server, and check the result. |
| 31 | + """ |
| 32 | + encoded = VectorUtil.base64_encode(VECTOR) |
| 33 | + assert encoded == EXPECTED_BASE64 |
| 34 | + |
| 35 | + # Use MarkLogic's eval endpoint to decode the vector on the server |
| 36 | + xquery = f"vec:base64-decode('{encoded}')" |
| 37 | + binary_result = client.eval(xquery=xquery) |
| 38 | + float_list = ast.literal_eval(binary_result[0].decode("utf-8")) |
| 39 | + assert len(float_list) == len(VECTOR) |
| 40 | + for a, b in zip(float_list, VECTOR): |
| 41 | + assert math.isclose(a, b, abs_tol=ACCEPTABLE_DELTA) |
| 42 | + |
| 43 | + |
| 44 | +def test_encode_with_server_and_decode_with_python(client: Client): |
| 45 | + """ |
| 46 | + Encode a vector on the MarkLogic server, decode it in Python, and check the result. |
| 47 | + """ |
| 48 | + xquery = "vec:base64-encode(vec:vector((3.14, 1.59, 2.65)))" |
| 49 | + encoded = client.eval(xquery=xquery)[0] |
| 50 | + assert encoded == EXPECTED_BASE64 |
| 51 | + |
| 52 | + decoded = VectorUtil.base64_decode(encoded) |
| 53 | + assert len(decoded) == len(VECTOR) |
| 54 | + for a, b in zip(decoded, VECTOR): |
| 55 | + assert math.isclose(a, b, abs_tol=ACCEPTABLE_DELTA) |
0 commit comments