Skip to content

Commit 33415f4

Browse files
vedanuj authored and apsdehal committed
[fix] Fix issue with unused parameter in BertLMPredictionHead in ViLBERT (#35)
1 parent 828abb4 commit 33415f4

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

pythia/models/vilbert.py

+22-14
Original file line number | Diff line number | Diff line change
@@ -935,11 +935,13 @@ def __init__(self, config, training_head_type, dropout_prob=0.1):
935935
super(BertForMultiModalPreTraining, self).__init__(config)
936936

937937
self.bert = BertModel(config)
938-
self.cls = BertPreTrainingHeads(config)
939938
self.training_head_type = training_head_type
940939
self.fusion_method = config.fusion_method
941940
self.dropout = nn.Dropout(dropout_prob)
942941

942+
if "pretraining" in self.training_head_type:
943+
self.cls = BertPreTrainingHeads(config)
944+
943945
# Create a copy of config since struct mode won't allow direct overrides
944946
# classifier_config is only needed for initializing the classifier
945947
classifier_config = deepcopy(config)
@@ -949,30 +951,30 @@ def __init__(self, config, training_head_type, dropout_prob=0.1):
949951
self.answer_space_size = 3129
950952
self.classifier = nn.Sequential(
951953
BertPredictionHeadTransform(classifier_config),
952-
nn.Linear(classifier_config.hidden_size, 3129)
954+
nn.Linear(classifier_config.hidden_size, 3129),
953955
)
954956
elif "vizwiz" in self.training_head_type:
955957
self.answer_space_size = 7371
956958
self.classifier = nn.Sequential(
957959
BertPredictionHeadTransform(classifier_config),
958-
nn.Linear(classifier_config.hidden_size, 7371)
960+
nn.Linear(classifier_config.hidden_size, 7371),
959961
)
960962
elif self.training_head_type == "nlvr2":
961963
classifier_config.hidden_size *= 2
962964
self.classifier = nn.Sequential(
963965
BertPredictionHeadTransform(classifier_config),
964-
nn.Linear(classifier_config.hidden_size, 2)
966+
nn.Linear(classifier_config.hidden_size, 2),
965967
)
966968
classifier_config.hidden_size /= 2
967969
elif self.training_head_type == "visual_entailment":
968970
self.classifier = nn.Sequential(
969971
BertPredictionHeadTransform(classifier_config),
970-
nn.Linear(classifier_config.hidden_size, 3)
972+
nn.Linear(classifier_config.hidden_size, 3),
971973
)
972974
elif self.training_head_type == "mmimdb":
973975
self.classifier = nn.Sequential(
974976
BertPredictionHeadTransform(classifier_config),
975-
nn.Linear(classifier_config.hidden_size, 24)
977+
nn.Linear(classifier_config.hidden_size, 24),
976978
)
977979

978980
self.init_weights()
@@ -993,9 +995,10 @@ def tie_weights(self):
993995
""" Make sure we are sharing the input and output embeddings.
994996
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
995997
"""
996-
self._tie_or_clone_weights(
997-
self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
998-
)
998+
if hasattr(self, "cls"):
999+
self._tie_or_clone_weights(
1000+
self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
1001+
)
9991002

10001003
def forward(
10011004
self,
@@ -1012,7 +1015,13 @@ def forward(
10121015
output_all_attention_masks=False,
10131016
):
10141017
# in this model, we first embed the images.
1015-
sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, all_attention_mask = self.bert(
1018+
(
1019+
sequence_output_t,
1020+
sequence_output_v,
1021+
pooled_output_t,
1022+
pooled_output_v,
1023+
all_attention_mask,
1024+
) = self.bert(
10161025
input_ids,
10171026
image_feat,
10181027
image_loc,
@@ -1023,12 +1032,11 @@ def forward(
10231032
output_all_attention_masks=output_all_attention_masks,
10241033
)
10251034

1026-
prediction_scores_t, prediction_scores_v, seq_relationship_score = self.cls(
1027-
sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v
1028-
)
1029-
10301035
output_dict = {}
10311036
if "pretraining" in self.training_head_type:
1037+
prediction_scores_t, prediction_scores_v, seq_relationship_score = self.cls(
1038+
sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v
1039+
)
10321040
if image_target is not None:
10331041
if self.visual_target == 1:
10341042
img_loss = self.vis_criterion(prediction_scores_v, image_target)

0 commit comments

Comments (0)