Implementation of Parquet reader based on Arrow #2136

Status: Open. Wants to merge 6 commits into base branch master.
WORKSPACE (5 additions, 5 deletions)
@@ -54,7 +54,6 @@ http_archive(
     sha256 = "14bf9bf97431b890e0ae5dca8f8904841d4883b8596a7108a42f5700ae58d711",
     strip_prefix = "google-cloud-cpp-1.21.0",
     urls = [
-        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/googleapis/google-cloud-cpp/archive/v1.21.0.tar.gz",
         "https://github.com/googleapis/google-cloud-cpp/archive/v1.21.0.tar.gz",
     ],
 )
@@ -174,11 +173,12 @@ http_archive(
 http_archive(
     name = "arrow",
     build_file = "//third_party:arrow.BUILD",
-    sha256 = "57e13c62f27b710e1de54fd30faed612aefa22aa41fa2c0c3bacd204dd18a8f3",
-    strip_prefix = "arrow-apache-arrow-7.0.0",
+    patch_cmds = ["""sed -i.bak '24i\\'$'\\n#undef ARROW_WITH_OPENTELEMETRY\\n' cpp/src/arrow/util/tracing_internal.h"""],
+    sha256 = "19ece12de48e51ce4287d2dee00dc358fbc5ff02f41629d16076f77b8579e272",
+    strip_prefix = "arrow-apache-arrow-8.0.0",
     urls = [
-        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz",
-        "https://github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz",
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-8.0.0.tar.gz",
+        "https://github.com/apache/arrow/archive/apache-arrow-8.0.0.tar.gz",
     ],
 )
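Note: the added patch_cmds runs a sed command that inserts "#undef ARROW_WITH_OPENTELEMETRY" before line 24 of Arrow's cpp/src/arrow/util/tracing_internal.h, presumably so this vendored build does not require the OpenTelemetry dependency that Arrow 8.0.0's tracing header would otherwise pull in.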

tensorflow_io/arrow.py (3 additions, 0 deletions)
@@ -17,6 +17,7 @@
 @@ArrowDataset
 @@ArrowFeatherDataset
 @@ArrowStreamDataset
+@@ArrowParquetDataset
 @@list_feather_columns
 """

@@ -26,13 +27,15 @@
 from tensorflow_io.python.ops.arrow_dataset_ops import ArrowDataset
 from tensorflow_io.python.ops.arrow_dataset_ops import ArrowFeatherDataset
 from tensorflow_io.python.ops.arrow_dataset_ops import ArrowStreamDataset
+from tensorflow_io.python.ops.arrow_dataset_ops import ArrowParquetDataset
 from tensorflow_io.python.ops.arrow_dataset_ops import list_feather_columns


 _allowed_symbols = [
     "ArrowDataset",
     "ArrowFeatherDataset",
     "ArrowStreamDataset",
+    "ArrowParquetDataset",
     "list_feather_columns",
 ]

tensorflow_io/core/kernels/arrow/arrow_dataset_ops.cc (360 additions, 123 deletions; large diff, not rendered by default)
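The kernel is the bulk of this PR but its diff is not shown. As a rough illustration only, the access pattern it presumably implements (in C++ against the Arrow 8.0.0 API) corresponds to the following pyarrow sketch; the path, column list, and batch size here are hypothetical:

    # Hypothetical sketch: iterate record batches of selected columns
    # across several parquet files, at most batch_size rows per batch.
    import pyarrow.parquet as pq

    def iter_parquet_batches(file_paths, column_names, batch_size):
        for path in file_paths:
            reader = pq.ParquetFile(path)
            # iter_batches reads only the requested columns and yields
            # pyarrow.RecordBatch objects of up to batch_size rows.
            yield from reader.iter_batches(
                batch_size=batch_size, columns=list(column_names)
            )

    for batch in iter_parquet_batches(["/tmp/test.parquet"], ["tips"], 4):
        print(batch.to_pydict()["tips"])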

tensorflow_io/core/ops/arrow_ops.cc (22 additions, 7 deletions)
@@ -38,6 +38,21 @@ in file format.
 buffer_size: Buffer size in bytes
 )doc");

+REGISTER_OP("IO>ArrowParquetDataset")
+    .Input("file_paths: string")
+    .Input("column_names: string")
+    .Input("columns: int32")
+    .Input("batch_size: int64")
+    .Input("batch_mode: string")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset from parquet files.
+)doc");
+
 REGISTER_OP("IO>ArrowSerializedDataset")
     .Input("serialized_batches: string")
     .Input("columns: int32")
@@ -92,7 +107,7 @@ REGISTER_OP("IO>ListFeatherColumns")
     .Output("columns: string")
     .Output("dtypes: string")
     .Output("shapes: int64")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
       c->set_output(1, c->MakeShape({c->UnknownDim()}));
       c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
@@ -105,7 +120,7 @@ REGISTER_OP("IO>FeatherReadableInit")
     .Output("components: string")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       c->set_output(0, c->Scalar());
       c->set_output(1, c->MakeShape({}));
       return OkStatus();
@@ -116,7 +131,7 @@ REGISTER_OP("IO>FeatherReadableSpec")
     .Output("shape: int64")
     .Output("dtype: int64")
     .Attr("component: string")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
       c->set_output(1, c->MakeShape({}));
       return OkStatus();
@@ -130,7 +145,7 @@ REGISTER_OP("IO>FeatherReadableRead")
     .Attr("component: string")
     .Attr("shape: shape")
     .Attr("dtype: type")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       PartialTensorShape shape;
       TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
       shape_inference::ShapeHandle entry;
@@ -148,7 +163,7 @@ REGISTER_OP("IO>ArrowReadableFromMemoryInit")
     .Output("resource: resource")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       c->set_output(0, c->Scalar());
       return OkStatus();
     });
@@ -159,7 +174,7 @@ REGISTER_OP("IO>ArrowReadableSpec")
     .Input("column_name: string")
     .Output("shape: int64")
     .Output("dtype: int64")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
       c->set_output(1, c->MakeShape({}));
       return OkStatus();
@@ -174,7 +189,7 @@ REGISTER_OP("IO>ArrowReadableRead")
     .Input("stop: int64")
     .Output("value: dtype")
     .Attr("dtype: type")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
       shape_inference::ShapeHandle full;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &full));
       if (!(c->RankKnown(full) && c->Rank(full) > 0)) {
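In Python, this new op surfaces as the generated binding core_ops.io_arrow_parquet_dataset; the ArrowParquetDataset wrapper in the next file partially applies it with the file_paths and column_names tensors, and ArrowBaseDataset supplies the remaining columns, batch_size, and batch_mode inputs. The remaining hunks above only change pointer placement (InferenceContext* c to InferenceContext *c) with no functional effect.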
tensorflow_io/python/ops/arrow_dataset_ops.py (33 additions, 0 deletions)
@@ -651,6 +651,39 @@ def gen_record_batches():
     )


+class ArrowParquetDataset(ArrowBaseDataset):
+    """An Arrow Dataset for reading record batches from parquet files.
+    """
+
+    def __init__(
+        self,
+        file_paths,
+        column_names,
+        columns,
+        output_types,
+        output_shapes=None,
+        batch_size=None,
+        batch_mode="keep_remainder",
+    ):
+        file_paths = tf.convert_to_tensor(
+            file_paths, dtype=dtypes.string, name="file_paths"
+        )
+        column_names = tf.convert_to_tensor(
+            column_names, dtype=dtypes.string, name="column_names"
+        )
+        super().__init__(
+            partial(
+                core_ops.io_arrow_parquet_dataset,
+                file_paths,
+                column_names,
+            ),
+            columns,
+            output_types,
+            output_shapes,
+            batch_size,
+            batch_mode,
+        )
+
+
 def list_feather_columns(filename, **kwargs):
     """list_feather_columns"""
     if not tf.executing_eagerly():
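As with the other Arrow datasets in this module, batch_mode is expected to accept "keep_remainder" (the default; a partial final batch is kept), "drop_remainder" (partial batches are discarded), and "auto" (batches follow the record-batch boundaries of the underlying files).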
test_parquet_dataset.py (15 additions, 0 deletions; new file)
@@ -0,0 +1,15 @@
+import tensorflow as tf
+import tensorflow_io.arrow as arrow_io
+
+dataset = arrow_io.ArrowParquetDataset(
+    file_paths=["/home/yye/training-platform/training-platform/bento/apps/demos/chicago_taxi/data/test.parquet"],
+    column_names=("tips",),
+    columns=(),
+    output_types=(tf.float32,),
+    output_shapes=([],),
+    batch_size=4,
+    batch_mode="keep_remainder")
+
+# This will iterate over the rows of each file provided, in batches of 4
+for row in dataset:
+    print(row)
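Note that single-element tuples above need trailing commas (("tips",), not ("tips")). For anyone reproducing the test, a minimal pyarrow sketch (output path and values are made up) that writes a parquet file with the single float32 "tips" column the script expects:

    import pyarrow as pa
    import pyarrow.parquet as pq

    # One float32 column named "tips"; five rows gives one full batch of
    # four plus a remainder row for batch_mode="keep_remainder".
    table = pa.table({"tips": pa.array([0.0, 1.5, 2.25, 3.0, 4.75], type=pa.float32())})
    pq.write_table(table, "test.parquet")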
third_party/arrow.BUILD (17 additions, 1 deletion)
@@ -37,7 +37,12 @@ cc_library(
         [
             "cpp/src/arrow/*.cc",
             "cpp/src/arrow/array/*.cc",
+            "cpp/src/arrow/compute/*.cc",
+            "cpp/src/arrow/compute/exec/*.cc",
+            "cpp/src/arrow/compute/kernels/*.cc",
             "cpp/src/arrow/csv/*.cc",
+            "cpp/src/arrow/dataset/*.cc",
+            "cpp/src/arrow/filesystem/*.cc",
             "cpp/src/arrow/io/*.cc",
             "cpp/src/arrow/ipc/*.cc",
             "cpp/src/arrow/json/*.cc",
@@ -46,6 +51,10 @@ cc_library(
             "cpp/src/arrow/vendored/optional.hpp",
             "cpp/src/arrow/vendored/string_view.hpp",
             "cpp/src/arrow/vendored/variant.hpp",
+            "cpp/src/arrow/vendored/base64.cpp",
+            "cpp/src/arrow/vendored/datetime/tz.cpp",
+            "cpp/src/arrow/vendored/uriparser/*.c",
+            "cpp/src/arrow/vendored/pcg/*.hpp",
             "cpp/src/arrow/**/*.h",
             "cpp/src/parquet/**/*.h",
             "cpp/src/parquet/**/*.cc",
@@ -58,9 +67,11 @@ cc_library(
             "cpp/src/**/*_main.cc",
             "cpp/src/**/*_nossl.cc",
             "cpp/src/**/*_test.cc",
-            "cpp/src/**/test_*.cc",
+            "cpp/src/**/*test*.h",
+            "cpp/src/**/*test*.cc",
             "cpp/src/**/*hdfs*.cc",
             "cpp/src/**/*fuzz*.cc",
+            "cpp/src/**/*gcsfs*.cc",
             "cpp/src/**/file_to_stream.cc",
             "cpp/src/**/stream_to_file.cc",
             "cpp/src/arrow/util/bpacking_avx2.cc",
@@ -99,16 +110,21 @@ cc_library(
         "PARQUET_STATIC",
         "PARQUET_EXPORT=",
         "WIN32_LEAN_AND_MEAN",
+        "ARROW_DS_STATIC",
+        "URI_STATIC_BUILD",
     ],
     includes = [
         "cpp/src",
+        "cpp/src/generated",
        "cpp/src/arrow/vendored/xxhash",
         "cpp/thirdparty/flatbuffers/include",
     ],
     textual_hdrs = [
         "cpp/src/arrow/vendored/xxhash/xxhash.c",
     ],
     deps = [
+        "@aws-sdk-cpp//:identity-management",
+        "@aws-sdk-cpp//:s3",
         "@boringssl//:crypto",
         "@brotli",
         "@bzip2",
third_party/aws-sdk-cpp.BUILD (55 additions, 0 deletions)
@@ -163,6 +163,61 @@
     ],
 )

+cc_library(
+    name = "cognito-identity",
+    srcs = glob([
+        "aws-cpp-sdk-cognito-identity/source/*.cpp",
+        "aws-cpp-sdk-cognito-identity/source/model/*.cpp",
+    ]),
+    hdrs = glob([
+        "aws-cpp-sdk-cognito-identity/include/aws/cognito-identity/*.h",
+        "aws-cpp-sdk-cognito-identity/include/aws/cognito-identity/model/*.h",
+    ]),
+    includes = [
+        "aws-cpp-sdk-cognito-identity/include",
+    ],
+    deps = [
+        ":core",
+    ],
+)
+
+cc_library(
+    name = "sts",
+    srcs = glob([
+        "aws-cpp-sdk-sts/source/*.cpp",
+        "aws-cpp-sdk-sts/source/model/*.cpp",
+    ]),
+    hdrs = glob([
+        "aws-cpp-sdk-sts/include/aws/sts/*.h",
+        "aws-cpp-sdk-sts/include/aws/sts/model/*.h",
+    ]),
+    includes = [
+        "aws-cpp-sdk-sts/include",
+    ],
+    deps = [
+        ":core",
+    ],
+)
+
+cc_library(
+    name = "identity-management",
+    srcs = glob([
+        "aws-cpp-sdk-identity-management/source/auth/*.cpp",
+    ]),
+    hdrs = glob([
+        "aws-cpp-sdk-identity-management/include/aws/identity-management/*.h",
+        "aws-cpp-sdk-identity-management/include/aws/identity-management/auth/*.h",
+    ]),
+    includes = [
+        "aws-cpp-sdk-identity-management/include",
+    ],
+    deps = [
+        ":cognito-identity",
+        ":core",
+        ":sts",
+    ],
+)
+
 genrule(
     name = "SDKConfig_h",
     outs = [