|
113 | 113 | },
|
114 | 114 | {
|
115 | 115 | "cell_type": "code",
|
116 |
| - "execution_count": 3, |
| 116 | + "execution_count": null, |
117 | 117 | "metadata": {
|
118 | 118 | "id": "m6KXZuTBWgRm"
|
119 | 119 | },
|
|
134 | 134 | },
|
135 | 135 | {
|
136 | 136 | "cell_type": "code",
|
137 |
| - "execution_count": 4, |
| 137 | + "execution_count": null, |
138 | 138 | "metadata": {
|
139 | 139 | "id": "dX74RKfZ_TdF"
|
140 | 140 | },
|
|
188 | 188 | {
|
189 | 189 | "cell_type": "markdown",
|
190 | 190 | "metadata": {
|
191 |
| - "id": "IGnbXuVnSo8T" |
| 191 | + "id": "jJzE6lMwhY7l" |
192 | 192 | },
|
193 | 193 | "source": [
|
194 | 194 | "Download the schema file that corresponds to the sample Avro file:"
|
|
198 | 198 | "cell_type": "code",
|
199 | 199 | "execution_count": null,
|
200 | 200 | "metadata": {
|
201 |
| - "id": "Tu01THzWcE-J" |
| 201 | + "id": "Cpxa6yhLhY7l" |
202 | 202 | },
|
203 | 203 | "outputs": [],
|
204 | 204 | "source": [
|
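The body of this download cell is elided in the diff. As a rough sketch, the schema file could be fetched like this, assuming it is hosted alongside the tutorial in the tensorflow/io repository (the URL is a guess, not taken from the notebook):

```python
# Hypothetical download of the Avro schema; the URL is an assumption.
import urllib.request

SCHEMA_URL = ('https://github.com/tensorflow/io/raw/master/'
              'docs/tutorials/avro/train.avsc')  # assumed location
urllib.request.urlretrieve(SCHEMA_URL, 'train.avsc')
```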
|
238 | 238 | {
|
239 | 239 | "cell_type": "markdown",
|
240 | 240 | "metadata": {
|
241 |
| - "id": "upgCc3gXybsB" |
| 241 | + "id": "m7XR0agdhY7n" |
242 | 242 | },
|
243 | 243 | "source": [
|
244 | 244 | "To read and print an Avro file in a human-readable format:\n"
|
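The reading cell itself is elided here, but a minimal sketch using the `avro` PyPI package (not necessarily what the notebook's cell does) would be:

```python
# Iterate over an Avro container file and print each record as a Python
# dict; DataFileReader resolves the writer schema embedded in the file.
from avro.datafile import DataFileReader
from avro.io import DatumReader

with DataFileReader(open('train.avro', 'rb'), DatumReader()) as reader:
    for record in reader:
        print(record)
```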
|
276 | 276 | {
|
277 | 277 | "cell_type": "markdown",
|
278 | 278 | "metadata": {
|
279 |
| - "id": "z9GCyPWNuOm7" |
| 279 | + "id": "qKgUPm6JhY7n" |
280 | 280 | },
|
281 | 281 | "source": [
|
282 | 282 | "The schema of `train.avro`, represented by `train.avsc`, is a JSON-formatted file.\n",
|
|
287 | 287 | "cell_type": "code",
|
288 | 288 | "execution_count": null,
|
289 | 289 | "metadata": {
|
290 |
| - "id": "nS3eTBvjt-O5" |
| 290 | + "id": "D-95aom1hY7o" |
291 | 291 | },
|
292 | 292 | "outputs": [],
|
293 | 293 | "source": [
|
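Since `train.avsc` is plain JSON, one way to inspect it (a sketch, not necessarily the cell's exact code) is to parse and pretty-print it:

```python
# Pretty-print the Avro schema; any JSON tooling works here.
import json

with open('train.avsc') as f:
    print(json.dumps(json.load(f), indent=2))
```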
|
302 | 302 | {
|
303 | 303 | "cell_type": "markdown",
|
304 | 304 | "metadata": {
|
305 |
| - "id": "4CfKVmCvwcL7" |
| 305 | + "id": "21szKFY1hY7o" |
306 | 306 | },
|
307 | 307 | "source": [
|
308 | 308 | "### Prepare the dataset\n"
|
|
311 | 311 | {
|
312 | 312 | "cell_type": "markdown",
|
313 | 313 | "metadata": {
|
314 |
| - "id": "z9GCyPWNuOm7" |
| 314 | + "id": "hNeBO9m-hY7o" |
315 | 315 | },
|
316 | 316 | "source": [
|
317 | 317 | "Load `train.avro` as a TensorFlow dataset with the Avro dataset API:\n"
|
|
321 | 321 | "cell_type": "code",
|
322 | 322 | "execution_count": null,
|
323 | 323 | "metadata": {
|
324 |
| - "id": "nS3eTBvjt-O5" |
| 324 | + "id": "v-nbLZHKhY7o" |
325 | 325 | },
|
326 | 326 | "outputs": [],
|
327 | 327 | "source": [
|
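The cell body is elided at this point in the diff, but based on the fragments visible further down, the loading step plausibly looks like the following; the batch size and epoch count are assumptions:

```python
import tensorflow as tf
import tensorflow_io as tfio

# Map each Avro field to a parsing spec: 'features[*]' reads every
# element of the 'features' array as a sparse tensor of rank 1.
features = {
    'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),
    'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),
}

# The reader schema is simply the JSON text of train.avsc.
schema = tf.io.gfile.GFile('train.avsc').read()

dataset = tfio.experimental.columnar.make_avro_record_dataset(
    file_pattern=['train.avro'],
    features=features,
    reader_schema=schema,
    batch_size=3,   # assumed value
    shuffle=True,
    num_epochs=1,
)
```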
|
363 | 363 | "cell_type": "code",
|
364 | 364 | "execution_count": null,
|
365 | 365 | "metadata": {
|
366 |
| - "id": "nS3eTBvjt-O5" |
| 366 | + "id": "bc9vDHyghY7p" |
367 | 367 | },
|
368 | 368 | "outputs": [],
|
369 | 369 | "source": [
|
|
382 | 382 | {
|
383 | 383 | "cell_type": "markdown",
|
384 | 384 | "metadata": {
|
385 |
| - "id": "IF_kYz_o2DH4" |
| 385 | + "id": "x45KolnDhY7p" |
386 | 386 | },
|
387 | 387 | "source": [
|
388 | 388 | "One can also increase `num_parallel_reads` to expedite Avro data processing by increasing Avro parse/read parallelism.\n"
|
|
392 | 392 | "cell_type": "code",
|
393 | 393 | "execution_count": null,
|
394 | 394 | "metadata": {
|
395 |
| - "id": "nS3eTBvjt-O5" |
| 395 | + "id": "Z2x-gPj_hY7p" |
396 | 396 | },
|
397 | 397 | "outputs": [],
|
398 | 398 | "source": [
|
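The elided cell presumably repeats the call above with higher read parallelism; a sketch (the value 16 is an arbitrary illustration, not a recommendation):

```python
# Same dataset, but with multiple parallel Avro read/parse calls.
dataset = tfio.experimental.columnar.make_avro_record_dataset(
    file_pattern=['train.avro'],
    features=features,
    reader_schema=schema,
    batch_size=3,
    num_epochs=1,
    num_parallel_reads=16,  # assumed value; tune to your I/O budget
)
```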
|
412 | 412 | {
|
413 | 413 | "cell_type": "markdown",
|
414 | 414 | "metadata": {
|
415 |
| - "id": "IF_kYz_o2DH4" |
| 415 | + "id": "6V-nwDJGhY7p" |
416 | 416 | },
|
417 | 417 | "source": [
|
418 | 418 | "For detailed usage of `make_avro_record_dataset`, please refer to the <a target=\"_blank\" href=\"https://www.tensorflow.org/io/api_docs/python/tfio/experimental/columnar/make_avro_record_dataset\">API doc</a>.\n"
|
|
421 | 421 | {
|
422 | 422 | "cell_type": "markdown",
|
423 | 423 | "metadata": {
|
424 |
| - "id": "4CfKVmCvwcL7" |
| 424 | + "id": "vIOijGlAhY7p" |
425 | 425 | },
|
426 | 426 | "source": [
|
427 | 427 | "### Train tf.keras models with the Avro dataset\n",
|
|
432 | 432 | {
|
433 | 433 | "cell_type": "markdown",
|
434 | 434 | "metadata": {
|
435 |
| - "id": "z9GCyPWNuOm7" |
| 435 | + "id": "s7K85D53hY7q" |
436 | 436 | },
|
437 | 437 | "source": [
|
438 | 438 | "Load `train.avro` as a TensorFlow dataset with the Avro dataset API:\n"
|
|
442 | 442 | "cell_type": "code",
|
443 | 443 | "execution_count": null,
|
444 | 444 | "metadata": {
|
445 |
| - "id": "nS3eTBvjt-O5" |
| 445 | + "id": "VFoeLwIOhY7q" |
446 | 446 | },
|
447 | 447 | "outputs": [],
|
448 | 448 | "source": [
|
449 | 449 | "features = {\n",
|
450 |
| - " 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32)\n", |
| 450 | + " 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),\n", |
| 451 | + " 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),\n", |
451 | 452 | "}\n",
|
452 | 453 | "\n",
|
| 454 | + "\n", |
453 | 455 | "schema = tf.io.gfile.GFile('train.avsc').read()\n",
|
454 | 456 | "\n",
|
455 | 457 | "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n",
|
|
463 | 465 | {
|
464 | 466 | "cell_type": "markdown",
|
465 | 467 | "metadata": {
|
466 |
| - "id": "z9GCyPWNuOm7" |
| 468 | + "id": "hR2FnIIMhY7q" |
467 | 469 | },
|
468 | 470 | "source": [
|
469 | 471 | "Define a simple Keras model:\n"
|
470 | 472 | ]
|
471 | 473 | },
|
472 | 474 | {
|
473 | 475 | "cell_type": "code",
|
474 |
| - "execution_count": 3, |
| 476 | + "execution_count": null, |
475 | 477 | "metadata": {
|
476 |
| - "id": "m6KXZuTBWgRm" |
| 478 | + "id": "hGV5rHfJhY7q" |
477 | 479 | },
|
478 | 480 | "outputs": [],
|
479 | 481 | "source": [
|
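The model definition is elided in the diff; a plausible minimal stand-in (layer sizes, loss, and optimizer are assumptions, not the notebook's exact choices) is:

```python
# A tiny dense network over the densified integer features.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')
```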
|
488 | 490 | {
|
489 | 491 | "cell_type": "markdown",
|
490 | 492 | "metadata": {
|
491 |
| - "id": "4CfKVmCvwcL7" |
| 493 | + "id": "Tuv9n6HshY7q" |
492 | 494 | },
|
493 | 495 | "source": [
|
494 | 496 | "### Train the Keras model with the Avro dataset\n"
|
495 | 497 | ]
|
496 | 498 | },
|
497 | 499 | {
|
498 | 500 | "cell_type": "code",
|
499 |
| - "execution_count": 3, |
| 501 | + "execution_count": null, |
500 | 502 | "metadata": {
|
501 |
| - "id": "m6KXZuTBWgRm" |
| 503 | + "id": "lb44cUuWhY7r" |
502 | 504 | },
|
503 | 505 | "outputs": [],
|
504 | 506 | "source": [
|
505 |
| - "model.fit(x=dataset, epochs=1, steps_per_epoch=1, verbose=1)\n" |
| 507 | + "def extract_label(feature):\n", |
| 508 | + " label = feature.pop('label')\n", |
| 509 | + " return tf.sparse.to_dense(feature['features[*]']), label\n", |
| 510 | + "\n", |
| 511 | + "model.fit(x=dataset.map(extract_label), epochs=1, steps_per_epoch=1, verbose=1)\n" |
506 | 512 | ]
|
507 | 513 | },
|
508 | 514 | {
|
509 | 515 | "cell_type": "markdown",
|
510 | 516 | "metadata": {
|
511 |
| - "id": "IF_kYz_o2DH4" |
| 517 | + "id": "7K6qAv5rhY7r" |
512 | 518 | },
|
513 | 519 | "source": [
|
514 | 520 | "The Avro dataset can parse and coerce any Avro data into TensorFlow tensors, including records within records, maps, arrays, branches, and enumerations. The parsing information is passed into the Avro dataset implementation as a map where \n",
|
|
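The sentence above is cut off by the diff, but the map it refers to is the `features` dict passed to `make_avro_record_dataset`. As a purely illustrative sketch of the key syntax (the field names below are invented for this example, not taken from the tutorial's schema):

```python
# Hypothetical feature map showing how nested Avro structure might be
# addressed; keys and fields here are made up for illustration.
features = {
    # a scalar field of the top-level record
    'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=0),
    # a field inside a nested record, addressed with a dotted key
    'address.city': tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
    # every element of an array field, parsed as a sparse tensor
    'scores[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.float32),
}
```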
541 | 547 | {
|
542 | 548 | "cell_type": "markdown",
|
543 | 549 | "metadata": {
|
544 |
| - "id": "IF_kYz_o2DH4" |
| 550 | + "id": "1PFQPuy5hY7r" |
545 | 551 | },
|
546 | 552 | "source": [
|
547 | 553 | "A comprehensive set of examples of the Avro dataset API is provided within <a target=\"_blank\" href=\"https://github.com/tensorflow/io/blob/master/tests/test_parse_avro.py#L437\">the tests</a>.\n"
|
|