Skip to content

Commit

Permalink
refactor DecisionFramework for testability
Browse files Browse the repository at this point in the history
  • Loading branch information
ahouseholder committed Mar 3, 2025
1 parent d953c6a commit e8c94dd
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 21 deletions.
39 changes: 31 additions & 8 deletions src/ssvc/framework/decision_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ssvc._mixins import _Base, _Namespaced, _Versioned
from ssvc.dp_groups.base import SsvcDecisionPointGroup
from ssvc.outcomes.base import OutcomeGroup, OutcomeValue
from ssvc.outcomes.base import OutcomeGroup
from ssvc.policy_generator import PolicyGenerator


Expand All @@ -37,12 +37,19 @@ class DecisionFramework(_Versioned, _Namespaced, _Base, BaseModel):

decision_point_group: SsvcDecisionPointGroup
outcome_group: OutcomeGroup
mapping: dict[tuple[tuple[str, str], ...], OutcomeValue]
mapping: dict[str, str]

def populate_mapping(self):
def __init__(self, **data):
super().__init__(**data)

if not self.mapping:
self.mapping = self.generate_mapping()

def generate_mapping(self) -> dict[str, str]:
"""
Populate the mapping with all possible combinations of decision points.
"""
mapping = {}
dp_lookup = {
dp.name.lower(): dp for dp in self.decision_point_group.decision_points
}
Expand Down Expand Up @@ -84,21 +91,37 @@ def populate_mapping(self):
dp = dp_lookup[col]
val = value_lookup[val]

k = (f"{dp.name}:{dp.version}", f"{val.name}")
key_delim = ":"
k = key_delim.join([dp.namespace, dp.key, val.key])
dp_values.append(k)

key = tuple(dp_values)
self.mapping[key] = outcome
key = ",".join([str(k) for k in dp_values])

outcome_group = self.outcome_group
outcome_str = ":".join([outcome_group.key, outcome.key])

mapping[key] = outcome_str

return self.mapping
return mapping


# convenience alias
Policy = DecisionFramework


def main():
pass
from ssvc.dp_groups.ssvc.supplier import LATEST as dpg
from ssvc.outcomes.groups import MOSCOW as og

dfw = DecisionFramework(
name="Example Decision Framework",
description="The description for an Example Decision Framework",
version="1.0.0",
decision_point_group=dpg,
outcome_group=og,
mapping={},
)
print(dfw.model_dump_json(indent=2))


if __name__ == "__main__":
Expand Down
27 changes: 14 additions & 13 deletions src/test/test_decision_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ssvc.decision_points.system_exposure import LATEST as exposure_dp
from ssvc.dp_groups.base import SsvcDecisionPointGroup
from ssvc.framework.decision_framework import DecisionFramework
from ssvc.outcomes.base import OutcomeGroup
from ssvc.outcomes.groups import DSOI as dsoi_og


Expand All @@ -33,11 +32,7 @@ def setUp(self):
description="Test Decision Point Group Description",
decision_points=[exploitation_dp, exposure_dp, safety_dp],
),
outcome_group=OutcomeGroup(
name="Test Outcome Group",
description="Test Outcome Group Description",
outcomes=dsoi_og,
),
outcome_group=dsoi_og,
mapping={},
)

Expand All @@ -51,22 +46,28 @@ def test_create(self):
self.assertEqual(3, len(self.framework.decision_point_group))

def test_populate_mapping(self):
result = self.framework.populate_mapping()
result = self.framework.generate_mapping()

# there should be one row in result for each combination of decision points
combo_count = len(list(self.framework.decision_point_group.combinations()))
self.assertEqual(len(result), combo_count)

# the length of each key should be the number of decision points
for key in result.keys():
self.assertEqual(len(key), 3)
for i, (dp_name_version, dp_value_name) in enumerate(key):
parts = key.split(",")
self.assertEqual(len(parts), 3)
for i, keypart in enumerate(parts):
dp_namespace, dp_key, dp_value_key = keypart.split(":")

dp = self.framework.decision_point_group.decision_points[i]
name_version = f"{dp.name}:{dp.version}"
self.assertEqual(name_version, dp_name_version)
self.assertEqual(dp_namespace, dp.namespace)
self.assertEqual(dp_key, dp.key)
value_keys = [v.key for v in dp.values]
self.assertIn(dp_value_key, value_keys)

value_names = [v.name for v in dp.values]
self.assertIn(dp_value_name, value_names)
print()
print()
print(self.framework.model_dump_json(indent=2))


if __name__ == "__main__":
Expand Down

0 comments on commit e8c94dd

Please sign in to comment.