
Commit cd90c60

Include labels when training an Avro dataset (#1571)
1 parent e8ed07d commit cd90c60

File tree

1 file changed (+33, -27 lines)


docs/tutorials/avro.ipynb

@@ -113,7 +113,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "metadata": {
 "id": "m6KXZuTBWgRm"
 },
@@ -134,7 +134,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
 "metadata": {
 "id": "dX74RKfZ_TdF"
 },
@@ -188,7 +188,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "IGnbXuVnSo8T"
+"id": "jJzE6lMwhY7l"
 },
 "source": [
 "Download the corresponding schema file of the sample Avro file:"
@@ -198,7 +198,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "Tu01THzWcE-J"
+"id": "Cpxa6yhLhY7l"
 },
 "outputs": [],
 "source": [
@@ -238,7 +238,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "upgCc3gXybsB"
+"id": "m7XR0agdhY7n"
 },
 "source": [
 "To read and print an Avro file in a human-readable format:\n"
@@ -276,7 +276,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "z9GCyPWNuOm7"
+"id": "qKgUPm6JhY7n"
 },
 "source": [
 "And the schema of `train.avro` which is represented by `train.avsc` is a JSON-formatted file.\n",
@@ -287,7 +287,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "nS3eTBvjt-O5"
+"id": "D-95aom1hY7o"
 },
 "outputs": [],
 "source": [
@@ -302,7 +302,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "4CfKVmCvwcL7"
+"id": "21szKFY1hY7o"
 },
 "source": [
 "### Prepare the dataset\n"
@@ -311,7 +311,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "z9GCyPWNuOm7"
+"id": "hNeBO9m-hY7o"
 },
 "source": [
 "Load `train.avro` as TensorFlow dataset with Avro dataset API: \n"
@@ -321,7 +321,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "nS3eTBvjt-O5"
+"id": "v-nbLZHKhY7o"
 },
 "outputs": [],
 "source": [
@@ -363,7 +363,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "nS3eTBvjt-O5"
+"id": "bc9vDHyghY7p"
 },
 "outputs": [],
 "source": [
@@ -382,7 +382,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "IF_kYz_o2DH4"
+"id": "x45KolnDhY7p"
 },
 "source": [
 "One can also increase num_parallel_reads to expediate Avro data processing by increasing avro parse/read parallelism.\n"
@@ -392,7 +392,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "nS3eTBvjt-O5"
+"id": "Z2x-gPj_hY7p"
 },
 "outputs": [],
 "source": [
@@ -412,7 +412,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "IF_kYz_o2DH4"
+"id": "6V-nwDJGhY7p"
 },
 "source": [
 "For detailed usage of `make_avro_record_dataset`, please refer to <a target=\"_blank\" href=\"https://www.tensorflow.org/io/api_docs/python/tfio/experimental/columnar/make_avro_record_dataset\">API doc</a>.\n"
@@ -421,7 +421,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "4CfKVmCvwcL7"
+"id": "vIOijGlAhY7p"
 },
 "source": [
 "### Train tf.keras models with Avro dataset\n",
@@ -432,7 +432,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "z9GCyPWNuOm7"
+"id": "s7K85D53hY7q"
 },
 "source": [
 "Load `train.avro` as TensorFlow dataset with Avro dataset API: \n"
@@ -442,14 +442,16 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "nS3eTBvjt-O5"
+"id": "VFoeLwIOhY7q"
 },
 "outputs": [],
 "source": [
 "features = {\n",
-" 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32)\n",
+" 'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),\n",
+" 'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),\n",
 "}\n",
 "\n",
+"\n",
 "schema = tf.io.gfile.GFile('train.avsc').read()\n",
 "\n",
 "dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],\n",
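The hunk above adds a `label` entry to the feature map with `default_value=-100`. As a plain-Python sketch (no TensorFlow required, and `parse_record` is a hypothetical stand-in for the Avro parser, not tfio API), the effect of that default is that records missing a label fall back to the sentinel:

```python
# Hypothetical stand-in for parsing one Avro record against a feature map
# where 'label' has a FixedLenFeature default of -100 (the sentinel the
# diff uses); records without a 'label' field get the default instead of
# failing to parse.
DEFAULT_LABEL = -100

def parse_record(record):
    return {
        'features[*]': record.get('features[*]', []),
        'label': record.get('label', DEFAULT_LABEL),
    }

labeled = parse_record({'features[*]': [1, 2, 3], 'label': 7})
unlabeled = parse_record({'features[*]': [4, 5]})
```

A record with a label keeps it; an unlabeled record is parsed with `label == -100`, which downstream code can filter or mask.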
@@ -463,17 +465,17 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "z9GCyPWNuOm7"
+"id": "hR2FnIIMhY7q"
 },
 "source": [
 "Define a simple keras model: \n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "metadata": {
-"id": "m6KXZuTBWgRm"
+"id": "hGV5rHfJhY7q"
 },
 "outputs": [],
 "source": [
@@ -488,27 +490,31 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "4CfKVmCvwcL7"
+"id": "Tuv9n6HshY7q"
 },
 "source": [
 "### Train the keras model with Avro dataset:\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "metadata": {
-"id": "m6KXZuTBWgRm"
+"id": "lb44cUuWhY7r"
 },
 "outputs": [],
 "source": [
-"model.fit(x=dataset, epochs=1, steps_per_epoch=1, verbose=1)\n"
+"def extract_label(feature):\n",
+" label = feature.pop('label')\n",
+" return tf.sparse.to_dense(feature['features[*]']), label\n",
+"\n",
+"model.fit(x=dataset.map(extract_label), epochs=1, steps_per_epoch=1, verbose=1)\n"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "IF_kYz_o2DH4"
+"id": "7K6qAv5rhY7r"
 },
 "source": [
 "The avro dataset can parse and coerce any avro data into TensorFlow tensors, including records in records, maps, arrays, branches, and enumerations. The parsing information is passed into the avro dataset implementation as a map where \n",
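The `extract_label` helper introduced above pops the label out of the parsed feature dict so that `model.fit` receives `(inputs, labels)` pairs from the dataset. A minimal plain-Python analogue of that mapping (no TensorFlow; `tf.sparse.to_dense` is omitted since plain lists are already dense) shows the split:

```python
# Plain-Python analogue of the extract_label mapping added in the diff:
# remove the label from the feature dict and return the (inputs, label)
# pair that Keras' model.fit expects from a dataset.
def extract_label(feature):
    feature = dict(feature)       # work on a copy; don't mutate the caller's dict
    label = feature.pop('label')  # the label must not remain among the inputs
    return feature['features[*]'], label

records = [
    {'features[*]': [1, 2, 3], 'label': 0},
    {'features[*]': [4, 5, 6], 'label': 1},
]
# what dataset.map(extract_label) would yield, element by element
pairs = [extract_label(r) for r in records]
```

Popping the label before returning is the design point: if it stayed in the feature dict, the model would train on its own target.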
@@ -541,7 +547,7 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "IF_kYz_o2DH4"
+"id": "1PFQPuy5hY7r"
 },
 "source": [
 "A comprehensive set of examples of Avro dataset API is provided within <a target=\"_blank\" href=\"https://github.com/tensorflow/io/blob/master/tests/test_parse_avro.py#L437\">the tests</a>.\n"
