diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index f2902c8df..763840091 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -368,6 +368,23 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "orc_ops",
+    srcs = [
+        "kernels/orc/orc_kernels.cc",
+        "ops/orc_ops.cc",
+    ],
+    copts = tf_io_copts(),
+    linkstatic = True,
+    deps = [
+        "//tensorflow_io/core:dataset_ops",
+        "@liborc",
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "text_ops",
     srcs = [
@@ -531,19 +548,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "orc_ops",
-    srcs = [
-    ],
-    copts = tf_io_copts(),
-    linkstatic = True,
-    deps = [
-        "//tensorflow_io/core:dataset_ops",
-        "@liborc",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "numpy_ops",
     srcs = [
diff --git a/tensorflow_io/core/kernels/orc/orc_kernels.cc b/tensorflow_io/core/kernels/orc/orc_kernels.cc
new file mode 100644
index 000000000..c3b86ce9c
--- /dev/null
+++ b/tensorflow_io/core/kernels/orc/orc_kernels.cc
@@ -0,0 +1,246 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <ctime>
+#include <iostream>
+#include <orc/Exceptions.hh>
+#include <orc/OrcFile.hh>
+#include <orc/Reader.hh>
+#include <orc/Type.hh>
+
+#include "orc/orc-config.hh"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow_io/core/kernels/io_interface.h"
+#include "tensorflow_io/core/kernels/io_stream.h"
+
+namespace tensorflow {
+namespace data {
+
+class ORCReadable : public IOReadableInterface {
+ public:
+  ORCReadable(Env* env) : env_(env) {}
+  ~ORCReadable() {}
+  Status Init(const std::vector<string>& input,
+              const std::vector<string>& metadata, const void* memory_data,
+              const int64 memory_size) override {
+    if (input.size() > 1) {
+      return errors::InvalidArgument("more than 1 filename is not supported");
+    }
+    const string& filename = input[0];
+    // read packet data
+    orc::RowReaderOptions row_reader_opts;
+    orc::ReaderOptions reader_opts;
+    std::unique_ptr<orc::Reader> reader =
+        orc::createReader(orc::readFile(filename), reader_opts);
+
+    row_reader_ = reader->createRowReader(row_reader_opts);
+    LOG(INFO) << "ORC file schema:" << reader->getType().toString();
+
+    // Parse columns. We assume the orc record file is a flat array
+    auto row_count = reader->getNumberOfRows();
+    for (uint64_t i = 0; i < reader->getType().getSubtypeCount(); ++i) {
+      auto field_name = reader->getType().getFieldName(i);
+      auto subtype = reader->getType().getSubtype(i);
+      DataType dtype;
+      switch (static_cast<int64_t>(subtype->getKind())) {
+        case orc::SHORT:
+          dtype = DT_INT16;
+          break;
+        case orc::INT:
+          dtype = DT_INT32;
+          break;
+        case orc::LONG:
+          dtype = DT_INT64;
+          break;
+        case orc::STRING:
+          dtype = DT_STRING;
+          break;
+        case orc::DOUBLE:
+          dtype = DT_DOUBLE;
+          break;
+        case orc::FLOAT:
+          dtype = DT_FLOAT;
+          break;
+        default:
+          return errors::InvalidArgument("data type is not supported: ",
+                                         subtype->toString());
+      }
+      columns_.push_back(field_name);
+      shapes_.push_back(TensorShape({static_cast<int64>(row_count)}));
+      dtypes_.push_back(dtype);
+      columns_index_[field_name] = i;
+      tensors_.emplace_back(
+          Tensor(dtype, TensorShape({static_cast<int64>(row_count)})));
+    }
+    // Fill in the values
+    std::unique_ptr<orc::ColumnVectorBatch> batch =
+        row_reader_->createRowBatch(10);
+    auto* fields = dynamic_cast<orc::StructVectorBatch*>(batch.get());
+    int64_t record_index = 0;
+// Template type conversions between ORC and TensorFlow DT
+#define PROCESS_TYPE(VTYPE, VDTYPE, TDTYPE)                                   \
+  {                                                                           \
+    auto* col = dynamic_cast<VTYPE>(fields->fields[column_index]);            \
+    VDTYPE* buffer1 = col->data.data();                                       \
+    tensors_[column_index].flat<TDTYPE>()(record_index) = (TDTYPE)buffer1[r]; \
+  }
+    while (row_reader_->next(*batch)) {
+      for (uint32_t r = 0; r < batch->numElements; ++r) {
+        for (size_t column_index = 0; column_index < columns_.size();
+             column_index++) {
+          switch (dtypes_[column_index]) {
+            case DT_DOUBLE:
+              PROCESS_TYPE(orc::DoubleVectorBatch*, double, double);
+              break;
+            case DT_FLOAT:
+              PROCESS_TYPE(orc::DoubleVectorBatch*, double, float);
+              break;
+            case DT_INT16:
+              PROCESS_TYPE(orc::LongVectorBatch*, int64, int16);
+              break;
+            case DT_INT32:
+              PROCESS_TYPE(orc::LongVectorBatch*, int64, int32);
+              break;
+            case DT_INT64:
+              PROCESS_TYPE(orc::LongVectorBatch*, int64, int64);
+              break;
+            case DT_STRING: {
+              auto* string_col = dynamic_cast<orc::StringVectorBatch*>(
+                  fields->fields[column_index]);
+              char** buffer = string_col->data.data();
+              int64_t* lengths = string_col->length.data();
+              tensors_[column_index].flat<tstring>()(record_index) =
+                  std::string(buffer[r], lengths[r]);
+              break;
+            }
+            default:
+              return errors::InvalidArgument(
+                  "data type is not supported: ",
+                  DataTypeString(dtypes_[column_index]));
+          }
+        }
+        record_index++;
+      }
+    }
+
+    return Status::OK();
+  }
+
+  Status Read(const int64 start, const int64 stop, const string& component,
+              int64* record_read, Tensor* value, Tensor* label) override {
+    if (columns_index_.find(component) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component, " is invalid");
+    }
+    int64 column_index = columns_index_[component];
+
+    (*record_read) = 0;
+    if (start >= shapes_[column_index].dim_size(0)) {
+      return Status::OK();
+    }
+    const string& column = component;
+    int64 element_start = start < shapes_[column_index].dim_size(0)
+                              ? start
+                              : shapes_[column_index].dim_size(0);
+    int64 element_stop = stop < shapes_[column_index].dim_size(0)
+                             ? stop
+                             : shapes_[column_index].dim_size(0);
+    if (element_start > element_stop) {
+      return errors::InvalidArgument("dataset ", column,
+                                     " selection is out of boundary");
+    }
+    if (element_start == element_stop) {
+      return Status::OK();
+    }
+
+#define PROCESS_VALUE(VTYPE)                            \
+  {                                                     \
+    value->flat<VTYPE>().data()[i] =                    \
+        tensors_[column_index].flat<VTYPE>().data()[i]; \
+  }
+    for (int i = element_start; i < element_stop; i++) {
+      switch (dtypes_[column_index]) {
+        case DT_DOUBLE:
+          PROCESS_VALUE(double);
+          break;
+        case DT_FLOAT:
+          PROCESS_VALUE(float);
+          break;
+        case DT_INT16:
+          PROCESS_VALUE(int16);
+          break;
+        case DT_INT32:
+          PROCESS_VALUE(int32);
+          break;
+        case DT_INT64:
+          PROCESS_VALUE(int64);
+          break;
+        case DT_STRING: {
+          PROCESS_VALUE(tstring);
+          break;
+        }
+        default:
+          return errors::InvalidArgument("data type is not supported: ",
+                                         DataTypeString(dtypes_[column_index]));
+      }
+    }
+    (*record_read) = element_stop - element_start;
+
+    return Status::OK();
+  }
+
+  Status Components(std::vector<string>* components) override {
+    components->clear();
+    for (size_t i = 0; i < columns_.size(); i++) {
+      components->push_back(columns_[i]);
+    }
+    return Status::OK();
+  }
+
+  Status Spec(const string& component, PartialTensorShape* shape,
+              DataType* dtype, bool label) override {
+    if (columns_index_.find(component) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component, " is invalid");
+    }
+    int64 column_index = columns_index_[component];
+    *shape = shapes_[column_index];
+    *dtype = dtypes_[column_index];
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("ORCReadable");
+  }
+
+ private:
+  mutable mutex mu_;
+  Env* env_ TF_GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ TF_GUARDED_BY(mu_);
+  std::unique_ptr<orc::RowReader> row_reader_ TF_GUARDED_BY(mu_);
+  std::vector<Tensor> tensors_;
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+  std::unordered_map<string, int64> columns_index_;
+};
+REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<ORCReadable>);
+REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableSpec").Device(DEVICE_CPU),
+                        IOInterfaceSpecOp<ORCReadable>);
+REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableRead").Device(DEVICE_CPU),
+                        IOReadableReadOp<ORCReadable>);
+}  // namespace data
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow_io/core/ops/orc_ops.cc b/tensorflow_io/core/ops/orc_ops.cc
new file mode 100644
index 000000000..9bc292459
--- /dev/null
+++ b/tensorflow_io/core/ops/orc_ops.cc
@@ -0,0 +1,60 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+REGISTER_OP("IO>ORCReadableInit")
+    .Input("input: string")
+    .Output("resource: resource")
+    .Output("components: string")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->MakeShape({}));
+      return Status::OK();
+    });
+
+REGISTER_OP("IO>ORCReadableSpec")
+    .Input("input: resource")
+    .Output("shape: int64")
+    .Output("dtype: int64")
+    .Attr("component: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({}));
+      return Status::OK();
+    });
+
+REGISTER_OP("IO>ORCReadableRead")
+    .Input("input: resource")
+    .Input("start: int64")
+    .Input("stop: int64")
+    .Output("value: dtype")
+    .Attr("component: string")
+    .Attr("shape: shape")
+    .Attr("dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      PartialTensorShape shape;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+      shape_inference::ShapeHandle entry;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
+      c->set_output(0, entry);
+      return Status::OK();
+    });
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow_io/core/python/ops/io_dataset.py b/tensorflow_io/core/python/ops/io_dataset.py
index 84720994d..de5ed0f0d 100644
--- a/tensorflow_io/core/python/ops/io_dataset.py
+++ b/tensorflow_io/core/python/ops/io_dataset.py
@@ -26,6 +26,7 @@
 from tensorflow_io.core.python.ops import parquet_dataset_ops
 from tensorflow_io.core.python.ops import pcap_dataset_ops
 from tensorflow_io.core.python.ops import mnist_dataset_ops
+from tensorflow_io.core.python.ops import orc_dataset_ops
 
 
 class IODataset(io_dataset_ops._IODataset):  # pylint: disable=protected-access
@@ -308,6 +309,21 @@ def from_pcap(cls, filename, **kwargs):
         with tf.name_scope(kwargs.get("name", "IOFromPcap")):
             return pcap_dataset_ops.PcapIODataset(filename, internal=True, **kwargs)
 
+    @classmethod
+    def from_orc(cls, filename, **kwargs):
+        """Creates an `IODataset` from an ORC file.
+
+        Args:
+          filename: A string, the filename of an ORC file.
+          name: A name prefix for the IOTensor (optional).
+
+        Returns:
+          A `IODataset`.
+
+        """
+        with tf.name_scope(kwargs.get("name", "IOFromORC")):
+            return orc_dataset_ops.ORCIODataset(filename, internal=True, **kwargs)
+
 
 class StreamIODataset(
     io_dataset_ops._StreamIODataset
diff --git a/tensorflow_io/core/python/ops/orc_dataset_ops.py b/tensorflow_io/core/python/ops/orc_dataset_ops.py
new file mode 100644
index 000000000..05425f3de
--- /dev/null
+++ b/tensorflow_io/core/python/ops/orc_dataset_ops.py
@@ -0,0 +1,102 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ORCDataset"""
+
+import sys
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import core_ops
+
+
+class _ORCIODatasetFunction:
+    def __init__(self, function, resource, component, shape, dtype):
+        self._function = function
+        self._resource = resource
+        self._component = component
+        self._shape = tf.TensorShape([None]).concatenate(shape[1:])
+        self._dtype = dtype
+
+    def __call__(self, start, stop):
+        return self._function(
+            self._resource,
+            start=start,
+            stop=stop,
+            component=self._component,
+            shape=self._shape,
+            dtype=self._dtype,
+        )
+
+
+class ORCIODataset(tf.data.Dataset):
+    """ORCIODataset"""
+
+    def __init__(self, filename, columns=None, internal=True, **kwargs):
+        if not internal:
+            raise ValueError(
+                "ORCIODataset constructor is private; please use one "
+                "of the factory methods instead (e.g., "
+                "IODataset.from_orc())"
+            )
+        with tf.name_scope("ORCIODataset") as scope:
+            capacity = 4096
+            resource, columns_v = core_ops.io_orc_readable_init(
+                filename,
+                container=scope,
+                shared_name="{}/{}".format(filename, uuid.uuid4().hex),
+            )
+            columns = columns if columns is not None else columns_v.numpy()
+            columns_dataset = []
+            columns_function = []
+            for column in columns:
+                shape, dtype = core_ops.io_orc_readable_spec(resource, column)
+                shape = tf.TensorShape([None if e < 0 else e for e in shape.numpy()])
+                dtype = tf.as_dtype(dtype.numpy())
+                function = _ORCIODatasetFunction(
+                    core_ops.io_orc_readable_read, resource, column, shape, dtype
+                )
+                columns_function.append(function)
+
+            for (column, function) in zip(columns, columns_function):
+                column_dataset = tf.compat.v2.data.Dataset.range(
+                    0, sys.maxsize, capacity
+                )
+                column_dataset = column_dataset.map(
+                    lambda index: function(index, index + capacity)
+                )
+                column_dataset = column_dataset.apply(
+                    tf.data.experimental.take_while(
+                        lambda v: tf.greater(tf.shape(v)[0], 0)
+                    )
+                )
+                columns_dataset.append(column_dataset)
+            if len(columns_dataset) == 1:
+                dataset = columns_dataset[0]
+            else:
+                dataset = tf.compat.v2.data.Dataset.zip(tuple(columns_dataset))
+            dataset = dataset.unbatch()
+
+            self._function = columns_function
+            self._dataset = dataset
+            super().__init__(
+                self._dataset._variant_tensor
+            )  # pylint: disable=protected-access
+
+    def _inputs(self):
+        return []
+
+    @property
+    def element_spec(self):
+        return self._dataset.element_spec
diff --git a/tests/test_orc.py b/tests/test_orc.py
new file mode 100644
index 000000000..d4cbd5f10
--- /dev/null
+++ b/tests/test_orc.py
@@ -0,0 +1,97 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License.  You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""
+Test ORCDataset
+"""
+
+import os
+import numpy as np
+
+import tensorflow as tf
+import tensorflow_io as tfio
+
+
+def test_orc_input():
+    """test_pcap_input
+    """
+    print("Testing ORCDataset")
+    orc_filename = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "test_orc", "iris.orc"
+    )
+
+    dataset = tfio.IODataset.from_orc(orc_filename, capacity=15).batch(1)
+    packets_total = 0
+    for v in dataset:
+        if packets_total == 0:
+            sepal_length, _, _, _, species = v
+            assert sepal_length.dtype == tf.float32
+            assert species.dtype == tf.string
+            assert tf.math.less(tf.math.abs(sepal_length - 5.0999999), 0.0001)
+            assert tf.math.equal(species, "setosa")
+        packets_total += 1
+
+    assert packets_total == 150
+
+
+def test_orc_keras():
+    """Test case for ORCDataset with Keras"""
+    orc_filename = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "test_orc", "iris.orc"
+    )
+
+    feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
+    label_cols = ["species"]
+
+    feature_dataset = tfio.IODataset.from_orc(orc_filename, columns=feature_cols)
+
+    label_dataset = tfio.IODataset.from_orc(orc_filename, columns=label_cols)
+
+    @tf.function
+    def species_float_conversion(x):
+        if x == "virginica":
+            return 1.0
+        if x == "versicolor":
+            return 2.0
+        if x == "setosa":
+            return 3.0
+        return 4.0
+
+    label_dataset = label_dataset.map(species_float_conversion)
+    dataset = tf.data.Dataset.zip((feature_dataset, label_dataset))
+    dataset = dataset.batch(1)
+
+    def pack_features_vector(features, labels):
+        """Pack the features into a single array."""
+        features = tf.stack(list(features), axis=1)
+        return features, labels
+
+    dataset = dataset.map(pack_features_vector)
+
+    model = tf.keras.Sequential(
+        [
+            tf.keras.layers.Dense(
+                10, activation=tf.nn.relu, input_shape=(4,)
+            ),  # input shape required
+            tf.keras.layers.Dense(10, activation=tf.nn.relu),
+            tf.keras.layers.Dense(3),
+        ]
+    )
+
+    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
+    model.fit(dataset, epochs=5)
+
+
+if __name__ == "__main__":
+    test.main()
diff --git a/tests/test_orc/iris.orc b/tests/test_orc/iris.orc
new file mode 100644
index 000000000..717948c05
Binary files /dev/null and b/tests/test_orc/iris.orc differ