@@ -33,11 +33,12 @@ class ElasticsearchReadableResource : public ResourceBase {
33
33
Status Init (const std::string& healthcheck_url,
34
34
const std::string& healthcheck_field,
35
35
const std::string& request_url,
36
+ const std::vector<string>& headers,
36
37
std::function<Status(const TensorShape& columns_shape,
37
38
Tensor** columns, Tensor** dtypes)>
38
39
allocate_func) {
39
40
// Perform healthcheck before proceeding
40
- Healthcheck (healthcheck_url, healthcheck_field);
41
+ Healthcheck (healthcheck_url, healthcheck_field, headers );
41
42
42
43
// Make the request API call and set the metadata based on a sample of
43
44
// data returned. The request_url will have the "scroll" param set with
@@ -46,7 +47,7 @@ class ElasticsearchReadableResource : public ResourceBase {
46
47
base_dtypes_.clear ();
47
48
base_columns_.clear ();
48
49
rapidjson::Document response_json;
49
- MakeAPICall (request_url, &response_json);
50
+ MakeAPICall (request_url, &response_json, headers );
50
51
51
52
// Validate the presence of the _scroll_id in the response.
52
53
// The _scroll_id keeps might change in subsequent calls, thus not
@@ -121,9 +122,9 @@ class ElasticsearchReadableResource : public ResourceBase {
121
122
data_allocate_func) {
122
123
rapidjson::Document response_json;
123
124
if (scroll_id == " " ) {
124
- MakeAPICall (request_url, &response_json);
125
+ MakeAPICall (request_url, &response_json, headers_ );
125
126
} else {
126
- MakeAPICall (scroll_request_url, &response_json);
127
+ MakeAPICall (scroll_request_url, &response_json, headers_ );
127
128
}
128
129
129
130
if (response_json.HasMember (" _scroll_id" )) {
@@ -172,10 +173,11 @@ class ElasticsearchReadableResource : public ResourceBase {
172
173
173
174
protected:
174
175
Status Healthcheck (const std::string& healthcheck_url,
175
- const std::string& healthcheck_field) {
176
+ const std::string& healthcheck_field,
177
+ const std::vector<string>& headers) {
176
178
// Make the healthcheck API call and get the response json
177
179
rapidjson::Document response_json;
178
- MakeAPICall (healthcheck_url, &response_json);
180
+ MakeAPICall (healthcheck_url, &response_json, headers );
179
181
180
182
if (response_json.HasMember (healthcheck_field.c_str ())) {
181
183
// LOG(INFO) << "cluster health: "
@@ -186,8 +188,8 @@ class ElasticsearchReadableResource : public ResourceBase {
186
188
return Status::OK ();
187
189
}
188
190
189
- Status MakeAPICall (const std::string& url,
190
- rapidjson::Document* response_json ) {
191
+ Status MakeAPICall (const std::string& url, rapidjson::Document* response_json,
192
+ const std::vector<string>& headers ) {
191
193
HttpRequest* request = http_request_factory_.Create ();
192
194
193
195
if (scroll_id != " " ) {
@@ -200,7 +202,15 @@ class ElasticsearchReadableResource : public ResourceBase {
200
202
}
201
203
202
204
// LOG(INFO) << "Setting the headers";
203
- request->AddHeader (" Content-Type" , " application/json; charset=utf-8" );
205
+ for (size_t i = 0 ; i < headers.size (); ++i) {
206
+ std::string header = headers[i];
207
+ std::vector<string> parts = str_util::Split (header, " =" );
208
+ if (parts.size () != 2 ) {
209
+ return errors::InvalidArgument (" invalid header configuration: " ,
210
+ header);
211
+ }
212
+ request->AddHeader (parts[0 ], parts[1 ]);
213
+ }
204
214
205
215
// LOG(INFO) << "Setting the response buffer";
206
216
std::vector<char > response;
@@ -231,6 +241,9 @@ class ElasticsearchReadableResource : public ResourceBase {
231
241
" Invalid JSON response. The response should be an object" );
232
242
}
233
243
244
+ // Store the default headers if the response is valid
245
+ headers_ = headers;
246
+
234
247
return Status::OK ();
235
248
}
236
249
@@ -242,6 +255,7 @@ class ElasticsearchReadableResource : public ResourceBase {
242
255
std::vector<DataType> base_dtypes_;
243
256
std::vector<string> base_columns_;
244
257
std::string scroll_id = " " ;
258
+ std::vector<string> headers_;
245
259
};
246
260
247
261
class ElasticsearchReadableInitOp
@@ -271,17 +285,24 @@ class ElasticsearchReadableInitOp
271
285
OP_REQUIRES_OK (context, context->input (" request_url" , &request_url_tensor));
272
286
const string& request_url = request_url_tensor->scalar <tstring>()();
273
287
288
+ const Tensor* headers_tensor;
289
+ OP_REQUIRES_OK (context, context->input (" headers" , &headers_tensor));
290
+ std::vector<string> headers;
291
+ for (int64 i = 0 ; i < headers_tensor->NumElements (); i++) {
292
+ headers.push_back (headers_tensor->flat <tstring>()(i));
293
+ }
294
+
274
295
OP_REQUIRES_OK (
275
- context,
276
- resource_-> Init ( healthcheck_url, healthcheck_field, request_url,
277
- [&](const TensorShape& columns_shape, Tensor** columns,
278
- Tensor** dtypes) -> Status {
279
- TF_RETURN_IF_ERROR (context-> allocate_output (
280
- 1 , columns_shape, columns));
281
- TF_RETURN_IF_ERROR (context-> allocate_output (
282
- 2 , columns_shape, dtypes));
283
- return Status::OK ();
284
- }));
296
+ context, resource_-> Init (
297
+ healthcheck_url, healthcheck_field, request_url, headers ,
298
+ [&](const TensorShape& columns_shape, Tensor** columns,
299
+ Tensor** dtypes) -> Status {
300
+ TF_RETURN_IF_ERROR (
301
+ context-> allocate_output ( 1 , columns_shape, columns));
302
+ TF_RETURN_IF_ERROR (
303
+ context-> allocate_output ( 2 , columns_shape, dtypes));
304
+ return Status::OK ();
305
+ }));
285
306
}
286
307
287
308
Status CreateResource (ElasticsearchReadableResource** resource)
0 commit comments