From f250f8620b9e823d92af36a43226acf5088a5b01 Mon Sep 17 00:00:00 2001 From: Martin Uhrin Date: Thu, 9 Nov 2017 11:56:51 +0000 Subject: [PATCH] Fixes 902 There was a mistake in the logic of the 'if' block of workchains when loading. A position varaiable was being used to keep track of how many times that conditional block was ticked. This was being upped each time but also being used to determine which of possible conditional branches the condition was at e.g.: if(...)( <-- pos 0 ) elif(...) <-- pos 1 ) else( ) Long story short, when the condition was reloaded from a saved state it was possible that pos was larger than the number of conditions (usual just one if there's only an if) and it couldn't resume from where it was when it was saved. --- aiida/backends/tests/work/workChain.py | 60 ++++++++++++++++++++++++-- aiida/work/workchain.py | 40 ++++++++++------- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/aiida/backends/tests/work/workChain.py b/aiida/backends/tests/work/workChain.py index 2f662f8ff8..9ab8eb3539 100644 --- a/aiida/backends/tests/work/workChain.py +++ b/aiida/backends/tests/work/workChain.py @@ -14,6 +14,7 @@ from aiida.backends.testbase import AiidaTestCase from plum.engine.ticking import TickingEngine +from plum.persistence.bundle import Bundle import plum.process_monitor from aiida.orm.calculation.work import WorkCalculation from aiida.work.workchain import WorkChain, \ @@ -61,7 +62,7 @@ def __init__(self): [self.s1.__name__, self.s2.__name__, self.s3.__name__, self.s4.__name__, self.s5.__name__, self.s6.__name__, self.isA.__name__, self.isB.__name__, self.ltN.__name__] - } + } def s1(self): self._set_finished(inspect.stack()[0][3]) @@ -120,6 +121,33 @@ def test_dict(self): c['new_attr'] +class IfTest(WorkChain): + @classmethod + def define(cls, spec): + super(IfTest, cls).define(spec) + spec.outline( + if_(cls.condition)( + cls.step1, + cls.step2 + ) + ) + + def on_create(self, pid, inputs, saved_state): + super(IfTest, self).on_create(pid, inputs, saved_state) + if saved_state is None: + self.ctx.s1 = False + self.ctx.s2 = False + + def condition(self): + return True + + def step1(self): + self.ctx.s1 = True + + def step2(self): + self.ctx.s2 = True + + class TestWorkchain(AiidaTestCase): def setUp(self): super(TestWorkchain, self).setUp() @@ -321,6 +349,29 @@ def run(self): run(MainWorkChain) + def test_if_block_persistence(self): + """ This test was created to capture issue #902 """ + wc = IfTest.new_instance() + + while not wc.ctx.s1 and not wc.has_finished(): + wc.tick() + self.assertTrue(wc.ctx.s1) + self.assertFalse(wc.ctx.s2) + + # Now bundle the thing + b = Bundle() + wc.save_instance_state(b) + # Abort the current one + wc.stop() + wc.destroy(execute=True) + + # Load from saved tate + wc = IfTest.create_from(b) + self.assertTrue(wc.ctx.s1) + self.assertFalse(wc.ctx.s2) + + wc.run_until_complete() + def _run_with_checkpoints(self, wf_class, inputs=None): finished_steps = {} @@ -336,7 +387,6 @@ def _run_with_checkpoints(self, wf_class, inputs=None): class TestWorkchainWithOldWorkflows(AiidaTestCase): - def setUp(self): super(TestWorkchainWithOldWorkflows, self).setUp() import logging @@ -409,10 +459,12 @@ def test_get_proc_outputs(self): self.assertEquals(outputs['a'], a) self.assertEquals(outputs['b'], b) + class TestWorkChainAbort(AiidaTestCase): """ Test the functionality to abort a workchain """ + class AbortableWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -490,11 +542,13 @@ def test_simple_kill_through_process(self): self.assertEquals(future.process.calc.has_aborted(), True) engine.shutdown() + class TestWorkChainAbortChildren(AiidaTestCase): """ Test the functionality to abort a workchain and verify that children are also aborted appropriately """ + class SubWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -575,4 +629,4 @@ def test_simple_kill_through_node(self): self.assertEquals(future.process.calc.has_finished_ok(), False) self.assertEquals(future.process.calc.has_failed(), False) self.assertEquals(future.process.calc.has_aborted(), True) - engine.shutdown() \ No newline at end of file + engine.shutdown() diff --git a/aiida/work/workchain.py b/aiida/work/workchain.py index f35f630adf..e5a2564ad9 100644 --- a/aiida/work/workchain.py +++ b/aiida/work/workchain.py @@ -342,6 +342,7 @@ def abort(self, msg=None, timeout=None): self._aborted = True self.stop() + def ToContext(**kwargs): """ Utility function that returns a list of UpdateContext Interstep instances @@ -366,6 +367,7 @@ class _InterstepFactory(object): Factory to create the appropriate Interstep instance based on the class string that was written to the bundle """ + def create(self, bundle): class_string = bundle[Bundle.CLASS] if class_string == get_class_string(ToContext): @@ -567,22 +569,21 @@ class Stepper(Stepper): def __init__(self, workflow, if_spec): super(_If.Stepper, self).__init__(workflow) self._if_spec = if_spec - self._pos = 0 + self._pos = -1 self._current_stepper = None def step(self): if self._current_stepper is None: - stepper = self._get_next_stepper() - # If we can't get a stepper then no conditions match, return - if stepper is None: - return True, None - self._current_stepper = stepper + self._create_stepper() + + # If we can't get a stepper then no conditions match, return + if self._current_stepper is None: + return True, None finished, retval = self._current_stepper.step() if finished: self._current_stepper = None - else: - self._pos += 1 + self._pos = -1 return finished, retval @@ -596,15 +597,24 @@ def save_position(self, out_position): def load_position(self, bundle): self._pos = bundle[self._POSITION] if self._STEPPER_POS in bundle: - self._current_stepper = self._get_next_stepper() + self._create_stepper() self._current_stepper.load_position(bundle[self._STEPPER_POS]) + else: + self._current_stepper = None - def _get_next_stepper(self): - # Check the conditions until we find that that is true - for conditional in self._if_spec.conditionals[self._pos:]: - if conditional.is_true(self._workflow): - return conditional.body.create_stepper(self._workflow) - return None + def _create_stepper(self): + if self._pos == -1: + self._current_stepper = None + # Check the conditions until we find one that is true + for idx, condition in enumerate(self._if_spec.conditionals): + if condition.is_true(self._workflow): + stepper = condition.body.create_stepper(self._workflow) + self._pos = idx + self._current_stepper = stepper + return + else: + branch = self._if_spec.conditionals[self._pos] + self._current_stepper = branch.body.create_stepper(self._workflow) def __init__(self, condition): super(_If, self).__init__()