@@ -24,12 +24,7 @@
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, setup_logger, str2bool
 
 
 def get_parser():
@@ -61,7 +56,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=50,
+        default=15,
         help="Number of epochs to train.",
     )
 
@@ -129,11 +124,10 @@ def get_params() -> AttributeDict:
         {
             "exp_dir": Path("tdnn/exp"),
             "lang_dir": Path("data/lang_phone"),
-            "lr": 1e-3,
+            "lr": 1e-2,
             "feature_dim": 23,
             "weight_decay": 1e-6,
             "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
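Dropping the duplicated "num_epochs": 50 entry leaves the --num-epochs flag as the single source of truth. A minimal sketch of why the dict default was redundant, assuming the usual icefall pattern of merging the parsed args into params via params.update(vars(args)) (the exact merge point in this script is an assumption here):

import argparse

from icefall.utils import AttributeDict

parser = argparse.ArgumentParser()
parser.add_argument("--num-epochs", type=int, default=15)
args = parser.parse_args([])  # take the argparse defaults

params = AttributeDict({"lr": 1e-2, "start_epoch": 0})
params.update(vars(args))  # the CLI value always wins

print(params.num_epochs)  # 15 -- a dict default would just have been overwritten
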
@@ -278,9 +272,14 @@ def compute_loss(
     # different duration in decreasing order, required by
     # `k2.intersect_dense` called in `k2.ctc_loss`
     supervisions = batch["supervisions"]
-    supervision_segments, texts = encode_supervisions(
-        supervisions, subsampling_factor=1
+    texts = supervisions["text"]
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
+
     decoding_graph = graph_compiler.compile(texts)
 
     dense_fsa_vec = k2.DenseFsaVec(
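For context on the replacement: each row of supervision_segments is [sequence_index, start_frame, num_frames] (int32), and k2.intersect_dense requires rows sorted by non-increasing num_frames, which is the ordering encode_supervisions used to enforce. Because every row here spans the full nnet_output.shape[1] frames, that ordering holds by construction. A self-contained sketch with made-up shapes:

import k2
import torch

batch_size, num_frames, num_classes = 2, 10, 5
# Fake network output: per-frame log-probabilities over classes.
nnet_output = torch.randn(batch_size, num_frames, num_classes).log_softmax(-1)

# One row per utterance; all rows cover all frames, so the
# decreasing-duration requirement is trivially satisfied.
supervision_segments = torch.tensor(
    [[i, 0, num_frames] for i in range(batch_size)],
    dtype=torch.int32,
)

dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
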
@@ -491,7 +490,7 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = optim.AdamW(
+    optimizer = optim.SGD(
         model.parameters(),
         lr=params.lr,
         weight_decay=params.weight_decay,
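The optimizer swap is constructor-compatible: torch.optim.SGD accepts the same params/lr/weight_decay keywords as optim.AdamW, so only the class name changes (momentum stays at SGD's default of 0, since the diff passes none). Raising lr to 1e-2 in the same commit is consistent with plain SGD typically needing a larger step size than AdamW. A quick sketch with a stand-in model:

import torch
import torch.optim as optim

model = torch.nn.Linear(23, 10)  # stand-in; not the actual TDNN

optimizer = optim.SGD(
    model.parameters(),
    lr=1e-2,            # params.lr after this commit
    weight_decay=1e-6,  # params.weight_decay
)

# One training step to show the optimizer is drop-in compatible.
loss = model(torch.randn(4, 23)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()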