@@ -122,9 +122,9 @@
     "        flattened_token_ids.extend(id_group)\n",
     "    return flattened_token_ids\n",
     "    \n",
-    "def get_token_id_groups_from_words_in_input(text, tokenizer):\n",
+    "def get_token_id_groups_from_words_in_input(text, baseline_token, tokenizer):\n",
     "    # get default start and end ids of a sentence\n",
-    "    text_ids, _ = get_input_baseline_ids(text, \"\", tokenizer)\n",
+    "    text_ids, baseline_ids = get_input_baseline_ids(text, baseline_token, tokenizer)\n",
     "    start_id, end_id = text_ids[0], text_ids[-1]\n",
     "\n",
     "    # get each word in the input text as a list of corresponding token ids\n",
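The hunk above threads an explicit `baseline_token` argument through to `get_input_baseline_ids` instead of hardcoding an empty-string baseline. As a rough sketch of what that helper plausibly does with a Hugging Face-style tokenizer (the `_sketch` name and the body below are a hypothetical reconstruction, not the repo's actual implementation):

```python
def get_input_baseline_ids_sketch(text, baseline_token, tokenizer):
    """Hypothetical stand-in for get_input_baseline_ids (assumption,
    not the repo's code): encode `text` with its special tokens, then
    build a same-length baseline whose non-special positions are all
    `baseline_token`."""
    text_ids = tokenizer.encode(text)  # e.g. [CLS] ... [SEP] for BERT
    baseline_id = tokenizer.convert_tokens_to_ids(baseline_token)
    # Keep the special tokens at both ends; replace everything between.
    baseline_ids = [text_ids[0]] + [baseline_id] * (len(text_ids) - 2) + [text_ids[-1]]
    return text_ids, baseline_ids
```

This matches how the changed function consumes the result: `text_ids[0]` and `text_ids[-1]` are peeled off as the sentence's start and end ids.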
@@ -145,7 +145,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "token_id_groups, baseline_id_groups = get_token_id_groups_from_words_in_input(text, tokenizer)\n",
+    "token_id_groups, baseline_id_groups = get_token_id_groups_from_words_in_input(text, baseline_token, tokenizer)\n",
     "\n",
     "xf = WordXformer(token_id_groups, baseline_id_groups) \n",
     "apgo = Archipelago(model_wrapper, data_xformer=xf, output_indices=class_idx, batch_size=20)"
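With the new signature, this cell needs a `baseline_token` defined alongside `text` before it runs. A minimal setup sketch (the model checkpoint, example sentence, `[MASK]` baseline, and `class_idx` value are illustrative assumptions, not taken from the notebook):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "this movie was surprisingly good"  # assumed example input
baseline_token = "[MASK]"  # stand-in token representing an absent word
class_idx = 1              # assumed index of the class logit to explain

token_id_groups, baseline_id_groups = get_token_id_groups_from_words_in_input(
    text, baseline_token, tokenizer
)
```

Any token the model treats as carrying no signal (e.g. `[PAD]` or `[UNK]`) could serve as the baseline; `[MASK]` is a common choice for BERT-style models because masked positions are in-distribution for them.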