Commit 95b4ac7

Elasticsearch: headers support for retrieving data (#1110)

* headers support for retrieving data
* modify docstring
* formatted connection error message
* lint fix

1 parent d89f6ed commit 95b4ac7

File tree

5 files changed: +114 -35 lines
tensorflow_io/core/kernels/elasticsearch_kernels.cc

+40-19
@@ -33,11 +33,12 @@ class ElasticsearchReadableResource : public ResourceBase {
   Status Init(const std::string& healthcheck_url,
               const std::string& healthcheck_field,
               const std::string& request_url,
+              const std::vector<string>& headers,
               std::function<Status(const TensorShape& columns_shape,
                                    Tensor** columns, Tensor** dtypes)>
                   allocate_func) {
     // Perform healthcheck before proceeding
-    Healthcheck(healthcheck_url, healthcheck_field);
+    Healthcheck(healthcheck_url, healthcheck_field, headers);
 
     // Make the request API call and set the metadata based on a sample of
     // data returned. The request_url will have the "scroll" param set with
@@ -46,7 +47,7 @@ class ElasticsearchReadableResource : public ResourceBase {
     base_dtypes_.clear();
     base_columns_.clear();
     rapidjson::Document response_json;
-    MakeAPICall(request_url, &response_json);
+    MakeAPICall(request_url, &response_json, headers);
 
     // Validate the presence of the _scroll_id in the response.
     // The _scroll_id keeps might change in subsequent calls, thus not
@@ -121,9 +122,9 @@ class ElasticsearchReadableResource : public ResourceBase {
                   data_allocate_func) {
     rapidjson::Document response_json;
     if (scroll_id == "") {
-      MakeAPICall(request_url, &response_json);
+      MakeAPICall(request_url, &response_json, headers_);
     } else {
-      MakeAPICall(scroll_request_url, &response_json);
+      MakeAPICall(scroll_request_url, &response_json, headers_);
     }
 
     if (response_json.HasMember("_scroll_id")) {
@@ -172,10 +173,11 @@ class ElasticsearchReadableResource : public ResourceBase {
 
  protected:
   Status Healthcheck(const std::string& healthcheck_url,
-                     const std::string& healthcheck_field) {
+                     const std::string& healthcheck_field,
+                     const std::vector<string>& headers) {
     // Make the healthcheck API call and get the response json
     rapidjson::Document response_json;
-    MakeAPICall(healthcheck_url, &response_json);
+    MakeAPICall(healthcheck_url, &response_json, headers);
 
     if (response_json.HasMember(healthcheck_field.c_str())) {
       // LOG(INFO) << "cluster health: "
@@ -186,8 +188,8 @@ class ElasticsearchReadableResource : public ResourceBase {
     return Status::OK();
   }
 
-  Status MakeAPICall(const std::string& url,
-                     rapidjson::Document* response_json) {
+  Status MakeAPICall(const std::string& url, rapidjson::Document* response_json,
+                     const std::vector<string>& headers) {
     HttpRequest* request = http_request_factory_.Create();
 
     if (scroll_id != "") {
@@ -200,7 +202,15 @@ class ElasticsearchReadableResource : public ResourceBase {
     }
 
     // LOG(INFO) << "Setting the headers";
-    request->AddHeader("Content-Type", "application/json; charset=utf-8");
+    for (size_t i = 0; i < headers.size(); ++i) {
+      std::string header = headers[i];
+      std::vector<string> parts = str_util::Split(header, "=");
+      if (parts.size() != 2) {
+        return errors::InvalidArgument("invalid header configuration: ",
+                                       header);
+      }
+      request->AddHeader(parts[0], parts[1]);
+    }
 
     // LOG(INFO) << "Setting the response buffer";
     std::vector<char> response;
@@ -231,6 +241,9 @@ class ElasticsearchReadableResource : public ResourceBase {
           "Invalid JSON response. The response should be an object");
     }
 
+    // Store the default headers if the response is valid
+    headers_ = headers;
+
     return Status::OK();
   }
 
@@ -242,6 +255,7 @@ class ElasticsearchReadableResource : public ResourceBase {
   std::vector<DataType> base_dtypes_;
   std::vector<string> base_columns_;
   std::string scroll_id = "";
+  std::vector<string> headers_;
 };
 
 class ElasticsearchReadableInitOp
@@ -271,17 +285,24 @@ class ElasticsearchReadableInitOp
     OP_REQUIRES_OK(context, context->input("request_url", &request_url_tensor));
     const string& request_url = request_url_tensor->scalar<tstring>()();
 
+    const Tensor* headers_tensor;
+    OP_REQUIRES_OK(context, context->input("headers", &headers_tensor));
+    std::vector<string> headers;
+    for (int64 i = 0; i < headers_tensor->NumElements(); i++) {
+      headers.push_back(headers_tensor->flat<tstring>()(i));
+    }
+
     OP_REQUIRES_OK(
-        context,
-        resource_->Init(healthcheck_url, healthcheck_field, request_url,
-                        [&](const TensorShape& columns_shape, Tensor** columns,
-                            Tensor** dtypes) -> Status {
-                          TF_RETURN_IF_ERROR(context->allocate_output(
-                              1, columns_shape, columns));
-                          TF_RETURN_IF_ERROR(context->allocate_output(
-                              2, columns_shape, dtypes));
-                          return Status::OK();
-                        }));
+        context, resource_->Init(
+                     healthcheck_url, healthcheck_field, request_url, headers,
+                     [&](const TensorShape& columns_shape, Tensor** columns,
+                         Tensor** dtypes) -> Status {
+                       TF_RETURN_IF_ERROR(
+                           context->allocate_output(1, columns_shape, columns));
+                       TF_RETURN_IF_ERROR(
+                           context->allocate_output(2, columns_shape, dtypes));
+                       return Status::OK();
+                     }));
  }
 
  Status CreateResource(ElasticsearchReadableResource** resource)
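
Note: the kernel above receives headers as a flat list of "key=value" strings and splits each entry on "=" before calling request->AddHeader, rejecting anything that does not split into exactly two parts. A minimal Python sketch of that contract (helper names are illustrative only, not part of the commit):

# Sketch: mirrors how the Python handler flattens a headers dict into the
# "key=value" strings that the C++ kernel splits on "=".
def flatten_headers(headers_dict=None):
    headers = ["Content-Type=application/json"]
    if headers_dict is not None:
        if not isinstance(headers_dict, dict):
            raise ValueError("Headers should be a dict of key:value pairs")
        for key, value in headers_dict.items():
            if key.lower() == "content-type":
                continue
            headers.append("{}={}".format(key, value))
    return headers


def split_header(header):
    # The kernel returns InvalidArgument for anything other than exactly two parts.
    parts = header.split("=")
    if len(parts) != 2:
        raise ValueError("invalid header configuration: " + header)
    return parts[0], parts[1]


print(flatten_headers({"Authorization": "Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk"}))
# ['Content-Type=application/json', 'Authorization=Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk']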

tensorflow_io/core/ops/elasticsearch_ops.cc

+1
@@ -25,6 +25,7 @@ REGISTER_OP("IO>ElasticsearchReadableInit")
     .Input("healthcheck_url: string")
     .Input("healthcheck_field: string")
     .Input("request_url: string")
+    .Input("headers: string")
     .Output("resource: resource")
     .Output("columns: string")
     .Output("dtypes: string")

tensorflow_io/core/python/experimental/elasticsearch_dataset_ops.py

+25-9
@@ -25,10 +25,11 @@ class _ElasticsearchHandler:
     session data.
     """
 
-    def __init__(self, nodes, index, doc_type):
+    def __init__(self, nodes, index, doc_type, headers_dict):
         self.nodes = nodes
         self.index = index
         self.doc_type = doc_type
+        self.headers_dict = headers_dict
         self.prepare_base_urls()
         self.prepare_connection_data()
 
@@ -57,8 +58,6 @@ def prepare_base_urls(self):
             base_url = "{}://{}".format(url_obj.scheme, url_obj.netloc)
             self.base_urls.append(base_url)
 
-        return self.base_urls
-
     def prepare_connection_data(self):
         """Prepares the healthcheck and resource urls from the base_urls"""
 
@@ -75,7 +74,17 @@ def prepare_connection_data(self):
             )
             self.request_urls.append(request_url)
 
-        return self.healthcheck_urls, self.request_urls
+        self.headers = ["Content-Type=application/json"]
+        if self.headers_dict is not None:
+            if isinstance(self.headers_dict, dict):
+                for key, value in self.headers_dict.items():
+                    if key.lower() == "content-type":
+                        continue
+                    self.headers.append("{}={}".format(key, value))
+            else:
+                raise ValueError(
+                    "Headers should be a dict of key:value pairs. Got: ", self.headers
+                )
 
     def get_healthy_resource(self):
         """Retrieve the resource which is connected to a healthy node"""
@@ -88,6 +97,7 @@ def get_healthy_resource(self):
                     healthcheck_url=healthcheck_url,
                     healthcheck_field="status",
                     request_url=request_url,
+                    headers=self.headers,
                 )
                 print("Connection successful: {}".format(healthcheck_url))
                 dtypes = []
@@ -102,11 +112,13 @@ def get_healthy_resource(self):
                     dtypes.append(tf.string)
                 return resource, columns.numpy(), dtypes, request_url
             except Exception:
-                print("Skipping host: {}".format(healthcheck_url))
+                print("Skipping node: {}".format(healthcheck_url))
                 continue
         else:
             raise ConnectionError(
-                "No healthy node available for this index, check the cluster status and index"
+                "No healthy node available for the index: {}, please check the cluster config".format(
+                    self.index
+                )
             )
 
     def get_next_batch(self, resource, request_url):
@@ -155,21 +167,25 @@ def parse_json(self, raw_item, columns, dtypes):
 class ElasticsearchIODataset(tf.compat.v2.data.Dataset):
     """Represents an elasticsearch based tf.data.Dataset"""
 
-    def __init__(self, nodes, index, doc_type=None, internal=True):
+    def __init__(self, nodes, index, doc_type=None, headers=None, internal=True):
         """Prepare the ElasticsearchIODataset.
 
         Args:
             nodes: A `tf.string` tensor containing the hostnames of nodes
                 in [protocol://hostname:port] format.
                 For example: ["http://localhost:9200"]
             index: A `tf.string` representing the elasticsearch index to query.
-            doc_type: A `tf.string` representing the type of documents in the index
+            doc_type: (Optional) A `tf.string` representing the type of documents in the index
                 to query.
+            headers: (Optional) A dict of headers. For example:
+                {'Content-Type': 'application/json'}
         """
         with tf.name_scope("ElasticsearchIODataset"):
            assert internal
 
-            handler = _ElasticsearchHandler(nodes=nodes, index=index, doc_type=doc_type)
+            handler = _ElasticsearchHandler(
+                nodes=nodes, index=index, doc_type=doc_type, headers_dict=headers
+            )
            resource, columns, dtypes, request_url = handler.get_healthy_resource()
 
            dataset = tf.data.experimental.Counter()
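
With the change above, callers can pass credentials as regular HTTP headers. A usage sketch (node, index, and credentials are the placeholder values used by the test setup in this commit):

import tensorflow_io as tfio

# Hypothetical cluster details for illustration only.
dataset = tfio.experimental.elasticsearch.ElasticsearchIODataset(
    nodes=["http://localhost:9200"],
    index="people",
    doc_type="survivors",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk",
    },
)

for item in dataset.batch(2):
    # Each item maps column names to batched tensors.
    print(item)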

tests/test_elasticsearch/elasticsearch_test.sh

+22-2
@@ -21,14 +21,20 @@ action=$1
 
 if [ "$action" == "start" ]; then
 
+echo ""
 echo "Preparing the environment variables file..."
+echo ""
 cat >> .env-vars << "EOF"
 cluster.name=tfio-es-cluster
 bootstrap.memory_lock=true
 discovery.type=single-node
+ELASTIC_PASSWORD=default_password
+xpack.security.enabled=true
 EOF
 
+echo ""
 echo "Starting the tfio elasticsearch docker container..."
+echo ""
 ELASTICSEARCH_IMAGE="docker.elastic.co/elasticsearch/elasticsearch:7.4.0"
 
 docker run -d --rm --name=tfio-elasticsearch \
@@ -37,22 +43,36 @@ docker run -d --rm --name=tfio-elasticsearch \
 --ulimit memlock=-1:-1 \
 ${ELASTICSEARCH_IMAGE}
 
+echo ""
 echo "Waiting for the elasticsearch cluster to be up and running..."
+echo ""
 sleep 20
 
+echo ""
 echo "Checking the base REST-API endpoint..."
-curl localhost:9200/
+echo ""
+# The Authorization header contains the base64 encoded value of "elastic:default_password"
+# As per the environment variable set while starting the container.
+curl -X GET localhost:9200/ --header 'Authorization: Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk'
 
+echo ""
 echo "Checking the healthcheck REST-API endpoint..."
-curl localhost:9200/_cluster/health
+echo ""
+curl -X GET localhost:9200/_cluster/health --header 'Authorization: Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk'
 
+echo ""
 echo "Clean up..."
+echo ""
 rm -rf ./.env-vars
 
 elif [ "$action" == "stop" ]; then
+echo ""
 echo "Removing the tfio elasticsearch container..."
+echo ""
 docker rm -f tfio-elasticsearch
 
 else
+echo ""
 echo "Invalid value: Use 'start' to run the container and 'stop' to remove it."
+echo ""
 fi
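
The Authorization value used in the curl calls and in the tests is simply the base64 encoding of "elastic:default_password" (the ELASTIC_PASSWORD set above). A quick Python check:

import base64

# Reproduces the Basic auth token used in this commit.
credentials = "elastic:default_password"
token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
print("Authorization: Basic " + token)
# Prints: Authorization: Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk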

tests/test_elasticsearch_eager.py

+26-5
@@ -29,7 +29,10 @@
 NODE = "http://localhost:9200"
 INDEX = "people"
 DOC_TYPE = "survivors"
-HEADERS = {"Content-Type": "application/json"}
+HEADERS = {
+    "Content-Type": "application/json",
+    "Authorization": "Basic ZWxhc3RpYzpkZWZhdWx0X3Bhc3N3b3Jk",
+}
 ATTRS = ["name", "gender", "age", "fare", "survived"]
 
 
@@ -51,7 +54,7 @@ def test_create_index():
     """Create an index in the cluster"""
 
     create_index_url = "{}/{}".format(NODE, INDEX)
-    res = requests.put(create_index_url)
+    res = requests.put(create_index_url, headers=HEADERS)
     assert res.status_code == 200
 
 
@@ -87,7 +90,7 @@ def test_elasticsearch_io_dataset():
     """Test the functionality of the ElasticsearchIODataset"""
 
     dataset = tfio.experimental.elasticsearch.ElasticsearchIODataset(
-        nodes=[NODE], index=INDEX, doc_type=DOC_TYPE
+        nodes=[NODE], index=INDEX, doc_type=DOC_TYPE, headers=HEADERS
     )
 
     assert issubclass(type(dataset), tf.data.Dataset)
@@ -97,13 +100,31 @@ def test_elasticsearch_io_dataset():
         assert attr in item
 
 
+@pytest.mark.skipif(not is_container_running(), reason="The container is not running")
+def test_elasticsearch_io_dataset_no_auth():
+    """Test the functionality of the ElasticsearchIODataset when basic auth is
+    required but the associated header is not passed.
+    """
+
+    try:
+        dataset = tfio.experimental.elasticsearch.ElasticsearchIODataset(
+            nodes=[NODE], index=INDEX, doc_type=DOC_TYPE
+        )
+    except ConnectionError as e:
+        assert str(
+            e
+        ) == "No healthy node available for the index: {}, please check the cluster config".format(
+            INDEX
+        )
+
+
 @pytest.mark.skipif(not is_container_running(), reason="The container is not running")
 def test_elasticsearch_io_dataset_batch():
     """Test the functionality of the ElasticsearchIODataset"""
 
     BATCH_SIZE = 2
     dataset = tfio.experimental.elasticsearch.ElasticsearchIODataset(
-        nodes=[NODE], index=INDEX, doc_type=DOC_TYPE
+        nodes=[NODE], index=INDEX, doc_type=DOC_TYPE, headers=HEADERS
     ).batch(BATCH_SIZE)
 
     assert issubclass(type(dataset), tf.data.Dataset)
@@ -119,5 +140,5 @@ def test_cleanup():
     """Clean up the index"""
 
     delete_index_url = "{}/{}".format(NODE, INDEX)
-    res = requests.delete(delete_index_url)
+    res = requests.delete(delete_index_url, headers=HEADERS)
     assert res.status_code == 200
