These are some notes regarding how we'll interact with Python and PyTorch.
# Assumes that A is an acceptor but B may
# have auxiliary symbols (i.e. may be a transducer).
def TransducerCompose(a: FsaVec, a_weights: Tensor,
                      b: FsaVec, b_weights: Tensor,
                      b_aux_symbols=None):
    c, indexes_a, indexes_b = fsa.FsaVecCompose(a, b)
    c_weights = a_weights[indexes_a] + b_weights[indexes_b]
    if b_aux_symbols is None:
        return c, c_weights
    else:
        return c, c_weights, fsa.MapAuxSymbols(b_aux_symbols, indexes_b)
def Compose(a: FsaVec, a_weights: Tensor,
            b: FsaVec, b_weights: Tensor):
    c, input_indexes1, input_indexes2 = fsa.FsaVecCompose(a, b)
    c_weights = a_weights[input_indexes1] + b_weights[input_indexes2]
    # Handle transducers:
    if a.aux_symbol is not None:
        c.aux_symbol = a.aux_symbol[input_indexes1]
    if b.aux_symbol is not None:
        c.aux_symbol = b.aux_symbol[input_indexes2]
    return c, c_weights
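
A quick usage sketch (hedged: `a`, `b` and their weight tensors are assumed
to exist, with shapes as in the signatures above):

    # a, b: FsaVecs; a_weights, b_weights: float Tensors of shape (num_arcs,).
    c, c_weights = Compose(a, a_weights, b, b_weights)
    # Each arc of c was produced from one arc of a and one arc of b, and
    # c_weights is the sum of those two source weights, so autograd will
    # propagate gradients back to both a_weights and b_weights.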
def RmEpsilon(a: FsaVec, a_weights, a_aux_symbols=None):
    # At the C++ level, RmEpsilon outputs a `vector<vector<int> > indexes` that
    # says, for each arc in the output, what list of arcs in the input it
    # corresponds to.  For exposition purposes, imagine we're dealing with a
    # single FSA, not a vector of FSAs.  Suppose indexes == [ [ 1, 2 ], [ 6, 8, 9 ] ];
    # we'd form indexes1 = [ 1, 2, 6, 8, 9 ] and indexes2 = [ 0, 0, 1, 1, 1 ].
    b, indexes1, indexes2 = fsa.FsaVecRmEpsilon(a)
    # Note: the 1st dim of a_weights must equal a.num_arcs, but it is allowed
    # to have other dims (e.g. for non-scalar weights).
    # In the normal case, a_weights and b_weights will have just one axis.
    b_weights = torch.zeros((b.num_arcs, *a_weights.shape[1:]))
    b_weights.index_add_(0, indexes2, a_weights[indexes1])
    # If we later need access to indexes1 and indexes2, we can
    # create a different version of this function or extend its interface.
    if a_aux_symbols is None:
        return b, b_weights
    else:
        return b, b_weights, fsa.MapAuxSymbols(a_aux_symbols, indexes1, indexes2)
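
To make the index bookkeeping concrete, here is a self-contained numeric
sketch of just the weight-summing step (plain PyTorch, no FSA code):

    import torch

    # Toy version of the example above: output arc 0 came from input arcs
    # [1, 2], output arc 1 from input arcs [6, 8, 9].
    indexes1 = torch.tensor([1, 2, 6, 8, 9])   # flattened input-arc ids
    indexes2 = torch.tensor([0, 0, 1, 1, 1])   # which output arc each maps to
    a_weights = torch.arange(10, dtype=torch.float32)  # 10 input arcs

    b_weights = torch.zeros(2)
    b_weights.index_add_(0, indexes2, a_weights[indexes1])
    # b_weights is now [1+2, 6+8+9] == [3., 23.]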
class TotalWeight(Function):
    """
    Returns the total weight of FSAs (i.e. the log-sum-exp across
    paths) as a Tensor with shape (num_fsas,).
    """

    @staticmethod
    def forward(ctx, a: FsaVec, a_weights: Tensor):
        ctx.a = a
        ctx.fb = fsa.FsaVecWithFbWeights(a, a_weights, fsa.kLogSumWeight)
        ans = ctx.fb.GetTotalWeights()  # a Tensor
        return ans

    @staticmethod
    def backward(ctx, grad_out):
        # `indexes` is a Tensor that would contain, for each arc in a,
        # the index of the FSA in the FsaVec it belongs to.
        # It would be of the form [ 0, 0, 0 .., 0, 1, 1, 1, .. 1, 2 ... ]
        indexes = fsa.GetFsaVecIndexes(ctx.a)
        # GetArcProbs would return the probability of traversing each arc of
        # each FSA, as a single Tensor.
        # TODO: handle transfers to/from GPU in case grad_out was on GPU.
        # Maybe mark this only once differentiable (it's twice differentiable,
        # I think, but this code doesn't currently support that).
        # backward must return one gradient per forward argument; `a` itself
        # is not differentiable, hence the None.
        return None, ctx.fb.GetArcProbs() * grad_out[indexes]
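
A sketch of how this would be invoked (hedged: assumes the FsaVec bindings
above exist; autograd.Function subclasses are called via .apply):

    a_weights.requires_grad_(True)
    total = TotalWeight.apply(a, a_weights)   # shape (num_fsas,)
    total.sum().backward()
    # a_weights.grad now holds, per arc, its occupation probability
    # (scaled by the incoming gradient, which is 1 here).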
=====================
From here on are some ideas about how we'd use this in a program.
# decoding_graph is an FsaVec, graph_weights is a float Tensor,
# graph_word_syms is a LongTensor (both with 1 axis).
# decoding_graph will have the following extra attributes:
#   decoding_graph.weights,
#   decoding_graph.word_syms
# both of shape (num_arcs,), of dtypes float and long respectively.
decoding_graph = fsa.ReadDecodingGraph('a/b/c.fsa')
# nnet_output's shape is (num_seq, num_frames, num_symbols); these symbols
# might be phones or letters or small word-pieces.
nnet_output = model(input_feats)
# Interpret each sequence numbered `n` in `nnet_output` as being a FSA with
# `num_frames + 1` states numbered 0, 1, ... num_frames, and for each 0 <= i
# < num_frames, arcs from state i to i+1 with each symbol `s` as the label
# and the output given by `nnet_output[n,i,s]`. This is a lightweight
# operation (except for the fact that it transfers the matrix to CPU for
# now).
# NOTE: the above is a slight simplification because the cuts may not
# span the entirety of `nnet_output`, they may be sub-sequences of frames
# within there, and `cut_info` will describe that somehow. So the ends
# of the cuts may be "ragged".
nnet_output_fsas = fsa.DenseFsaVec(nnet_output, cut_info)
# nnet_output_fsas.input_indexes is a LongTensor with shape
# (num_arcs_in_nnet_output_fsas, 3), where the 3 columns are the
# indexes n, i, s into nnet_output.
nnet_output_fsas.input_indexes = nnet_output_fsas.GetIndexes()
# ... and implicitly nnet_output_fsas has a `weights` vector, which is a
# reshape/copy of parts of `nnet_output`. (It's not just a reshape of the
# whole thing due to its ragged structure).
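
As a hedged sketch of what that weights vector amounts to for a single
segment (hypothetical variable names, just to illustrate the "reshape/copy
of parts of nnet_output" point):

    # For a segment of sequence n covering frames [start, end), each frame
    # contributes num_symbols arcs, and the arc with symbol s on frame i
    # has weight nnet_output[n, i, s]; so that segment's arc weights are:
    seg_weights = nnet_output[n, start:end, :].reshape(-1)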
# Composing with the FSAs representing the supervision; this is CTC alignment.
# alignment_fsas will be an FsaVec.
alignment_fsas = fsa.compose_pruned(nnet_output_fsas, supervision_fsas, beam=10.0)
# objf_part1 is the CTC part of the objective; we'll later add it with the others.
objf_part1 = fsa.GetTotalLogWeight(alignment_fsas).sum()
# first_pass_fsas will be an FsaVec.
# This differs from `alignment_fsas` because it allows all possible word sequences,
# not just the supervision one(s).
# From here on, for a while, the code path will be the same as the one
# we'll take at test time.
#
first_pass_fsas = fsa.compose_pruned(nnet_output_fsas, decoding_graph, beam=10.0)
# `first_pass_fsas` will have the following attributes:
#
# input0.{arc_indexes,weights,input_indexes}
# input1.{arc_indexes,weights,word_syms}
#
# and the following derived/computed params:
# weights = input0.weights + input1.weights
# word_syms = input1.word_syms
# We'll want to propagate `nnet_arc_indexes` forward, so we can
# keep track of alignments.
first_pass_fsas.nnet_arc_indexes = first_pass_fsas.input0.arc_indexes
# the following computes arc-level log-posteriors
log_posts = first_pass_fsas.ComputeLogPosteriors()
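# (For reference: with forward scores alpha and backward scores beta from the
# log-sum-exp forward-backward computation, the log-posterior of an arc from
# state p to state q with weight w is alpha[p] + w + beta[q] - total_weight.)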
# Also get the symbol-level posteriors:
# The following will be the posteriors of a particular phone (or blank) at a
# particular time-index of a particular sequence.  See fsautil.py.
first_pass_fsas.phone_log_posts = fsautil.sum_by_index(
    log_posts.exp(),
    first_pass_fsas.input0.arc_indexes).log()

# the following means "the posterior of the current word".
first_pass_fsas.word_log_posts = fsautil.sum_by_index(
    log_posts.exp(),
    torch.stack((first_pass_fsas.input0.arc_indexes,
                 first_pass_fsas.word_syms), dim=1)).log()
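
For orientation, a minimal sketch of what a sum_by_index helper could look
like for 1-D integer indexes (a guess at the semantics, not the real
fsautil.py; the pair-valued index used for word_log_posts would first need
the pairs mapped to unique group ids, e.g. with
torch.unique(pairs, dim=0, return_inverse=True)):

    import torch

    def sum_by_index(values, indexes):
        # Sum `values` within groups of elements sharing the same index
        # value, then map each group's total back to its elements, so the
        # result is parallel to `values`.
        num_groups = int(indexes.max().item()) + 1
        totals = torch.zeros(num_groups, dtype=values.dtype)
        totals.index_add_(0, indexes, values)
        return totals[indexes]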
# aggregate all these weights?
# first_pass_fsas.all_weights = torch.cat(
#     (first_pass_fsas.input0.weights.unsqueeze(0),   # acoustic weight
#      first_pass_fsas.input1.weights.unsqueeze(0),   # graph weight
#      first_pass_fsas.phone_log_posts.unsqueeze(0),
#      first_pass_fsas.word_log_posts.unsqueeze(0)))
objf_part2 = ...
# OK, now we want to do a pass of RNNLM rescoring. Compose with a
# virtual FSA (DeterministicOnDemandFsa) that has words as the
# labels. We can make this a bit more efficient by determinizing the
# input first.
# We need the following to propagate the members weights, all_weights.
# Invert the FST, swapping its symbols with word_syms.
fsas_inverted = fsa.Invert(first_pass_fsas,
                           first_pass_fsas.word_syms,
                           other_label_name='phone_syms',
                           keep=[...])
fsas_rmeps = fsa.RmEpsWeighted(fsas_inverted)
fsas_det = fsa.DeterminizeTropicalPruned(fsas_rmeps, beam=...)
# note: fsas_det still has `phone_syms` and `nnet_arc_indexes`, now as sequences.
class PseudoTensor:
    """This is a class that behaves like a torch.Tensor but in fact only
    supports one kind of operation, namely indexing with another
    torch.Tensor."""

    def __init__(self, t, divisor):
        """ Constructor.
        Parameters:
            t: torch.LongTensor
            divisor: int
        """
        self.t = t
        self.divisor = divisor

    def __getitem__(self, indexes):
        return self.t[indexes // self.divisor]
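
A short usage sketch (toy values; the point is just the index arithmetic):

    t = torch.tensor([10, 20, 30])     # one value per frame
    p = PseudoTensor(t, divisor=5)     # e.g. 5 symbols -> 5 arcs per frame
    arc_ids = torch.tensor([0, 4, 7, 12])
    p[arc_ids]                         # tensor([10, 10, 20, 30])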
class DenseFsaVec:
    """Represents a vector of FSAs, but with a special regular
    structure.  Each FSA would normally correspond to one supervised
    segment within an acoustic sequence.  This wraps the data
    output from the neural net.  Each segment has T+2 states
    numbered 0, 1, .. T, T+1 (the T+1'th is the final-state and
    only has final-arcs entering it).  From state i to i+1
    there is an arc for each symbol, whose loglike comes from
    lookup in the neural-net output."""
    def __init__(self, loglikes, seq_indexes, start_times, end_times):
        """ Constructor.
        Params:
            loglikes (torch.Tensor): The tensor of log-likelihoods
                output by the neural network.  Will be interpreted
                as having shape (num_seqs, num_frames, num_symbols).
                Here `num_symbols` includes epsilon.
            seq_indexes, start_times, end_times (torch.Tensor):
                These must all have the same shape, of the form
                (num_segments,).  Here, num_segments would normally
                be >= num_seqs (each sequence may have more than
                one supervised segment in it, and they may
                overlap).
                - seq_indexes says, for each segment, which sequence
                  it is a part of
                - start_times says, for each segment, what the first
                  frame-index in `loglikes` is
                - end_times says, for each segment, what the
                  one-past-the-last frame-index in `loglikes` is.
        """
        # note: this will create a csrc.CfsaVec object internal to this
        # object.
        # It will also create self.arc_loglikes containing the
        # loglikes, one per arc of the CfsaVec object.  This is
        # a repeat of `loglikes` but possibly in a different
        # order.
        pass
    @property
    def loglikes(self):
        return self.arc_loglikes

    @property
    def seg_frames(self):
        """Return something that 'acts' like a tensor, indexed by arc, of
        the frame-index relative to the segment start corresponding to that
        arc.  NOTE: self.frame_loglikes will actually be a sub-Tensor
        of the Tensor created at the C++ level as the DenseFsaVecMeta object.
        """
        return PseudoTensor(self.frame_loglikes, self.num_symbols)

    @property
    def seq_frames(self):
        """ as for seg_frames """
        pass

    @property
    def seq_ids(self):
        """ as for seg_frames """
        pass

    @property
    def seg_ids(self):
        """ as for seg_frames """
        pass
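
A hedged construction sketch (toy shapes; the segment boundaries are chosen
arbitrarily, just to show that segments may overlap and need not cover the
whole sequence):

    # loglikes: 2 sequences, 100 frames, 50 symbols (incl. epsilon).
    loglikes = torch.randn(2, 100, 50)
    dense_fsas = DenseFsaVec(
        loglikes,
        seq_indexes=torch.tensor([0, 0, 1]),    # two segments from seq 0
        start_times=torch.tensor([0, 40, 10]),
        end_times=torch.tensor([50, 90, 100]))  # one-past-the-last frame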