From b9ffae3bde5d3d40e3c85187871b1c11c4fa7048 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 May 2020 17:01:11 +0800 Subject: [PATCH 01/14] Added interface for auxiliary labels --- k2/csrc/aux_labels.h | 104 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 k2/csrc/aux_labels.h diff --git a/k2/csrc/aux_labels.h b/k2/csrc/aux_labels.h new file mode 100644 index 000000000..e6ab199c8 --- /dev/null +++ b/k2/csrc/aux_labels.h @@ -0,0 +1,104 @@ +// k2/csrc/aux_labels.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Daniel Povey) + +// See ../../LICENSE for clarification regarding multiple authors + +#ifndef K2_CSRC_AUX_LABELS_H_ +#define K2_CSRC_AUX_LABELS_H_ + +#include + +#include "k2/csrc/fsa.h" +#include "k2/csrc/fsa_util.h" +#include "k2/csrc/properties.h" + +namespace k2 { + +/* + This header contains utilities for dealing with auxiliary labels on FSAs. + + These auxiliary labels can be used where you really have a transducer (i.e. to + store the ilabels or olabels, whichever of the two is not participating + directly in the operation you are doing). + + We deal with two formats of labels: a vector of int32_t, one per arc, + for cases where we know we have at most one label per arc; and + struct AuxLabels for cases where there may in general be a sequence + of labels (one per arc). + +*/ + + +/* + This allows you to store auxiliary labels (e.g. olabels or ilabels) + on each arc of an Fsa. + */ +struct AuxLabels { + /* Suppose this is associated with an Fsa f. start_pos will be of + size f.arcs.size() + 1; start_pos[i] is the start position in + `labels` of the label sequence on arc i. start_pos.end() + equals labels.size(). */ + std::vector start_pos; + /* For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ... labels[start_pos[i+1]-1]) + are the list of labels on that arc. None of the elements of `labels` are + expected to be zero (epsilon). */ + std::vector labels; +}; + + +/* + Maps auxiliary labels after an FSA operation where each arc in the output + FSA corresponds to exactly one arc in the input FSA. + @param [in] labels_in Labels on the arcs of the input FSA + @param [in] arc_map Vector of size (output_fsa.arcs.size()), + saying which arc of the input FSA it + corresponds to. + @param [in] labels_out Labels on the arcs of the output FSA + */ +void MapAuxLabels1(const AuxLabels &labels_in, + const std::vector &arc_map, + AuxLabels *labels_out); + +/* + Maps auxiliary labels after an FSA operation where each arc in the output + FSA can correspond to a sequence of arcs in the input FSA. + @param [in] labels_in Labels on the arcs of the input FSA + @param [in] arc_map Vector of size (output_fsa.arcs.size()), + giving the sequence of arc-indexes in the input + FSA that it corresponds to. + @param [in] labels_out Labels on the arcs of the output FSA + */ +void MapAuxLabels2(const AuxLabels &labels_in, + const std::vector > &arc_map, + AuxLabels *labels_out); + + +/* + Invert an FST, swapping the symbols in the FSA with the auxiliary labels. + (e.g. swap input and output symbols in FST, but you decide which is which). + Because each arc may have more than one auxiliary label, in general + the output FSA may have more states than the input FSA. + + @param [in] fsa_in Input FSA + @param [in] labels_in Input aux-label sequences, one for each arc in + fsa_in + @param [out] fsa_out Output FSA. Will have a number of states + >= that in fsa_in. If fsa_in was top-sorted it + will be top-sorted. Labels in the FSA will + correspond to those in `labels_in`. + @param [out] aux_labels_out Auxiliary labels on the arcs of + fsa_out. Will be the same as the labels on + `fsa_in`, although epsilons (kEpsilon, zeros) will be + removed. + */ +void InvertFst(const Fsa &fsa_in, + const AuxLabels &labels_in, + Fsa *fsa_out, + AuxLabels *aux_labels_out); + + + +} // namespace k2 + +#endif // K2_CSRC_AUX_LABELS_H_ From a9e17687d61139d97d1f023d35b12731a2a61f3d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 May 2020 20:42:00 +0800 Subject: [PATCH 02/14] Add notes on Python interface --- k2/csrc/fsa.h | 28 +++++++++++++++ k2/csrc/fsa_util.h | 30 ++++++++++++++++ k2/csrc/weights.h | 2 ++ notes/python.txt | 87 ++++++++++++++++++++++++++++++++++++++++++++++ notes/training.txt | 2 -- 5 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 notes/python.txt diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index 93c0d5cdb..b19fbbba4 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -123,6 +123,34 @@ struct Fst { std::vector aux_label; }; +/* + This demonstrates an interface for a deterministic FSA or FST; it's similar + to Kaldi's DeterministicOnDemandFst class. It can be used for things like + language models. Actually we'll template on types like this. There is no + need to actually inherit from this class. */ +class DeterministicGenericFsa { + + int32_t Start(); + + + bool LookupArc(int32_t cur_state, + int32_t label, + int32_t *arc_index); + + + float GetWeightForArc(int32_t arc_index); + + int32_t GetLabelForArc(int32_t arc_index); + + int32_t GetPrevStateForArc(int32_t arc_index); + + int32_t GetNextStateForArc(int32_t arc_index); + + // Specific subclasses of this may have additional functions, e.g. + int32_t GetOlabelForArc(int32_t arc_index); + +}; + using FsaVec = std::vector; using FstVec = std::vector; using DenseFsaVec = std::vector; diff --git a/k2/csrc/fsa_util.h b/k2/csrc/fsa_util.h index 879535e8b..9e6c8cc9d 100644 --- a/k2/csrc/fsa_util.h +++ b/k2/csrc/fsa_util.h @@ -35,6 +35,36 @@ namespace k2 { void GetEnteringArcs(const Fsa &fsa, std::vector *arc_index, std::vector *end_index); +/* + Convert indexes (typically arc-mapping indexes, e.g. as output by Compose()) + from int32 to long int; this will be needed for conversion to LongTensor. + */ +void ConvertIndexes1(const std::vector &arc_map, + long int *indexes_out); + +/* + Convert indexes (typically arc-mapping indexes, e.g. as output by RmEpsilonPruned()) + from int32 to long int; this will be needed for conversion to LongTensor. + + This version is for when each arc of the output FSA may correspond to a + sequence of arcs in the input FSA. + + @param [in] arc_map Indexed by arc-index in the output FSA, the + sequence of arc-indexes in the input FSA that + it corresponds to + @param [out] indexes1 This vector, of length equal to the + total number of int32's in arc_map, will contain + arc-indexes in the input FSA + @param [out] indexes2 This vector, also of length equal to the + total number of int32's in arc_map, will contain + arc-indexes in the output FSA + */ +void GetArcIndexes2(const std::vector > &arc_map, + std::vector *indexes1, + std::vector *indexes2); + + + } // namespace k2 #endif // K2_CSRC_FSA_UTIL_H_ diff --git a/k2/csrc/weights.h b/k2/csrc/weights.h index a7ca6e83c..94d91c4f5 100644 --- a/k2/csrc/weights.h +++ b/k2/csrc/weights.h @@ -83,6 +83,8 @@ struct WfsaWithFbWeights { */ WfsaWithFbWeights(const Fsa *fsa, const float *arc_weights, FbWeightType t); + + private: std::vector mem_; }; diff --git a/notes/python.txt b/notes/python.txt new file mode 100644 index 000000000..8c5a6b95b --- /dev/null +++ b/notes/python.txt @@ -0,0 +1,87 @@ + + These are some notes regarding how we'll interact with Python and PyTorch. + + + + # For composition we can rely on PyTorch's inbuilt autograd + def Compose(a: FsaVec, a_weights: Tensor, + a: FsaVec, b_weights: Tensor): + c, input_indexes1, input_indexes2 = fsa.FsaVecCompose(a, b) + + c_weights = a_weights[input_indexes1] + b_weights[input_indexes2] + + # Handle transducers: + if a.aux_symbol != None: + c.aux_symbol = a.aux_symbol[input_indexes1] + if b.aux_symbol != None: + c.aux_symbol = b.aux_symbol[input_indexes2] + + return c, c_weights + + def RmEpsilon(a: FsaVec, a_weights, a_aux_symbols = None): + # At the C++ level, RmEpsilon outputs a `vector > indexes` that + # says, for each arc in the output, what list of arcs in the input it corresponds + # to. For exposition purposes, imagine we're dealing with a single FSA, not + # a vector of FSAs. Suppose indexes == [ [ 1, 2 ], [ 6, 8, 9 ] ], we'd form + # indexes1 = [ 1, 2, 6, 8, 9 ], and indexes2 = [ 0, 0, 1, 1, 1 ]. + + b, indexes1, indexes2 = fsa.FsaVecRmEpsilon(a) + + # Note: the 1st dim of a_weights must equal a.num_arcs, but it is allowed + # to have other dims (e.g. for non-scalar weights). + # In the normal case, a_weights and b_weights will have just one axis. + + b_weights = torch.zeros( (b.num_arcs, *a_weights.shape[1:]) ) + b_weights._index_add(0, a_weights[indexes1], indexes2) + + # If we later need access to indexes1 and indexes2, we can + # create a different version of this function or extend its interface. + + if a_aux_symbols is None: + return b, b_weights + else: + return b, b_weights, fsa.MapAuxSymbols(a_aux_symbols, indexes1, indexes2) + + + + + # Composing with transducers. Assumes that A is an acceptor but B may + # have auxiliary symbols. + def TransducerCompose(a: FsaVec, a_weights: Tensor, + a: FsaVec, b_weights: Tensor, + b_aux_symbols = None): + c, indexes_a, indexes_b = fsa.FsaVecCompose(a, b) + + c_weights = a_weights[indexes_a] + b_weights[indexes_b] + if b_aux_symbols is None: + return c, c_weights + else: + return c, c_weights, fsa.MapAuxSymbols(b_aux_symbols, indexes_b) + + + class TotalWeight(Function): + """ + Returns the total weight of FSAs (i.e. the log-sum-exp across + paths) as a Tensor with shapee (num_fsas,) + """ + + @staticmethod + def forward(ctx, a:FsaVec, a_weights: Tensor): + ctx.a = a + ctx.fb = fsa.FsaVecWithFbWeights(a, a_weights, fsa.kLogSumWeight) + ans = fb.GetTotalWeights() # a Tensor + return ans + + def backward(ctx, grad_out): + # `indexes` is a Tensor that would contain, for each arc in a, + # the index of the FSA in the FSAVec it belongs to. + # It would be of the form [ 0, 0, 0 .., 0, 1, 1, 1, .. 1, 2 ... ] + indexes = fsa.GetFsaVecIndexes(ctx.a) + + # GetArcProbs would return the probability of traversing each arc of + # each FSA, as a single Tensor. + return ctx.fb.GetArcProbs() * grad_out[indexes] + + # TODO: handle transfers to/from GPU in case grad_out was on GPU. + # Maybe mark this only once differentiable (it's twice differentiable, + # I think, but this code doesn't currently support that). diff --git a/notes/training.txt b/notes/training.txt index 0b87213d4..4588e9718 100644 --- a/notes/training.txt +++ b/notes/training.txt @@ -67,8 +67,6 @@ Raw training data is a list of tuples automatically backprop derivatives w.r.t. the output k2.wsfsavec back to the source weights. - - Then we compute the per-sequence objective functions from the ctc lattices via: ctc_objfs = k2.ForwardProb(ctc_lattices) From 9f764588e5e4569187f01677cde9d9a8daa6a3c3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 May 2020 12:07:46 +0800 Subject: [PATCH 03/14] Fix typos --- notes/python.txt | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/notes/python.txt b/notes/python.txt index 8c5a6b95b..274e0ef87 100644 --- a/notes/python.txt +++ b/notes/python.txt @@ -3,9 +3,24 @@ - # For composition we can rely on PyTorch's inbuilt autograd + + + # Assumes that A is an acceptor but B may + # have auxiliary symbols (i.e. may be a transducer). + def TransducerCompose(a: FsaVec, a_weights: Tensor, + b: FsaVec, b_weights: Tensor, + b_aux_symbols = None): + c, indexes_a, indexes_b = fsa.FsaVecCompose(a, b) + + c_weights = a_weights[indexes_a] + b_weights[indexes_b] + if b_aux_symbols is None: + return c, c_weights + else: + return c, c_weights, fsa.MapAuxSymbols(b_aux_symbols, indexes_b) + + def Compose(a: FsaVec, a_weights: Tensor, - a: FsaVec, b_weights: Tensor): + b: FsaVec, b_weights: Tensor): c, input_indexes1, input_indexes2 = fsa.FsaVecCompose(a, b) c_weights = a_weights[input_indexes1] + b_weights[input_indexes2] @@ -45,24 +60,10 @@ - # Composing with transducers. Assumes that A is an acceptor but B may - # have auxiliary symbols. - def TransducerCompose(a: FsaVec, a_weights: Tensor, - a: FsaVec, b_weights: Tensor, - b_aux_symbols = None): - c, indexes_a, indexes_b = fsa.FsaVecCompose(a, b) - - c_weights = a_weights[indexes_a] + b_weights[indexes_b] - if b_aux_symbols is None: - return c, c_weights - else: - return c, c_weights, fsa.MapAuxSymbols(b_aux_symbols, indexes_b) - - class TotalWeight(Function): """ Returns the total weight of FSAs (i.e. the log-sum-exp across - paths) as a Tensor with shapee (num_fsas,) + paths) as a Tensor with shape (num_fsas,) """ @staticmethod From 10e6712a4710d228c5a16d1f0ec4e4dd1623c1da Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 May 2020 12:22:25 +0800 Subject: [PATCH 04/14] Fix conflicts, remove some typedefs --- k2/csrc/fsa.h | 5 ++--- k2/csrc/fsa_algo.cc | 8 +++++--- k2/csrc/fsa_util.h | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index f0338771f..b50a09b5d 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -130,7 +130,7 @@ struct Fsa { weights[t,n]. */ struct DenseFsa { - Weight *weights; // Would typically be a log-prob or unnormalized log-prob + float *weights; // Would typically be a log-prob or unnormalized log-prob int32_t T; // The number of time steps == rows in the matrix `weights`; // this FSA has T + 2 states, see explanation above. int32_t num_symbols; // The number of symbols == columns in the matrix @@ -144,7 +144,7 @@ struct DenseFsa { CAUTION: we may later enforce that stride == num_symbols, in order to be able to know the layout of a phantom matrix of arcs. (?) */ - DenseFsa(Weight *data, int32_t T, int32_t num_symbols, int32_t stride); + DenseFsa(float *data, int32_t T, int32_t num_symbols, int32_t stride); }; struct Fst { @@ -152,7 +152,6 @@ struct Fst { std::vector aux_label; }; -<<<<<<< HEAD /* This demonstrates an interface for a deterministic FSA or FST; it's similar to Kaldi's DeterministicOnDemandFst class. It can be used for things like diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index aaa458bd0..a20e5938a 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -31,10 +31,12 @@ struct DfsState { int32_t arc_end; // end of the arc index of the visiting node }; +using StatePair = std::pair; + inline int32_t InsertIntersectionState( - const k2::StatePair &new_state, int32_t *state_index_c, - std::queue *qstates, - std::unordered_map *state_pair_map) { + const StatePair &new_state, int32_t *state_index_c, + std::queue *qstates, + std::unordered_map *state_pair_map) { auto result = state_pair_map->insert({new_state, *state_index_c + 1}); if (result.second) { // we have not visited `new_state` before. diff --git a/k2/csrc/fsa_util.h b/k2/csrc/fsa_util.h index f0522fe64..9719921fe 100644 --- a/k2/csrc/fsa_util.h +++ b/k2/csrc/fsa_util.h @@ -37,7 +37,6 @@ namespace k2 { void GetEnteringArcs(const Fsa &fsa, std::vector *arc_index, std::vector *end_index); -<<<<<<< HEAD /* Convert indexes (typically arc-mapping indexes, e.g. as output by Compose()) from int32 to long int; this will be needed for conversion to LongTensor. From 735f83df567447ccb94b8141994cdbac570164dc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 May 2020 12:45:37 +0800 Subject: [PATCH 05/14] Fixes from review --- k2/csrc/aux_labels.h | 26 +++++++++++++------------- k2/csrc/determinize.cc | 2 +- k2/csrc/fsa.h | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/k2/csrc/aux_labels.h b/k2/csrc/aux_labels.h index 104167424..e6ab199c8 100644 --- a/k2/csrc/aux_labels.h +++ b/k2/csrc/aux_labels.h @@ -24,7 +24,7 @@ namespace k2 { We deal with two formats of labels: a vector of int32_t, one per arc, for cases where we know we have at most one label per arc; and - struct Auxint32_ts for cases where there may in general be a sequence + struct AuxLabels for cases where there may in general be a sequence of labels (one per arc). */ @@ -34,7 +34,7 @@ namespace k2 { This allows you to store auxiliary labels (e.g. olabels or ilabels) on each arc of an Fsa. */ -struct Auxint32_ts { +struct AuxLabels { /* Suppose this is associated with an Fsa f. start_pos will be of size f.arcs.size() + 1; start_pos[i] is the start position in `labels` of the label sequence on arc i. start_pos.end() @@ -50,28 +50,28 @@ struct Auxint32_ts { /* Maps auxiliary labels after an FSA operation where each arc in the output FSA corresponds to exactly one arc in the input FSA. - @param [in] labels_in int32_ts on the arcs of the input FSA + @param [in] labels_in Labels on the arcs of the input FSA @param [in] arc_map Vector of size (output_fsa.arcs.size()), saying which arc of the input FSA it corresponds to. - @param [in] labels_out int32_ts on the arcs of the output FSA + @param [in] labels_out Labels on the arcs of the output FSA */ -void MapAuxint32_ts1(const Auxint32_ts &labels_in, +void MapAuxLabels1(const AuxLabels &labels_in, const std::vector &arc_map, - Auxint32_ts *labels_out); + AuxLabels *labels_out); /* Maps auxiliary labels after an FSA operation where each arc in the output FSA can correspond to a sequence of arcs in the input FSA. - @param [in] labels_in int32_ts on the arcs of the input FSA + @param [in] labels_in Labels on the arcs of the input FSA @param [in] arc_map Vector of size (output_fsa.arcs.size()), giving the sequence of arc-indexes in the input FSA that it corresponds to. - @param [in] labels_out int32_ts on the arcs of the output FSA + @param [in] labels_out Labels on the arcs of the output FSA */ -void MapAuxint32_ts2(const Auxint32_ts &labels_in, +void MapAuxLabels2(const AuxLabels &labels_in, const std::vector > &arc_map, - Auxint32_ts *labels_out); + AuxLabels *labels_out); /* @@ -85,7 +85,7 @@ void MapAuxint32_ts2(const Auxint32_ts &labels_in, fsa_in @param [out] fsa_out Output FSA. Will have a number of states >= that in fsa_in. If fsa_in was top-sorted it - will be top-sorted. int32_ts in the FSA will + will be top-sorted. Labels in the FSA will correspond to those in `labels_in`. @param [out] aux_labels_out Auxiliary labels on the arcs of fsa_out. Will be the same as the labels on @@ -93,9 +93,9 @@ void MapAuxint32_ts2(const Auxint32_ts &labels_in, removed. */ void InvertFst(const Fsa &fsa_in, - const Auxint32_ts &labels_in, + const AuxLabels &labels_in, Fsa *fsa_out, - Auxint32_ts *aux_labels_out); + AuxLabels *aux_labels_out); diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index 86fbc4600..e7a0b05ff 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -168,7 +168,7 @@ class DetStateMap { } } - size_t size() const { return cur_output_state_; } + int32_t size() const { return cur_output_state_; } private: diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index b50a09b5d..3aec4f33d 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -158,7 +158,7 @@ struct Fst { language models. Actually we'll template on types like this. There is no need to actually inherit from this class. */ class DeterministicGenericFsa { - + public: int32_t Start(); From 80339a50bee9d665548e06a212f3b8026816b415 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 May 2020 12:58:18 +0800 Subject: [PATCH 06/14] Small fixes in determinization code --- k2/csrc/determinize.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index e7a0b05ff..7c74c53e7 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -1,4 +1,4 @@ -// k2/csrc/fsa_algo.cc +// k2/csrc/determinize.cc // Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey dpove@gmail.com, Haowen Qiu qindazhu@gmail.com) @@ -39,8 +39,7 @@ struct DetStateElement { bool operator < (const DetStateElement &other) const { if (weight < other.weight) return true; else if (weight > other.weight) return false; - - + // TODO. } }; @@ -178,7 +177,11 @@ class DetStateMap { /* Turns DetState into a compact form of 128 bits. Technically there could be collisions, which would be fatal for the algorithm, but this is one of those lifetime-of-the-universe type of things (kind of like - the theoretical potential for git hash collision) that we ignore. */ + the theoretical potential for git hash collision) that we ignore. + + The normalized form + + */ void DetStateToCompact(const DetState &d, std::pair *vec) { assert(d.normalized); @@ -191,9 +194,10 @@ class DetStateMap { // matter which element we choose to trace back. DetStateElement *elem = d.head; int32_t seq_len = d.seq_len; - for (int32_t i = 0; i < seq_len; i++) { + for (int32_t i = 0; i < seq_len; ++i) { a = elem->symbol + 102299 * a; b = elem->symbol + 102983 * b; + elem = elem->parent } vec->first = a; vec->second = b; @@ -222,7 +226,7 @@ void DeterminizeMax(const WfsaWithFbWeights &a, return; } float cutoff = a.backward_state_weights[0] - beam; - + // TODO. } From 3afad81a46f9f96e44693372bdea9853ea8cb6e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 May 2020 12:08:51 +0800 Subject: [PATCH 07/14] Progress on determinization code; add new declarations of un-pruned functions --- k2/csrc/determinize.cc | 335 +++++++++++++++++++++++++++++++++++++---- k2/csrc/fsa.h | 2 +- k2/csrc/fsa_algo.h | 65 ++++++-- k2/csrc/weights.h | 9 +- 4 files changed, 368 insertions(+), 43 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index 7c74c53e7..a2737fb98 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -8,44 +8,133 @@ #include #include +#include namespace k2 { +using std::shared_ptr; +using std::vector; +using std::priority_queue; +using std::pair -struct DetStateElement { - // Element of the doubly linked list whose start/end are - // members 'head' and 'tail' of DetState. +struct MaxTracebackState { + // Element of a path from the start state to some state in an FSA // We can trace back the `parent` links, which will take // us backward along a path in the original FSA. - DetStateElement *parent = nullptr; - int32_t arc_index; // Index of most recent arc in path to the dest-state. - // This data-structure represents a path through the FSA, - // with this arc being the most recent arc on that path. + std::shared_ptr prev; + + int32_t arc_index; // Index of most recent arc in path from start-state to + // the dest-state, or -1 if the path is empty (only + // possible if this element belongs to the start-state). + + int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA + // (copied here for convenience), or 0 if arc_index == -1. + + MaxTracebackState(std::shared_ptr prev, + int32_t arc_index, int32_t symbol): + prev(prev), arc_index(arc_index), symbol(symbol) { } + +}; + + +class LogSumTracebackState; + +// This struct is used inside LogSumTracebackState; it represents an +// arc that traces back to a previous LogSumTracebackState. +// A LogSumTracebackState represents a weighted colletion of paths +// terminating in a specific state. +struct LogSumTracebackLink { + + int32_t arc_index; // Index of most recent arc in path from start-state to + // the dest-state, or -1 if the path is empty (only + // possible if this element belongs to the start-state). + int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA - // (copied here for convenience). + // (copied here for convenience), or 0 if arc_index == -1. + + double prob; // The probability mass associated with this incoming + // arc in the LogSumTracebackState to which this belongs. + + std::shared_ptr prev_state; +}; + +struct LogSumTracebackState { + // LogSumTracebackState can be thought of as as a weighted set of paths from the + // start state to a particular state. (It will be limited to the subset of + // paths that have a specific symbol sequence). + + // `prev_elements` is, conceptually, a list of pairs (incoming arc-index, + // traceback link); we will keep it free of duplicates of the same incoming + // arc. + vector prev_elements; + + + int32_t arc_index; // Index of most recent arc in path from start-state to + // the dest-state, or -1 if the path is empty (only + // possible if this element belongs to the start-state). + + int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA + // (copied here for convenience), or 0 if arc_index == -1. + + MaxTracebackState(std::shared_ptr prev, + int32_t arc_index, int32_t symbol): + prev(prev), arc_index(arc_index), symbol(symbol) { } + +}; + + +struct DetStateElement { double weight; // Weight from reference state to this state, along - // the path taken by following the 'parent' links + // the path taken by following the 'prev' links // (the path would have `seq_len` arcs in it). // Note: by "this state" we mean the destination-state of // the arc at `arc_index`. + // Interpret this with caution, because the + // base state, and the length of the sequence arcs from the + // base state to here, are known only in the DetState + // that owns this DetStateElement. - // `prev` and `next` form the doubly linked list of DetStateElement - DetStateElement *prev = nullptr; - DetStateElement *next = nullptr; + std::shared_ptr path; + // The path from the start state to here (actually we will + // only follow back `seq_len` links. Will be nullptr if + // seq_len == 0 and this belongs to the initial determinized + // state. - // This comparator function compares the weights, but is careful in case of - // ties to ensure deterministic behavior. - bool operator < (const DetStateElement &other) const { - if (weight < other.weight) return true; - else if (weight > other.weight) return false; - // TODO. + + DetStateElement &&Advance(float arc_weight, int32_t arc_index, int32_t arc_symbol) { + return DetStateElement(weight + arc_weight, + std::make_shared(path, arc_index, arc_symbol)); } + DetStateElement(double weight, std::shared_ptr &&path): + weight(weight), path(path) { } + +}; + +class DetState; + + +struct DetStateCompare { + // Comparator for priority queue. Less-than operator that compares + // forward_backward_weight for best-first processing. + bool operator()(const shared_ptr &a, + const shared_ptr &b); }; +class Determinizer { + public: + private: + + using DetStatePriorityQueue = priority_queue, + vector >, + DetStateCompare>; + + +}; + /* @@ -64,10 +153,23 @@ struct DetStateElement { it will just give us an output that's less minimal than it could be). + Not really following the Google guidelines by not having _ at the end of class + members, but this is more struct-like (members are public). + */ -struct DetState { - // `base_state` is a state in the input FSA. +class DetState { + public: + // `output_state` is the state in the output FSA that this determinized + // state corresponds to. + int32_t output_state; + + // `base_state` is the state in the input FSA from which the sequence of + // `seq_len` symbols starts. The weighted set of states that this DetState + // represents is the set of states reachable by following that symbol sequence + // from state `base_state`, with the best weights (per reachable state) along + // those paths. When Normalize() is called we may advance int32_t base_state; + // seq_len is the length of symbol sequence that we follow from state `base_state`. // The sequence of symbols can be found by tracing back one of the DetStateElements // in the doubly linked list (it doesn't matter which you pick, the result will be the @@ -76,13 +178,42 @@ struct DetState { bool normalized { false }; - DetState *parent; // Maybe not needed! - - DetStateElement *head; - DetStateElement *tail; + std::list elements; + // This is the weight on the best path that includes this determinized state. + // It's needed to form a priority queue on DetStates, so we can process them + // best-first. It is computed as: the forward-weight on `base_state`, + // plus the best/most-positive of: (the weight in a DetStateElement plus + // the backward-weight of the state associated with that DetStateElement). double forward_backward_weight; + + /* + Process arcs leaving this determinized state, possibly creating new determinized + states in the process. + @param [in] wfsa_in The input FSA that we are determinizing, along + with forward-backward weights. + The input FSA should normally be epsilon-free as + epsilons are treated as a normal symbol; and require + wfsa_in.weight_tpe == kMaxWeight, for + now (might later create a version of this code + that works + @param [in] prune_cutoff Cutoff on forward-backward likelihood + that we use for pruning; will equal + wfsa_in.backward_state_weights[0] - prune_beam. + Will be -infinity if we're not doing pruning. + @param [in,out] state_map Map from DetState to state-index in + + + + */ + void ProcessArcs(const WfsaWithFbWeights &wfsa_in, + Fsa *wfsa_out, + float prune_cutoff, + DetStateMap *state_map, + DetStatePriorityQueue *queue); + + /* Normalizes this DetState and sets forward_backward_weight. @@ -123,22 +254,170 @@ struct DetState { plus (the greatest over the DetStateElements, of its `weight` element, plus the backward weight in the input FSA of the state that corresponds to it). + */ + void Normalize(const Fsa &input_fsa, + const float *input_fsa_weights, + float *removed_weight, + std::vector *leftover_arcs) { +#ifndef NDEBUG + CheckElementOrder(); +#endif + RemoveDuplicatesOfStates(input_fsa); + RemoveCommonPrefix(input_fsa, input_fsa_weights, removed_weight, leftover_arcs); + } + private: + /* + Called from Normalize(), this function removes duplicates in + `elements`: that is, if two elements represent paths that terminate at + the same state in `input_fsa`, we choose the one with the better + weight (or the first one in case of a tie). + */ + void RemoveDuplicatesOfStates(const Fsa &input_fsa, + const float *input_fsa_weights); - worked outobtained from - + /* + Called from Normalize(), this function removes any common prefix that the + paths in `elements` possess. If there is a common prefix it will reduce + `seq_len`, subtract the weights associated with the removed arcs from the + weights in `elements`, and set `input_arcs` to the sequence of arcs that + were removed from */ - void Normalize(std::vector *leftover_arcs); + RemoveCommonPrefix(const Fsa &input_fsa, + const float *input_fsa_weights, + std::vector *input_arcs); + /* + This function just does some checking on the `elements` list that + they are in the correct order, which is a lexicographical + order (by state-id) on the paths of length `seq_len` starting from + `base_state`. The label sequences don't come into it because + they are all the same. + */ + void CheckElementOrder() const; + }; +bool DetStateCompare::operator()(const shared_ptr &a, + const shared_ptr &b) { + return a->forward_backward_weight < b->forward_backward_weight; +} + + + +void DetState::RemoveDuplicatesOfStates(const Fsa &input_fsa) { + + /* + `state_to_elem` maps from int32_t state-id to the DetStateElement + associated with it (there can be only one, we choose the one with + the best weight). + */ + std::unordered_map::iterator> state_to_elem; + -void DetState::Normalize(std::vector *input_arcs) { + for (auto iter = elements.begin(); iter != elements.end(); ++iter) { + int32_t state = input_fsa.arcs[elem.arc_index].nextstate; + auto p = state_to_elem.insert({state, elem}); + bool inserted = p.second; + if (!inserted) { + DetStateElement *old_elem = p.first->second; + if (old_elem->weight > elem->weight) { // old weight is better + this->RemoveElement(elem); + } else { + p.first->second = elem; + this->RemoveElement(old_elem); + } + } + } } +void DetState::RemoveCommonPrefix(const Fsa &input_fsa, + const float *input_fsa_weights, + float *removed_weight_out, + std::vector *input_arcs) { + + CHECK_GE(seq_len, 0); + int32_t len; + auto first_path = elements.front().path, + last_path = elements.back().path; + + for (len = 1; len < seq_len; ++len) { + first_path = first_path->prev; + last_path = last_path->prev; + if (first_path == last_path) { + // Note: we are comparing pointers here. We reached the same PathElement, + // which means we reached the same state. + break; + } + } + input_arcs->clear(); + if (len < seq_len) { + /* We reach a common state after traversing fewer than `seq_len` arcs, + so we can remove a shared prefix. */ + double removed_weight = 0.0; + int32_t new_seq_len = len, + removed_seq_len = seq_len - len; + input_arcs->resize(removed_seq_len); + // Advance base_state + int32_t new_base_state = input_fsa.arcs[first_path->arc_index].src_state; + for (; len < seq_len; ++len) { + auto arc = input_fsa.arcs[first_path->arc_index]; + input_arcs[seq_len - 1 - len] = first_path->arc_index; + removed_weight += input_fsa_weights[first_path->arc_index]; + first_path = first_path->prev; + } + // Check that we got to base_state. + CHECK((self->base_state == 0 && first_path == nullptr) || + fsa.arcs[first_path->arc_index].dest_state == this->base_state); + this->base_state = new_base_state; + if (removed_weight != 0) { + for (DetStateElement &det_state_elem: elements) { + det_state_elem.weight -= removed_weight; + } + } + *removed_weight_out = removed_weight; + } else { + *removed_weight_out = 0; + input_arcs->clear(); + } +} + +void DetState::CheckElementOrder(const Fsa &input_fsa) const { + // Checks that the DetStateElements are in a lexicographical order on the + // lists of states in their paths. This will be true becase of how we + // construct them (it requires on the IsArcSorted() property, whereby arcs + // leaving each state in the FSA are sorted first on label and then on + // dest_state. + if (seq_len == 0) { + CHECK(elements.size() == 1); + CHECK(elements.front().weight == 0.0); + } + + std::vector prev_seq; + for (auto iter = elements.begin(); iter != elements.end(); ++iter) { + auto path = iter->path; + std::vector cur_seq; + for (int32_t i = 0; i < seq_len; i++) { + cur_seq.push_back(input_fsa.arcs[path->arc_index].prev_state); + path = path->prev; + } + std::reverse(cur_seq.begin(), cur_seq.end()); + if (iter != elements.begin()) { + CHECK(cur_seq > prev_seq); + } + prev_seq.swap(cur_seq); + } +} + + +/* + This class maps from determinized states (DetState) to integer state-id + in the determinized output. + */ class DetStateMap { public: + /* Outputs the output state-id corresponding to a specific DetState structure. This does not store any pointers to the DetState or its contents, so diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index 3aec4f33d..65f2f1005 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -169,7 +169,7 @@ class DeterministicGenericFsa { float GetWeightForArc(int32_t arc_index); - int32_t Getint32_tForArc(int32_t arc_index); + int32_t GetLabelForArc(int32_t arc_index); int32_t GetPrevStateForArc(int32_t arc_index); diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 6ed53060b..dce853892 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -39,11 +39,9 @@ bool ConnectCore(const Fsa &fsa, std::vector *state_map); Removes states that are not accessible (from the start state) or are not co-accessible (i.e. that cannot reach the final state), and ensures that if the FSA admits a topological sorting (i.e. it contains no cycles except - self-loops), the version that is output is topologically sorted (states may - be renumbered). - - Whenever the output fsa is acyclic or contains only self-loops, it is - topsorted. + self-loops), the version that is output is topologically sorted. This + is not a stable sort, i.e. states may be renumbered even for top-sorted + input. @param [in] a Input FSA @param [out] b Output FSA, that will be trim / connected (there are @@ -53,10 +51,13 @@ bool ConnectCore(const Fsa &fsa, std::vector *state_map); output a map from the arc-index in `b` to the corresponding arc-index in `a`. - Returns true on success (i.e. the output is topsorted). - The only failure condition is when the input had cycles that were not self loops. + @return The return status indicates whether topological sorting + was successful; if true, the result is top-sorted. The only situation + it might return false is when the input had cycles that were not self + loops; such FSAs do not admit a topological sorting. - Caution: true return status does not imply that the returned FSA is nonempty. + Caution: true return status does not imply that the returned FSA is + nonempty. Notes: - If `a` admitted a topological sorting, b will be topologically @@ -81,6 +82,7 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); as required by this computation. For now we assume that `a` is topologically sorted, as required by the current constructor of WfsaWithFbWeights. + a.weight_type must be kMaxWeight. @param [in] beam beam > 0 that affects pruning; this algorithm will keep paths that are within `beam` of the best path. Just make this very large if you don't want pruning. @@ -130,11 +132,11 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); // best cost they are going to get. If // FSA was top-sorted at the start, which we assume, we could perhaps // process them in numerical order, e.g. using a heap. - queue.pop_front() + queue.pop_front() for each arc leaving state ji: next_weight = local_forward_state_weights[ji] + arc_weights[this_arc_index] if next_weight + backward_state_weights[arc_dest_state] < best_path_weight - beam: - if arc label is epsilon: + if arc label is epsilon: if next_weight < local_forward_state_weight[next_state]: local_forward_state_weight[next_state] = next_weight local_backpointers[next_state] = ji @@ -144,9 +146,50 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); figure out the details). Note: the output FSA's weights can be computed later on, by calling code, using the info in arc_map. */ -void RmEpsilonsPruned(const WfsaWithFbWeights &a, float beam, Fsa *b, +void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, + std::vector> *arc_map); + +/* + Version of RmEpsilonsPrunedMax that doesn't support pruning; see its + documentation. + */ +void RmEpsilonsMax(const Fsa &a, float *a_weights, Fsa *b, + std::vector> *arc_map); + + +/** + This version of RmEpsilonsPruned does log-sum on weights along alternative + epsilon paths rather than taking the max. + + @param [in] a The input, with weights and forward-backward weights + as required by this computation. For now we assume + that `a` is topologically sorted, as required by + the current constructor of WfsaWithFbWeights. + a.weight_type may be kMaxWeight or kLogSumWeight; + the difference will affect pruning slightly. + @param [in] beam Beam for pruning, must be > 0. + @param [out] b The output FSA + @param [out] arc_derivs Indexed by arc-index in b, it is an list of + (input-arc, deriv), where 0 < deriv <= 1, where the + lists are ordered by input-arc (unlike + RmEpsilonsPrunedMax, they should not be interpreted + as a sequence). arc_derivs may be interpreted as + a CSR-format matrix of dimension num_arcs_out by + num_arcs in; it gives the derivatives of output-arcs + weights w.r.t. input-arc weights. + */ +void RmEpsilonsPrunedLogSum(const WfsaWithFbWeights &a, float beam, Fsa *b, + std::vector> *arc_derivs); + + +/* + Version of RmEpsilonsLogSum that doesn't support pruning; see its + documentation. + */ +void RmEpsilonsLogSum(const Fsa &a, float *a_weights, Fsa *b, std::vector> *arc_map); + /* Compute the intersection of two FSAs; this is the equivalent of composition for automata rather than transducers, and can be used as the core of diff --git a/k2/csrc/weights.h b/k2/csrc/weights.h index 94d91c4f5..971ccbe7a 100644 --- a/k2/csrc/weights.h +++ b/k2/csrc/weights.h @@ -61,14 +61,17 @@ enum FbWeightType { kMaxWeight, kLogSumWeight }; struct WfsaWithFbWeights { const Fsa *fsa; const float *arc_weights; - // forward_state_weights are the sum of weights along the best path from the + // forward_state_weights are the log-sum or max of weights along all paths from the // start-state to each state. We use double because for long FSAs roundoff // effects can cause nasty errors in pruning. const double *forward_state_weights; - // backward_state_weights are the sum of weights along the best path - // from each state to the final state. + // backward_state_weights are the log-sum or max of weights along all paths from + // each state to the final state. const double *backward_state_weights; + // Records whether we use max or log-sum. + FbWeightType weight_type; + /* Constructor. @param [in] fsa Pointer to an FSA; must satisfy From 14a7ccace5d0e37fe80c52f169b9ac466841212b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 May 2020 12:16:59 +0800 Subject: [PATCH 08/14] Fix compile error --- k2/csrc/fsa_algo.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index bcf3153c8..341842cc8 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -179,7 +179,7 @@ void RmEpsilonsMax(const Fsa &a, float *a_weights, Fsa *b, weights w.r.t. input-arc weights. */ void RmEpsilonsPrunedLogSum(const WfsaWithFbWeights &a, float beam, Fsa *b, - std::vector> *arc_derivs); + std::vector>> *arc_derivs); /* From 316786485249919ea742d7a6a73ba20586fc2e6d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 May 2020 12:30:44 +0800 Subject: [PATCH 09/14] Resolve conflicts --- k2/csrc/determinize.cc | 53 ++++-------------------------------------- 1 file changed, 5 insertions(+), 48 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index 340b6ed3b..547234bbd 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -87,7 +87,6 @@ struct LogSumTracebackState { struct DetStateElement { -<<<<<<< HEAD double weight; // Weight from reference state to this state, along // the path taken by following the 'prev' links // (the path would have `seq_len` arcs in it). @@ -97,13 +96,6 @@ struct DetStateElement { // base state, and the length of the sequence arcs from the // base state to here, are known only in the DetState // that owns this DetStateElement. -======= - double weight; // Weight from reference state to this state, along - // the path taken by following the 'parent' links - // (the path would have `seq_len` arcs in it). - // Note: by "this state" we mean the destination-state of - // the arc at `arc_index`. ->>>>>>> upstream/master std::shared_ptr path; // The path from the start state to here (actually we will @@ -111,8 +103,6 @@ struct DetStateElement { // seq_len == 0 and this belongs to the initial determinized // state. -<<<<<<< HEAD - DetStateElement &&Advance(float arc_weight, int32_t arc_index, int32_t arc_symbol) { return DetStateElement(weight + arc_weight, std::make_shared(path, arc_index, arc_symbol)); @@ -147,20 +137,6 @@ class Determinizer { }; - -======= - // This comparator function compares the weights, but is careful in case of - // ties to ensure deterministic behavior. - bool operator<(const DetStateElement &other) const { - if (weight < other.weight) - return true; - else if (weight > other.weight) - return false; - // TODO(dpovey) - } -}; - ->>>>>>> upstream/master /* Conceptually a determinized state in weighted FSA determinization would normally @@ -194,32 +170,18 @@ class DetState { // from state `base_state`, with the best weights (per reachable state) along // those paths. When Normalize() is called we may advance int32_t base_state; -<<<<<<< HEAD - // seq_len is the length of symbol sequence that we follow from state `base_state`. - // The sequence of symbols can be found by tracing back one of the DetStateElements - // in the doubly linked list (it doesn't matter which you pick, the result will be the -======= + // seq_len is the length of symbol sequence that we follow from state - // `base_state`. - // The sequence of symbols can be found by tracing back one of the - // DetStateElements - // in the doubly linked list (it doesn't matter which you pick, the result - // will be the ->>>>>>> upstream/master - // same. + // `base_state`. The sequence of symbols can be found by tracing back one of + // the DetStateElements in the doubly linked list (it doesn't matter which you + // pick, the result will be the same. int32_t seq_len; bool normalized{false}; -<<<<<<< HEAD - std::list elements; -======= - DetState *parent; // Maybe not needed! - DetStateElement *head; - DetStateElement *tail; ->>>>>>> upstream/master + std::list elements; // This is the weight on the best path that includes this determinized state. // It's needed to form a priority queue on DetStates, so we can process them @@ -339,7 +301,6 @@ class DetState { }; -<<<<<<< HEAD bool DetStateCompare::operator()(const shared_ptr &a, const shared_ptr &b) { return a->forward_backward_weight < b->forward_backward_weight; @@ -374,10 +335,6 @@ void DetState::RemoveDuplicatesOfStates(const Fsa &input_fsa) { } } -======= -void DetState::Normalize(std::vector *input_arcs) {} ->>>>>>> upstream/master - void DetState::RemoveCommonPrefix(const Fsa &input_fsa, const float *input_fsa_weights, float *removed_weight_out, From c8378d459d82528a3db4a30a0187f5a1f3a30b76 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 May 2020 11:15:28 +0800 Subject: [PATCH 10/14] Add LogAdd --- k2/csrc/util.h | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/k2/csrc/util.h b/k2/csrc/util.h index db42237d4..559b398c2 100644 --- a/k2/csrc/util.h +++ b/k2/csrc/util.h @@ -31,5 +31,53 @@ struct PairHash { } }; +static const double kMinLogDiffDouble = Log(DBL_EPSILON); // negative! +static const float kMinLogDiffFloat = Log(FLT_EPSILON); // negative! + +// returns log(exp(x) + exp(y)). +inline double LogAdd(double x, double y) { + double diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffDouble) { + double res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + + +// returns log(exp(x) + exp(y)). +inline float LogAdd(float x, float y) { + float diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffFloat) { + float res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + + + } // namespace k2 #endif // K2_CSRC_UTIL_H_ From 04fafe3c41d905b357f4ab6c865929159b0f03a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 May 2020 11:34:38 +0800 Subject: [PATCH 11/14] Fix compile errors in util.h --- k2/csrc/util.h | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/k2/csrc/util.h b/k2/csrc/util.h index 559b398c2..9422bed5a 100644 --- a/k2/csrc/util.h +++ b/k2/csrc/util.h @@ -7,9 +7,11 @@ #ifndef K2_CSRC_UTIL_H_ #define K2_CSRC_UTIL_H_ +#include #include #include + #include "k2/csrc/fsa.h" namespace k2 { @@ -31,8 +33,15 @@ struct PairHash { } }; -static const double kMinLogDiffDouble = Log(DBL_EPSILON); // negative! -static const float kMinLogDiffFloat = Log(FLT_EPSILON); // negative! +#ifndef DBL_EPSILON +#define DBL_EPSILON 2.2204460492503131e-16 +#endif +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +static const double kMinLogDiffDouble = log(DBL_EPSILON); // negative! +static const float kMinLogDiffFloat = log(FLT_EPSILON); // negative! // returns log(exp(x) + exp(y)). inline double LogAdd(double x, double y) { @@ -48,7 +57,7 @@ inline double LogAdd(double x, double y) { if (diff >= kMinLogDiffDouble) { double res; - res = x + Log1p(Exp(diff)); + res = x + log1p(exp(diff)); return res; } else { return x; // return the larger one. @@ -70,7 +79,7 @@ inline float LogAdd(float x, float y) { if (diff >= kMinLogDiffFloat) { float res; - res = x + Log1p(Exp(diff)); + res = x + log1pf(expf(diff)); return res; } else { return x; // return the larger one. From 908457b5b2591f8d94c5ad0560c504bd9824a9b6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 May 2020 20:13:28 +0800 Subject: [PATCH 12/14] More progress on determinization code --- k2/csrc/determinize.cc | 688 ++++++++++++++++++++++++++++++----------- k2/csrc/fsa.h | 21 +- 2 files changed, 516 insertions(+), 193 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index 547234bbd..506dc810f 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -16,26 +16,62 @@ namespace k2 { using std::shared_ptr; using std::vector; using std::priority_queue; -using std::pair +using std::pair; + + struct MaxTracebackState { - // Element of a path from the start state to some state in an FSA - // We can trace back the `parent` links, which will take - // us backward along a path in the original FSA. - std::shared_ptr prev; + using DerivOutputType = int32_t; + - int32_t arc_index; // Index of most recent arc in path from start-state to - // the dest-state, or -1 if the path is empty (only - // possible if this element belongs to the start-state). + int32_t state_id; // state-id in the input FSA - int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA - // (copied here for convenience), or 0 if arc_index == -1. + int32_t arc_id; // arc-id in input FSA of the arc that enters + // state `state_id` (or -1 if this is the start state). - MaxTracebackState(std::shared_ptr prev, - int32_t arc_index, int32_t symbol): - prev(prev), arc_index(arc_index), symbol(symbol) { } + // prev_state is the state we trace back to (the previous state), + // which is the src_state of the arc numbered arc_id. + // It will be nullptr if state_id == 0 (start state). + shared_ptr prev_state; + double forward_prob; // The total forward log-probability from the start + // state to this state (along whichever specific + // sequence of symbols we took to get here; it's not + // necessarily the best forward in the lattice). + + // This constructor is for the start-state of the input FSA. + MaxTracebackState(): state_id(0), arc_id(-1), arc_symbol(-1), + prev_state(nullptr), forward_prob(0.0) { } + + /** + @param [in] state_id State in input FSA that this corresponds to + @param [in] src Previous LogSumTracebackState that we'll point back + to, or NULL + @param [in] incoming_arc_index. Its src_state will equal src->state_id, + its dest_state will equal state_id. + @param [in] src_symbol Symbol on the input arc + @param [in] arc_weight Weight on the input arc + */ + MaxTracebackState(int32_t state_id, + const std::shared_ptr &src, + int32_t incoming_arc_index, + int32_t arc_weight): + state_id(state_id), + arc_id(incoming_arc_index), + prev_state(src), + forward_prob(element->forward_prob + arc_weight) { } + + void Accept(const std::shared_ptr &src, + int32_t arc_index, int32_t _symbol, float arc_weight) { + double new_forward_prob = src->forward_prob + arc_weight; + if (new_forward_prob > forward_prob) { + forward_prob = new_forward_prob; + arc_id = arc_index; + prev_state = src; + // state_id doesn't change, nor does _symbol. + } + } }; @@ -47,101 +83,280 @@ class LogSumTracebackState; // terminating in a specific state. struct LogSumTracebackLink { - int32_t arc_index; // Index of most recent arc in path from start-state to - // the dest-state, or -1 if the path is empty (only - // possible if this element belongs to the start-state). + shared_ptr prev_state; + // `prev_state` is the state that this points back to. + + + int32_t arc_index; // Index (in input FSA) of this arc from prev_state to the + // destination state (in whose LogSumTracebackState this + // LogSumTracebackLink will be located). + + double forward_prob; // The total forward log-probability from the start + // state to the end of this arc just before it joins the + // node (conceptually). Note: this is only the total + // forward log-prob limited to whatever symbol-sequence + // got us to this point... there may be other + // symbol-sequences terminating in this state. + // (That symbol-sequence can be obtained by tracing back + // this data structure till you hit a LogSumTracebackState + // with prev_state == nullptr (and state_id == 0). - int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA - // (copied here for convenience), or 0 if arc_index == -1. - double prob; // The probability mass associated with this incoming - // arc in the LogSumTracebackState to which this belongs. + LogSumTracebackLink(const std::shared_ptr &src, + int32_t arc_index, float arc_weight): + src(src), arc_index(arc_index), + forward_prob(arc_weight + src->forward_prob) { } + - std::shared_ptr prev_state; }; + struct LogSumTracebackState { - // LogSumTracebackState can be thought of as as a weighted set of paths from the - // start state to a particular state. (It will be limited to the subset of - // paths that have a specific symbol sequence). + using DerivOutputType = pair; + + // LogSumTracebackState can be thought of as as a weighted set of paths from + // the start state to a particular state. (It will be limited to the subset + // of paths that have a specific symbol sequence). - // `prev_elements` is, conceptually, a list of pairs (incoming arc-index, - // traceback link); we will keep it free of duplicates of the same incoming - // arc. + // `prev_elements` is, conceptually, a list of incoming arcs with associated + // weights. vector prev_elements; + int32_t state_id; // The state-id in the input FSA that this + // LogSumTracebackState corresponds to. (Unique to + // this determinized state; the same state-id may + // appear in multiple determinized states, in general. + + double forward_prob; // The total forward log-probability from the start + // state to this state (along whichever specific + // sequence of symbols we took to get here; it's not + // necessarily the best forward-prob in the lattice + // that would take us to this state). Will equal the + // log-sum of the forward_probs of the prev_elements. + + double backward_prob; // Used temporarily in algorithms as a backward prob. + // Undefined most of the time. + + // This constructor is to be used only for the start-state (of both the + // input FSA and the determinized FSA). + LogSumTracebackState(): state_id(0), forward_prob(0.0) { } + + /** + @param [in] state_id State in input FSA that this corresponds to + @param [in] src Previous LogSumTracebackState that we'll point back + to, or nullptr if this belongs to the initial + determinized-state. + @param [in] incoming_arc_index. Arc-index in input FSA. + Its src_state will equal src->state_id, its dest_state + will equal state_id. + @param [in] arc_weight Weight on the arc + */ + LogSumTracebackState(int32_t state_id, + const std::shared_ptr &src, + int32_t incoming_arc_index, + int32_t arc_weight): + state_id(state_id), + forward_prob(element->forward_prob + arc_weight) { + prev_elements.emplace_back(src, incoming_arc_index, forward_prob); + } + + /* + Accept a new incoming link. The args are the same as for + the constructor just above; see documentation there. + */ + void Accept(const std::shared_ptr &src, + int32_t arc_index, float arc_weight) { + double link_forward_prob = src.forward_prob + arc_weight; + prev_elements.emplace_back(src, arc_index, link_forward_prob); + this->forward_prob = LogAdd(this->forward_prob, link_forward_prob); + } +}; - int32_t arc_index; // Index of most recent arc in path from start-state to - // the dest-state, or -1 if the path is empty (only - // possible if this element belongs to the start-state). - int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA - // (copied here for convenience), or 0 if arc_index == -1. - MaxTracebackState(std::shared_ptr prev, - int32_t arc_index, int32_t symbol): - prev(prev), arc_index(arc_index), symbol(symbol) { } -}; +/* Given a set of traceback states (e.g. as present in one determinized state or + after tracing back a certain number of times), trace back along one arc to + get the set of traceback states reachable from this. These will correspond + to a symbol sequence from the start-state that's shorter by one. -struct DetStateElement { + @param [in] cur_states A set of traceback states that we want to + trace back from + @param [out] prev_states The set of states reachable by one + link of traceback from any state in `cur_states` + is written to here. + TODO(dpovey): remove +*/ +void GetPreviousTracebackStates( + const std::vector &cur_states, + std::vector *prev_states); - double weight; // Weight from reference state to this state, along - // the path taken by following the 'prev' links - // (the path would have `seq_len` arcs in it). - // Note: by "this state" we mean the destination-state of - // the arc at `arc_index`. - // Interpret this with caution, because the - // base state, and the length of the sequence arcs from the - // base state to here, are known only in the DetState - // that owns this DetStateElement. +} - std::shared_ptr path; - // The path from the start state to here (actually we will - // only follow back `seq_len` links. Will be nullptr if - // seq_len == 0 and this belongs to the initial determinized - // state. - DetStateElement &&Advance(float arc_weight, int32_t arc_index, int32_t arc_symbol) { - return DetStateElement(weight + arc_weight, - std::make_shared(path, arc_index, arc_symbol)); +/* + Find the most recent common ancestor LogSumTracebackState of a set of + LogSumTracebackStates, and return the number of links we had to follow to get + there (i.e. the length of symbol sequence we had to remove from + each path). + + @param [in,out] cur_states A set of TracebackStates that we'll + trace back from. Must be nonempty. Equality + here is simply pointer identity. + + At exit it will contain a single member which will + be the most recent common ancestor. + @return Returns the number of links we had to follow (>=0). If + `cur_states.size() == 1` this will be zero. + */ +int32_t GetMostRecentCommonAncestor( + std::unordered_set *cur_states) { + int32_t ans = 0; + std::unordered_set prev_states; + for (; cur_states->size() != 1; ans++) { + CHECK(!cur_states->empty()); + for (LogSumTracebackState *s: cur_states) { + for (const LogSumTracebackLink &l: s->prev_elements) { + prev_states.insert(l.prev_state.get()); + } + } + cur_states->clear(); + cur_states->swap(prev_states); } + return ans; +} - DetStateElement(double weight, std::shared_ptr &&path): - weight(weight), path(path) { } - -}; -class DetState; +// Version of GetMostRecentCommonAncestor() for MaxTracebackState; +// see documentation for the other version. +int32_t GetMostRecentCommonAncestor( + std::unordered_set *cur_states) { + int32_t ans = 0; + std::unordered_set prev_states; + for (; cur_states->size() != 1; ans++) { + CHECK(!cur_states->empty()); + for (MaxTracebackState *s: cur_states) { + prev_states.insert(s->prev_state.get()); + } + cur_states->clear(); + cur_states->swap(prev_states); + } + return ans; +} -struct DetStateCompare { - // Comparator for priority queue. Less-than operator that compares - // forward_backward_weight for best-first processing. - bool operator()(const shared_ptr &a, - const shared_ptr &b); -}; +/** + A TraceBack() function exists for LogSumTracebackState and MaxTracebackState; + it's used in DetState::Normalize(). It finds the cost and derivative + information from getting rid of `num_steps` symbols from a symbol sequence. + + @param [in] cur_states (This is consumed destructively, i.e. don't + expect it to contain the same set on exit). + A set of states; we'll iteratively trace back this + set one step at a time. At entry it must have + size() == 1; it will also have size() == 1 after + `num_steps` steps. + @param [in] num_steps The number of steps to trace back + @param [in] arc_weights_in Weights on the arcs of the input FSA + @param [out] weight_out The output weight; will be the forward-backward + weight of the sub-graph whose final-state is + (*cur_states).front() and whose start-state is + the result of following that back for `num_steps` steps + (which will also be a single state, by virtue of how + the whole determinization algorithm works). Will be + zero if num_steps == 0. + @param [out] deriv_out Some derivative information at the output + will be written to here, which tells us how the weight + `weight_out` varies as a function of the weights + on the arcs of the input FSA; it's a list + (input_arc_id, deriv) where, mathematically, 0 < deriv <= 1 + (but we might still get exact zeros due to limitations + of floating point representation). + Note: the sum of the float values in this vector should + be equal to `num_steps`. + */ +void TraceBack(std::unordered_set *cur_states, + int32_t num_steps, + const float *arc_weights_in, + float *weight_out, + std::vector > *deriv_out) { + std::unordered_set prev_states; + assert(cur_states.size() == 1); + // In the standard forward-backward algorithm for HMMs this backward_prob + // would, mathematically, be 0.0, but if we set it to the negative of the + // forward prob we can avoid having to subtract the total log-prob + // when we compute posterior/occupation probabilities for arcs. + double cur_forward_prob = cur_states.front()->forward_prob; + cur_states.front()->backward_prob = cur_forward_prob; + deriv_out->clear(); + for (int32_t i = 0; i < num_steps; i++) { + for (LogSumTracebackState *state_ptr: *cur_states) { + double backward_prob = state_ptr->backward_prob; + for (auto link: state_tr->prev_elements) { + float arc_log_posterior = link.forward_prob + backward_prob; + deriv_out->push_back(std::pair(link.arc_index, expf(log_posterior))); + LogSumTracebackState *prev_state = link.prev_state.get(); + double new_backward_prob = backward_prob + arc_weights_in[link.arc_index]; + if (prev_states.insert(prev_state).second) { // newly inserted + prev_state->backward_prob = new_backward_prob; + } else { + prev_state->backward_prob = LogAdd(new_backward_prob, + prev_state->backward_prob); + } + } + } + cur_states->clear(); + cur_states->swap(prev_states); + } + // failure of the next assertion may indicate many kinds of bugs in the + // algorithm. + CHECK_EQ(cur_states.size(), 1); + double prev_forward_prob = cur_states.front()->forward_prob; + *weight_out = cur_forward_prob - prev_forward_prob; + // The following is mostly for ease of interpretability of the output; + // conceptually the order makes no difference. + std::reverse(deriv_out->begin(), deriv_out->end()); +} -class Determinizer { - public: - private: +// See documentation of TraceBack for LogSumTracebackState, above. +// This version is simpler. +void TraceBack(std::unordered_set *cur_states, + int32_t num_steps, + const float *, // arc_weights_in, unused. + float *weight_out, + std::vector *deriv_out) { + // we recompute the arc weight sum from arc_weights_in, which should + // hopefully give + float arc_weight_sum = 0.0; + CHECK_EQ(cur_states.size(), 1); + MaxTracebackState *state = cur_states->front(); + double cur_forward_prob = state->forward_prob; + deriv_out->resize(num_steps); + for (int32_t i = num_steps - 1; i >= 0; i--) { + (*deriv_out)[i] = state->arc_id; + } + double prev_forward_prob = state->forward_prob; + *weight_out = cur_forward_prob - prev_forward_prob; +} - using DetStatePriorityQueue = priority_queue, - vector >, - DetStateCompare>; +// Priority queue templated on: +// item queued = unique_ptr (using pointer equality as comparison) +// container type = vector > +// less-than operator = DetStateCompare (which compares the forward_backward_prob). - -}; +template +using DetStatePriorityQueue = priority_queue >, + vector > >, + DetStateCompare >; /* Conceptually a determinized state in weighted FSA determinization would - normally - be a weighted subset of states in the input FSA, with the weights normalized - somehow (e.g. subtracting the sum of the weights). + normally be a weighted subset of states in the input FSA, with the weights + normalized somehow (e.g. subtracting the sum of the weights). Two determinized states are equal if the states and weights are the same. To ensure differentiability, our assumption is that in general no two arcs in the @@ -154,34 +369,71 @@ class Determinizer { it will just give us an output that's less minimal than it could be). - Not really following the Google guidelines by not having _ at the end of class - members, but this is more struct-like (members are public). - + We're not really following the Google guidelines by not having _ at the end of + class members, but this is more struct-like (members are public). */ +template // TracebackState == MaxTracebackState or LogSumTracebackState class DetState { + + public: + using DerivOutputType = typename TracebackState::DerivOutputType; + // .. and DerivOutputType == int32_t or pair + // respectively. + + + DetState(int32_t seq_len): + seq_len(seq_len), + output_state(-1), // Not known yet + normalized(false) { } // .. and forward_backward_weight undefined + + /** + Process incoming arc to this DetState. See documentation for + constructor of TracebackState (which will be MaxTracebackState or + LogSumTracebackState). + @param [in] state_id State-id, in input FSA, into which this arc enters. + [Note: a DetState is a weighted subset of state-ids.] + @param [in] incoming_arc_index Arc in input FSA that enters state + `state_id`. + @param [in] arc_symbol The symbol on + */ + void AcceptIncomingArc(int32_t state_id, + const std::shared_ptr &src, + int32_t incoming_arc_index, + int32_t arc_weight) { + auto iter = elements.find(state_id); + if (iter == elements.end()) { + elements[state_id] = std::make_shared( + state_id, src, incoming_arc_index, arc_weight); + } else { + iter.second->Accept( + src, incoming_arc_index, arc_symbol, arc_weight); + } + } + + // Length of sequence of symbols leading to this DetState. + int32_t seq_len; + // `output_state` is the state in the output FSA that this determinized // state corresponds to. int32_t output_state; - // `base_state` is the state in the input FSA from which the sequence of - // `seq_len` symbols starts. The weighted set of states that this DetState - // represents is the set of states reachable by following that symbol sequence - // from state `base_state`, with the best weights (per reachable state) along - // those paths. When Normalize() is called we may advance - int32_t base_state; - - // seq_len is the length of symbol sequence that we follow from state - // `base_state`. The sequence of symbols can be found by tracing back one of - // the DetStateElements in the doubly linked list (it doesn't matter which you - // pick, the result will be the same. + // `base_state`. (Note: we don't store base_state as a member any more, it + // can be worked out by tracing back in the TracebackState data structure). + // The sequence of symbols can be found by tracing back one of the + // DetStateElements in the doubly linked list (it doesn't matter which you + // pick, the result will be the same). int32_t seq_len; - bool normalized{false}; + bool normalized; - std::list elements; + // `elements` can be thought of as weighted subsets of states in the input + // FSA, that also stores some traceback information that lets us compute + // derivatives. + // It's a map from (state-id in input FSA) -> its corresponding TracebackState. + std::unordered_map > elements; // This is the weight on the best path that includes this determinized state. // It's needed to form a priority queue on DetStates, so we can process them @@ -193,7 +445,9 @@ class DetState { /* Process arcs leaving this determinized state, possibly creating new determinized - states in the process. + states in the process. Note: Normalize() should already have been called on + *this. + @param [in] wfsa_in The input FSA that we are determinizing, along with forward-backward weights. The input FSA should normally be epsilon-free as @@ -205,70 +459,60 @@ class DetState { that we use for pruning; will equal wfsa_in.backward_state_weights[0] - prune_beam. Will be -infinity if we're not doing pruning. - @param [in,out] state_map Map from DetState to state-index in - - - + @param [out] arcs_out Output-FSA arcs-- those leaving this + determinized state-- will be appended to here. + @param [out] arc_weights_out Weights for the output arcs will + be appended to here. + @param [out] derivs_per_arc Derivative information for the output + arcs will be appended to here: either sequences + of int32_t (corresponding to input arcs), or + lists of pair, corresponding + to weighted input arcs. + @param [in,out] state_map Maps from DetState to int32_t state-id + in the output FSA. */ void ProcessArcs(const WfsaWithFbWeights &wfsa_in, - Fsa *wfsa_out, float prune_cutoff, - DetStateMap *state_map, - DetStatePriorityQueue *queue); - + vector *arcs_out, + vector *arc_weights_out, + vector *derivs_per_arc, + DetStateMap *state_map, + DetStatePriorityQueue *queue); + + // Computes the forward-backward weight of this DetState. This is + // related to the best cost of any path through the output FSA + // that included this determinized state. I say "related to" + // because while it should be exact in the Max case, in the + // LogSum case the relationship is a bit more complicated; + // maybe just best to say that this is a weight that we use + // for pruning. + // @param [in] backward_state_weight Array, indexed by + // state in input WFSA, of the weight from this state + // to the end. (Of the best path or the sum of paths, + // depending how it was computed; this will of + // course affect the pruning). + void ComputeFbWeight(const float *backward_state_weights); /* - Normalizes this DetState and sets forward_backward_weight. - - By 'normalize' what we mean is the following: - - - Remove duplicates. - - If the DLL of DetStateElements contains duplicate elements (i.e. - elements whose paths end in the same state) it removes whichever has - the - smallest weight. (Remember, a determinized state is, conceptually, a - weighted subset of elements; we are implementing determinization in a - tropical-like semiring where we take the best weight. - - In case of ties on the weights, we carefully re-examine the paths to - make sure that the tie was not due to numerical roundoffi; and if it - was still a tie, we disambiguate using a lexical order on state - sequences. The reason it's important to have deterministic behavior in - case of ties on weights, is that a failure here could lead to - situations where we didn't advance the base state where we could, - leading the number of determinized states to be larger than it could - be. - - - Advance the base state if possible. Each DetState can be represented - as a base state and a sequence of symbols from that base state, but - if some initial subsequence of that symbol sequence takes us to - a unique state then we say the DetState is not normalized. In that - case we need to advance the base state and reduced `seq_len`. - If this happens, then the arc sequence which takes us to the new - base state will be output to `leftover_arcs`. When this is done, - the 'weight' components of the DetStateElement members also need - to be adjusted to remove the weight contribution from those arcs. - - The forward_backward_weight is the weight on the best path through the - output determinized FSA that will include this DetState. It will determine - the order of expansion of DetStates and also whether the states are - expanded at all (if the pruning beam `beam` is finite). - forward_backward_weight is the sum of the forward weight of the base state, - plus (the greatest over the DetStateElements, of its `weight` element, - plus the backward weight in the input FSA of the state that corresponds - to it). + Normalizes this DetState by reducing seq_len to the extent possible + and outputting the weight and derivative info corresponding to this + reduction of sequence length. Recall that these determinized states + are represented as a (base state, and a sequence of symbols that we + followed from the base state). This allows a smaller set of + input FSAs to be determinized than the normal weighted-subet-of-states + formulation, equivalent (I believe) to the assumption that all + the weights are distinct and have no 'special relationships' between + them, i.e. no equalities like a + b = c. This kind of requirement is + necessarly for differentiablity. + + @param [in] wfsa_in The weighted FSA we are determinizing + @param [out] removed_weight The part of the weight that was + removed when we reduced `seq_len`, if any, will + be written to here (else 0.0). */ - void Normalize(const Fsa &input_fsa, - const float *input_fsa_weights, + void Normalize(const WfsaWithFbWeights &wfsa_in, float *removed_weight, - std::vector *leftover_arcs) { -#ifndef NDEBUG - CheckElementOrder(); -#endif - RemoveDuplicatesOfStates(input_fsa); - RemoveCommonPrefix(input_fsa, input_fsa_weights, removed_weight, leftover_arcs); - } + std::vector *deriv_info); private: /* @@ -289,52 +533,116 @@ class DetState { */ RemoveCommonPrefix(const Fsa &input_fsa, const float *input_fsa_weights, - std::vector *input_arcs); - /* - This function just does some checking on the `elements` list that - they are in the correct order, which is a lexicographical - order (by state-id) on the paths of length `seq_len` starting from - `base_state`. The label sequences don't come into it because - they are all the same. - */ - void CheckElementOrder() const; + float *weight_out, + std::vector *input_arcs); }; -bool DetStateCompare::operator()(const shared_ptr &a, - const shared_ptr &b) { +template +bool DetStateCompare::operator()( + const shared_ptr > &a, + const shared_ptr > &b) { return a->forward_backward_weight < b->forward_backward_weight; } -void DetState::RemoveDuplicatesOfStates(const Fsa &input_fsa) { - /* - `state_to_elem` maps from int32_t state-id to the DetStateElement - associated with it (there can be only one, we choose the one with - the best weight). - */ - std::unordered_map::iterator> state_to_elem; +template +void DetState::ProcessArcs( + const WfsaWithFbWeights &wfsa_in, + float prune_cutoff, + vector *arcs_out, + vector *arc_weights_out, + vector > *derivs_per_arc, + DetStateMap *state_map, + DetStatePriorityQueue *queue) { + std::unordered_map > > label_to_state; - for (auto iter = elements.begin(); iter != elements.end(); ++iter) { - int32_t state = input_fsa.arcs[elem.arc_index].nextstate; - auto p = state_to_elem.insert({state, elem}); - bool inserted = p.second; - if (!inserted) { - DetStateElement *old_elem = p.first->second; - if (old_elem->weight > elem->weight) { // old weight is better - this->RemoveElement(elem); + Fsa *fsa = wsfa_in.fsa; + const float *arc_weights = wfsa_in.arc_weights; + for (const std::shared_ptr &state_ptr: elements) { + int32_t state_id = state_ptr->state_id, + begin_arc = fsa->arc_indexes[state_id], + end_arc = fsa->arc_indexes[state_id + 1]; + for (int32_t a = begin_arc; a < end_arc; ++a) { + const Arc &arc = fsa->arcs[a]; + float weight = arc_weights[a]; + int32_t label = arc.label; + TracebackState *state; + if (iter != label_to_state.end()) { + state = iter->second.get(); } else { - p.first->second = elem; - this->RemoveElement(old_elem); + auto new_state = std::make_shared >(seq_len + 1); + state = new_state.get(); + label_to_state[label] = std::move(new_state); + } else { + state->Accept(state_ptr, a, arc.label, weight); + + state_to_weight[arc.dest_state] = std::make_shared( + state_ptr, a, arc.label, weight); } } } + CHECK(!label_to_state.empty() || + elements[0]->state_id == fsa->FinalState()); // I'm assuming the input + // FSA is connected. + + for (auto iter = state_to_weight.begin(); + iter != state_to_weight.end(); ++iter) { + int32_t label = iter->first; + std::shared_ptr &next_det_state = iter->second; + float arc_weight; + std::vector deriv_info; + next_det_state->Normalize(wfsa_in, &arc_weight, &deriv_info); + int32_t next_state_id; + if (state_map->GetOutputState(&next_state_id)) { + // State was newly created. + queue->push(iter->second); + } + arcs_out->push_back({this->state_id, next_state_id, label}); + arc_weights_out->push_back(arc_weight); + derivs_per_arc->push_back(std::move(deriv_info)); + } } +template +void DetState::ComputeFbWeight( + const float *backward_state_weights) { + forward_backward_weight = -std::numeric_limits::infinity(); + for (auto p: elements) { + TracebackState *state = p.second.get(); + forward_backward_weight = max(forward_backward_weight, + state->forward_prob + + backward_state_weights[state->state_id]); + } +} + +template +void DetState::Normalize(const WfsaWithFbWeights &wfsa_in, + float *removed_weight, + std::vector *deriv_info) { + std::unordered_set cur_states; + for (auto p: elements) { + TracebackState *state = p.second.get(); + cur_states.insert(state); + } + int32_t new_seq_len = GetMostRecentCommonAncestor(&cur_states); + // now cur_states.size() == 1. + CHECK_LE(new_seq_len, seq_len); + + TraceBack(&cur_states, seq_len - new_seq_len, + wfsa_in.arc_weights, + removed_weight, deriv_info); + + seq_len = new_seq_len; + normalized = true; +} + + + void DetState::RemoveCommonPrefix(const Fsa &input_fsa, const float *input_fsa_weights, float *removed_weight_out, @@ -397,12 +705,12 @@ void DetState::CheckElementOrder(const Fsa &input_fsa) const { CHECK(elements.front().weight == 0.0); } - std::vector prev_seq; + std::vector prev_seq; for (auto iter = elements.begin(); iter != elements.end(); ++iter) { auto path = iter->path; - std::vector cur_seq; + std::vector cur_seq; for (int32_t i = 0; i < seq_len; i++) { - cur_seq.push_back(input_fsa.arcs[path->arc_index].prev_state); + cur_seq.push_back(input_fsa.arcs[path->arc_index].src_state); path = path->prev; } std::reverse(cur_seq.begin(), cur_seq.end()); @@ -415,9 +723,12 @@ void DetState::CheckElementOrder(const Fsa &input_fsa) const { /* - This class maps from determinized states (DetState) to integer state-id - in the determinized output. + This class maps from determinized states (DetState) to integer state-ids + in the determinized output. Caution: it uses a randomized algorithm that + could in principle produce collisions that would generate wrong output. + We don't think this will ever happen though (128-bit representations). */ +template class DetStateMap { public: @@ -435,7 +746,7 @@ class DetStateMap { @return Returns true if this was a NEWLY CREATED state, false otherwise. */ - bool GetOutputState(const DetState &a, int32_t *state_id) { + bool GetOutputState(const DetState &a, int32_t *state_id) { std::pair compact; DetStateToCompact(a, &compact); auto p = map_.insert({compact, cur_output_state}); @@ -452,9 +763,16 @@ class DetStateMap { int32_t size() const { return cur_output_state_; } private: + struct PairHasher { + size_t operator () (const std::pair &p) const { + return static_cast(p.first); + } + } + + int32_t cur_output_state_{0}; std::unordered_map, int32_t, - DetStateVectorHasher> map_; + PairHasher> map_; /* Turns DetState into a compact form of 128 bits. Technically there could be collisions, which would be fatal for the algorithm, but this diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index b682d59e3..6df24214b 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -71,14 +71,14 @@ struct ArcHash { accepts no strings) by having no states at all, so `arcs` would be empty. */ struct Fsa { - // `arc_indexes` is indexed by state-index, is of length num-states, - // contains the first arc-index leaving this state (index into `arcs`). - // The next element of this array gives the end of that range. - // Note: the final-state is numbered last, and implicitly has no - // arcs leaving it. For non-empty FSA, we put a duplicate of the final state - // at the end of `arc_indexes` to avoid boundary check for some FSA - // operations. Caution: users should never call `arc_indexes.size()` to get - // the number of states, they should call `NumStates()` to get the number. + // `arc_indexes` is indexed by state-index, is of length num-states + 1; it + // contains the first arc-index leaving this state (index into `arcs`). The + // next element of this array gives the end of that range. Note: the + // final-state is numbered last, and implicitly has no arcs leaving it. For + // non-empty FSA, we put a duplicate of the final state at the end of + // `arc_indexes` to avoid boundary check for some FSA operations. Caution: + // users should never call `arc_indexes.size()` to get the number of states, + // they should call `NumStates()` to get the number. std::vector arc_indexes; // Note: an index into the `arcs` array is called an arc-index. @@ -112,6 +112,11 @@ struct Fsa { return !arc_indexes.empty() ? (static_cast(arc_indexes.size()) - 1) : 0; } + int32_t FinalState() const { + // It's not valid to call this if the FSA is empty. + CHECK(!arc_indexes.empty()); + return arc_indexes.size() - 2; + } }; /* From 3ed45faea0d4403c774f3853161ee18c0062f308 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 12 May 2020 21:13:50 +0800 Subject: [PATCH 13/14] More progress on determinizaton draft. --- k2/csrc/determinize.cc | 540 +++++++++++++++++++++-------------------- k2/csrc/fsa_algo.h | 66 ++++- 2 files changed, 340 insertions(+), 266 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index 506dc810f..4d23b8cc6 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -14,12 +14,18 @@ namespace k2 { using std::shared_ptr; +using std::weak_ptr; using std::vector; using std::priority_queue; using std::pair; +// setting this to true will reduce memory consumption, but could increase +// compute time in cases where we finish before the beam. +constexpr bool PROCESS_ARCS_IMMEDIATELY = false; + + struct MaxTracebackState { using DerivOutputType = int32_t; @@ -173,28 +179,6 @@ struct LogSumTracebackState { }; - - - -/* Given a set of traceback states (e.g. as present in one determinized state or - after tracing back a certain number of times), trace back along one arc to - get the set of traceback states reachable from this. These will correspond - to a symbol sequence from the start-state that's shorter by one. - - @param [in] cur_states A set of traceback states that we want to - trace back from - @param [out] prev_states The set of states reachable by one - link of traceback from any state in `cur_states` - is written to here. - TODO(dpovey): remove -*/ -void GetPreviousTracebackStates( - const std::vector &cur_states, - std::vector *prev_states); - -} - - /* Find the most recent common ancestor LogSumTracebackState of a set of LogSumTracebackStates, and return the number of links we had to follow to get @@ -353,8 +337,11 @@ using DetStatePriorityQueue = priority_queue DetStateCompare >; +template +class DetStateMap; + /* - Conceptually a determinized state in weighted FSA determinization would + Conceptually, a determinized state in weighted FSA determinization would normally be a weighted subset of states in the input FSA, with the weights normalized somehow (e.g. subtracting the sum of the weights). @@ -381,11 +368,22 @@ class DetState { // .. and DerivOutputType == int32_t or pair // respectively. + // Constructor for the initial state of the determinized FSA + DetState(): seq_len(0), output_state(-1), normalized(true) { + // the constructor that takes no args gives us what we need for the + // start-state. + elements[0] = std::make_shared(); + } + + // TODO: constructor for start-state. - DetState(int32_t seq_len): + DetState(int32_t seq_len, int32_t src_output_state, + int32_t pending_symbol): seq_len(seq_len), output_state(-1), // Not known yet - normalized(false) { } // .. and forward_backward_weight undefined + normalized(false), + src_output_state(src_output_state), + pending_symbol(pending_symbol) { } // .. and forward_backward_weight undefined /** Process incoming arc to this DetState. See documentation for @@ -411,23 +409,41 @@ class DetState { } } - // Length of sequence of symbols leading to this DetState. + // Length of sequence of symbols (from the base state) leading to this DetState. + // each DetState can be described by: start from base_state, follow paths + // with a specific symbol sequence, and the weighted set of states that you + // reach corresponds to this DetState. + // + // The sequence of symbols, and the base_state, can be found by tracing back + // one of the DetStateElements in the doubly linked list (it doesn't matter + // which you pick, the result will be the same). int32_t seq_len; // `output_state` is the state in the output FSA that this determinized - // state corresponds to. + // state corresponds to. (Only known if `normalized` is true; otherwise, -1). int32_t output_state; - // seq_len is the length of symbol sequence that we follow from state - // `base_state`. (Note: we don't store base_state as a member any more, it - // can be worked out by tracing back in the TracebackState data structure). - // The sequence of symbols can be found by tracing back one of the - // DetStateElements in the doubly linked list (it doesn't matter which you - // pick, the result will be the same). - int32_t seq_len; - + // `normalized` is true if this DetState is known to be normalized (meaning: + // we have reduced seq_len as much as possible). DetStates that are not + // normalized will not yet have an `output_state`. bool normalized; + // The following two elements are only relevant if `normalized` is false. + // The purpose is so that we can delay outputting the arc leading to this + // DetState, until we know that this DetState will end up having arcs + // leaving it processed (see ProcessArcs()). For pruned determinization, + // this avoids unnecessary work. + // + // src_output_state is the state, in the output FSA, of the preceding state + // from which we generated this state. In the end any given state in the output FSA may + // have many preceding-states; this is just the one from which this + // particular DetState structure was generated. It's remembered so that + // we can correctly output the arc in the output FSA from `src_output_state`. + int32_t src_output_state; + // pending_symbol is the symbol, in the output FSA, on the arc from + // src_output_state. + int32_t pending_symbol; + // `elements` can be thought of as weighted subsets of states in the input // FSA, that also stores some traceback information that lets us compute @@ -470,14 +486,16 @@ class DetState { to weighted input arcs. @param [in,out] state_map Maps from DetState to int32_t state-id in the output FSA. + @return Returns a number that approximately indicates how much + computation was done (so we can avoid it taking too long). */ - void ProcessArcs(const WfsaWithFbWeights &wfsa_in, - float prune_cutoff, - vector *arcs_out, - vector *arc_weights_out, - vector *derivs_per_arc, - DetStateMap *state_map, - DetStatePriorityQueue *queue); + int32_t ProcessArcs(const WfsaWithFbWeights &wfsa_in, + float prune_cutoff, + vector *arcs_out, + vector *arc_weights_out, + vector *derivs_per_arc, + DetStateMap *state_map, + DetStatePriorityQueue *queue); // Computes the forward-backward weight of this DetState. This is // related to the best cost of any path through the output FSA @@ -505,6 +523,8 @@ class DetState { them, i.e. no equalities like a + b = c. This kind of requirement is necessarly for differentiablity. + This function also sets the forward_backward_weight field. + @param [in] wfsa_in The weighted FSA we are determinizing @param [out] removed_weight The part of the weight that was removed when we reduced `seq_len`, if any, will @@ -513,29 +533,6 @@ class DetState { void Normalize(const WfsaWithFbWeights &wfsa_in, float *removed_weight, std::vector *deriv_info); - private: - - /* - Called from Normalize(), this function removes duplicates in - `elements`: that is, if two elements represent paths that terminate at - the same state in `input_fsa`, we choose the one with the better - weight (or the first one in case of a tie). - */ - void RemoveDuplicatesOfStates(const Fsa &input_fsa, - const float *input_fsa_weights); - - /* - Called from Normalize(), this function removes any common prefix that the - paths in `elements` possess. If there is a common prefix it will reduce - `seq_len`, subtract the weights associated with the removed arcs from the - weights in `elements`, and set `input_arcs` to the sequence of arcs that - were removed from - */ - RemoveCommonPrefix(const Fsa &input_fsa, - const float *input_fsa_weights, - float *weight_out, - std::vector *input_arcs); - }; template @@ -546,182 +543,6 @@ bool DetStateCompare::operator()( } - - -template -void DetState::ProcessArcs( - const WfsaWithFbWeights &wfsa_in, - float prune_cutoff, - vector *arcs_out, - vector *arc_weights_out, - vector > *derivs_per_arc, - DetStateMap *state_map, - DetStatePriorityQueue *queue) { - - std::unordered_map > > label_to_state; - - - Fsa *fsa = wsfa_in.fsa; - const float *arc_weights = wfsa_in.arc_weights; - for (const std::shared_ptr &state_ptr: elements) { - int32_t state_id = state_ptr->state_id, - begin_arc = fsa->arc_indexes[state_id], - end_arc = fsa->arc_indexes[state_id + 1]; - for (int32_t a = begin_arc; a < end_arc; ++a) { - const Arc &arc = fsa->arcs[a]; - float weight = arc_weights[a]; - int32_t label = arc.label; - TracebackState *state; - if (iter != label_to_state.end()) { - state = iter->second.get(); - } else { - auto new_state = std::make_shared >(seq_len + 1); - state = new_state.get(); - label_to_state[label] = std::move(new_state); - } else { - state->Accept(state_ptr, a, arc.label, weight); - - state_to_weight[arc.dest_state] = std::make_shared( - state_ptr, a, arc.label, weight); - } - } - } - CHECK(!label_to_state.empty() || - elements[0]->state_id == fsa->FinalState()); // I'm assuming the input - // FSA is connected. - - for (auto iter = state_to_weight.begin(); - iter != state_to_weight.end(); ++iter) { - int32_t label = iter->first; - std::shared_ptr &next_det_state = iter->second; - float arc_weight; - std::vector deriv_info; - next_det_state->Normalize(wfsa_in, &arc_weight, &deriv_info); - int32_t next_state_id; - if (state_map->GetOutputState(&next_state_id)) { - // State was newly created. - queue->push(iter->second); - } - arcs_out->push_back({this->state_id, next_state_id, label}); - arc_weights_out->push_back(arc_weight); - derivs_per_arc->push_back(std::move(deriv_info)); - } -} - -template -void DetState::ComputeFbWeight( - const float *backward_state_weights) { - forward_backward_weight = -std::numeric_limits::infinity(); - for (auto p: elements) { - TracebackState *state = p.second.get(); - forward_backward_weight = max(forward_backward_weight, - state->forward_prob + - backward_state_weights[state->state_id]); - } -} - -template -void DetState::Normalize(const WfsaWithFbWeights &wfsa_in, - float *removed_weight, - std::vector *deriv_info) { - std::unordered_set cur_states; - for (auto p: elements) { - TracebackState *state = p.second.get(); - cur_states.insert(state); - } - int32_t new_seq_len = GetMostRecentCommonAncestor(&cur_states); - // now cur_states.size() == 1. - CHECK_LE(new_seq_len, seq_len); - - TraceBack(&cur_states, seq_len - new_seq_len, - wfsa_in.arc_weights, - removed_weight, deriv_info); - - seq_len = new_seq_len; - normalized = true; -} - - - -void DetState::RemoveCommonPrefix(const Fsa &input_fsa, - const float *input_fsa_weights, - float *removed_weight_out, - std::vector *input_arcs) { - - CHECK_GE(seq_len, 0); - int32_t len; - auto first_path = elements.front().path, - last_path = elements.back().path; - - for (len = 1; len < seq_len; ++len) { - first_path = first_path->prev; - last_path = last_path->prev; - if (first_path == last_path) { - // Note: we are comparing pointers here. We reached the same PathElement, - // which means we reached the same state. - break; - } - } - input_arcs->clear(); - if (len < seq_len) { - /* We reach a common state after traversing fewer than `seq_len` arcs, - so we can remove a shared prefix. */ - double removed_weight = 0.0; - int32_t new_seq_len = len, - removed_seq_len = seq_len - len; - input_arcs->resize(removed_seq_len); - // Advance base_state - int32_t new_base_state = input_fsa.arcs[first_path->arc_index].src_state; - for (; len < seq_len; ++len) { - auto arc = input_fsa.arcs[first_path->arc_index]; - input_arcs[seq_len - 1 - len] = first_path->arc_index; - removed_weight += input_fsa_weights[first_path->arc_index]; - first_path = first_path->prev; - } - // Check that we got to base_state. - CHECK((self->base_state == 0 && first_path == nullptr) || - fsa.arcs[first_path->arc_index].dest_state == this->base_state); - this->base_state = new_base_state; - if (removed_weight != 0) { - for (DetStateElement &det_state_elem: elements) { - det_state_elem.weight -= removed_weight; - } - } - *removed_weight_out = removed_weight; - } else { - *removed_weight_out = 0; - input_arcs->clear(); - } -} - -void DetState::CheckElementOrder(const Fsa &input_fsa) const { - // Checks that the DetStateElements are in a lexicographical order on the - // lists of states in their paths. This will be true becase of how we - // construct them (it requires on the IsArcSorted() property, whereby arcs - // leaving each state in the FSA are sorted first on label and then on - // dest_state. - if (seq_len == 0) { - CHECK(elements.size() == 1); - CHECK(elements.front().weight == 0.0); - } - - std::vector prev_seq; - for (auto iter = elements.begin(); iter != elements.end(); ++iter) { - auto path = iter->path; - std::vector cur_seq; - for (int32_t i = 0; i < seq_len; i++) { - cur_seq.push_back(input_fsa.arcs[path->arc_index].src_state); - path = path->prev; - } - std::reverse(cur_seq.begin(), cur_seq.end()); - if (iter != elements.begin()) { - CHECK(cur_seq > prev_seq); - } - prev_seq.swap(cur_seq); - } -} - - /* This class maps from determinized states (DetState) to integer state-ids in the determinized output. Caution: it uses a randomized algorithm that @@ -733,29 +554,29 @@ class DetStateMap { public: /* - Outputs the output state-id corresponding to a specific DetState structure. - This does not store any pointers to the DetState or its contents, so - you can delete the DetState without affecting this object's ability to map - an equivalent DetState to the same state-id. - - @param [in] a The DetState that we're looking up - @param [out] state_id The state-index in the output FSA - corresponding to this DetState (will - be freshly allocated if an equivalent of - this DetState did not already exist. + Looks up the output state-id corresponding to a specific DetState structure, + creating a new output-state if necessary. This does not store any pointers + to the DetState or its contents, so you can delete the DetState without + affecting this object's ability to map an equivalent DetState to the same + state-id. + + @param [in,out] a The DetState whose state-id we are looking up. + The integer id of the output-FSA state will be written + to its `output_state` field, which at entry is assumed + to be unset. @return Returns true if this was a NEWLY CREATED state, false otherwise. */ - bool GetOutputState(const DetState &a, int32_t *state_id) { + bool GetOutputState(DetState *a) { std::pair compact; DetStateToCompact(a, &compact); auto p = map_.insert({compact, cur_output_state}); bool inserted = p.second; if (inserted) { - *state_id = cur_output_state_++; + a->state_id = cur_output_state_++; return true; } else { - *state_id = p.first->second; + a->state_id = p.first->second; return false; } } @@ -771,7 +592,8 @@ class DetStateMap { int32_t cur_output_state_{0}; - std::unordered_map, int32_t, + /* maps from 128-bit key (stored as a pair of uint64_t's) to the int32_t state-id. */ + std::unordered_map, uint32_t, PairHasher> map_; /* Turns DetState into a compact form of 128 bits. Technically there @@ -810,15 +632,205 @@ class DetStateMap { }; }; -void DeterminizeMax(const WfsaWithFbWeights &a, float beam, Fsa *b, - std::vector > *arc_map) { - // TODO(dpovey): use glog stuff. - assert(IsValid(a) && IsEpsilonFree(a) && IsTopSortedAndAcyclic(a)); - if (a.arc_indexes.empty()) { - b->Clear(); - return; +/* + Convenience function that normalizes the state and outputs the arc for it. + + Returns true if the state was newly added (not already present in + `state_map`). + */ +template +bool NormalizeStateAndOutputArc( + DetState *state, + const WfsaWithFbWeights &wfsa_in, + float prune_cutoff, + vector *arcs_out, + vector *arc_weights_out, + vector > *derivs_per_arc, + DetStateMap *state_map) { + float arc_weight; + std::vector deriv_info; + state->Normalize(wfsa_in, &arc_weight, &deriv_info); + int32_t next_state_id; + bool is_new_state = state_map->GetOutputState(state); + arcs_out->push_back({this->state_id, next_state_id, label}); + arc_weights_out->push_back(arc_weight); + derivs_per_arc->push_back(std::move(deriv_info)); + return is_new_state; +} + + +template +int32_t DetState::ProcessArcs( + const WfsaWithFbWeights &wfsa_in, + double prune_cutoff, + vector *arcs_out, + vector *arc_weights_out, + vector > *derivs_per_arc, + DetStateMap *state_map, + DetStatePriorityQueue *queue) { + int32_t num_steps = 0; + + std::unordered_map > > label_to_state; + + Fsa *fsa = wsfa_in.fsa; + const float *arc_weights = wfsa_in.arc_weights; + for (const std::shared_ptr &state_ptr: elements) { + int32_t state_id = state_ptr->state_id, + begin_arc = fsa->arc_indexes[state_id], + end_arc = fsa->arc_indexes[state_id + 1]; + num_steps += end_arc - begin_arc; + for (int32_t a = begin_arc; a < end_arc; ++a) { + const Arc &arc = fsa->arcs[a]; + float weight = arc_weights[a]; + int32_t label = arc.label; + + + auto ret = label_to_state.insert({label, nullptr}); + auto iter = ret.first; + if (ret.second) { // Inserted -> this label was not a key in this map. + // Allocate new DetState. + iter->second = std::make_shared >(seq_len + 1, + this->output_state, + label); + } + TracebackState *state = iter->second.get(); + state->Accept(state_ptr, a, arc.label, weight); + } + } + CHECK(!label_to_state.empty() || + elements[0]->state_id == fsa->FinalState()); // I'm assuming the input + // FSA is connected. + + + for (auto iter = label_to_state.begin(); + iter != label_to_state.end(); ++iter) { + std::shared_ptr &det_state = iter->second; + + float arc_weight; + std::vector deriv_info; + det_state->Normalize(wfsa_in, &arc_weight, &deriv_info); + if (det_state->forward_backward_weight >= prune_cutoff) { + bool is_new_state = state_map->GetOutputState(state); + arcs_out->push_back({this->state_id, next_state_id, label}); + arc_weights_out->push_back(arc_weight); + derivs_per_arc->push_back(std::move(deriv_info)); + if (is_new_state) + queue->push(det_state); + } + } + return num_steps; +} + +template +void DetState::ComputeFbWeight( + const float *backward_state_weights) { + forward_backward_weight = -std::numeric_limits::infinity(); + for (auto p: elements) { + TracebackState *state = p.second.get(); + forward_backward_weight = max(forward_backward_weight, + state->forward_prob + + backward_state_weights[state->state_id]); } - float cutoff = a.backward_state_weights[0] - beam; - // TODO(dpovey) } + +template +double LogSumOrMax(double, double); + +template <> +double LogSumOrMax(double a, double b) { + return max(a, b); +} +template <> +double LogSumOrMax(double a, double b) { + return LogSum(a, b); +} + + +template +void DetState::Normalize(const WfsaWithFbWeights &wfsa_in, + float *removed_weight, + std::vector *deriv_info) { + std::unordered_set cur_states; + + double fb_prob = -std::numeric_limits::infinity(); + for (auto p: elements) { + TracebackState *state = p.second.get(); + fb_prob = LogSumOrMax( + fb_prob, + state->forward_prob + wfsa_in.backward_state_weights[state->state_d]); + cur_states.insert(state); + } + + int32_t new_seq_len = GetMostRecentCommonAncestor(&cur_states); + // now cur_states.size() == 1. + CHECK_EQ(cur_states.size(), 1); + CHECK_LE(new_seq_len, seq_len); + + const TracebackState *base_state = cur_states.front().get(); + // The following statement is a correction term that we add to + // forward_backward_prob, in which we replace the forward_prob in the DetState + // (which will have been computed in a path-dependent way) with the + // forward_prob in wfsa_in. Note: the values of state->forward_prob above can + // be thought of as base_state->forward_prob plus some value that only depends + // on the symbol sequence. The point of this is to ensure that + // this->forward_backward_prob (which is used for pruning) depends only on the + // base_state and the symbol sequence, and not on "how we got here", i.e. the + // history of DetStates from which this one is derived via ProcessArcs(). + fb_prob += wfsa_in.forward_state_weights[base_state->state_id] - + base_state->forward_prob; + // set thi->forward_backward_prob; it will affect pruning. + this->forward_backward_prob = fb_prob; + this->seq_len = new_seq_len; + + // the following will set removed_weight and deriv_info. + TraceBack(&cur_states, seq_len - new_seq_len, + wfsa_in.arc_weights, + removed_weight, deriv_info); + + normalized = true; +} + + + +void DeterminizePrunedLogSum( + const WfsaWithFbWeights &wfsa_in, + float beam, + int64_t max_step, + Fsa *fsa_out, + std::vector *arc_weights_out, + std::vector > > *arc_derivs_out) { + CHECK_GT(beam, 0); + + DetStatePriorityQueue queue; + DetStateMap map; + + std::shared_ptr start_state = std::make_shared(); + + std::vector arcs_out; + arc_weights_out->clear(); + arc_derivs_out->clear(); + + bool ans = map.GetOutputState(start_state.get()); + CHECK(ans && ans->state_id == 0); + + if (max_step <= 0) + max_step = std::numeric_limits::max(); + int64_t num_steps = 0; + int32_t block_size = 32; // process a number of queue elements at a time + // between certain checks.. + + double total_prob = wfsa_in.backward_state_weights[0], + prune_cutoff = total_prob - beam; + while (num_steps < max_step && !queue.empty()) { + std::shared_ptr state = queue.top(); + queue.pop(); + num_steps += state->ProcessArcs(wfsa_in, prune_cutoff, arcs_out, + arc_weights_out, arc_derivs_out, + &map, &queue); + } +} + +// TODO: do the max version of Determinize(), which is much the same as the +// LogSum version. + } // namespace k2 diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 9e664cd76..0aad41a47 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -298,10 +298,72 @@ void ArcSort(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); bool TopSort(const Fsa& a, Fsa* b, std::vector* state_map = nullptr); /** + Pruned determinization with log-sum on weights (interpret them as log-probs), + equivalent to log semiring + @param [in] a Input FSA `a` to be determinized. Expected to be epsilon free, but this + is not checked; in any case, epsilon will be treated as a normal symbol. + Forward-backward weights must be provided for pruning purposes; + a.weight_type must be kLogSumWeight. + @param [in] beam Pruning beam; should be greater than 0. + @param [in] max_step Maximum number of computation steps before we return + (or if <= 0, there is no limit); provided so users can limit the time + taken in pathological cases. + @param [out] b Output FSA; will be deterministic. For a symbol sequence S accepted by a, + the total (log-sum) weight of S in a should equal the total (log-sum) weight + of S in b (as discoverable by composition then finding the total + weight of the result), except as affected by pruning of course. + @param [out] b_arc_weights Weights per arc of b. + @param [out] arc_derivs Indexed by arc in b, this is a list of pairs (arc_in_a, x) + where 0 < x <= 1 is the derivative of that arc's weight w.r.t. the + weight of `arc_in_a` in a. Note: the x values may actually be zero + if the pruning beam is very large, due to limited floating point range. + @return Returns the effective pruning beam, a value >= 0 which is the difference + between the total weight of the output FSA and the cost of the last + arc expanded. +*/ +float DeterminizePrunedLogSum( + const WfsaWithFbWeights &a, + float beam, + int64_t max_step, + Fsa *b, + std::vector *b_arc_weights, + std::vector > > *arc_derivs); +/** + Pruned determinization with max on weights, equivalent to the tropical semiring. + + @param [in] a Input FSA `a` to be determinized. Expected to be epsilon free, but this + is not checked; in any case, epsilon will be treated as a normal symbol. + Forward-backward weights must be provided for pruning purposes; + a.weight_type must be kMaxWeight. + @param [in] beam Pruning beam; should be greater than 0. + @param [in] max_step Maximum number of computation steps before we return + (or if <= 0, there is no limit); provided so users can limit + the time taken in pathological cases. + @param [out] b Output FSA; will be deterministic For a symbol sequence + S accepted by a, the best weight of symbol-sequence S in + a should equal the best weight of S in b (as discoverable + by composition then finding the total weight of the + result), except as affected by pruning of course. + @param [out] b_arc_weights Weights per arc of b. Note: these can be + computed from arc_derivs and the weights of a, so this + output is not strictly necessary; it's provided mostly due + to sharing the internal code with the log-sum version. + @param [out] arc_derivs Indexed by arc in `b`, this is the sequence of + arcs in `a` that this arc in `b` corresponds to; the + weight of the arc in b will equal the sum of those input + arcs' weights. + @return Returns the effective pruning beam, a value >= 0 which is the difference + between the total weight of the output FSA and the cost of the last + arc expanded. */ -void Determinize(const Fsa &a, Fsa *b, - std::vector> *state_map); +float DeterminizePrunedMax(const WfsaWithFbWeights &a, + float beam, + int64_t max_step, + Fsa *b, + std::vector *b_arc_weights, + std::vector *arc_derivs); + /* Create an acyclic FSA from a list of arcs. From a79f527e32d6fe064f704e0987c263bbf7894e16 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 13 May 2020 23:43:54 +0800 Subject: [PATCH 14/14] More work on Determinize code. --- k2/csrc/determinize.cc | 468 ++++++++++++++++++++++++++--------------- 1 file changed, 301 insertions(+), 167 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index 4d23b8cc6..e1ebcb8a5 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -20,47 +20,166 @@ using std::priority_queue; using std::pair; +/* + HOW THIS WORKS + + This is FSA determinization that also outputs derivative information that says how + the weights on the arcs in the output FSA vary with the weights on the arcs in the + input FSA. + + INTRO TO DETERMINIZATION. + + The problem in determinization of a weighted FSA is to find a deterministic + FSA (i.e. one that has no two arcs leaving any given state with the same symbol + on), which is equivalent to the input FSA (meaning: the weight it assigns + to any given symbol-sequence is the same as the input FSA). In this explanation, + assume epsilons don't exist, if the input FSA had epsilons we'd get rid of + them prior to determinization. -// setting this to true will reduce memory consumption, but could increase -// compute time in cases where we finish before the beam. -constexpr bool PROCESS_ARCS_IMMEDIATELY = false; + SUBSET-BASED ALGORITHMS + In general, in determinization algorithms, states in the output FSA correspond + to weighted subsets of states in the input FSA. The overall structure of these + algorithms will be: -struct MaxTracebackState { - using DerivOutputType = int32_t; + - Let state 0 in the output FSA correspond to the weighted subset { 0, 0.0 } + in the input FSA, where the 0 is the start state-id in the input FSA + and the 0.0 is the weight (interpret this as a log-prob). + Put that in the queue. + + - While (queue is not empty) + Pop output-state from the queue + Process successor states of this output-state. + + Obviously most of the detail in the outline above resides in "Process + successor states of this output-state." Let's discuss the unweighted + case first. + + *Unweighted case + + Each output-state corresponds to some subset of input-states, say, { s1, s2, + .. }. The set of labels on arcs leaving the output-state will correspond to + the set of labels on all the arcs leaving s1, s2 and so on; and the + destination-states of those arcs will correspond to the sets of + destination-states of those arcs. The algorithm requires us to store a map + from (subset of input-state-ids) to (output-state-id). + + *Weighted case + + In the weighted case, the difference is that instead of a set of + input-state-ids we have a weighted subset of them, and the map is from these + weighted subsets to state-ids in the output FSA. The weights in the weighted + subsets have to be normalized somehow. The natural normalization is in the + "max/tropical-semiring" case to have the most negative weight be 0.0, and in + the "log-sum/log-semiring" case to have the log-sum be 0.0. Imagine + we have a function: + Normalize (unnormalized-weighted-subset) -> normalized-weighted-subset, leftover-weight + E.g., in the Max case: + Normalize( { (6, 1.0), (7, 5.0) } ) -> { (6, 0.0), (7, 4.0) }, 1.0 + The "leftover-weights" become the weights on the arcs in the output FSA. + + + *The problem with differentability + + Consider how to differentiate the weights of the output weighted FSA + w.r.t. those of the input. The problem with differentiability if we use the + algorithm above is the case of special symmetries. What if two weighted + subsets happen to coincide because there was an exact relationship between the + values of the weights in the input FSA, but there was no *structural* reason + in the input FSA why those weighted subsets have to be the same? Then we + have a problem with how to differentiate, because any small change in the + input weights would lead to a structural change in the output FSA. + + + OUR ALGORITHM + *Different representation of subsets + + Our algorithm is still very similar to the subset-based algorithms mentioned + above, and it still involves weighted subsets, but we use a different + representation of them. Our representation (think of this as the key in + the map) is: ( base_state, symbol_sequence ). The relationship with + the weighted subset is: start from state `base_state` in the input FSA, + traverse all sequences of arcs that have sequence `symbol_sequence` on them, + and the weighted set of states you end up with is the weighted subset + in the algorithm above. + + *Different normalization + + Our form of "normalization" of this representation is differen too. The + normalization is to make `symbol_sequence` as short as possible, and advance + `base_state` to compensate. For instance, if `symbol_sequence` is `a b c + d`, but the weighted subset of states we can reach by this symbol sequence + is the same as what we'd get by moving `base_state` forward two steps and + letting `symbol_sequence` be just `c d`, then the latter representation is + the canonical one (assuming that was the largest prefix we could remove). + Algorithmically, finding the "most recent base_state" involves finding the + most recent common ancestor in a directed tree of paths through the input + FSA (in the max case) or a directed lattice of paths through the input FSA + (in the log-sum case). + + The weights on arcs are related to the total weight of paths from `original + base_state` to `new base_state`, (counting only paths that have the removed + symbol sequence `a b`). Just as with the subset algorithm, these + weights are what gets "spit out" by the normalization process; we simply + have a different normalization process. + + + PRUNED DETERMINIZATION + + We support pruned determinization. We won't describe all of the details + here, but it involves a priority queue of determinized states, so we always + process the "best" queue element first, and we may terminate before the + queue is empty. + + + IMPLEMENTATION DETAILS + + A few details on the implementation: + + - To save memory space, the process of hashing from `base_state, symbol_seq` + to output state-id maps them to a fixed-size 128-bit value. This could + in principle generate collisions which would generate incorrect output, + but we consider that vanishingly improbable. + + */ + +struct MaxTracebackState { + using DerivType = int32_t; int32_t state_id; // state-id in the input FSA - int32_t arc_id; // arc-id in input FSA of the arc that enters - // state `state_id` (or -1 if this is the start state). + int32_t arc_id; // arc-id in input FSA of the arc that enters state + // `state_id` (or -1 if this is the start state). It will + // be the best arc if there were multiple possible arcs + // from the base_state to this state with the same symbol + // sequence. - // prev_state is the state we trace back to (the previous state), - // which is the src_state of the arc numbered arc_id. - // It will be nullptr if state_id == 0 (start state). + // prev_state is the state we trace back to (the previous state), which is the + // src_state of the arc numbered arc_id. It will be nullptr if state_id == 0 + // (start state). shared_ptr prev_state; - double forward_prob; // The total forward log-probability from the start + double forward_prob; // The best forward log-probability from the start // state to this state (along whichever specific - // sequence of symbols we took to get here; it's not - // necessarily the best forward in the lattice). + // sequence of symbols we took to get here) - // This constructor is for the start-state of the input FSA. - MaxTracebackState(): state_id(0), arc_id(-1), arc_symbol(-1), + // This constructor is for the start-state (state zero) of the input FSA. + MaxTracebackState(): state_id(0), arc_id(-1), prev_state(nullptr), forward_prob(0.0) { } /** @param [in] state_id State in input FSA that this corresponds to @param [in] src Previous LogSumTracebackState that we'll point back to, or NULL - @param [in] incoming_arc_index. Its src_state will equal src->state_id, + @param [in] incoming_arc_index Arc-index in input FSA. + Its src_state will equal src->state_id, its dest_state will equal state_id. - @param [in] src_symbol Symbol on the input arc @param [in] arc_weight Weight on the input arc */ MaxTracebackState(int32_t state_id, - const std::shared_ptr &src, + const shared_ptr &src, int32_t incoming_arc_index, int32_t arc_weight): state_id(state_id), @@ -68,14 +187,18 @@ struct MaxTracebackState { prev_state(src), forward_prob(element->forward_prob + arc_weight) { } - void Accept(const std::shared_ptr &src, - int32_t arc_index, int32_t _symbol, float arc_weight) { + /* + This takes the same args as the constructor. It will update the traceback + info if this incoming arc had higher weight. + */ + void Accept(const shared_ptr &src, + int32_t arc_index, float arc_weight) { double new_forward_prob = src->forward_prob + arc_weight; if (new_forward_prob > forward_prob) { forward_prob = new_forward_prob; arc_id = arc_index; prev_state = src; - // state_id doesn't change, nor does _symbol. + // state_id doesn't change. } } }; @@ -83,10 +206,13 @@ struct MaxTracebackState { class LogSumTracebackState; -// This struct is used inside LogSumTracebackState; it represents an -// arc that traces back to a previous LogSumTracebackState. -// A LogSumTracebackState represents a weighted colletion of paths -// terminating in a specific state. + +/* + This struct is used inside LogSumTracebackState; it represents an + arc that traces back to a previous LogSumTracebackState. + A LogSumTracebackState represents a weighted colletion of paths + terminating in a specific state. +*/ struct LogSumTracebackLink { shared_ptr prev_state; @@ -108,7 +234,7 @@ struct LogSumTracebackLink { // with prev_state == nullptr (and state_id == 0). - LogSumTracebackLink(const std::shared_ptr &src, + LogSumTracebackLink(const shared_ptr &src, int32_t arc_index, float arc_weight): src(src), arc_index(arc_index), forward_prob(arc_weight + src->forward_prob) { } @@ -117,21 +243,23 @@ struct LogSumTracebackLink { }; -struct LogSumTracebackState { - using DerivOutputType = pair; +/* + This stores traceback information for the log-sum case. Rather than a tree + structure, the LogSumTracebackStates interconnect with a lattice structure. - // LogSumTracebackState can be thought of as as a weighted set of paths from - // the start state to a particular state. (It will be limited to the subset - // of paths that have a specific symbol sequence). + It can be thought of as as a weighted set of paths from the start state to a + particular state. It will be limited to the subset of paths that have a + specific symbol sequence. +*/ +struct LogSumTracebackState { + using DerivType = pair; // `prev_elements` is, conceptually, a list of incoming arcs with associated // weights. vector prev_elements; int32_t state_id; // The state-id in the input FSA that this - // LogSumTracebackState corresponds to. (Unique to - // this determinized state; the same state-id may - // appear in multiple determinized states, in general. + // LogSumTracebackState corresponds to. double forward_prob; // The total forward log-probability from the start // state to this state (along whichever specific @@ -147,18 +275,16 @@ struct LogSumTracebackState { // input FSA and the determinized FSA). LogSumTracebackState(): state_id(0), forward_prob(0.0) { } - /** - @param [in] state_id State in input FSA that this corresponds to - @param [in] src Previous LogSumTracebackState that we'll point back - to, or nullptr if this belongs to the initial - determinized-state. + /* + @param [in] state_id State in input FSA + @param [in] src Previous LogSumTracebackState that we'll point back to @param [in] incoming_arc_index. Arc-index in input FSA. Its src_state will equal src->state_id, its dest_state will equal state_id. @param [in] arc_weight Weight on the arc */ LogSumTracebackState(int32_t state_id, - const std::shared_ptr &src, + const shared_ptr &src, int32_t incoming_arc_index, int32_t arc_weight): state_id(state_id), @@ -167,10 +293,17 @@ struct LogSumTracebackState { } /* - Accept a new incoming link. The args are the same as for - the constructor just above; see documentation there. + Accept a new incoming link. The args are the same as for the constructor + just above; see documentation there. + + @param [in] src Previous LogSumTracebackState that we'll point back to + @param [in] incoming_arc_index. Arc-index in input FSA. + Its src_state will equal src->state_id, its dest_state + will equal this->state_id. + @param [in] arc_weight Weight on the incoming arc + */ - void Accept(const std::shared_ptr &src, + void Accept(const shared_ptr &src, int32_t arc_index, float arc_weight) { double link_forward_prob = src.forward_prob + arc_weight; prev_elements.emplace_back(src, arc_index, link_forward_prob); @@ -178,12 +311,10 @@ struct LogSumTracebackState { } }; - /* Find the most recent common ancestor LogSumTracebackState of a set of LogSumTracebackStates, and return the number of links we had to follow to get - there (i.e. the length of symbol sequence we had to remove from - each path). + there (i.e. the length of symbol sequence). @param [in,out] cur_states A set of TracebackStates that we'll trace back from. Must be nonempty. Equality @@ -191,6 +322,7 @@ struct LogSumTracebackState { At exit it will contain a single member which will be the most recent common ancestor. + @return Returns the number of links we had to follow (>=0). If `cur_states.size() == 1` this will be zero. */ @@ -200,8 +332,8 @@ int32_t GetMostRecentCommonAncestor( std::unordered_set prev_states; for (; cur_states->size() != 1; ans++) { CHECK(!cur_states->empty()); - for (LogSumTracebackState *s: cur_states) { - for (const LogSumTracebackLink &l: s->prev_elements) { + for (LogSumTracebackState *s: *cur_states) { + for (LogSumTracebackLink &l: s->prev_elements) { prev_states.insert(l.prev_state.get()); } } @@ -220,7 +352,7 @@ int32_t GetMostRecentCommonAncestor( std::unordered_set prev_states; for (; cur_states->size() != 1; ans++) { CHECK(!cur_states->empty()); - for (MaxTracebackState *s: cur_states) { + for (MaxTracebackState *s: *cur_states) { prev_states.insert(s->prev_state.get()); } cur_states->clear(); @@ -239,8 +371,7 @@ int32_t GetMostRecentCommonAncestor( expect it to contain the same set on exit). A set of states; we'll iteratively trace back this set one step at a time. At entry it must have - size() == 1; it will also have size() == 1 after - `num_steps` steps. + size() == 1; it will also have size() == 1 at exit. @param [in] num_steps The number of steps to trace back @param [in] arc_weights_in Weights on the arcs of the input FSA @param [out] weight_out The output weight; will be the forward-backward @@ -257,8 +388,8 @@ int32_t GetMostRecentCommonAncestor( (input_arc_id, deriv) where, mathematically, 0 < deriv <= 1 (but we might still get exact zeros due to limitations of floating point representation). - Note: the sum of the float values in this vector should - be equal to `num_steps`. + Note: the sum of the float values in this vector at + exit should be equal to `num_steps`. */ void TraceBack(std::unordered_set *cur_states, @@ -280,7 +411,8 @@ void TraceBack(std::unordered_set *cur_states, double backward_prob = state_ptr->backward_prob; for (auto link: state_tr->prev_elements) { float arc_log_posterior = link.forward_prob + backward_prob; - deriv_out->push_back(std::pair(link.arc_index, expf(log_posterior))); + deriv_out->push_back(std::pair(link.arc_index, + expf(arc_log_posterior))); LogSumTracebackState *prev_state = link.prev_state.get(); double new_backward_prob = backward_prob + arc_weights_in[link.arc_index]; if (prev_states.insert(prev_state).second) { // newly inserted @@ -301,36 +433,34 @@ void TraceBack(std::unordered_set *cur_states, *weight_out = cur_forward_prob - prev_forward_prob; // The following is mostly for ease of interpretability of the output; // conceptually the order makes no difference. + // TODO(dpovey): maybe remove this, for efficiency? std::reverse(deriv_out->begin(), deriv_out->end()); } - -// See documentation of TraceBack for LogSumTracebackState, above. -// This version is simpler. +// The TraceBack function for MaxTracebackState. See documentation of TraceBack +// for LogSumTracebackState, above. This version is simpler. void TraceBack(std::unordered_set *cur_states, int32_t num_steps, const float *, // arc_weights_in, unused. float *weight_out, std::vector *deriv_out) { - // we recompute the arc weight sum from arc_weights_in, which should - // hopefully give - float arc_weight_sum = 0.0; CHECK_EQ(cur_states.size(), 1); MaxTracebackState *state = cur_states->front(); double cur_forward_prob = state->forward_prob; deriv_out->resize(num_steps); for (int32_t i = num_steps - 1; i >= 0; i--) { + // `deriv_out` is just a list of arc indexes in the input FSA + // that this output arc depends on (it's their sum). (*deriv_out)[i] = state->arc_id; } double prev_forward_prob = state->forward_prob; *weight_out = cur_forward_prob - prev_forward_prob; } -// Priority queue templated on: +// Priority queue template arguments: // item queued = unique_ptr (using pointer equality as comparison) // container type = vector > // less-than operator = DetStateCompare (which compares the forward_backward_prob). - template using DetStatePriorityQueue = priority_queue >, vector > >, @@ -340,50 +470,47 @@ using DetStatePriorityQueue = priority_queue template class DetStateMap; + /* - Conceptually, a determinized state in weighted FSA determinization would - normally be a weighted subset of states in the input FSA, with the weights - normalized somehow (e.g. subtracting the sum of the weights). - - Two determinized states are equal if the states and weights are the same. To - ensure differentiability, our assumption is that in general no two arcs in the - input FSA have identical weights. We argue that two determinized states can - always be represented as a base-state and a symbol sequence. Imagine that we - follow arcs with that symbol sequence from the base-state, and then in case we - reach the same states in the different ways we always select the best path - from the base-state. That process gives us a set of states and weights. We - argue that this representation is unique. (If not, it won't matter actually; - it will just give us an output that's less minimal than it could be). - - - We're not really following the Google guidelines by not having _ at the end of - class members, but this is more struct-like (members are public). + This represents a determinized state. Initially it has normalized == false + and it represents an un-normalized determinized state (see intro at the top + of this file), or an un-normalized determinized state under construction + (we add to it using AcceptIncomingArc()). + + After we call Normalize() on it, it is a normalized determinized-state (this + also outputs the weight you need for the incoming arc). + + After that + */ + template // TracebackState == MaxTracebackState or LogSumTracebackState class DetState { - - public: - using DerivOutputType = typename TracebackState::DerivOutputType; - // .. and DerivOutputType == int32_t or pair - // respectively. + using DerivType = typename TracebackState::DerivType; + // DerivType == int32_t for MaxTracbackState, or + // pair for LogSumTracebackState. // Constructor for the initial state of the determinized FSA DetState(): seq_len(0), output_state(-1), normalized(true) { - // the constructor that takes no args gives us what we need for the - // start-state. + // the constructor of TracebackState that takes no args gives us what we + // need for the start-state. elements[0] = std::make_shared(); } - // TODO: constructor for start-state. - DetState(int32_t seq_len, int32_t src_output_state, - int32_t pending_symbol): + /* + Constructor (this is the one that's normally used). + @param [in] seq_len Length of symbol sequence from its + base_state (this is before normalization). + Will be the seq_len of the source det_state plus + one. This seq_len may end up getting reduced + when Normalize() is called (reducing seq_len + implicitly advances the base_state). + */ + DetState(int32_t seq_len): seq_len(seq_len), - output_state(-1), // Not known yet - normalized(false), - src_output_state(src_output_state), - pending_symbol(pending_symbol) { } // .. and forward_backward_weight undefined + normalized(false) { } // .. and forward_backward_prob undefined /** Process incoming arc to this DetState. See documentation for @@ -391,21 +518,24 @@ class DetState { LogSumTracebackState). @param [in] state_id State-id, in input FSA, into which this arc enters. [Note: a DetState is a weighted subset of state-ids.] + @param [in] src The preceding state (from which the arc leaves). + This will be a member of the `elements` of the "parent" + DetState, i.e. the DetState in whose ProcessArcs() function + this DetState was created. @param [in] incoming_arc_index Arc in input FSA that enters state `state_id`. - @param [in] arc_symbol The symbol on + @param [in] arc_weight The weight on this arc */ void AcceptIncomingArc(int32_t state_id, - const std::shared_ptr &src, + const shared_ptr &src, int32_t incoming_arc_index, int32_t arc_weight) { - auto iter = elements.find(state_id); - if (iter == elements.end()) { - elements[state_id] = std::make_shared( + auto ret = elements.insert({state_id, nullptr}); + if (!ret.second) { // No such state existed in `elements` + ret.first->second = std::make_shared( state_id, src, incoming_arc_index, arc_weight); - } else { - iter.second->Accept( - src, incoming_arc_index, arc_symbol, arc_weight); + } else { // A state with this staste_id existed in `elements`. + ret.first->second->Accept(src, incoming_arc_index, arc_weight); } } @@ -419,44 +549,23 @@ class DetState { // which you pick, the result will be the same). int32_t seq_len; - // `output_state` is the state in the output FSA that this determinized - // state corresponds to. (Only known if `normalized` is true; otherwise, -1). - int32_t output_state; - // `normalized` is true if this DetState is known to be normalized (meaning: // we have reduced seq_len as much as possible). DetStates that are not // normalized will not yet have an `output_state`. bool normalized; - // The following two elements are only relevant if `normalized` is false. - // The purpose is so that we can delay outputting the arc leading to this - // DetState, until we know that this DetState will end up having arcs - // leaving it processed (see ProcessArcs()). For pruned determinization, - // this avoids unnecessary work. - // - // src_output_state is the state, in the output FSA, of the preceding state - // from which we generated this state. In the end any given state in the output FSA may - // have many preceding-states; this is just the one from which this - // particular DetState structure was generated. It's remembered so that - // we can correctly output the arc in the output FSA from `src_output_state`. - int32_t src_output_state; - // pending_symbol is the symbol, in the output FSA, on the arc from - // src_output_state. - int32_t pending_symbol; - - // `elements` can be thought of as weighted subsets of states in the input // FSA, that also stores some traceback information that lets us compute // derivatives. // It's a map from (state-id in input FSA) -> its corresponding TracebackState. - std::unordered_map > elements; + std::unordered_map > elements; // This is the weight on the best path that includes this determinized state. // It's needed to form a priority queue on DetStates, so we can process them // best-first. It is computed as: the forward-weight on `base_state`, // plus the best/most-positive of: (the weight in a DetStateElement plus // the backward-weight of the state associated with that DetStateElement). - double forward_backward_weight; + double forward_backward_prob; /* @@ -493,7 +602,7 @@ class DetState { float prune_cutoff, vector *arcs_out, vector *arc_weights_out, - vector *derivs_per_arc, + vector *derivs_per_arc, DetStateMap *state_map, DetStatePriorityQueue *queue); @@ -523,7 +632,7 @@ class DetState { them, i.e. no equalities like a + b = c. This kind of requirement is necessarly for differentiablity. - This function also sets the forward_backward_weight field. + This function also sets the forward_backward_prob field. @param [in] wfsa_in The weighted FSA we are determinizing @param [out] removed_weight The part of the weight that was @@ -532,14 +641,14 @@ class DetState { */ void Normalize(const WfsaWithFbWeights &wfsa_in, float *removed_weight, - std::vector *deriv_info); + std::vector *deriv_info); }; template bool DetStateCompare::operator()( const shared_ptr > &a, const shared_ptr > &b) { - return a->forward_backward_weight < b->forward_backward_weight; + return a->forward_backward_prob < b->forward_backward_prob; } @@ -584,6 +693,7 @@ class DetStateMap { int32_t size() const { return cur_output_state_; } private: + // simple hashing function that just takes the first element of the pair. struct PairHasher { size_t operator () (const std::pair &p) const { return static_cast(p.first); @@ -645,10 +755,10 @@ bool NormalizeStateAndOutputArc( float prune_cutoff, vector *arcs_out, vector *arc_weights_out, - vector > *derivs_per_arc, + vector > *derivs_per_arc, DetStateMap *state_map) { float arc_weight; - std::vector deriv_info; + std::vector deriv_info; state->Normalize(wfsa_in, &arc_weight, &deriv_info); int32_t next_state_id; bool is_new_state = state_map->GetOutputState(state); @@ -665,16 +775,18 @@ int32_t DetState::ProcessArcs( double prune_cutoff, vector *arcs_out, vector *arc_weights_out, - vector > *derivs_per_arc, + vector > *derivs_per_arc, DetStateMap *state_map, DetStatePriorityQueue *queue) { int32_t num_steps = 0; - std::unordered_map > > label_to_state; + std::unordered_map* > label_to_state; + // The following loop populates `label_to_state`, creating successor + // DetStates (unnormalized). Fsa *fsa = wsfa_in.fsa; const float *arc_weights = wfsa_in.arc_weights; - for (const std::shared_ptr &state_ptr: elements) { + for (const shared_ptr &state_ptr: elements) { int32_t state_id = state_ptr->state_id, begin_arc = fsa->arc_indexes[state_id], end_arc = fsa->arc_indexes[state_id + 1]; @@ -683,56 +795,43 @@ int32_t DetState::ProcessArcs( const Arc &arc = fsa->arcs[a]; float weight = arc_weights[a]; int32_t label = arc.label; - - auto ret = label_to_state.insert({label, nullptr}); auto iter = ret.first; if (ret.second) { // Inserted -> this label was not a key in this map. // Allocate new DetState. - iter->second = std::make_shared >(seq_len + 1, - this->output_state, - label); + iter->second = new DetState >(seq_len + 1); } - TracebackState *state = iter->second.get(); - state->Accept(state_ptr, a, arc.label, weight); + DetState *det_state = iter->second.get(); + det_state->Accept(state_ptr, a, arc.label, weight); } } CHECK(!label_to_state.empty() || elements[0]->state_id == fsa->FinalState()); // I'm assuming the input // FSA is connected. - + // The following loop normalizes successor det-states, outputs the arcs + // that lead to them, and adds them to the queue if necessary. for (auto iter = label_to_state.begin(); iter != label_to_state.end(); ++iter) { - std::shared_ptr &det_state = iter->second; + DetState *det_state = iter->second; float arc_weight; - std::vector deriv_info; + std::vector deriv_info; det_state->Normalize(wfsa_in, &arc_weight, &deriv_info); - if (det_state->forward_backward_weight >= prune_cutoff) { + if (det_state->forward_backward_prob >= prune_cutoff) { bool is_new_state = state_map->GetOutputState(state); arcs_out->push_back({this->state_id, next_state_id, label}); arc_weights_out->push_back(arc_weight); derivs_per_arc->push_back(std::move(deriv_info)); if (is_new_state) - queue->push(det_state); + queue->push(std::unique_ptr >(det_state)); + } else { + delete det_state; } } return num_steps; } -template -void DetState::ComputeFbWeight( - const float *backward_state_weights) { - forward_backward_weight = -std::numeric_limits::infinity(); - for (auto p: elements) { - TracebackState *state = p.second.get(); - forward_backward_weight = max(forward_backward_weight, - state->forward_prob + - backward_state_weights[state->state_id]); - } -} - template double LogSumOrMax(double, double); @@ -749,7 +848,7 @@ double LogSumOrMax(double a, double b) { template void DetState::Normalize(const WfsaWithFbWeights &wfsa_in, float *removed_weight, - std::vector *deriv_info) { + std::vector *deriv_info) { std::unordered_set cur_states; double fb_prob = -std::numeric_limits::infinity(); @@ -791,20 +890,23 @@ void DetState::Normalize(const WfsaWithFbWeights &wfsa_in, } - -void DeterminizePrunedLogSum( +template +float DeterminizePrunedTpl( const WfsaWithFbWeights &wfsa_in, float beam, int64_t max_step, Fsa *fsa_out, std::vector *arc_weights_out, - std::vector > > *arc_derivs_out) { + std::vector > *arc_derivs_out) { CHECK_GT(beam, 0); + CHECK(IsDeterministic(*wfsa_in.fsa)); + CHECK(!IsEmpty(*wfs_in.fsa)); - DetStatePriorityQueue queue; - DetStateMap map; + DetStatePriorityQueue queue; + DetStateMap map; + using DS = DetState; - std::shared_ptr start_state = std::make_shared(); + shared_ptr start_state = std::make_shared(); std::vector arcs_out; arc_weights_out->clear(); @@ -822,15 +924,47 @@ void DeterminizePrunedLogSum( double total_prob = wfsa_in.backward_state_weights[0], prune_cutoff = total_prob - beam; while (num_steps < max_step && !queue.empty()) { - std::shared_ptr state = queue.top(); + shared_ptr state = queue.top(); queue.pop(); num_steps += state->ProcessArcs(wfsa_in, prune_cutoff, arcs_out, arc_weights_out, arc_derivs_out, &map, &queue); } + if (!queue.empty()) { // We stopped early due to max_step + return total_prob - queue.top()->forward_backward_prob; + } else { + return beam; + } +} + + +void DeterminizePrunedLogSum( + const WfsaWithFbWeights &wfsa_in, + float beam, + int64_t max_step, + Fsa *fsa_out, + std::vector *arc_weights_out, + std::vector > > *arc_derivs_out) { + CHECK_EQ(wfsa_in.weight_type, kLogSumWeight); + return DeterminizePrunedTpl *arc_weights_out, + std::vector > > *arc_derivs_out) { + CHECK_EQ(wfsa_in.weight_type, kMaxWeight); + return DeterminizePrunedTpl( + wfsa_in, beam, max_step, fsa_out, + arc_weights_out, arc_derivs_out); } -// TODO: do the max version of Determinize(), which is much the same as the -// LogSum version. + + } // namespace k2