|
42 | 42 | parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.")
|
43 | 43 | parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
|
44 | 44 | parser.add_argument("--seed", type=int, default=3, help="random seed for initialization")
|
45 |
| -parser.add_argument("--rationale_num", type=int, default=3, help="Number of rationales per example.") |
| 45 | +parser.add_argument("--rationale_num_sparse", type=int, default=3, help="Number of rationales per example for sparse data.") |
| 46 | +parser.add_argument("--rationale_num_support", type=int, default=6, help="Number of rationales per example for support data.") |
46 | 47 | parser.add_argument("--sparse_num", type=int, default=100, help="Number of sparse data.")
|
47 | 48 | parser.add_argument("--support_threshold", type=float, default="0.7", help="The threshold to select support data.")
|
48 | 49 | parser.add_argument("--support_num", type=int, default=100, help="Number of support data.")
|
@@ -180,7 +181,8 @@ def find_sparse_data():
|
180 | 181 | # Feature similarity analysis & select sparse data
|
181 | 182 | analysis_result = []
|
182 | 183 | for batch in dev_data_loader:
|
183 |
| - analysis_result += feature_sim(batch, sample_num=args.rationale_num) |
| 184 | + analysis_result += feature_sim(batch, |
| 185 | + sample_num=args.rationale_num_sparse) |
184 | 186 | sparse_indexs, sparse_scores, preds = get_sparse_data(
|
185 | 187 | analysis_result, args.sparse_num)
|
186 | 188 |
|
@@ -285,7 +287,8 @@ def find_support_data():
|
285 | 287 | # Feature similarity analysis
|
286 | 288 | analysis_result = []
|
287 | 289 | for batch in sparse_data_loader:
|
288 |
| - analysis_result += feature_sim(batch, sample_num=-1) |
| 290 | + analysis_result += feature_sim(batch, |
| 291 | + sample_num=args.rationale_num_support) |
289 | 292 |
|
290 | 293 | support_indexs, support_scores = get_support_data(analysis_result,
|
291 | 294 | args.support_num,
|
|
0 commit comments