@@ -118,6 +118,35 @@ def set_seed(seed):
118
118
np .random .seed (seed )
119
119
120
120
121
+ def create_data_loader (dataset , mode = "train" , batch_size = 1 , trans_fn = None ):
122
+ """
123
+ Create dataloader.
124
+ Args:
125
+ dataset(obj:`paddle.io.Dataset`): Dataset instance.
126
+ mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
127
+ batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
128
+ trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
129
+ Returns:
130
+ dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
131
+ """
132
+ if trans_fn :
133
+ dataset = dataset .map (trans_fn )
134
+
135
+ shuffle = True if mode == 'train' else False
136
+ if mode == "train" :
137
+ sampler = paddle .io .DistributedBatchSampler (dataset = dataset ,
138
+ batch_size = batch_size ,
139
+ shuffle = shuffle )
140
+ else :
141
+ sampler = paddle .io .BatchSampler (dataset = dataset ,
142
+ batch_size = batch_size ,
143
+ shuffle = shuffle )
144
+ dataloader = paddle .io .DataLoader (dataset ,
145
+ batch_sampler = sampler ,
146
+ return_list = True )
147
+ return dataloader
148
+
149
+
121
150
def convert_example (example , tokenizer , max_seq_len ):
122
151
"""
123
152
example: {
@@ -267,6 +296,48 @@ def unify_prompt_name(prompt):
267
296
return prompt
268
297
269
298
299
+ def get_relation_type_dict (relation_data ):
300
+
301
+ def compare (a , b ):
302
+ a = a [::- 1 ]
303
+ b = b [::- 1 ]
304
+ res = ''
305
+ for i in range (min (len (a ), len (b ))):
306
+ if a [i ] == b [i ]:
307
+ res += a [i ]
308
+ else :
309
+ break
310
+ if res == "" :
311
+ return res
312
+ elif res [::- 1 ][0 ] == "็" :
313
+ return res [::- 1 ][1 :]
314
+ return ""
315
+
316
+ relation_type_dict = {}
317
+ added_list = []
318
+ for i in range (len (relation_data )):
319
+ added = False
320
+ if relation_data [i ][0 ] not in added_list :
321
+ for j in range (i + 1 , len (relation_data )):
322
+ match = compare (relation_data [i ][0 ], relation_data [j ][0 ])
323
+ if match != "" :
324
+ match = unify_prompt_name (match )
325
+ if relation_data [i ][0 ] not in added_list :
326
+ added_list .append (relation_data [i ][0 ])
327
+ relation_type_dict .setdefault (match , []).append (
328
+ relation_data [i ][1 ])
329
+ added_list .append (relation_data [j ][0 ])
330
+ relation_type_dict .setdefault (match , []).append (
331
+ relation_data [j ][1 ])
332
+ added = True
333
+ if not added :
334
+ added_list .append (relation_data [i ][0 ])
335
+ suffix = relation_data [i ][0 ].rsplit ("็" , 1 )[1 ]
336
+ suffix = unify_prompt_name (suffix )
337
+ relation_type_dict [suffix ] = relation_data [i ][1 ]
338
+ return relation_type_dict
339
+
340
+
270
341
def add_entity_negative_example (examples , texts , prompts , label_set ,
271
342
negative_ratio ):
272
343
negative_examples = []
@@ -610,26 +681,31 @@ def _sep_cls_label(label, separator):
610
681
redundants1 = inverse_relation_list [i ]
611
682
612
683
# 2. entity_name_set ^ subject_goldens[i]
613
- nonentity_list = list (
614
- set (entity_name_set ) ^ set (subject_goldens [i ]))
615
- nonentity_list .sort ()
616
-
617
- redundants2 = [
618
- nonentity + "็" + predicate_list [i ][random .randrange (
619
- len (predicate_list [i ]))]
620
- for nonentity in nonentity_list
621
- ]
684
+ redundants2 = []
685
+ if len (predicate_list [i ]) != 0 :
686
+ nonentity_list = list (
687
+ set (entity_name_set ) ^ set (subject_goldens [i ]))
688
+ nonentity_list .sort ()
689
+
690
+ redundants2 = [
691
+ nonentity + "็" +
692
+ predicate_list [i ][random .randrange (
693
+ len (predicate_list [i ]))]
694
+ for nonentity in nonentity_list
695
+ ]
622
696
623
697
# 3. entity_label_set ^ entity_prompts[i]
624
- non_ent_label_list = list (
625
- set (entity_label_set ) ^ set (entity_prompts [i ]))
626
- non_ent_label_list .sort ()
627
-
628
- redundants3 = [
629
- subject_goldens [i ][random .randrange (
630
- len (subject_goldens [i ]))] + "็" + non_ent_label
631
- for non_ent_label in non_ent_label_list
632
- ]
698
+ redundants3 = []
699
+ if len (subject_goldens [i ]) != 0 :
700
+ non_ent_label_list = list (
701
+ set (entity_label_set ) ^ set (entity_prompts [i ]))
702
+ non_ent_label_list .sort ()
703
+
704
+ redundants3 = [
705
+ subject_goldens [i ][random .randrange (
706
+ len (subject_goldens [i ]))] + "็" + non_ent_label
707
+ for non_ent_label in non_ent_label_list
708
+ ]
633
709
634
710
redundants_list = [redundants1 , redundants2 , redundants3 ]
635
711
0 commit comments