4 changes: 2 additions & 2 deletions docs/source/locale/en/LC_MESSAGES/df.po
@@ -36,8 +36,8 @@ msgid ""
"PyODPS DataFrame 未来将停止维护。对于新项目,建议使用 `MaxFrame <https://"
"maxframe.readthedocs.io/en/latest/index.html>`_\\ 。"
msgstr ""
"Maintenance of PyODPS DataFrame is discontinued. For new projects, please"
" do not use this feature."
"**Maintenance of PyODPS DataFrame is discontinued. Please do not use this"
" feature in new projects.**"

#: ../../source/df.rst:16
msgid ""
39 changes: 34 additions & 5 deletions odps/src/types_c.pyx
@@ -93,8 +93,8 @@ cdef class StringValidator(TypeValidator):
if s_size <= max_field_size:
return u_val
raise ValueError(
"InvalidData: Length of string(%s) is more than %sM.'" %
(val, max_field_size / (1024 ** 2))
"InvalidData: Byte length of string(%s) is more than %sM.'" %
(s_size, max_field_size / (1024 ** 2))
)


@@ -118,8 +118,33 @@ cdef class BinaryValidator(TypeValidator):
if s_size <= max_field_size:
return bytes_val
raise ValueError(
"InvalidData: Length of string(%s) is more than %sM.'" %
(val, max_field_size / (1024 ** 2)))
"InvalidData: Byte length of string(%s) is more than %sM.'" %
(s_size, max_field_size / (1024 ** 2)))


+ cdef class SizeLimitedStringValidator(TypeValidator):
+ cdef int _size_limit
+
+ def __init__(self, int size_limit):
+ self._size_limit = size_limit
+
+ cdef object validate(self, object val, int64_t max_field_size):
+ cdef:
+ unicode u_val
+
+ if type(val) is bytes or isinstance(val, bytes):
+ u_val = (<bytes> val).decode("utf-8")
+ elif type(val) is unicode or isinstance(val, unicode):
+ u_val = <unicode> val
+ else:
+ raise TypeError("Invalid data type: expect bytes or unicode, got %s" % type(val))
+
+ if len(u_val) <= self._size_limit:
+ return u_val
+ raise ValueError(
+ "InvalidData: Length of string(%s) is more than %s.'" %
+ (val, self._size_limit)
+ )


py_strptime = datetime.strptime
@@ -431,7 +456,9 @@ cdef object _build_type_validator(int type_id, object data_type):
elif type_id == DOUBLE_TYPE_ID:
return DoubleValidator()
elif type_id == STRING_TYPE_ID:
- if options.tunnel.string_as_binary:
+ if isinstance(data_type, types.SizeLimitedString):
+ return SizeLimitedStringValidator(data_type.size_limit)
+ elif options.tunnel.string_as_binary:
return BinaryValidator()
else:
return StringValidator()
@@ -502,6 +529,8 @@ cdef class SchemaSnapshot:

cdef class BaseRecord:
def __cinit__(self, columns=None, schema=None, values=None, max_field_size=None):
+ if isinstance(columns, types.Schema):
+ schema, columns = columns, None
self._c_schema_snapshot = getattr(schema, "_snapshot", None)
if columns is not None:
self._c_columns = columns
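Note on the types_c.pyx changes above: the string/binary validators now report the encoded byte length (`s_size`) instead of the raw value in the overflow error, the new `SizeLimitedStringValidator` enforces `varchar`/`char` limits by character count after UTF-8 decoding, and `BaseRecord.__cinit__` accepts a `Schema` passed as the first positional argument. A minimal pure-Python sketch of the character-count rule (illustrative only, not the Cython code above; the helper name is made up):

```python
def validate_size_limited_string(val, size_limit):
    """Sketch of the varchar/char rule: the limit counts characters, not bytes."""
    if isinstance(val, bytes):
        val = val.decode("utf-8")  # bytes are decoded before counting
    elif not isinstance(val, str):
        raise TypeError("Invalid data type: expect bytes or unicode, got %s" % type(val))
    if len(val) > size_limit:  # len() counts code points
        raise ValueError(
            "InvalidData: Length of string(%s) is more than %s." % (val, size_limit)
        )
    return val

# A varchar(3) column accepts u"测试字" (3 characters, 9 UTF-8 bytes)
# but rejects "1234" (4 characters).
validate_size_limited_string("测试字", 3)
```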
5 changes: 4 additions & 1 deletion odps/src/utils_c.pyx
@@ -308,4 +308,7 @@ cpdef inline unicode to_text(s, encoding="utf-8"):
cpdef str to_lower_str(s, encoding="utf-8"):
if s is None:
return None
- return to_str(s, encoding).lower()
+ if _is_py3:
+ return <str>(to_text(s, encoding).lower())
+ else:
+ return <str>(to_binary(s, encoding).lower())
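The `to_lower_str` change makes the helper return the native `str` type on both Python lines: text is lowercased on Python 3, bytes on Python 2. A Python-3-only sketch of the intended behaviour (illustrative, not the Cython source):

```python
def to_lower_str(s, encoding="utf-8"):
    """Sketch: always return native str, decoding bytes input first."""
    if s is None:
        return None
    if isinstance(s, bytes):
        s = s.decode(encoding)
    return s.lower()

assert to_lower_str(b"ABC") == "abc"
assert to_lower_str(None) is None
```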
54 changes: 52 additions & 2 deletions odps/tests/test_types.py
@@ -1,4 +1,3 @@
- #!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
@@ -28,6 +27,7 @@

from .. import options
from .. import types as odps_types
+ from .. import utils
from ..tests.core import pandas_case, py_and_c


@@ -274,19 +274,23 @@ def test_composite_types():
comp_type = odps_types.validate_data_type("decimal(10)")
assert isinstance(comp_type, odps_types.Decimal)
assert comp_type.precision == 10
assert comp_type == "decimal(10)"

comp_type = odps_types.validate_data_type("decimal(10, 2)")
assert isinstance(comp_type, odps_types.Decimal)
assert comp_type.precision == 10
assert comp_type.scale == 2
assert comp_type == "decimal(10,2)"

comp_type = odps_types.validate_data_type("varchar(10)")
assert isinstance(comp_type, odps_types.Varchar)
assert comp_type.size_limit == 10
assert comp_type == "varchar(10)"

comp_type = odps_types.validate_data_type("char(20)")
assert isinstance(comp_type, odps_types.Char)
assert comp_type.size_limit == 20
assert comp_type == "char(20)"

with pytest.raises(ValueError) as ex_info:
odps_types.validate_data_type("array")
@@ -295,6 +299,7 @@ def test_composite_types():
comp_type = odps_types.validate_data_type("array<bigint>")
assert isinstance(comp_type, odps_types.Array)
assert isinstance(comp_type.value_type, odps_types.Bigint)
assert comp_type == "array<bigint>"

with pytest.raises(ValueError) as ex_info:
odps_types.validate_data_type("map")
@@ -304,12 +309,16 @@ def test_composite_types():
assert isinstance(comp_type, odps_types.Map)
assert isinstance(comp_type.key_type, odps_types.Bigint)
assert isinstance(comp_type.value_type, odps_types.String)
assert comp_type == "map<bigint, string>"

comp_type = odps_types.validate_data_type("struct<abc:int, def:string>")
assert isinstance(comp_type, odps_types.Struct)
assert len(comp_type.field_types) == 2
assert isinstance(comp_type.field_types["abc"], odps_types.Int)
assert isinstance(comp_type.field_types["def"], odps_types.String)
assert comp_type == "struct<abc:int, def:string>"
assert comp_type != "struct<abc:int>"
assert comp_type != "struct<abc:int, uvw:string>"

comp_type = odps_types.validate_data_type(
"struct<abc:int, def:map<bigint, string>, ghi:string>"
@@ -347,10 +356,15 @@ def test_set_with_cast():
@py_and_c_deco
def test_record_copy():
s = TableSchema.from_lists(["col1"], ["string"])
- r = Record(schema=s)
+ r = Record(s)
r.col1 = "a"

cr = copy.copy(r)
assert cr == r
+ assert cr.col1 == r.col1

+ cr = copy.deepcopy(r)
+ assert cr == r
+ assert cr.col1 == r.col1


@@ -515,6 +529,42 @@ def test_validate_struct():
options.struct_as_dict = False


+ @py_and_c_deco
+ @pytest.mark.parametrize("use_binary", [False, True])
+ def test_varchar_size_limit(use_binary):
+ def _c(s):
+ if use_binary:
+ return utils.to_binary(s)
+ return utils.to_text(s)
+
+ s = TableSchema.from_lists(["col1"], ["varchar(3)"])
+ r = Record(schema=s)
+ r[0] = _c("123")
+ r[0] = _c("测试字")
+ with pytest.raises(ValueError):
+ r[0] = _c("1234")
+ with pytest.raises(ValueError):
+ r[0] = _c("测试字符")
+
+
+ @py_and_c_deco
+ @pytest.mark.parametrize("use_binary", [False, True])
+ def test_field_size_limit(use_binary):
+ def _c(s):
+ if use_binary:
+ return utils.to_binary(s)
+ return utils.to_text(s)
+
+ s = TableSchema.from_lists(["str_col", "bin_col"], ["string", "binary"])
+ r = Record(schema=s, max_field_size=1024)
+ r[0] = _c("1" * 1024)
+ r[1] = _c("1" * 1024)
+ with pytest.raises(ValueError):
+ r[0] = _c("1" * 1023 + "测")
+ with pytest.raises(ValueError):
+ r[1] = _c("1" * 1023 + "测")


def test_validate_decimal():
with pytest.raises(ValueError):
odps_types.Decimal(32, 60)
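The new tests above exercise two different limits: `varchar(n)`/`char(n)` count characters after decoding, while `max_field_size` for `string`/`binary` counts encoded bytes, which is why multi-byte characters appear in both. A small standalone illustration of the distinction:

```python
s = "测试字"                                  # 3 characters
assert len(s) == 3                            # within a varchar(3) limit
assert len(s.encode("utf-8")) == 9            # 9 bytes count against max_field_size

payload = "1" * 1023 + "测"                   # 1024 characters...
assert len(payload.encode("utf-8")) == 1026   # ...but 1026 bytes, over a 1024-byte limit
```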
61 changes: 52 additions & 9 deletions odps/types.py
@@ -91,7 +91,30 @@ def __repr__(self):
)

def __hash__(self):
- return hash((type(self), self.name, self.type, self.comment, self.label))
+ return hash(
+ (
+ type(self),
+ self.name,
+ self.type,
+ self.comment,
+ self.label,
+ self.nullable,
+ self.generate_expression,
+ )
+ )

+ def __eq__(self, other):
+ return self is other or all(
+ getattr(self, attr, None) == getattr(other, attr, None)
+ for attr in (
+ "name",
+ "type",
+ "comment",
+ "label",
+ "nullable",
+ "generate_expression",
+ )
+ )

def to_sql_clause(self, with_column_comments=True):
from .expressions import parse as parse_expression
@@ -604,6 +627,8 @@ class BaseRecord(object):
__slots__ = "_values", "_columns", "_name_indexes", "_max_field_size"

def __init__(self, columns=None, schema=None, values=None, max_field_size=None):
+ if isinstance(columns, Schema):
+ schema, columns = columns, None
if columns is not None:
self._columns = columns
self._name_indexes = {
@@ -978,6 +1003,17 @@ class Double(BaseFloat):
_type_id = 1


+ def _check_string_byte_size(val, max_size):
+ if isinstance(val, six.binary_type):
+ byt_len = len(val)
+ else:
+ byt_len = 4 * len(val)
+ if byt_len > max_size:
+ # encode only when necessary
+ byt_len = len(utils.to_binary(val))
+ return byt_len <= max_size, byt_len


@_primitive_doc(is_odps2=False)
class String(OdpsPrimitive):
__slots__ = ()
@@ -997,11 +1033,12 @@ def validate_value(self, val, max_field_size=None):
if val is None and self.nullable:
return True
max_field_size = max_field_size or self._max_length
- if len(val) <= max_field_size:
+ valid, byt_len = _check_string_byte_size(val, max_field_size)
+ if valid:
return True
raise ValueError(
- "InvalidData: Length of string(%s) is more than %sM.'"
- % (val, max_field_size / (1024**2))
+ "InvalidData: Byte length of string(%s) is more than %sM.'"
+ % (byt_len, max_field_size / (1024**2))
)

def cast_value(self, value, data_type):
@@ -1104,11 +1141,12 @@ def validate_value(self, val, max_field_size=None):
if val is None and self.nullable:
return True
max_field_size = max_field_size or self._max_length
- if len(val) <= max_field_size:
+ valid, byt_len = _check_string_byte_size(val, max_field_size)
+ if valid:
return True
raise ValueError(
- "InvalidData: Length of binary(%s) is more than %sM.'"
- % (val, max_field_size / (1024**2))
+ "InvalidData: Byte length of binary(%s) is more than %sM.'"
+ % (byt_len, max_field_size / (1024**2))
)

def cast_value(self, value, data_type):
@@ -1261,9 +1299,14 @@ def validate_value(self, val, max_field_size=None):
if val is None and self.nullable:
return True
if len(val) <= self.size_limit:
+ # binary size >= unicode size
return True
+ elif isinstance(val, six.binary_type):
+ val = val.decode("utf-8")
+ if len(val) <= self.size_limit:
+ return True
raise ValueError(
- "InvalidData: Length of string(%d) is more than %sM.'"
+ "InvalidData: Length of string(%d) is more than %s.'"
% (len(val), self.size_limit)
)

@@ -1734,7 +1777,7 @@ def _equals(self, other):
isinstance(other, Struct)
and len(self.field_types) == len(other.field_types)
and all(
- self.field_types[k] == other.field_types[k]
+ self.field_types[k] == other.field_types.get(k)
for k in six.iterkeys(self.field_types)
)
)
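Note on the types.py changes above: `Column` equality and hashing now take `nullable` and `generate_expression` into account, `Struct._equals` uses `field_types.get(k)` so a missing field name compares unequal instead of raising `KeyError`, and `String`/`Binary` validation goes through `_check_string_byte_size`, which avoids encoding in the common case because a UTF-8 code point is at most 4 bytes long. A small standalone sketch of that fast path (illustrative; the real helper is `_check_string_byte_size` shown in the diff):

```python
def within_byte_limit(val, max_size):
    """Return (fits, byte_length), encoding only when the cheap bound fails."""
    if isinstance(val, bytes):
        byte_len = len(val)
    else:
        byte_len = 4 * len(val)  # UTF-8 upper bound: at most 4 bytes per character
        if byte_len > max_size:
            byte_len = len(val.encode("utf-8"))  # encode only when necessary
    return byte_len <= max_size, byte_len

assert within_byte_limit("abc", 1024) == (True, 12)        # bound 12, never encoded
assert within_byte_limit("测" * 300, 1024) == (True, 900)   # bound 1200 > 1024, so encoded: 900 bytes
```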