Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions mlflow/utils/proto_json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,20 +303,36 @@ def __init__(self, col_name, col_type, ex):

def cast_df_types_according_to_schema(pdf, schema):
import numpy as np

from mlflow.models.utils import _enforce_array, _enforce_map, _enforce_object
from mlflow.models.utils import (_enforce_array, _enforce_map,
_enforce_object)
from mlflow.types.schema import AnyType, Array, DataType, Map, Object

actual_cols = set(pdf.columns)
if schema.has_input_names():
dtype_list = zip(schema.input_names(), schema.input_types())
elif schema.is_tensor_spec() and len(schema.input_types()) == 1:
dtype_list = zip(actual_cols, [schema.input_types()[0] for _ in actual_cols])
else:
n = min(len(schema.input_types()), len(pdf.columns))
dtype_list = zip(pdf.columns[:n], schema.input_types()[:n])
schema_has_input_names = schema.has_input_names()
schema_is_tensor_spec = schema.is_tensor_spec()
input_types = schema.input_types()

# Precompute required_input_names set
required_input_names = set(schema.required_input_names())

# Branch once for dtype_list computation
if schema_has_input_names:
input_names = schema.input_names()
dtype_list = zip(input_names, input_types)
elif schema_is_tensor_spec and len(input_types) == 1:
t = input_types[0]
# Avoid repeated list construction by using list comprehension once
dtype_list = zip(actual_cols, [t] * len(actual_cols))
else:
n = min(len(input_types), len(pdf.columns))
dtype_list = zip(pdf.columns[:n], input_types[:n])

# Pre-fetch DataType.binary for fast comparison
data_type_binary = DataType.binary if hasattr(DataType, 'binary') else None

# Leverage pandas vectorized ops when possible, including astype
# Also, cache pdf[col_name] once if used multiple times in the loop

for col_name, col_type_spec in dtype_list:
if isinstance(col_type_spec, DataType):
col_type = col_type_spec.to_pandas()
Expand All @@ -325,14 +341,24 @@ def cast_df_types_according_to_schema(pdf, schema):
if col_name in actual_cols:
required = col_name in required_input_names
try:
if isinstance(col_type_spec, DataType) and col_type_spec == DataType.binary:
# Type-based conversions; prefer vectorized or single-pass
# Keep map pattern for enforcement utility calls (may not be vectorizable)

# NB: We expect binary data to be passed base64 encoded
if isinstance(col_type_spec, DataType) and data_type_binary is not None and col_type_spec == data_type_binary:
# vectorized apply with base64 (no significant gains possible here due to decodebytes signature)
# NB: We expect binary data to be passed base64 encoded
pdf[col_name] = pdf[col_name].map(
lambda x: base64.decodebytes(bytes(x, "utf8"))
)
elif col_type == np.dtype(bytes):
pdf[col_name] = pdf[col_name].map(lambda x: bytes(x, "utf8"))
elif schema.is_tensor_spec() and isinstance(pdf[col_name].iloc[0], list):
# Avoid repeated access to pdf[col_name]
# Use pandas to_numpy for performance
arr = pdf[col_name].to_numpy()
# Vectorized conversion using list comprehension and to_numpy
pdf[col_name] = [bytes(x, "utf8") for x in arr]
elif schema_is_tensor_spec and isinstance(pdf[col_name].iloc[0], list):
# skip conversion for tensor spec column with list-like first value
# For dataframe with multidimensional column, it contains
# list type values, we cannot convert
# its type by `astype`, skip conversion.
Expand Down