diff --git a/docs/source/python_tutorials/ragged/basics.rst b/docs/source/python_tutorials/ragged/basics.rst index 68cd67ca9..b39195d18 100644 --- a/docs/source/python_tutorials/ragged/basics.rst +++ b/docs/source/python_tutorials/ragged/basics.rst @@ -13,8 +13,6 @@ In this tutorial, we describe - What is ``RaggedShape``? - What is ``row_splits`` ? - What is ``row_ids`` ? - - What is ``dim0`` ? - - What is ``tot_size`` ? What are ragged tensors? ------------------------ @@ -29,7 +27,7 @@ tensors, i.e., regular tensors, look like. :lines: 8-20 The shape of the 2-D regular tensor ``a`` is ``(3, 4)``, meaning it has 3 - rows and 4 columns. Each row has **exactly** 4 elements, no more, no less. + rows and 4 columns. Each row has **exactly** 4 elements. - 3-D regular tensors @@ -38,8 +36,8 @@ tensors, i.e., regular tensors, look like. :lines: 24-45 The shape of the 3-D regular tensor ``b`` is ``(3, 3, 2)``, meaning it has - 3 planes. Each plane has **exactly** 3 rows, no more, no less. Each row has - **exactly** two entries, no more, no less. + 3 planes. Each plane has **exactly** 3 rows and each row has **exactly** two + entries. - N-D regular tensors (N >= 4) @@ -89,7 +87,7 @@ tensors in ``k2``. A ragged tensor in ``k2`` has ``N`` (``N >= 2``) axes. Unlike regular tensors, each axis of a ragged tensor can have different number of elements. -Ragged tensors are **the most important** data structures in ``k2``. FSAs are +Ragged tensors are **the most important** data structure in ``k2``. FSAs are represented as ragged tensors. There are also various operations defined on ragged tensors. @@ -113,7 +111,7 @@ Exercise 1 - Row 1 is empty, i.e., it has no elements. - Row 2 has two elements: ``-1.5, 2`` - (Click ▶ to see it) + (Click ▶ to view the solution) .. literalinclude:: code/basics/ragged-tensors.py :language: python @@ -130,11 +128,34 @@ Exercise 2 How to create a ragged tensor with only 1 axis? 
- (Click ▶ to see it) + (Click ▶ to view the solution) You **cannot** create a ragged tensor with only 1 axis. Ragged tensors in ``k2`` have at least 2 axes. +dtype and device +^^^^^^^^^^^^^^^^ + +Like tensors in PyTorch, ragged tensors in ``k2`` have the attributes ``dtype`` and +``device``. The following code shows that you can specify the ``dtype`` and +``device`` while constructing ragged tensors. + +.. literalinclude:: code/basics/dtype-device.py + :language: python + :lines: 3-23 + +.. container:: toggle + + .. container:: header + + .. Note:: + + (Click ▶ to view the output) + + .. literalinclude:: code/basics/dtype-device.py + :language: python + :lines: 25-50 + Concepts about ragged tensors ----------------------------- @@ -144,18 +165,18 @@ A ragged tensor in ``k2`` consists of two parts: .. Caution:: - It is assumed that a shape within a ragged tensor in ``k2`` is a constant. + It is assumed that a shape within a ragged tensor in ``k2`` is a constant. Once constructed, you are not expected to modify it. Otherwise, unexpected things can happen; you will be SAD. - - ``data``, which is an **array** of type ``T`` + - ``values``, which is an **array** of type ``T`` .. Hint:: - ``data`` is stored ``contiguously`` in memory, whose entries have to be + ``values`` is stored ``contiguously`` in memory, whose entries have to be of the same type ``T``. ``T`` can be either primitive types, such as ``int``, ``float``, and ``double`` or can be user defined types. For instance, - ``data`` in FSAs contains ``arcs``, which is defined in C++ + ``values`` in FSAs contains ``arcs``, which is defined in C++ `as follows `_: .. code-block:: c++ @@ -167,8 +188,83 @@ A ragged tensor in ``k2`` consists of two parts: float score; } -In the following, we describe what is inside a ``shape`` and how to manipulate -``data``. 
+Before explaining what ``shape`` and ``values`` contain, let us look at an example of +how to use a ragged tensor to represent the following +FSA (see :numref:`ragged_basics_simple_fsa_1`). + +.. _ragged_basics_simple_fsa_1: +.. figure:: code/basics/images/simple-fsa.svg + :alt: A simple FSA + :align: center + :figwidth: 600px + + A simple FSA that is to be represented by a ragged tensor. + +The FSA in :numref:`ragged_basics_simple_fsa_1` has 3 arcs and 3 states. + ++---------+--------------------+--------------------+--------------------+--------------------+ +| | src_state | dst_state | label | score | ++---------+--------------------+--------------------+--------------------+--------------------+ +| Arc 0 | 0 | 1 | 1 | 0.1 | ++---------+--------------------+--------------------+--------------------+--------------------+ +| Arc 1 | 0 | 1 | 2 | 0.2 | ++---------+--------------------+--------------------+--------------------+--------------------+ +| Arc 2 | 1 | 2 | -1 | 0.3 | ++---------+--------------------+--------------------+--------------------+--------------------+ + +When the above FSA is saved in a ragged tensor, its arcs are saved in a 1-D contiguous +``values`` array containing ``[Arc0, Arc1, Arc2]``. +At this point, you might ask: + + - As we can construct the original FSA by using the ``values`` array, + what's the point of saving it in a ragged tensor? + +Using the ``values`` array alone, it is not possible to answer the following questions in ``O(1)`` +time: + + - How many states does the FSA have? + - How many arcs does each state have? + - Where do the arcs belonging to state 0 start in the ``values`` array? + +To answer the above questions, we introduce another 1-D array, called ``row_splits``. +``row_splits[s] = p`` means that for state ``s``, its first outgoing arc starts at position +``p`` in the ``values`` array. 
As a side effect, it also indicates that the last outgoing +arc for state ``s-1`` ends at position ``p`` (exclusive) in the ``values`` array. + +In our example, ``row_splits`` would be ``[0, 2, 3, 3]``, meaning: + + - The first outgoing arc for state 0 is at position ``row_splits[0] = 0`` + in the ``values`` array + - State 0 has ``row_splits[1] - row_splits[0] = 2 - 0 = 2`` arcs + - The first outgoing arc for state 1 is at position ``row_splits[1] = 2`` + in the ``values`` array + - State 1 has ``row_splits[2] - row_splits[1] = 3 - 2 = 1`` arc + - State 2 has no arcs since ``row_splits[3] - row_splits[2] = 3 - 3 = 0`` + - The FSA has ``len(row_splits) - 1 = 3`` states. + +We can construct a ``RaggedShape`` from a ``row_splits`` array: + +.. literalinclude:: code/basics/ragged_shape_1.py + :language: python + :lines: 3-14 + +Pay attention to the string form of the shape ``[ [x x] [x] [ ] ]``. +``x`` means we don't care about the actual content inside a ragged tensor. +The above shape has 2 axes, 3 rows, and 3 elements. Row 0 has two elements as there +are two ``x`` inside the 0th ``[]``. Row 1 has only one element, while +row 2 has no elements at all. We can assign names to the axes. In our case, +we say the shape has axes ``[state][arc]``. + +Combining the ragged shape and the ``values`` array, the above FSA can +be represented using a ragged tensor ``[ [Arc0 Arc1] [Arc2] [ ] ]``. + +The following code displays the string form of the above FSA when represented +as a ragged tensor in k2: + +.. 
literalinclude:: code/basics/single-fsa.py + :language: python + :lines: 2-14 + Shape ^^^^^ diff --git a/docs/source/python_tutorials/ragged/code/basics/dtype-device.py b/docs/source/python_tutorials/ragged/code/basics/dtype-device.py new file mode 100755 index 000000000..6124b0aa7 --- /dev/null +++ b/docs/source/python_tutorials/ragged/code/basics/dtype-device.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +import k2 +import torch + +a = k2.RaggedTensor([[1, 2], [1]]) +b = k2.RaggedTensor([[1, 2], [1]], dtype=torch.int32) +c = k2.RaggedTensor([[1, 2], [1.5]]) +d = k2.RaggedTensor([[1, 2], [1.5]], dtype=torch.float32) +e = k2.RaggedTensor([[1, 2], [1.5]], dtype=torch.float64) +f = k2.RaggedTensor([[1, 2], [1]], dtype=torch.float32, device=torch.device("cuda", 0)) +g = k2.RaggedTensor([[1, 2], [1]], device="cuda:0", dtype=torch.float64) +print(f"a:\n{a}") +print(f"b:\n{b}") +print(f"c:\n{c}") +print(f"d:\n{d}") +print(f"e:\n{e}") +print(f"f:\n{f}") +print(f"g:\n{g}") +print(f"g.to_str_simple():\n{g.to_str_simple()}") +print(f"a.dtype: {a.dtype}, g.device: {g.device}") +print(f"a.to(g.device).device: {a.to(g.device).device}") +print(f"a.to(g.dtype).dtype: {a.to(g.dtype).dtype}") +""" +a: +RaggedTensor([[1, 2], + [1]], dtype=torch.int32) +b: +RaggedTensor([[1, 2], + [1]], dtype=torch.int32) +c: +RaggedTensor([[1, 2], + [1.5]], dtype=torch.float32) +d: +RaggedTensor([[1, 2], + [1.5]], dtype=torch.float32) +e: +RaggedTensor([[1, 2], + [1.5]], dtype=torch.float64) +f: +RaggedTensor([[1, 2], + [1]], device='cuda:0', dtype=torch.float32) +g: +RaggedTensor([[1, 2], + [1]], device='cuda:0', dtype=torch.float64) +g.to_str_simple(): +RaggedTensor([[1, 2], [1]], device='cuda:0', dtype=torch.float64) +a.dtype: torch.int32, g.device: cuda:0 +a.to(g.device).device: cuda:0 +a.to(g.dtype).dtype: torch.float64 +""" diff --git a/docs/source/python_tutorials/ragged/code/basics/images/simple-fsa.svg b/docs/source/python_tutorials/ragged/code/basics/images/simple-fsa.svg new file mode 
100644 index 000000000..264164859 --- /dev/null +++ b/docs/source/python_tutorials/ragged/code/basics/images/simple-fsa.svg @@ -0,0 +1,53 @@ +[SVG figure "WFSA": states 0, 1, 2; arcs 0->1 labeled 1/0.1, 0->1 labeled 2/0.2, 1->2 labeled -1/0.3] diff --git a/docs/source/python_tutorials/ragged/code/basics/ragged_shape_1.py b/docs/source/python_tutorials/ragged/code/basics/ragged_shape_1.py new file mode 100755 index 000000000..d695024e3 --- /dev/null +++ b/docs/source/python_tutorials/ragged/code/basics/ragged_shape_1.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +import k2 +import torch + +shape = k2.ragged.create_ragged_shape2( + row_splits=torch.tensor([0, 2, 3, 3], dtype=torch.int32), +) +print(type(shape)) +print(shape) +""" + +[ [ x x ] [ x ] [ ] ] +""" +print("num_states:", shape.dim0) +print("num_arcs:", shape.numel()) diff --git a/docs/source/python_tutorials/ragged/code/basics/single-fsa.py b/docs/source/python_tutorials/ragged/code/basics/single-fsa.py new file mode 100755 index 000000000..fd99f44ac --- /dev/null +++ b/docs/source/python_tutorials/ragged/code/basics/single-fsa.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +import k2 + +s = """ +0 1 1 0.1 +0 1 2 0.2 +1 2 -1 0.3 +2 +""" +fsa = k2.Fsa.from_str(s) +print(fsa.arcs) +""" +[ [ 0 1 1 0.1 0 1 2 0.2 ] [ 1 2 -1 0.3 ] [ ] ] +""" + +sym_str = """ +a 1 +b 2 +""" + +# fsa.labels_sym = k2.SymbolTable.from_str(sym_str) +# fsa.draw("images/simple-fsa.svg") +# print(k2.to_dot(fsa)) diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index ad26333e7..111800083 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -335,7 +335,7 @@ std::string RaggedAny::ToString(bool compact /*=false*/, int32_t device_id /*=-1*/) const { ContextPtr context = any.Context(); if (context->GetDeviceType() != kCpu) { - return To("cpu").ToString(context->GetDeviceId()); + return To("cpu").ToString(compact, 
context->GetDeviceId()); } std::ostringstream os; diff --git a/k2/python/csrc/torch/v2/ragged_shape.cu b/k2/python/csrc/torch/v2/ragged_shape.cu index 81c3118ce..cb3bc8c13 100644 --- a/k2/python/csrc/torch/v2/ragged_shape.cu +++ b/k2/python/csrc/torch/v2/ragged_shape.cu @@ -232,8 +232,8 @@ void PybindRaggedShape(py::module &m) { m.def( "create_ragged_shape2", - [](torch::optional<torch::Tensor> row_splits, - torch::optional<torch::Tensor> row_ids, + [](torch::optional<torch::Tensor> row_splits = torch::nullopt, + torch::optional<torch::Tensor> row_ids = torch::nullopt, int32_t cached_tot_size = -1) -> RaggedShape { if (!row_splits.has_value() && !row_ids.has_value()) K2_LOG(FATAL) << "Both row_splits and row_ids are None"; @@ -257,7 +257,7 @@ void PybindRaggedShape(py::module &m) { row_splits.has_value() ? &array_row_splits : nullptr, row_ids.has_value() ? &array_row_ids : nullptr, cached_tot_size); }, - py::arg("row_splits"), py::arg("row_ids"), + py::arg("row_splits") = py::none(), py::arg("row_ids") = py::none(), py::arg("cached_tot_size") = -1, kCreateRaggedShape2Doc); m.def("random_ragged_shape", &RandomRaggedShape, "RandomRaggedShape",