Skip to content

Commit 87613d4

Browse files
authored
Update debug mode for relation prompt (#3263)
* update debug mode for relation prompt * update * update
1 parent 135e9fa commit 87613d4

File tree

6 files changed

+160
-74
lines changed

6 files changed

+160
-74
lines changed

โ€Žmodel_zoo/uie/README.md

+19-13
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ python finetune.py \
640640
--device gpu
641641
```
642642

643-
ๅคšๅกๅฏๅŠจ๏ผš
643+
ๅฆ‚ๆžœๅœจGPU็Žฏๅขƒไธญไฝฟ็”จ๏ผŒๅฏไปฅๆŒ‡ๅฎš``gpus``ๅ‚ๆ•ฐ่ฟ›่กŒๅคšๅก่ฎญ็ปƒ๏ผš
644644

645645
```shell
646646
python -u -m paddle.distributed.launch --gpus "0,1" finetune.py \
@@ -701,18 +701,24 @@ python evaluate.py \
701701
่พ“ๅ‡บๆ‰“ๅฐ็คบไพ‹๏ผš
702702

703703
```text
704-
[2022-06-23 08:25:23,017] [ INFO] - -----------------------------
705-
[2022-06-23 08:25:23,017] [ INFO] - Class name: ๆ—ถ้—ด
706-
[2022-06-23 08:25:23,018] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000
707-
[2022-06-23 08:25:23,145] [ INFO] - -----------------------------
708-
[2022-06-23 08:25:23,146] [ INFO] - Class name: ็›ฎ็š„ๅœฐ
709-
[2022-06-23 08:25:23,146] [ INFO] - Evaluation precision: 0.64286 | recall: 0.90000 | F1: 0.75000
710-
[2022-06-23 08:25:23,272] [ INFO] - -----------------------------
711-
[2022-06-23 08:25:23,273] [ INFO] - Class name: ่ดน็”จ
712-
[2022-06-23 08:25:23,273] [ INFO] - Evaluation precision: 0.11111 | recall: 0.10000 | F1: 0.10526
713-
[2022-06-23 08:25:23,399] [ INFO] - -----------------------------
714-
[2022-06-23 08:25:23,399] [ INFO] - Class name: ๅ‡บๅ‘ๅœฐ
715-
[2022-06-23 08:25:23,400] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000
704+
[2022-09-14 03:13:58,877] [ INFO] - -----------------------------
705+
[2022-09-14 03:13:58,877] [ INFO] - Class Name: ็–พ็—…
706+
[2022-09-14 03:13:58,877] [ INFO] - Evaluation Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420
707+
[2022-09-14 03:13:59,145] [ INFO] - -----------------------------
708+
[2022-09-14 03:13:59,145] [ INFO] - Class Name: ๆ‰‹ๆœฏๆฒป็–—
709+
[2022-09-14 03:13:59,145] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
710+
[2022-09-14 03:13:59,439] [ INFO] - -----------------------------
711+
[2022-09-14 03:13:59,440] [ INFO] - Class Name: ๆฃ€ๆŸฅ
712+
[2022-09-14 03:13:59,440] [ INFO] - Evaluation Precision: 0.77778 | Recall: 0.56757 | F1: 0.65625
713+
[2022-09-14 03:13:59,708] [ INFO] - -----------------------------
714+
[2022-09-14 03:13:59,709] [ INFO] - Class Name: X็š„ๆ‰‹ๆœฏๆฒป็–—
715+
[2022-09-14 03:13:59,709] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
716+
[2022-09-14 03:13:59,893] [ INFO] - -----------------------------
717+
[2022-09-14 03:13:59,893] [ INFO] - Class Name: X็š„ๅฎž้ชŒๅฎคๆฃ€ๆŸฅ
718+
[2022-09-14 03:13:59,894] [ INFO] - Evaluation Precision: 0.71429 | Recall: 0.55556 | F1: 0.62500
719+
[2022-09-14 03:14:00,057] [ INFO] - -----------------------------
720+
[2022-09-14 03:14:00,058] [ INFO] - Class Name: X็š„ๅฝฑๅƒๅญฆๆฃ€ๆŸฅ
721+
[2022-09-14 03:14:00,058] [ INFO] - Evaluation Precision: 0.69231 | Recall: 0.45000 | F1: 0.54545
716722
```
717723

718724
ๅฏ้…็ฝฎๅ‚ๆ•ฐ่ฏดๆ˜Ž๏ผš

โ€Žmodel_zoo/uie/data_distill/README.md

-7
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,6 @@ python train.py \
146146
'text': '็™ป้ฉ็ƒญ'}]}]
147147
```
148148

149-
## ๆ•ˆๆžœ้ชŒ่ฏ
150-
151-
| ๆจกๅž‹ | Entity-F1 | SPO-F1 |
152-
| :---: | :--------: | :--------: |
153-
| UIE-Finetune | 78.57 | 56.25 |
154-
| GPLinker-ernie-3.0-mini-zh | 68.18 | 47.06 |
155-
| GPLinker-ernie-3.0-mini-zh + UIEๆ•ฐๆฎ่’ธ้ฆ | 76.38 | 50.42 |
156149

157150
# References
158151

โ€Žmodel_zoo/uie/data_distill/data_distill.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def do_data_distill():
8585
for text in tqdm(infer_texts, desc="Predicting: ", leave=False):
8686
infer_results.extend(uie(text))
8787

88-
train_synthetic_lines = synthetic2distill(texts, infer_results,
88+
train_synthetic_lines = synthetic2distill(infer_texts, infer_results,
8989
args.task_type)
9090

9191
# Concat origin and synthetic data

โ€Žmodel_zoo/uie/evaluate.py

+33-11
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from paddlenlp.utils.log import logger
2424

2525
from model import UIE
26-
from utils import convert_example, reader, unify_prompt_name
26+
from utils import convert_example, reader, unify_prompt_name, get_relation_type_dict, create_data_loader
2727

2828

2929
@paddle.no_grad()
@@ -60,28 +60,34 @@ def do_eval():
6060
max_seq_len=args.max_seq_len,
6161
lazy=False)
6262
class_dict = {}
63+
relation_data = []
6364
if args.debug:
6465
for data in test_ds:
6566
class_name = unify_prompt_name(data['prompt'])
6667
# Only positive examples are evaluated in debug mode
6768
if len(data['result_list']) != 0:
68-
class_dict.setdefault(class_name, []).append(data)
69+
if "็š„" not in data['prompt']:
70+
class_dict.setdefault(class_name, []).append(data)
71+
else:
72+
relation_data.append((data['prompt'], data))
73+
relation_type_dict = get_relation_type_dict(relation_data)
6974
else:
7075
class_dict["all_classes"] = test_ds
76+
77+
trans_fn = partial(convert_example,
78+
tokenizer=tokenizer,
79+
max_seq_len=args.max_seq_len)
80+
7181
for key in class_dict.keys():
7282
if args.debug:
7383
test_ds = MapDataset(class_dict[key])
7484
else:
7585
test_ds = class_dict[key]
76-
test_ds = test_ds.map(
77-
partial(convert_example,
78-
tokenizer=tokenizer,
79-
max_seq_len=args.max_seq_len))
80-
test_batch_sampler = paddle.io.BatchSampler(dataset=test_ds,
81-
batch_size=args.batch_size,
82-
shuffle=False)
83-
test_data_loader = paddle.io.DataLoader(
84-
dataset=test_ds, batch_sampler=test_batch_sampler, return_list=True)
86+
87+
test_data_loader = create_data_loader(test_ds,
88+
mode="test",
89+
batch_size=args.batch_size,
90+
trans_fn=trans_fn)
8591

8692
metric = SpanEvaluator()
8793
precision, recall, f1 = evaluate(model, metric, test_data_loader)
@@ -90,6 +96,22 @@ def do_eval():
9096
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
9197
(precision, recall, f1))
9298

99+
if args.debug and len(relation_type_dict.keys()) != 0:
100+
for key in relation_type_dict.keys():
101+
test_ds = MapDataset(relation_type_dict[key])
102+
103+
test_data_loader = create_data_loader(test_ds,
104+
mode="test",
105+
batch_size=args.batch_size,
106+
trans_fn=trans_fn)
107+
108+
metric = SpanEvaluator()
109+
precision, recall, f1 = evaluate(model, metric, test_data_loader)
110+
logger.info("-----------------------------")
111+
logger.info("Class Name: X็š„%s" % key)
112+
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
113+
(precision, recall, f1))
114+
93115

94116
if __name__ == "__main__":
95117
# yapf: disable

โ€Žmodel_zoo/uie/finetune.py

+13-24
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from model import UIE
2828
from evaluate import evaluate
29-
from utils import set_seed, convert_example, reader, MODEL_MAP
29+
from utils import set_seed, convert_example, reader, MODEL_MAP, create_data_loader
3030

3131

3232
def do_train():
@@ -57,28 +57,18 @@ def do_train():
5757
max_seq_len=args.max_seq_len,
5858
lazy=False)
5959

60-
train_ds = train_ds.map(
61-
partial(convert_example,
62-
tokenizer=tokenizer,
63-
max_seq_len=args.max_seq_len))
64-
dev_ds = dev_ds.map(
65-
partial(convert_example,
66-
tokenizer=tokenizer,
67-
max_seq_len=args.max_seq_len))
68-
69-
train_batch_sampler = paddle.io.BatchSampler(dataset=train_ds,
70-
batch_size=args.batch_size,
71-
shuffle=True)
72-
train_data_loader = paddle.io.DataLoader(dataset=train_ds,
73-
batch_sampler=train_batch_sampler,
74-
return_list=True)
75-
76-
dev_batch_sampler = paddle.io.BatchSampler(dataset=dev_ds,
77-
batch_size=args.batch_size,
78-
shuffle=False)
79-
dev_data_loader = paddle.io.DataLoader(dataset=dev_ds,
80-
batch_sampler=dev_batch_sampler,
81-
return_list=True)
60+
trans_fn = partial(convert_example,
61+
tokenizer=tokenizer,
62+
max_seq_len=args.max_seq_len)
63+
64+
train_data_loader = create_data_loader(train_ds,
65+
mode="train",
66+
batch_size=args.batch_size,
67+
trans_fn=trans_fn)
68+
dev_data_loader = create_data_loader(dev_ds,
69+
mode="dev",
70+
batch_size=args.batch_size,
71+
trans_fn=trans_fn)
8272

8373
if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
8474
state_dict = paddle.load(args.init_from_ckpt)
@@ -95,7 +85,6 @@ def do_train():
9585

9686
loss_list = []
9787
global_step = 0
98-
best_step = 0
9988
best_f1 = 0
10089
tic_train = time.time()
10190
for epoch in range(1, args.num_epochs + 1):

โ€Žmodel_zoo/uie/utils.py

+94-18
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,35 @@ def set_seed(seed):
118118
np.random.seed(seed)
119119

120120

121+
def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
122+
"""
123+
Create dataloader.
124+
Args:
125+
dataset(obj:`paddle.io.Dataset`): Dataset instance.
126+
mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
127+
batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
128+
trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
129+
Returns:
130+
dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
131+
"""
132+
if trans_fn:
133+
dataset = dataset.map(trans_fn)
134+
135+
shuffle = True if mode == 'train' else False
136+
if mode == "train":
137+
sampler = paddle.io.DistributedBatchSampler(dataset=dataset,
138+
batch_size=batch_size,
139+
shuffle=shuffle)
140+
else:
141+
sampler = paddle.io.BatchSampler(dataset=dataset,
142+
batch_size=batch_size,
143+
shuffle=shuffle)
144+
dataloader = paddle.io.DataLoader(dataset,
145+
batch_sampler=sampler,
146+
return_list=True)
147+
return dataloader
148+
149+
121150
def convert_example(example, tokenizer, max_seq_len):
122151
"""
123152
example: {
@@ -267,6 +296,48 @@ def unify_prompt_name(prompt):
267296
return prompt
268297

269298

299+
def get_relation_type_dict(relation_data):
300+
301+
def compare(a, b):
302+
a = a[::-1]
303+
b = b[::-1]
304+
res = ''
305+
for i in range(min(len(a), len(b))):
306+
if a[i] == b[i]:
307+
res += a[i]
308+
else:
309+
break
310+
if res == "":
311+
return res
312+
elif res[::-1][0] == "็š„":
313+
return res[::-1][1:]
314+
return ""
315+
316+
relation_type_dict = {}
317+
added_list = []
318+
for i in range(len(relation_data)):
319+
added = False
320+
if relation_data[i][0] not in added_list:
321+
for j in range(i + 1, len(relation_data)):
322+
match = compare(relation_data[i][0], relation_data[j][0])
323+
if match != "":
324+
match = unify_prompt_name(match)
325+
if relation_data[i][0] not in added_list:
326+
added_list.append(relation_data[i][0])
327+
relation_type_dict.setdefault(match, []).append(
328+
relation_data[i][1])
329+
added_list.append(relation_data[j][0])
330+
relation_type_dict.setdefault(match, []).append(
331+
relation_data[j][1])
332+
added = True
333+
if not added:
334+
added_list.append(relation_data[i][0])
335+
suffix = relation_data[i][0].rsplit("็š„", 1)[1]
336+
suffix = unify_prompt_name(suffix)
337+
relation_type_dict[suffix] = relation_data[i][1]
338+
return relation_type_dict
339+
340+
270341
def add_entity_negative_example(examples, texts, prompts, label_set,
271342
negative_ratio):
272343
negative_examples = []
@@ -610,26 +681,31 @@ def _sep_cls_label(label, separator):
610681
redundants1 = inverse_relation_list[i]
611682

612683
# 2. entity_name_set ^ subject_goldens[i]
613-
nonentity_list = list(
614-
set(entity_name_set) ^ set(subject_goldens[i]))
615-
nonentity_list.sort()
616-
617-
redundants2 = [
618-
nonentity + "็š„" + predicate_list[i][random.randrange(
619-
len(predicate_list[i]))]
620-
for nonentity in nonentity_list
621-
]
684+
redundants2 = []
685+
if len(predicate_list[i]) != 0:
686+
nonentity_list = list(
687+
set(entity_name_set) ^ set(subject_goldens[i]))
688+
nonentity_list.sort()
689+
690+
redundants2 = [
691+
nonentity + "็š„" +
692+
predicate_list[i][random.randrange(
693+
len(predicate_list[i]))]
694+
for nonentity in nonentity_list
695+
]
622696

623697
# 3. entity_label_set ^ entity_prompts[i]
624-
non_ent_label_list = list(
625-
set(entity_label_set) ^ set(entity_prompts[i]))
626-
non_ent_label_list.sort()
627-
628-
redundants3 = [
629-
subject_goldens[i][random.randrange(
630-
len(subject_goldens[i]))] + "็š„" + non_ent_label
631-
for non_ent_label in non_ent_label_list
632-
]
698+
redundants3 = []
699+
if len(subject_goldens[i]) != 0:
700+
non_ent_label_list = list(
701+
set(entity_label_set) ^ set(entity_prompts[i]))
702+
non_ent_label_list.sort()
703+
704+
redundants3 = [
705+
subject_goldens[i][random.randrange(
706+
len(subject_goldens[i]))] + "็š„" + non_ent_label
707+
for non_ent_label in non_ent_label_list
708+
]
633709

634710
redundants_list = [redundants1, redundants2, redundants3]
635711

0 commit comments

Comments
ย (0)