Skip to content

Commit

Permalink
Populate _hooks for T5 and BART (#291)
Browse files · Browse the repository at this point in the history
  • Branch information
hmellor authored Mar 22, 2023
1 parent 0a9418d commit 69ec329
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
8 changes: 4 additions & 4 deletions optimum/graphcore/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def parallelize(self, for_generation=False):

for index, (layer, ipu) in enumerate(zip(self.model.encoder.layers, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
recomputation_checkpoint(layer)
self._hooks.append(recomputation_checkpoint(layer))
self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")

Expand All @@ -764,7 +764,7 @@ def parallelize(self, for_generation=False):

for index, (layer, ipu) in enumerate(zip(self.model.decoder.layers, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
recomputation_checkpoint(layer)
self._hooks.append(recomputation_checkpoint(layer))
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")

Expand Down Expand Up @@ -887,7 +887,7 @@ def parallelize(self):
for index, layer in enumerate(self.model.encoder.layers):
ipu = layer_ipu[index]
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
recomputation_checkpoint(layer)
self._hooks.append(recomputation_checkpoint(layer))
self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")

Expand All @@ -901,7 +901,7 @@ def parallelize(self):
for index, layer in enumerate(self.model.decoder.layers):
ipu = layer_ipu[index + shift]
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
recomputation_checkpoint(layer)
self._hooks.append(recomputation_checkpoint(layer))
self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")

Expand Down
5 changes: 3 additions & 2 deletions optimum/graphcore/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def parallelize(self, for_generation=False):
model = PipelinedT5ForConditionalGeneration(config).parallelize().half()
```
"""
PipelineMixin.parallelize(self)

logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
Expand Down Expand Up @@ -276,7 +277,7 @@ def parallelize(self, for_generation=False):

for index, (layer, ipu) in enumerate(zip(self.encoder.block, encoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
recomputation_checkpoint(layer)
self._hooks.append(recomputation_checkpoint(layer))
self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
logger.info(f"Encoder {index:<2} --> IPU {ipu}")

Expand All @@ -286,7 +287,7 @@ def parallelize(self, for_generation=False):

for index, (layer, ipu) in enumerate(zip(self.decoder.block, decoder_layer_ipu)):
if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
recomputation_checkpoint(layer)
self._hooks.append(recomputation_checkpoint(layer))
self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
logger.info(f"Decoder {index:<2} --> IPU {ipu}")

Expand Down

0 comments on commit 69ec329

Please sign in to comment.