Fix pruning bounds #1009

Merged · 5 commits · Aug 7, 2022
Conversation

@pkufool (Collaborator) commented Jul 11, 2022

This PR fixes the issue of returning nan or inf loss (posted at k2-fsa/fast_rnnt#10).
It also updates the method used to generate the pruning bounds (i.e. the method we described in our pruned RNN-T paper). I did some experiments in April but forgot to submit the code, sorry! About the new pruning bounds: we cannot 100% guarantee that they are better than the old ones, as both are locally optimal bounds.

# We won't do any assertion in this test, just print out the losses,
# because we cannot be 100% sure that the new method is better than the old
# one all the time; both of them are locally optimal bounds.
def test_prune_ranges(self):
@pkufool (Collaborator, Author) Jul 11, 2022

I tested the two pruning bounds here. First, I generated a mask according to the ranges and filled the logits with 0.0 at the positions out of the ranges; then I computed the loss of the modified logits. If the new pruning bounds are better than the old ones, the loss should be smaller. From my experiments (both on real data and on randomly generated data), most of the time the loss of the logits pruned with the new bounds is smaller than that of the logits pruned with the old bounds, but it is not guaranteed 100% of the time. If I did something wrong, please LMK @danpovey, thanks!
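A minimal sketch of that check (not the literal test code; it assumes un-pruned `logits` of shape `(B, T, S + 1, V)` and `ranges`, `symbols`, `boundary`, `blank_id` already computed as elsewhere in this PR):

```python
import torch
import k2

# Mark the (t, s) positions that lie inside the pruning bounds.
B, T, S1, V = logits.shape  # S1 == S + 1
mask = torch.zeros(B, T, S1, dtype=torch.bool, device=logits.device)
mask.scatter_(2, ranges.long(), True)  # ranges: (B, T, s_range), symbol indices kept per frame

# Zero out the logits outside the bounds and evaluate the un-pruned loss;
# a tighter set of bounds should give a smaller loss here.
masked_logits = logits.masked_fill(~mask.unsqueeze(-1), 0.0)
loss = k2.rnnt_loss(
    logits=masked_logits,
    symbols=symbols,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)
```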

if s_range > S:
-    s_range = S
+    s_range = S + 1
@pkufool (Collaborator, Author)
The buggy code that raises the nan or inf loss is here. It only occurs when S == 1 and modified == False, because when modified == False the transitions can only go up and right, so s_range MUST be equal to or greater than 2, or no path can survive pruning.
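An illustrative sketch (assumed wording, not the literal PR code) of how the clamp above and the new assertion quoted further down fit together:

```python
if s_range > S:
    # was `s_range = S`, which becomes 1 when S == 1
    s_range = S + 1

if not modified:  # regular RNN-T: transitions go up and right
    assert s_range >= 2, (
        "Pruning range for standard RNN-T should be equal to or greater "
        "than 2, or no valid paths could survive pruning."
    )
```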

@pkufool (Collaborator, Author) commented Jul 15, 2022

I tested the new bounds and the old bounds on real data (LibriSpeech 100h). In the following figure, the pink curve is the loss pruned with the old bounds and the blue curve is the loss pruned with the new bounds (note: I used the trivial joiner to calculate the losses). You can see in the figure that only in the early batches is the loss of the new bounds a little larger than that of the old bounds; after about 3k iterations the two are identical.

I also compared the pruning bounds themselves: over the first 3k iterations, torch.mean(new_bounds - old_bounds) == 0.01535; for the iterations from 5k to 10k, torch.mean(new_bounds - old_bounds) == 0.00558.

(figure: pruned-loss curves, old vs. new bounds)
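A minimal sketch of that comparison (assuming `px_grad`, `py_grad`, `boundary` and `prune_range` are the same tensors passed to the pruning functions; both range tensors have shape `(B, T, s_range)` and hold symbol indices):

```python
import torch
import k2

new_bounds = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
)
old_bounds = k2.get_rnnt_prune_ranges_deprecated(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
)
# Average element-wise shift of the new bounds relative to the old ones.
mean_diff = torch.mean((new_bounds - old_bounds).float())
```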

@pkufool (Collaborator, Author) commented Jul 15, 2022

It seems that the old and new bounds are almost identical; I will use the new bounds to reproduce one of our recipes.

@pkufool (Collaborator, Author) commented Jul 25, 2022

I tried the new pruning bounds on our two baseline recipes; here are the results.

From the results below, training with the new pruning bounds brings some improvement on the pruned_transducer_stateless4 recipe, especially on test-other, while on pruned_transducer_stateless5 (exp-B, 88M parameters) I got slightly worse results.

pruned_transducer_stateless4 (WER, test-clean & test-other)

|            | greedy_search | modified_beam_search | fast_beam_search | comments            |
|------------|---------------|----------------------|------------------|---------------------|
| baseline   | 2.69 & 6.64   | 2.62 & 6.57          | 2.66 & 6.6       | --epoch 30 --avg 6  |
| New bounds | 2.62 & 6.28   | 2.58 & 6.18          | 2.57 & 6.13      | --epoch 25 --avg 13 |

(tensorboard screenshots: baseline vs. new bounds)

pruned_transducer_stateless5 exp-B (WER, test-clean & test-other)

|            | greedy_search | modified_beam_search | fast_beam_search | comments            |
|------------|---------------|----------------------|------------------|---------------------|
| baseline   | 2.54 & 5.72   | 2.47 & 5.71          | 2.5 & 5.72       | --epoch 30 --avg 10 |
| New bounds | 2.53 & 5.77   | 2.53 & 5.7           | 2.52 & 5.71      | --epoch 25 --avg 11 |

(tensorboard screenshots: baseline vs. new bounds)

@pkufool (Collaborator, Author) commented Aug 4, 2022

I conducted two experiments to compare the pruned loss values of the old bounds and the new bounds.

It seems that whether we use the "old bounds" or the "new bounds" in training, the loss values of the "new bounds" are always larger than those of the "old bounds" in the early iterations. BTW, from the picture we can see that the "new bounds" help the model converge faster; the only difference between these two experiments is the bounds used for pruning.

Training with new bounds

The blue curve is for the new bounds; the pink curve is for the old bounds.

(figure: tensorboard loss curves)

Code used to generate the tensorboard curves above:

```python
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

# this loss will be used to train the model
pruned_loss = k2.rnnt_loss_pruned(
    logits=logits.float(),  # these logits come from the joiner output
    symbols=y_padded,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)

new_am_pruned, new_lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
new_logits = new_am_pruned + new_lm_pruned

# this loss is used to draw the blue curve
new_pruned_loss = k2.rnnt_loss_pruned(
    logits=new_logits.float(),
    symbols=y_padded,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)

old_ranges = k2.get_rnnt_prune_ranges_deprecated(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

old_am_pruned, old_lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=old_ranges)
old_logits = old_am_pruned + old_lm_pruned

# this loss is used to draw the pink curve
old_pruned_loss = k2.rnnt_loss_pruned(
    logits=old_logits.float(),
    symbols=y_padded,
    ranges=old_ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)

return (simple_loss, pruned_loss, new_pruned_loss, old_pruned_loss)
```

Training with old bounds

The blue curve is for the new bounds; the pink curve is for the old bounds.

(figure: tensorboard loss curves)

s_range >= 2
), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."

blk_grad = torch.as_strided(
@danpovey (Collaborator) Aug 4, 2022

Here, I think it would be better not to assume that py is contiguous. In place of S1 you should probably be using py_grad.stride(2), and so on, and the S1 * T should be py_grad.stride(0), or possibly a variable called B_stride, e.g. (B_stride, S_stride, T_stride) = py_grad.stride().
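A minimal sketch of that suggestion (the `size` tuple is illustrative, not the exact view built in the PR): take the strides from `py_grad` itself, so the code also works when `py_grad` is not contiguous.

```python
# Derive the strides from the tensor instead of computing them from the shape.
B_stride, S_stride, T_stride = py_grad.stride()
blk_grad = torch.as_strided(
    py_grad,
    size=(B, S, T),                         # whatever view the surrounding code needs
    stride=(B_stride, S_stride, T_stride),  # instead of strides hand-computed from the shape, e.g. S1 * T
)
```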


# (B, T)
s_begin = torch.argmax(final_grad, axis=1)
s_begin = s_begin[:, :T]
@danpovey (Collaborator)

this statement may be redundant?

# padding value which is `len(symbols) - s_range + 1`.
# This is to guarantee that we reach the last symbol at last frame of real
# data.
mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
@danpovey (Collaborator)

Can we add a comment saying what mask will be, e.g. its shape and when it is true?
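A possible version of the requested comment (a sketch, assuming the mask is meant to mark the real frames via the number of valid frames stored in `boundary[:, 3]`):

```python
# mask: shape (B, T); mask[b][t] is True for the real frames (t < boundary[b][3])
# and False for the padding frames, whose s_begin gets overwritten with the
# padding value `len(symbols) - s_range + 1`.
frame_index = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
mask = frame_index < boundary[:, 3].reshape(B, 1)
```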

@danpovey (Collaborator) commented Aug 4, 2022

After fixing those small issues and making sure the tests run, I incline towards merging this.
Generally if there are inconsistent WER results that are not clearly significant, I prefer to just average them to get the overall story, so I tend to think this is better than the previous method; also it's encouraging that it converges faster.
Regarding the difference between the loss functions: firstly, double-check that you didn't swap the two of them at some point in your code, or when you interpreted the graph; but if the difference persists, I suppose it might have to do with the continuity constraints somehow, e.g. the new method might be more locally optimal but less globally optimal somehow.

@pkufool (Collaborator, Author) commented Aug 4, 2022

> After fixing those small issues and making sure the tests run, I incline towards merging this. Generally if there are inconsistent WER results that are not clearly significant, I prefer to just average them to get the overall story, so I tend to think this is better than the previous method; also it's encouraging that it converges faster. Regarding the difference between the loss functions: firstly, double-check that you didn't swap the two of them at some point in your code, or when you interpreted the graph; but if the difference persists, I suppose it might have to do with the continuity constraints somehow, e.g. the new method might be more locally optimal but less globally optimal somehow.

Ok, will fix the issues and check the code again. Thanks!

@pkufool added the `ready` label (Ready for review and trigger GitHub actions to run) on Aug 5, 2022
@danpovey (Collaborator) commented Aug 6, 2022

Thanks... looks OK to me.

@pkufool merged commit 7b7a8a5 into k2-fsa:master on Aug 7, 2022
@RuABraun

On real-world datasets the assertion T >= S fails a not-insignificant number of times for me. I think a subsampling factor of 4 might be too aggressive.

@danpovey (Collaborator)

We should make sure there's a way for calling code to ignore this error.
@RuABraun Are you sure that the transcripts in these cases are correct?

@RuABraun

Yup (it's German if that matters). Doesn't happen that often honestly but just wanted to raise awareness.

@danpovey (Collaborator)

Ah, OK. For German it may be different than English (or might need a larger BPE vocab).
It should be possible to train with modified==False (on all minibatches), so we shouldn't have this limitation (IIRC).
To take advantage of that at test time, we can only use certain decoding methods.
In any case we should make it possible to skip over batches with this problem.
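A hedged sketch of such a guard (the names `x_lens` and `y_lens` are illustrative, not from the PR): drop or skip utterances whose number of frames after subsampling is smaller than the number of symbols, before computing the pruned loss.

```python
import torch

def valid_utterance_mask(x_lens: torch.Tensor, y_lens: torch.Tensor) -> torch.Tensor:
    """Boolean mask of utterances satisfying T >= S, so batches that would
    trip the assertion can be filtered or skipped by the calling code."""
    return x_lens >= y_lens

# usage sketch:
# keep = valid_utterance_mask(x_lens, y_lens)
# if not keep.all():
#     ...  # drop the offending utterances (or skip the whole batch)
```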
