feat: support types_mapper #2082

Draft · wants to merge 5 commits into base: main
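A minimal usage sketch of what the change below appears to enable; the project ID, field name, and mapping are illustrative assumptions, not part of the diff:

```python
from google.cloud import bigquery

# Hypothetical mapper: force the "account_id" column to be parsed as STRING.
# Returning None (or any falsy value) for a field keeps the type reported by
# the server, per the _to_schema_fields change below.
def my_types_mapper(field_name):
    return "STRING" if field_name == "account_id" else None

# New keyword-only argument added to Client.__init__ in this PR.
client = bigquery.Client(project="my-project", types_mapper=my_types_mapper)

# Rows fetched through this client's RowIterator would convert "account_id"
# with the STRING converter instead of the server-reported type.
for row in client.query("SELECT account_id FROM `my-project.dataset.table`").result():
    print(row["account_id"])
```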
4 changes: 2 additions & 2 deletions google/cloud/bigquery/_helpers.py
@@ -389,7 +389,7 @@ def default_converter(value, field):
return converter(resource, field)


def _row_tuple_from_json(row, schema):
def _row_tuple_from_json(row, schema, types_mapper=None):
"""Convert JSON row data to row with appropriate types.

Note: ``row['f']`` and ``schema`` are presumed to be of the same length.
@@ -406,7 +406,7 @@ def _row_tuple_from_json(row, schema):
"""
from google.cloud.bigquery.schema import _to_schema_fields

schema = _to_schema_fields(schema)
schema = _to_schema_fields(schema, types_mapper)

row_data = []
for field, cell in zip(schema, row["f"]):
4 changes: 2 additions & 2 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -711,7 +711,7 @@ def _row_iterator_page_to_arrow(page, column_names, arrow_types):
return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)


def download_arrow_row_iterator(pages, bq_schema):
def download_arrow_row_iterator(pages, bq_schema, types_mapper=None):
"""Use HTTP JSON RowIterator to construct an iterable of RecordBatches.

Args:
@@ -726,7 +726,7 @@ def download_arrow_row_iterator(pages, bq_schema):
:class:`pyarrow.RecordBatch`
The next page of records as a ``pyarrow`` record batch.
"""
bq_schema = schema._to_schema_fields(bq_schema)
bq_schema = schema._to_schema_fields(bq_schema, types_mapper)
column_names = bq_to_arrow_schema(bq_schema) or [field.name for field in bq_schema]
arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema]

20 changes: 20 additions & 0 deletions google/cloud/bigquery/client.py
@@ -31,6 +31,7 @@
import typing
from typing import (
Any,
Callable,
Dict,
IO,
Iterable,
@@ -219,6 +220,9 @@ class Client(ClientWithProject):
client_options (Optional[Union[google.api_core.client_options.ClientOptions, Dict]]):
Client options used to set user options on the client. API Endpoint
should be set through client_options.
types_mapper (Optional[typing.Callable]):
A callable that takes a schema field name and returns the BigQuery
type to use for that field, or a falsy value to keep the type
reported by the server.

Raises:
google.auth.exceptions.DefaultCredentialsError:
@@ -239,6 +243,8 @@ def __init__(
default_load_job_config=None,
client_info=None,
client_options=None,
*,
types_mapper=None,
) -> None:
super(Client, self).__init__(
project=project,
@@ -275,6 +281,9 @@ def __init__(
# Use property setter so validation can run.
self.default_query_job_config = default_query_job_config

# Client level types mapper setting.
self._types_mapper = types_mapper

@property
def location(self):
"""Default location for jobs / datasets / tables."""
@@ -308,6 +317,17 @@ def default_load_job_config(self):
def default_load_job_config(self, value: LoadJobConfig):
self._default_load_job_config = copy.deepcopy(value)


@property
def types_mapper(self):
"""TODO: add docstring
"""
return self._types_mapper

@types_mapper.setter
def types_mapper(self, value: Optional[Callable]):
self._types_mapper = value

def close(self):
"""Close the underlying transport objects, releasing system resources.

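The new property also makes the mapper settable after construction; a small sketch, reusing the client from the earlier example:

```python
# Swap the mapper on an existing client. New RowIterator instances read
# client.types_mapper when they are constructed, so this affects subsequent
# queries rather than iterators that already exist.
client.types_mapper = lambda name: "NUMERIC" if name.endswith("_amount") else None

# Clearing the mapper restores the default behavior of trusting the server schema.
client.types_mapper = None
```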
23 changes: 16 additions & 7 deletions google/cloud/bigquery/schema.py
@@ -15,6 +15,7 @@
"""Schemas for BigQuery tables / queries."""

import collections
import copy
import enum
from typing import Any, Dict, Iterable, Optional, Union, cast

@@ -473,7 +474,7 @@ def _build_schema_resource(fields):
return [field.to_api_repr() for field in fields]


def _to_schema_fields(schema):
def _to_schema_fields(schema, types_mapper=None):
"""Coerce `schema` to a list of schema field instances.

Args:
@@ -493,17 +494,25 @@ def _to_schema_fields(schema):
sequence is not a :class:`~google.cloud.bigquery.schema.SchemaField`
instance or a compatible mapping representation of the field.
"""
schema_fields = []
for field in schema:
if not isinstance(field, (SchemaField, collections.abc.Mapping)):
if isinstance(field, SchemaField):
current_field = copy.deepcopy(field)
field_name = field.name
elif isinstance(field, collections.abc.Mapping):
current_field = SchemaField.from_api_repr(field)
field_name = field["name"]
else:
raise ValueError(
"Schema items must either be fields or compatible "
"mapping representations."
)

if types_mapper and types_mapper(field_name):
current_field._properties["type"] = types_mapper(field_name)

return [
field if isinstance(field, SchemaField) else SchemaField.from_api_repr(field)
for field in schema
]
schema_fields.append(current_field)
return schema_fields


class PolicyTagList(object):
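A sketch of what the reworked helper does when a mapper is supplied, based on the loop above; `_to_schema_fields` is a private helper, and the field names and types here are made up for illustration:

```python
from google.cloud.bigquery.schema import SchemaField, _to_schema_fields

schema = [
    SchemaField("id", "INTEGER"),
    {"name": "created", "type": "TIMESTAMP", "mode": "NULLABLE"},
]

# Only "id" is remapped; the mapper returns None for everything else, so those
# fields fall through the `if types_mapper and types_mapper(field_name)` check.
mapped = _to_schema_fields(
    schema, types_mapper=lambda name: "STRING" if name == "id" else None
)

assert mapped[0].field_type == "STRING"      # overridden via _properties["type"]
assert mapped[1].field_type == "TIMESTAMP"   # unchanged
assert schema[0].field_type == "INTEGER"     # original is deep-copied, not mutated
```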
24 changes: 19 additions & 5 deletions google/cloud/bigquery/table.py
@@ -1594,19 +1594,29 @@ def __init__(
project: Optional[str] = None,
num_dml_affected_rows: Optional[int] = None,
):
if client:
types_mapper = client.types_mapper
else:
types_mapper = None

if types_mapper:
_item_to_row_with_mapper = functools.partial(_item_to_row, types_mapper=types_mapper)
else:
_item_to_row_with_mapper = _item_to_row

super(RowIterator, self).__init__(
client,
api_request,
path,
item_to_value=_item_to_row,
item_to_value=_item_to_row_with_mapper,
items_key="rows",
page_token=page_token,
max_results=max_results,
extra_params=extra_params,
page_start=_rows_page_start,
next_token="pageToken",
)
schema = _to_schema_fields(schema)
schema = _to_schema_fields(schema, types_mapper)
self._field_to_index = _helpers._field_to_index_mapping(schema)
self._page_size = page_size
self._preserve_order = False
@@ -1871,8 +1881,12 @@ def to_arrow_iterable(
max_queue_size=max_queue_size,
max_stream_count=max_stream_count,
)
if self.client is not None:
types_mapper = self.client.types_mapper
else:
types_mapper = None
tabledata_list_download = functools.partial(
_pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
_pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema, types_mapper=types_mapper,
)
return self._to_page_iterable(
bqstorage_download,
@@ -3262,7 +3276,7 @@ def from_api_repr(cls, resource: Dict[str, Any]) -> "TableConstraints":
return cls(primary_key, foreign_keys)


def _item_to_row(iterator, resource):
def _item_to_row(iterator, resource, types_mapper=None):
"""Convert a JSON row to the native object.

.. note::
Expand All @@ -3279,7 +3293,7 @@ def _item_to_row(iterator, resource):
google.cloud.bigquery.table.Row: The next row in the page.
"""
return Row(
_helpers._row_tuple_from_json(resource, iterator.schema),
_helpers._row_tuple_from_json(resource, iterator.schema, types_mapper),
iterator._field_to_index,
)

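Because only the JSON (`tabledata.list`) download path receives the mapper in this change, the override would also show up in the pyarrow schema built by `download_arrow_row_iterator`; a rough sketch, assuming no BigQuery Storage client is used and the client-level mapper from the first example is in place:

```python
import pyarrow as pa

rows = client.list_rows("my-project.dataset.table")

# The JSON download path builds its arrow types from the remapped BigQuery
# schema, so a column forced to "STRING" surfaces as pa.string().
for record_batch in rows.to_arrow_iterable():
    assert record_batch.schema.field("account_id").type == pa.string()
```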
6 changes: 4 additions & 2 deletions tests/unit/test_table.py
@@ -36,8 +36,10 @@
def _mock_client():
from google.cloud.bigquery import client

mock_client = mock.create_autospec(client.Client)
mock_client.project = "my-project"
mock_client = client.Client(project="my-project")
mock_client._ensure_bqstorage_client = mock.MagicMock(
mock_client._ensure_bqstorage_client,
)
return mock_client

