Skip to content

Commit e2123d3

Browse files
committed
Merge remote-tracking branch 'origin/master' into dolci/fix_sphinx_fail
2 parents 6199ea8 + 9bfe70f commit e2123d3

File tree

10 files changed

+729
-81
lines changed

10 files changed

+729
-81
lines changed

docs/source/documentation/index.rst

+1-2
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,7 @@ a tape. The current working tape can be set and retrieved with the functions :py
361361
:py:func:`get_working_tape`.
362362

363363
Annotation can be temporarily disabled using :py:func:`pause_annotation` and enabled again using :py:func:`continue_annotation`.
364-
Note that if you call :py:func:`pause_annotation` twice, then :py:func:`continue_annotation` must be called twice
365-
to enable annotation. Due to this, the recommended annotation control functions are :py:class:`stop_annotating` and :py:func:`no_annotations`.
364+
It is recommended to use :py:class:`stop_annotating` and :py:func:`no_annotations` for annotation control.
366365
:py:class:`stop_annotating` is a context manager and should be used as follows
367366

368367
.. code-block:: python

docs/source/documentation/pyadjoint_api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Core classes
1919
.. automethod:: add_block
2020
.. automethod:: visualise
2121
.. autoproperty:: progress_bar
22+
.. automethod:: end_timestep
23+
.. automethod:: timestepper
2224

2325
.. autoclass:: Block
2426

pyadjoint/block.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def add_dependency(self, dep, no_duplicates=False):
5050
5151
"""
5252
if not no_duplicates or dep.block_variable not in self._dependencies:
53-
dep._ad_will_add_as_dependency()
53+
dep.block_variable.will_add_as_dependency()
5454
self._dependencies.append(dep.block_variable)
5555

5656
def get_dependencies(self):

pyadjoint/block_variable.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .tape import no_annotations
1+
from .tape import no_annotations, get_working_tape
22

33

44
class BlockVariable(object):
@@ -16,6 +16,10 @@ def __init__(self, output):
1616
self.floating_type = False
1717
# Helper flag for use during tape traversals.
1818
self.marked_in_path = False
19+
# By default assume the variable is created externally to the tape.
20+
self.creation_timestep = -1
21+
# The timestep during which this variable was last used as an input.
22+
self.last_use = -1
1923

2024
def add_adj_output(self, val):
2125
if self.adj_value is None:
@@ -59,13 +63,23 @@ def saved_output(self):
5963

6064
def will_add_as_dependency(self):
6165
overwrite = self.output._ad_will_add_as_dependency()
62-
overwrite = False if overwrite is None else overwrite
63-
self.save_output(overwrite=overwrite)
66+
overwrite = bool(overwrite)
67+
tape = get_working_tape()
68+
if self.last_use < tape.latest_checkpoint:
69+
self.save_output(overwrite=overwrite)
70+
tape.add_to_checkpointable_state(self, self.last_use)
71+
self.last_use = tape.latest_timestep
6472

6573
def will_add_as_output(self):
74+
tape = get_working_tape()
75+
self.creation_timestep = tape.latest_timestep
76+
self.last_use = self.creation_timestep
6677
overwrite = self.output._ad_will_add_as_output()
67-
overwrite = True if overwrite is None else overwrite
68-
self.save_output(overwrite=overwrite)
78+
overwrite = bool(overwrite)
79+
if not overwrite:
80+
self._checkpoint = None
81+
if tape._eagerly_checkpoint_outputs:
82+
self.save_output()
6983

7084
def __str__(self):
7185
return str(self.output)

0 commit comments

Comments
 (0)