@@ -935,11 +935,13 @@ def __init__(self, config, training_head_type, dropout_prob=0.1):
         super(BertForMultiModalPreTraining, self).__init__(config)
 
         self.bert = BertModel(config)
-        self.cls = BertPreTrainingHeads(config)
         self.training_head_type = training_head_type
         self.fusion_method = config.fusion_method
         self.dropout = nn.Dropout(dropout_prob)
 
+        if "pretraining" in self.training_head_type:
+            self.cls = BertPreTrainingHeads(config)
+
         # Create a copy of config since struct mode won't allow direct overrides
         # classifier_config is only needed for initializing the classifier
         classifier_config = deepcopy(config)
@@ -949,30 +951,30 @@ def __init__(self, config, training_head_type, dropout_prob=0.1):
             self.answer_space_size = 3129
             self.classifier = nn.Sequential(
                 BertPredictionHeadTransform(classifier_config),
-                nn.Linear(classifier_config.hidden_size, 3129)
+                nn.Linear(classifier_config.hidden_size, 3129),
             )
         elif "vizwiz" in self.training_head_type:
             self.answer_space_size = 7371
             self.classifier = nn.Sequential(
                 BertPredictionHeadTransform(classifier_config),
-                nn.Linear(classifier_config.hidden_size, 7371)
+                nn.Linear(classifier_config.hidden_size, 7371),
             )
         elif self.training_head_type == "nlvr2":
             classifier_config.hidden_size *= 2
             self.classifier = nn.Sequential(
                 BertPredictionHeadTransform(classifier_config),
-                nn.Linear(classifier_config.hidden_size, 2)
+                nn.Linear(classifier_config.hidden_size, 2),
             )
             classifier_config.hidden_size /= 2
         elif self.training_head_type == "visual_entailment":
             self.classifier = nn.Sequential(
                 BertPredictionHeadTransform(classifier_config),
-                nn.Linear(classifier_config.hidden_size, 3)
+                nn.Linear(classifier_config.hidden_size, 3),
             )
         elif self.training_head_type == "mmimdb":
             self.classifier = nn.Sequential(
                 BertPredictionHeadTransform(classifier_config),
-                nn.Linear(classifier_config.hidden_size, 24)
+                nn.Linear(classifier_config.hidden_size, 24),
             )
 
         self.init_weights()
@@ -993,9 +995,10 @@ def tie_weights(self):
         """ Make sure we are sharing the input and output embeddings.
         Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        self._tie_or_clone_weights(
-            self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
-        )
+        if hasattr(self, "cls"):
+            self._tie_or_clone_weights(
+                self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
+            )
 
     def forward(
         self,
@@ -1012,7 +1015,13 @@ def forward(
         output_all_attention_masks=False,
     ):
         # in this model, we first embed the images.
-        sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, all_attention_mask = self.bert(
+        (
+            sequence_output_t,
+            sequence_output_v,
+            pooled_output_t,
+            pooled_output_v,
+            all_attention_mask,
+        ) = self.bert(
             input_ids,
             image_feat,
             image_loc,
@@ -1023,12 +1032,11 @@
             output_all_attention_masks=output_all_attention_masks,
         )
 
-        prediction_scores_t, prediction_scores_v, seq_relationship_score = self.cls(
-            sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v
-        )
-
         output_dict = {}
         if "pretraining" in self.training_head_type:
+            prediction_scores_t, prediction_scores_v, seq_relationship_score = self.cls(
+                sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v
+            )
             if image_target is not None:
                 if self.visual_target == 1:
                     img_loss = self.vis_criterion(prediction_scores_v, image_target)
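The diff follows one pattern throughout: a sub-module (`self.cls`) is only instantiated when the configured head type needs it, and every piece of code that touches it — weight tying and the forward pass — is guarded accordingly, so fine-tuning configurations never allocate or call the pretraining head. Below is a minimal standalone sketch of that pattern, with hypothetical `TinyEncoder` and `PretrainingHead` modules standing in for `BertModel` and `BertPreTrainingHeads`; it is not the repository's code, just an illustration of the guarded-construction idea.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    # Stand-in for BertModel: embeds token ids and mean-pools them.
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        sequence_output = self.word_embeddings(input_ids)
        pooled_output = sequence_output.mean(dim=1)
        return sequence_output, pooled_output


class PretrainingHead(nn.Module):
    # Stand-in for BertPreTrainingHeads: projects back to the vocabulary.
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, sequence_output):
        return self.decoder(sequence_output)


class TinyModel(nn.Module):
    def __init__(self, training_head_type, vocab_size=100, hidden_size=32, num_labels=3):
        super().__init__()
        self.training_head_type = training_head_type
        self.encoder = TinyEncoder(vocab_size, hidden_size)
        # Only build the pretraining head when it will actually be used,
        # mirroring the guarded `self.cls = BertPreTrainingHeads(config)` above.
        if "pretraining" in training_head_type:
            self.cls = PretrainingHead(vocab_size, hidden_size)
        else:
            self.classifier = nn.Linear(hidden_size, num_labels)
        self.tie_weights()

    def tie_weights(self):
        # Guard with hasattr, as in the patched tie_weights():
        # fine-tuning configurations have no `cls` module to tie.
        if hasattr(self, "cls"):
            self.cls.decoder.weight = self.encoder.word_embeddings.weight

    def forward(self, input_ids):
        sequence_output, pooled_output = self.encoder(input_ids)
        output = {}
        if "pretraining" in self.training_head_type:
            # The head is only called on the branch where it exists.
            output["prediction_scores"] = self.cls(sequence_output)
        else:
            output["logits"] = self.classifier(pooled_output)
        return output


if __name__ == "__main__":
    ids = torch.randint(0, 100, (2, 5))
    print(TinyModel("pretraining")(ids)["prediction_scores"].shape)  # (2, 5, 100)
    print(TinyModel("vqa")(ids)["logits"].shape)                     # (2, 3)
```

A side effect of constructing the head conditionally is that the unused pretraining parameters never appear in the fine-tuned model's state dict, which presumably also simplifies exporting the task-specific heads.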