Skip to content

Commit

Permalink
Fix docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Mar 3, 2025
1 parent ff427e5 commit 2688bfc
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions keras/src/backend/common/remat.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,30 +156,31 @@ def remat(f):
pass, the forward computation is recomputed as needed.
Example:
```python
from keras import Model
class CustomRematLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.remat_function = remat(self.intermediate_function)
def intermediate_function(self, x):
for _ in range(2):
x = x + x * 0.1 # Simple scaled transformation
return x
def call(self, inputs):
return self.remat_function(inputs)
# Define a simple model using the custom layer
inputs = layers.Input(shape=(4,))
x = layers.Dense(4, activation="relu")(inputs)
x = CustomRematLayer()(x) # Custom layer with rematerialization
outputs = layers.Dense(1)(x)
# Create and compile the model
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")
```
```python
from keras import Model
class CustomRematLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.remat_function = remat(self.intermediate_function)
def intermediate_function(self, x):
for _ in range(2):
x = x + x * 0.1 # Simple scaled transformation
return x
def call(self, inputs):
return self.remat_function(inputs)
# Define a simple model using the custom layer
inputs = layers.Input(shape=(4,))
x = layers.Dense(4, activation="relu")(inputs)
x = CustomRematLayer()(x) # Custom layer with rematerialization
outputs = layers.Dense(1)(x)
# Create and compile the model
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")
```
"""
return backend.core.remat(f)

0 comments on commit 2688bfc

Please sign in to comment.