Skip to content

Commit

Permalink
Wrap lattice to python (#50)
Browse files Browse the repository at this point in the history
* Wrap lattice to Python

* Fix style issues

* release v1.7.5
  • Loading branch information
csukuangfj authored Sep 6, 2023
1 parent 81adab6 commit 7e63f1d
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 23 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR)

project(kaldifst CXX)

set(KALDIFST_VERSION "1.7.4")
set(KALDIFST_VERSION "1.7.5")

if(NOT CMAKE_BUILD_TYPE)
message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
Expand Down
7 changes: 4 additions & 3 deletions kaldifst/csrc/lattice-weight.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,19 @@ class LatticeWeightTpl {
public:
typedef FloatType T; // normally float.
typedef LatticeWeightTpl ReverseWeight;
using ValueType = FloatType;

inline T Value1() const { return value1_; }
inline T Value1() const { return value1_; } // usually graph cost

inline T Value2() const { return value2_; }
inline T Value2() const { return value2_; } // usually acoustic cost

inline void SetValue1(T f) { value1_ = f; }

inline void SetValue2(T f) { value2_ = f; }

LatticeWeightTpl() : value1_{}, value2_{} {}

LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
LatticeWeightTpl(T a, T b = 0) : value1_(a), value2_(b) {} // NOLINT

LatticeWeightTpl(const LatticeWeightTpl &other)
: value1_(other.value1_), value2_(other.value2_) {}
Expand Down
1 change: 1 addition & 0 deletions kaldifst/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pybind11_add_module(_kaldifst
fstrmepsilon.cc
kaldi-table.cc
kaldifst.cc
lattice-weight.cc
mutable-fst.cc
pre-determinize.cc
symbol-table.cc
Expand Down
2 changes: 2 additions & 0 deletions kaldifst/python/csrc/arc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string>

#include "fst/arc.h"
#include "kaldifst/csrc/lattice-weight.h"

namespace kaldifst {

Expand Down Expand Up @@ -59,6 +60,7 @@ static void PybindArcImpl(py::module &m, // NOLINT

void PybindArc(py::module &m) { // NOLINT
PybindArcImpl<fst::TropicalWeight>(m, "StdArc");
PybindArcImpl<fst::LatticeWeight>(m, "LatticeArc");
}

} // namespace kaldifst
2 changes: 2 additions & 0 deletions kaldifst/python/csrc/expanded-fst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

#include "fst/expanded-fst.h"

#include "kaldifst/csrc/lattice-weight.h"
#include "kaldifst/python/csrc/expanded-fst.h"

namespace kaldifst {

void PybindExpandedFst(py::module &m) { // NOLINT
PybindExpandedFst<fst::StdArc>(m, "StdExpandedFst");
PybindExpandedFst<fst::LatticeArc>(m, "LatticeExpandedFst");
//
}

Expand Down
11 changes: 11 additions & 0 deletions kaldifst/python/csrc/fst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string>

#include "fst/fst.h"
#include "kaldifst/csrc/lattice-weight.h"

namespace kaldifst {

Expand Down Expand Up @@ -139,6 +140,16 @@ void PybindFst(py::module &m) { // NOLINT
// PybindArcIteratorImpl<fst::StdFst>(m, "_StdFstArcIterator");

PybindFstImpl<fst::StdArc>(m, "_StdFstImpl");

PybindFst<fst::LatticeArc>(
m, "LatticeFst",
"A generic FST, templated on the arc definition, with \n"
"common-demoninator methods (use StateIterator and \n"
"ArcIterator to iterate over its states and arcs).");
// PybindStateIteratorImpl<fst::StdFst>(m, "_StdFstStateIterator");
// PybindArcIteratorImpl<fst::StdFst>(m, "_StdFstArcIterator");

PybindFstImpl<fst::LatticeArc>(m, "_LatticeFstImpl");
}

} // namespace kaldifst
39 changes: 22 additions & 17 deletions kaldifst/python/csrc/fstext-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <utility>
#include <vector>

#include "kaldifst/csrc/lattice-weight.h"
#include "kaldifst/python/csrc/fstarcsort.h"
#include "kaldifst/python/csrc/fstext-utils.h"

Expand Down Expand Up @@ -184,6 +185,19 @@ case it outputs the symbol

namespace kaldifst {

template <class Arc, class I = int32_t>
std::tuple<bool, std::vector<I>, std::vector<I>, typename Arc::Weight>
GetLinearSymbolSequenceWrapper(const fst::Fst<Arc> &fst) {
std::vector<I> isymbols_out;
std::vector<I> osymbols_out;
typename Arc::Weight w;

bool succeeded =
GetLinearSymbolSequence(fst, &isymbols_out, &osymbols_out, &w);

return std::make_tuple(succeeded, isymbols_out, osymbols_out, w);
}

void PybindFstExtUtils(py::module &m) { // NOLINT
m.def(
"minimize_encoded",
Expand Down Expand Up @@ -212,23 +226,14 @@ void PybindFstExtUtils(py::module &m) { // NOLINT
},
py::arg("ifst"), py::arg("length"), py::arg("rand_seed"),
py::arg("num_retries") = 10, kEqualAlignDoc);
m.def(
"get_linear_symbol_sequence",
[](const fst::StdFst &fst) -> std::tuple<bool, std::vector<int32_t>,
std::vector<int32_t>, float> {
std::vector<int32_t> isymbols_out;
std::vector<int32_t> osymbols_out;
float total_weight_out;
fst::TropicalWeight w;

bool succeeded =
GetLinearSymbolSequence(fst, &isymbols_out, &osymbols_out, &w);
total_weight_out = w.Value();

return std::make_tuple(succeeded, isymbols_out, osymbols_out,
total_weight_out);
},
py::arg("fst"), kGetLinearSymbolSequenceDoc);

m.def("get_linear_symbol_sequence",
&(GetLinearSymbolSequenceWrapper<fst::StdArc>), py::arg("fst"),
kGetLinearSymbolSequenceDoc);

m.def("get_linear_symbol_sequence",
&(GetLinearSymbolSequenceWrapper<fst::LatticeArc>), py::arg("fst"),
kGetLinearSymbolSequenceDoc);
}

} // namespace kaldifst
2 changes: 2 additions & 0 deletions kaldifst/python/csrc/kaldifst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "kaldifst/python/csrc/fstreverse.h"
#include "kaldifst/python/csrc/fstrmepsilon.h"
#include "kaldifst/python/csrc/kaldi-table.h"
#include "kaldifst/python/csrc/lattice-weight.h"
#include "kaldifst/python/csrc/mutable-fst.h"
#include "kaldifst/python/csrc/pre-determinize.h"
#include "kaldifst/python/csrc/symbol-table.h"
Expand All @@ -35,6 +36,7 @@ PYBIND11_MODULE(_kaldifst, m) {
m.doc() = "Python wrapper for kaldifst";

PybindFloatWeight(m);
PybindLatticeWeight(&m);
PybindArc(m);
PybindSymbolTable(m);
PybindFst(m);
Expand Down
58 changes: 58 additions & 0 deletions kaldifst/python/csrc/lattice-weight.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// kaldifst/python/csrc/lattice-weight.cc
//
// Copyright (c) 2021-2023 Xiaomi Corporation (authors: Fangjun Kuang)

#include "kaldifst/python/csrc/lattice-weight.h"

#include "kaldifst/csrc/lattice-weight.h"

namespace kaldifst {

void PybindLatticeWeight(py::module *m) {
using PyClass = fst::LatticeWeight;
py::class_<PyClass>(*m, "LatticeWeight")
.def(py::init<>())
.def(py::init<float, float>(), py::arg("graph_cost"),
py::arg("acoustic_cost"))
.def_property_readonly_static("zero",
[](py::object) { return PyClass::Zero(); })
.def_property_readonly_static("one",
[](py::object) { return PyClass::One(); })
.def_property_readonly_static(
"no_weight", [](py::object) { return PyClass::NoWeight(); })
.def_property_readonly_static(
"type", [](py::object) { return PyClass::Type(); },
py::return_value_policy::reference)
.def("member", &PyClass::Member)
.def("quantize", &PyClass::Quantize, py::arg("delta") = fst::kDelta)
.def_property_readonly_static(
"properties", [](py::object) { return PyClass::Properties(); })
.def("hash", &PyClass::Hash)
.def("__eq__",
[](const PyClass &w1, const PyClass &w2) { return w1 == w2; })
.def("__ne__",
[](const PyClass &w1, const PyClass &w2) { return w1 != w2; })
.def("__str__", [](const PyClass &w) {
std::ostringstream os;
os << w.Value1() << ", " << w.Value2();
return os.str();
});

m->def("plus", [](const PyClass &w1, const PyClass &w2) {
return fst::Plus(w1, w2);
});

m->def("times", [](const PyClass &w1, const PyClass &w2) {
return fst::Times(w1, w2);
});

m->def("divide", [](const PyClass &w1, const PyClass &w2) {
return fst::Divide(w1, w2);
});

m->def("approx_equal", [](const PyClass &w1, const PyClass &w2) {
return fst::ApproxEqual(w1, w2);
});
}

} // namespace kaldifst
16 changes: 16 additions & 0 deletions kaldifst/python/csrc/lattice-weight.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// kaldifst/python/csrc/lattice-weight.h
//
// Copyright (c) 2021-2023 Xiaomi Corporation (authors: Fangjun Kuang)

#ifndef KALDIFST_PYTHON_CSRC_LATTICE_WEIGHT_H_
#define KALDIFST_PYTHON_CSRC_LATTICE_WEIGHT_H_

#include "kaldifst/python/csrc/kaldifst.h"

namespace kaldifst {

void PybindLatticeWeight(py::module *m);

} // namespace kaldifst

#endif // KALDIFST_PYTHON_CSRC_LATTICE_WEIGHT_H_
3 changes: 3 additions & 0 deletions kaldifst/python/csrc/mutable-fst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

#include "kaldifst/python/csrc/mutable-fst.h"

#include "kaldifst/csrc/lattice-weight.h"

namespace kaldifst {

void PybindMutableFst(py::module &m) { // NOLINT
PybindMutableFst<fst::StdArc>(m, "StdMutableFst");
PybindMutableFst<fst::LatticeArc>(m, "LatticeMutableFst");
}

} // namespace kaldifst
36 changes: 35 additions & 1 deletion kaldifst/python/csrc/vector-fst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

#include "kaldifst/python/csrc/vector-fst.h"

#include "kaldifst/csrc/lattice-weight.h"
#include "kaldifst/python/csrc/expanded-fst.h"
#include "kaldifst/python/csrc/fst.h"
#include "kaldifst/python/csrc/mutable-fst.h"

namespace kaldifst {

void PybindVectorFst(py::module &m) { // NOLINT
static void PybindStdVectorFst(py::module &m) { // NOLINT
PybindVectorState<fst::StdArc>(m, "_StdVectorState");
PybindVectorFstBaseImpl<fst::VectorState<fst::StdArc>>(
m, "_StdVectorFstBaseImpl");
Expand All @@ -34,4 +35,37 @@ void PybindVectorFst(py::module &m) { // NOLINT
m, "_ArcIteratorStdVectorFst");
}

static void PybindLattice(py::module &m) { // NOLINT
PybindVectorState<fst::LatticeArc>(m, "_LatticeVectorState");
PybindVectorFstBaseImpl<fst::VectorState<fst::LatticeArc>>(
m, "_LatticeVectorFstBaseImpl");
PybindVectorFstImpl<fst::VectorState<fst::LatticeArc>>(
m, "_LatticeVectorFstImpl");

PybindImplToFst<
fst::internal::VectorFstImpl<fst::VectorState<fst::LatticeArc>>,
fst::MutableFst<fst::LatticeArc>>(m, "_LatticeImplToFst");

PybindImplToExpandedFst<
fst::internal::VectorFstImpl<fst::VectorState<fst::LatticeArc>>,
fst::MutableFst<fst::LatticeArc>>(m, "_LatticeImplToExpandedFst");

PybindImplToMutableFst<
fst::internal::VectorFstImpl<fst::VectorState<fst::LatticeArc>>>(
m, "_LatticeImplToMutableFst");

PybindVectorFst<fst::LatticeArc>(m, "Lattice");
PybindStateIteratorVectorFst<fst::LatticeArc,
fst::VectorState<fst::LatticeArc>>(
m, "_StateIteratorLattice");
PybindArcIteratorVectorFst<fst::LatticeArc,
fst::VectorState<fst::LatticeArc>>(
m, "_ArcIteratorLattice");
}

void PybindVectorFst(py::module &m) { // NOLINT
PybindStdVectorFst(m);
PybindLattice(m);
}

} // namespace kaldifst
5 changes: 4 additions & 1 deletion kaldifst/python/kaldifst/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from _kaldifst import (
StdFst,
FloatWeight,
Lattice,
LatticeArc,
LatticeWeight,
StdArc,
StdConstFst,
StdFst,
StdVectorFst,
SymbolTable,
TropicalWeight,
Expand Down

0 comments on commit 7e63f1d

Please sign in to comment.