From 12b7946f3d6bcab19766816e031856a95dfb6671 Mon Sep 17 00:00:00 2001 From: "Joel Z. Leibo" Date: Mon, 25 Nov 2024 03:28:08 -0800 Subject: [PATCH] Improve a few baseline agents PiperOrigin-RevId: 699918558 Change-Id: I9e8fc1c130acf857d03f8f0de22fb9e77f1b4fbb --- .../contrib/components/agent/__init__.py | 2 + .../agent/observations_since_last_update.py | 140 ++++++++ .../situation_representation_via_narrative.py | 204 +++++++++++ .../factory/agent/alternative_basic_agent.py | 173 +--------- .../agent/alternative_rational_agent.py | 321 ++++++++++++++++++ concordia/factory/agent/factories_test.py | 14 + .../agent/observe_and_summarize_agent.py | 229 +++++++++++++ .../agent/parochial_universalization_agent.py | 170 +--------- 8 files changed, 916 insertions(+), 337 deletions(-) create mode 100644 concordia/contrib/components/agent/observations_since_last_update.py create mode 100644 concordia/contrib/components/agent/situation_representation_via_narrative.py create mode 100644 concordia/factory/agent/alternative_rational_agent.py create mode 100644 concordia/factory/agent/observe_and_summarize_agent.py diff --git a/concordia/contrib/components/agent/__init__.py b/concordia/contrib/components/agent/__init__.py index bff7fb04..4da1ef86 100644 --- a/concordia/contrib/components/agent/__init__.py +++ b/concordia/contrib/components/agent/__init__.py @@ -16,3 +16,5 @@ from concordia.contrib.components.agent import affect_reflection from concordia.contrib.components.agent import dialectical_reflection +from concordia.contrib.components.agent import observations_since_last_update +from concordia.contrib.components.agent import situation_representation_via_narrative diff --git a/concordia/contrib/components/agent/observations_since_last_update.py b/concordia/contrib/components/agent/observations_since_last_update.py new file mode 100644 index 00000000..f37e4a06 --- /dev/null +++ b/concordia/contrib/components/agent/observations_since_last_update.py @@ -0,0 +1,140 @@ +# 
Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A component for tracking observations since the last update. +""" + +from collections.abc import Callable +import datetime + +from absl import logging as absl_logging +from concordia.components import agent as agent_components +from concordia.components.agent import action_spec_ignored +from concordia.components.agent import memory_component +from concordia.language_model import language_model +from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component +from concordia.typing import logging + + +def _get_earliest_timepoint( + memory_component_: agent_components.memory_component.MemoryComponent, +) -> datetime.datetime: + """Returns the earliest timepoint in the memory bank. + + Args: + memory_component_: The memory component to retrieve memories from. 
+ """ + memories_data_frame = memory_component_.get_raw_memory() + if not memories_data_frame.empty: + sorted_memories_data_frame = memories_data_frame.sort_values( + 'time', ascending=True) + return sorted_memories_data_frame['time'][0] + else: + absl_logging.warn('No memories found in memory bank.') + return datetime.datetime.now() + + +class ObservationsSinceLastUpdate(action_spec_ignored.ActionSpecIgnored): + """Report all observations since the last update.""" + + def __init__( + self, + model: language_model.LanguageModel, + clock_now: Callable[[], datetime.datetime], + memory_component_name: str = ( + memory_component.DEFAULT_MEMORY_COMPONENT_NAME + ), + pre_act_key: str = '\nObservations', + logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel, + ): + """Initialize a component to consider the latest observations. + + Args: + model: The language model to use. + clock_now: Function that returns the current time. + memory_component_name: The name of the memory component from which to + retrieve related memories. + pre_act_key: Prefix to add to the output of the component when called + in `pre_act`. + logging_channel: The channel to log debug information to. 
+ """ + super().__init__(pre_act_key) + self._model = model + self._clock_now = clock_now + self._memory_component_name = memory_component_name + self._logging_channel = logging_channel + + self._previous_time = None + + def pre_observe( + self, + observation: str, + ) -> str: + memory = self.get_entity().get_component( + self._memory_component_name, + type_=memory_component.MemoryComponent) + memory.add( + f'[observation] {observation}', + metadata={'tags': ['observation']}, + ) + return '' + + def _make_pre_act_value(self) -> str: + """Returns a representation of the current situation to pre act.""" + current_time = self._clock_now() + memory = self.get_entity().get_component( + self._memory_component_name, + type_=memory_component.MemoryComponent) + + if self._previous_time is None: + self._previous_time = _get_earliest_timepoint(memory) + + interval_scorer = legacy_associative_memory.RetrieveTimeInterval( + time_from=self._previous_time, + time_until=current_time, + add_time=True, + ) + mems = [mem.text for mem in memory.retrieve(scoring_fn=interval_scorer)] + result = '\n'.join(mems) + '\n' + + self._logging_channel({ + 'Key': self.get_pre_act_key(), + 'Value': result, + }) + + self._previous_time = current_time + + return result + + def get_state(self) -> entity_component.ComponentState: + """Converts the component to JSON data.""" + with self._lock: + if self._previous_time is None: + previous_time = '' + else: + previous_time = self._previous_time.strftime('%Y-%m-%d %H:%M:%S') + return { + 'previous_time': previous_time, + } + + def set_state(self, state: entity_component.ComponentState) -> None: + """Sets the component state from JSON data.""" + with self._lock: + if state['previous_time']: + previous_time = datetime.datetime.strptime( + state['previous_time'], '%Y-%m-%d %H:%M:%S') + else: + previous_time = None + self._previous_time = previous_time diff --git a/concordia/contrib/components/agent/situation_representation_via_narrative.py 
b/concordia/contrib/components/agent/situation_representation_via_narrative.py new file mode 100644 index 00000000..3a4d4281 --- /dev/null +++ b/concordia/contrib/components/agent/situation_representation_via_narrative.py @@ -0,0 +1,204 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A component for representing the current situation via narrative. +""" + +from collections.abc import Callable, Sequence +import datetime + +from absl import logging as absl_logging +from concordia.components import agent as agent_components +from concordia.components.agent import action_spec_ignored +from concordia.components.agent import memory_component +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component +from concordia.typing import logging +from concordia.typing import memory as memory_lib + + +def _get_all_memories( + memory_component_: agent_components.memory_component.MemoryComponent, + add_time: bool = True, + sort_by_time: bool = True, + constant_score: float = 0.0, +) -> Sequence[memory_lib.MemoryResult]: + """Returns all memories in the memory bank. + + Args: + memory_component_: The memory component to retrieve memories from. 
+ add_time: whether to add time + sort_by_time: whether to sort by time + constant_score: assign this score value to each memory + """ + texts = memory_component_.get_all_memories_as_text(add_time=add_time, + sort_by_time=sort_by_time) + return [memory_lib.MemoryResult(text=t, score=constant_score) for t in texts] + + +def _get_earliest_timepoint( + memory_component_: agent_components.memory_component.MemoryComponent, +) -> datetime.datetime: + """Returns the earliest timepoint in the memory bank. + + Args: + memory_component_: The memory component to retrieve memories from. + """ + memories_data_frame = memory_component_.get_raw_memory() + if not memories_data_frame.empty: + sorted_memories_data_frame = memories_data_frame.sort_values( + 'time', ascending=True) + return sorted_memories_data_frame['time'][0] + else: + absl_logging.warn('No memories found in memory bank.') + return datetime.datetime.now() + + +class SituationRepresentation(action_spec_ignored.ActionSpecIgnored): + """Consider ``what kind of situation am I in now?``.""" + + def __init__( + self, + model: language_model.LanguageModel, + clock_now: Callable[[], datetime.datetime], + memory_component_name: str = ( + memory_component.DEFAULT_MEMORY_COMPONENT_NAME + ), + pre_act_key: str = 'The current situation', + logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel, + ): + """Initialize a component to consider the current situation. + + Args: + model: The language model to use. + clock_now: Function that returns the current time. + memory_component_name: The name of the memory component from which to + retrieve related memories. + pre_act_key: Prefix to add to the output of the component when called + in `pre_act`. + logging_channel: The channel to log debug information to. 
+ """ + super().__init__(pre_act_key) + self._model = model + self._clock_now = clock_now + self._memory_component_name = memory_component_name + self._logging_channel = logging_channel + + self._previous_time = None + self._situation_thus_far = None + + def _make_pre_act_value(self) -> str: + """Returns a representation of the current situation to pre act.""" + agent_name = self.get_entity().name + current_time = self._clock_now() + memory = self.get_entity().get_component( + self._memory_component_name, + type_=memory_component.MemoryComponent) + + initial_step_thought_chain = '' + if self._situation_thus_far is None: + self._previous_time = _get_earliest_timepoint(memory) + chain_of_thought = interactive_document.InteractiveDocument(self._model) + chain_of_thought.statement('~~ Creative Writing Assignment ~~') + chain_of_thought.statement(f'Protagonist: {agent_name}') + mems = '\n'.join([mem.text for mem in _get_all_memories(memory)]) + chain_of_thought.statement(f'Story fragments and world data:\n{mems}') + chain_of_thought.statement(f'Events continue after {current_time}') + self._situation_thus_far = chain_of_thought.open_question( + question=( + 'Narratively summarize the story fragments and world data. Give ' + 'special emphasis to atypical features of the setting such as ' + 'when and where the story takes place as well as any causal ' + 'mechanisms or affordances mentioned in the information ' + 'provided. Highlight the goals, personalities, occupations, ' + 'skills, and affordances of the named characters and ' + 'relationships between them. If any specific numbers were ' + 'mentioned then make sure to include them. 
Use third-person ' + 'omniscient perspective.'), + max_tokens=1000, + terminators=(), + question_label='Exercise') + initial_step_thought_chain = '\n'.join( + chain_of_thought.view().text().splitlines()) + + interval_scorer = legacy_associative_memory.RetrieveTimeInterval( + time_from=self._previous_time, + time_until=current_time, + add_time=True, + ) + mems = [mem.text for mem in memory.retrieve(scoring_fn=interval_scorer)] + result = '\n'.join(mems) + '\n' + chain_of_thought = interactive_document.InteractiveDocument(self._model) + chain_of_thought.statement(f'Context:\n{self._situation_thus_far}') + chain_of_thought.statement(f'Protagonist: {agent_name}') + chain_of_thought.statement( + f'Thoughts and memories of {agent_name}:\n{result}' + ) + self._situation_thus_far = chain_of_thought.open_question( + question=( + 'What situation does the protagonist find themselves in? ' + 'Make sure to provide enough detail to give the ' + 'reader a comprehensive understanding of the world ' + 'inhabited by the protagonist, their affordances in that ' + 'world, actions they may be able to take, effects their ' + 'actions may produce, and what is currently going on. If any ' + 'specific numbers were mentioned then make sure to include them.' + 'Also, make sure to repeat all details of the context that could ' + 'ever be relevant, now or in the future.' 
+ ), + max_tokens=1000, + terminators=(), + question_label='Exercise', + ) + chain_of_thought.statement(f'The current date and time is {current_time}') + + chain_of_thought_text = '\n'.join( + chain_of_thought.view().text().splitlines()) + + self._logging_channel({ + 'Key': self.get_pre_act_key(), + 'Value': self._situation_thus_far, + 'Chain of thought': (initial_step_thought_chain + + '\n***\n' + + chain_of_thought_text), + }) + + self._previous_time = current_time + + return self._situation_thus_far + + def get_state(self) -> entity_component.ComponentState: + """Converts the component to JSON data.""" + with self._lock: + if self._previous_time is None: + previous_time = '' + else: + previous_time = self._previous_time.strftime('%Y-%m-%d %H:%M:%S') + return { + 'previous_time': previous_time, + 'situation_thus_far': self._situation_thus_far, + } + + def set_state(self, state: entity_component.ComponentState) -> None: + """Sets the component state from JSON data.""" + with self._lock: + if state['previous_time']: + previous_time = datetime.datetime.strptime( + state['previous_time'], '%Y-%m-%d %H:%M:%S') + else: + previous_time = None + self._previous_time = previous_time + self._situation_thus_far = state['situation_thus_far'] diff --git a/concordia/factory/agent/alternative_basic_agent.py b/concordia/factory/agent/alternative_basic_agent.py index 40cd835a..a3b791cc 100644 --- a/concordia/factory/agent/alternative_basic_agent.py +++ b/concordia/factory/agent/alternative_basic_agent.py @@ -14,24 +14,19 @@ """A factory implementing the three key questions agent as an entity.""" -from collections.abc import Callable, Sequence +from collections.abc import Callable import datetime import json -from absl import logging as absl_logging from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory from concordia.associative_memory import formative_memories from concordia.clocks import game_clock from 
concordia.components import agent as agent_components -from concordia.components.agent import action_spec_ignored -from concordia.components.agent import memory_component -from concordia.document import interactive_document +from concordia.contrib.components.agent import situation_representation_via_narrative from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory from concordia.typing import entity_component -from concordia.typing import logging -from concordia.typing import memory as memory_lib from concordia.utils import measurements as measurements_lib import numpy as np @@ -45,168 +40,6 @@ def _get_class_name(object_: object) -> str: return object_.__class__.__name__ -def _get_all_memories( - memory_component_: agent_components.memory_component.MemoryComponent, - add_time: bool = True, - sort_by_time: bool = True, - constant_score: float = 0.0, -) -> Sequence[memory_lib.MemoryResult]: - """Returns all memories in the memory bank. - - Args: - memory_component_: The memory component to retrieve memories from. - add_time: whether to add time - sort_by_time: whether to sort by time - constant_score: assign this score value to each memory - """ - texts = memory_component_.get_all_memories_as_text(add_time=add_time, - sort_by_time=sort_by_time) - return [memory_lib.MemoryResult(text=t, score=constant_score) for t in texts] - - -def _get_earliest_timepoint( - memory_component_: agent_components.memory_component.MemoryComponent, -) -> datetime.datetime: - """Returns all memories in the memory bank. - - Args: - memory_component_: The memory component to retrieve memories from. 
- """ - memories_data_frame = memory_component_.get_raw_memory() - if not memories_data_frame.empty: - sorted_memories_data_frame = memories_data_frame.sort_values( - 'time', ascending=True) - return sorted_memories_data_frame['time'][0] - else: - absl_logging.warn('No memories found in memory bank.') - return datetime.datetime.now() - - -class SituationRepresentation(action_spec_ignored.ActionSpecIgnored): - """Consider ``what kind of situation am I in now?``.""" - - def __init__( - self, - model: language_model.LanguageModel, - clock_now: Callable[[], datetime.datetime], - memory_component_name: str = ( - memory_component.DEFAULT_MEMORY_COMPONENT_NAME - ), - pre_act_key: str = 'The current situation', - logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel, - ): - """Initialize a component to consider the current situation. - - Args: - model: The language model to use. - clock_now: Function that returns the current time. - memory_component_name: The name of the memory component from which to - retrieve related memories. - pre_act_key: Prefix to add to the output of the component when called - in `pre_act`. - logging_channel: The channel to log debug information to. 
- """ - super().__init__(pre_act_key) - self._model = model - self._clock_now = clock_now - self._memory_component_name = memory_component_name - self._logging_channel = logging_channel - - self._previous_time = None - self._situation_thus_far = None - - def _make_pre_act_value(self) -> str: - """Returns a representation of the current situation to pre act.""" - agent_name = self.get_entity().name - current_time = self._clock_now() - memory = self.get_entity().get_component( - self._memory_component_name, - type_=memory_component.MemoryComponent) - - if self._situation_thus_far is None: - self._previous_time = _get_earliest_timepoint(memory) - chain_of_thought = interactive_document.InteractiveDocument(self._model) - chain_of_thought.statement('~~ Creative Writing Assignment ~~') - chain_of_thought.statement(f'Protagonist: {agent_name}') - mems = '\n'.join([mem.text for mem in _get_all_memories(memory)]) - chain_of_thought.statement(f'Story fragments and world data:\n{mems}') - chain_of_thought.statement(f'Events continue after {current_time}') - self._situation_thus_far = chain_of_thought.open_question( - question=( - 'Narratively summarize the story fragments and world data. Give ' - 'special emphasis to atypical features of the setting such as ' - 'when and where the story takes place as well as any causal ' - 'mechanisms or affordances mentioned in the information ' - 'provided. Highlight the goals, personalities, occupations, ' - 'skills, and affordances of the named characters and ' - 'relationships between them. 
Use third-person omniscient ' - 'perspective.'), - max_tokens=1000, - terminators=(), - question_label='Exercise') - - interval_scorer = legacy_associative_memory.RetrieveTimeInterval( - time_from=self._previous_time, - time_until=current_time, - add_time=True, - ) - mems = [mem.text for mem in memory.retrieve(scoring_fn=interval_scorer)] - result = '\n'.join(mems) + '\n' - chain_of_thought = interactive_document.InteractiveDocument(self._model) - chain_of_thought.statement(f'Context:\n{self._situation_thus_far}') - chain_of_thought.statement(f'Protagonist: {agent_name}') - chain_of_thought.statement( - f'Thoughts and memories of {agent_name}:\n{result}' - ) - self._situation_thus_far = chain_of_thought.open_question( - question=( - 'What situation does the protagonist find themselves in? ' - 'Make sure to provide enough detail to give the ' - 'reader a comprehensive understanding of the world ' - 'inhabited by the protagonist, their affordances in that ' - 'world, actions they may be able to take, effects their ' - 'actions may produce, and what is currently going on.' 
- ), - max_tokens=1000, - terminators=(), - question_label='Exercise', - ) - chain_of_thought.statement(f'The current date and time is {current_time}') - - self._logging_channel({ - 'Key': self.get_pre_act_key(), - 'Value': self._situation_thus_far, - 'Chain of thought': chain_of_thought.view().text().splitlines(), - }) - - self._previous_time = current_time - - return self._situation_thus_far - - def get_state(self) -> entity_component.ComponentState: - """Converts the component to JSON data.""" - with self._lock: - if self._previous_time is None: - previous_time = '' - else: - previous_time = self._previous_time.strftime('%Y-%m-%d %H:%M:%S') - return { - 'previous_time': previous_time, - 'situation_thus_far': self._situation_thus_far, - } - - def set_state(self, state: entity_component.ComponentState) -> None: - """Sets the component state from JSON data.""" - with self._lock: - if state['previous_time']: - previous_time = datetime.datetime.strptime( - state['previous_time'], '%Y-%m-%d %H:%M:%S') - else: - previous_time = None - self._previous_time = previous_time - self._situation_thus_far = state['situation_thus_far'] - - def build_agent( *, config: formative_memories.AgentConfig, @@ -260,7 +93,7 @@ def build_agent( situation_representation_label = ( f'\nQuestion: What situation is {agent_name} in right now?\nAnswer') situation_representation = ( - SituationRepresentation( + situation_representation_via_narrative.SituationRepresentation( model=model, clock_now=clock.now, pre_act_key=situation_representation_label, diff --git a/concordia/factory/agent/alternative_rational_agent.py b/concordia/factory/agent/alternative_rational_agent.py new file mode 100644 index 00000000..086132af --- /dev/null +++ b/concordia/factory/agent/alternative_rational_agent.py @@ -0,0 +1,321 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An Agent Factory.""" + +from collections.abc import Callable +import datetime +import json + +from concordia.agents import entity_agent_with_logging +from concordia.associative_memory import associative_memory +from concordia.associative_memory import formative_memories +from concordia.clocks import game_clock +from concordia.components import agent as agent_components +from concordia.contrib.components.agent import situation_representation_via_narrative +from concordia.language_model import language_model +from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component +from concordia.utils import measurements as measurements_lib +import numpy as np + + +DEFAULT_INSTRUCTIONS_COMPONENT_KEY = 'Instructions' +DEFAULT_INSTRUCTIONS_PRE_ACT_KEY = '\nInstructions' +DEFAULT_GOAL_COMPONENT_KEY = 'Goal' + + +def _get_class_name(object_: object) -> str: + return object_.__class__.__name__ + + +class AvailableOptionsPerception( + agent_components.question_of_recent_memories.QuestionOfRecentMemories): + """This component answers the question 'what actions are available to me?'.""" + + def __init__(self, **kwargs): + + super().__init__( + question=( + 'Given the information above, what options are available to ' + '{agent_name} right now? Make sure not to consider too few ' + 'alternatives. Brainstorm at least three options. Try to include ' + 'actions that seem most likely to be effective along with some ' + 'creative or unusual choices that could also plausibly work.' 
+ ), + terminators=('\n\n',), + answer_prefix='', + add_to_memory=False, + **kwargs, + ) + + +class BestOptionPerception( + agent_components.question_of_recent_memories.QuestionOfRecentMemories): + """This component answers 'which action is best for achieving my goal?'.""" + + def __init__(self, **kwargs): + super().__init__( + question=( + "Given the information above, which of {agent_name}'s options " + 'has the highest likelihood of causing {agent_name} to achieve ' + 'their goal? If multiple options have the same likelihood, select ' + 'the option that {agent_name} thinks will most quickly and most ' + 'surely achieve their goal. The right choice is nearly always ' + 'one that is proactive, involves seizing the initative, ' + 'resoving uncertainty, and decisively moving towards the goal.' + ), + answer_prefix="{agent_name}'s best course of action is ", + add_to_memory=False, + **kwargs, + ) + + +def build_agent( + *, + config: formative_memories.AgentConfig, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + clock: game_clock.MultiIntervalClock, + update_time_interval: datetime.timedelta | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Build an agent. + + Args: + config: The agent config to use. + model: The language model to use. + memory: The agent's memory object. + clock: The clock to use. + update_time_interval: Agent calls update every time this interval passes. + + Returns: + An agent. 
+ """ + del update_time_interval + if not config.extras.get('main_character', False): + raise ValueError('This function is meant for a main character ' + 'but it was called on a supporting character.') + + agent_name = config.name + + raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory) + + measurements = measurements_lib.Measurements() + instructions = agent_components.instructions.Instructions( + agent_name=agent_name, + pre_act_key=DEFAULT_INSTRUCTIONS_PRE_ACT_KEY, + logging_channel=measurements.get_channel('Instructions').on_next, + ) + + time_display = agent_components.report_function.ReportFunction( + function=clock.current_time_interval_str, + pre_act_key='\nCurrent time', + logging_channel=measurements.get_channel('TimeDisplay').on_next, + ) + + observation_label = '\nObservation' + observation = agent_components.observation.Observation( + clock_now=clock.now, + timeframe=clock.get_step_size(), + pre_act_key=observation_label, + logging_channel=measurements.get_channel('Observation').on_next, + ) + + situation_representation_label = ( + f'\nQuestion: What situation is {agent_name} in right now?\nAnswer') + situation_representation = ( + situation_representation_via_narrative.SituationRepresentation( + model=model, + clock_now=clock.now, + pre_act_key=situation_representation_label, + logging_channel=measurements.get_channel( + 'SituationRepresentation' + ).on_next, + ) + ) + + options_perception_components = {} + universalization_context_components = {} + best_option_perception = {} + if config.goal: + goal_label = f'{agent_name}\'s goal' + overarching_goal = agent_components.constant.Constant( + state=config.goal, + pre_act_key=goal_label, + logging_channel=measurements.get_channel(goal_label).on_next) + options_perception_components[DEFAULT_GOAL_COMPONENT_KEY] = goal_label + universalization_context_components[DEFAULT_GOAL_COMPONENT_KEY] = goal_label + best_option_perception[DEFAULT_GOAL_COMPONENT_KEY] = goal_label + else: + 
overarching_goal = None + + options_perception_components.update({ + DEFAULT_INSTRUCTIONS_COMPONENT_KEY: DEFAULT_INSTRUCTIONS_PRE_ACT_KEY, + _get_class_name(situation_representation): situation_representation_label, + _get_class_name(observation): observation_label, + }) + options_perception_label = ( + f'\nQuestion: Which options are available to {agent_name} ' + 'right now?\nAnswer') + options_perception = ( + AvailableOptionsPerception( + model=model, + components=options_perception_components, + clock_now=clock.now, + pre_act_key=options_perception_label, + num_memories_to_retrieve=0, + logging_channel=measurements.get_channel( + 'AvailableOptionsPerception' + ).on_next, + ) + ) + + best_option_perception_label = ( + f'\nQuestion: Of the options available to {agent_name}, and ' + 'given their goal, which choice of action or strategy is ' + f'best to take right now?\nAnswer') + best_option_perception.update({ + DEFAULT_INSTRUCTIONS_COMPONENT_KEY: DEFAULT_INSTRUCTIONS_PRE_ACT_KEY, + _get_class_name(options_perception): options_perception_label, + }) + best_option_perception = ( + agent_components.question_of_recent_memories.BestOptionPerception( + model=model, + components=best_option_perception, + clock_now=clock.now, + pre_act_key=best_option_perception_label, + num_memories_to_retrieve=0, + logging_channel=measurements.get_channel( + 'BestOptionPerception' + ).on_next, + ) + ) + + entity_components = ( + # Components that provide pre_act context. + time_display, + observation, + situation_representation, + options_perception, + best_option_perception, + ) + components_of_agent = {_get_class_name(component): component + for component in entity_components} + components_of_agent[ + agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = ( + agent_components.memory_component.MemoryComponent(raw_memory)) + component_order = list(components_of_agent.keys()) + + # Put the instructions first. 
+ components_of_agent[DEFAULT_INSTRUCTIONS_COMPONENT_KEY] = instructions + component_order.insert(0, DEFAULT_INSTRUCTIONS_COMPONENT_KEY) + if overarching_goal is not None: + components_of_agent[DEFAULT_GOAL_COMPONENT_KEY] = overarching_goal + # Place goal after the instructions. + component_order.insert(1, DEFAULT_GOAL_COMPONENT_KEY) + + act_component = agent_components.concat_act_component.ConcatActComponent( + model=model, + clock=clock, + component_order=component_order, + logging_channel=measurements.get_channel('ActComponent').on_next, + ) + + agent = entity_agent_with_logging.EntityAgentWithLogging( + agent_name=agent_name, + act_component=act_component, + context_components=components_of_agent, + component_logging=measurements, + ) + + return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. 
+ """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), + ) + + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' + return agent diff --git a/concordia/factory/agent/factories_test.py b/concordia/factory/agent/factories_test.py index bb2d9e7d..ba1d06b3 100644 --- a/concordia/factory/agent/factories_test.py +++ b/concordia/factory/agent/factories_test.py @@ -25,8 +25,10 @@ from concordia.associative_memory import formative_memories from concordia.clocks import game_clock from concordia.factory.agent import alternative_basic_agent 
+from concordia.factory.agent import alternative_rational_agent from concordia.factory.agent import basic_agent from concordia.factory.agent import basic_agent_without_plan +from concordia.factory.agent import observe_and_summarize_agent from concordia.factory.agent import observe_recall_prompt_agent from concordia.factory.agent import paranoid_agent from concordia.factory.agent import parochial_universalization_agent @@ -48,9 +50,11 @@ AGENT_FACTORIES = { 'alternative_basic_agent': alternative_basic_agent, + 'alternative_rational_agent': alternative_rational_agent, 'basic_agent': basic_agent, 'basic_agent_without_plan': basic_agent_without_plan, 'observe_recall_prompt_agent': observe_recall_prompt_agent, + 'observe_and_summarize_agent': observe_and_summarize_agent, 'paranoid_agent': paranoid_agent, 'parochial_universalization_agent': parochial_universalization_agent, 'rational_agent': rational_agent, @@ -71,6 +75,11 @@ class AgentFactoriesTest(parameterized.TestCase): agent_name='alternative_basic_agent', main_role=True ), + dict( + testcase_name='alternative_rational_agent', + agent_name='alternative_rational_agent', + main_role=True + ), dict( testcase_name='basic_agent', agent_name='basic_agent', @@ -81,6 +90,11 @@ class AgentFactoriesTest(parameterized.TestCase): agent_name='basic_agent_without_plan', main_role=True, ), + dict( + testcase_name='observe_and_summarize_agent', + agent_name='observe_and_summarize_agent', + main_role=True, + ), dict( testcase_name='observe_recall_prompt_agent', agent_name='observe_recall_prompt_agent', diff --git a/concordia/factory/agent/observe_and_summarize_agent.py b/concordia/factory/agent/observe_and_summarize_agent.py new file mode 100644 index 00000000..60c4d587 --- /dev/null +++ b/concordia/factory/agent/observe_and_summarize_agent.py @@ -0,0 +1,229 @@ +# Copyright 2024 DeepMind Technologies Limited. 
def _get_class_name(object_: object) -> str:
  """Returns the name of the class of `object_`."""
  return object_.__class__.__name__


def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Builds a main-character agent that observes and summarizes its situation.

  The agent is assembled from these context components: instructions, a
  current-time report, observations, a narrative representation of the current
  situation, the memory component and, when `config.goal` is set, a constant
  holding the overarching goal (placed right after the instructions). A
  `ConcatActComponent` acts on them in that order.

  Args:
    config: The agent config to use. `config.extras['main_character']` must be
      truthy.
    model: The language model to use.
    memory: The agent's memory object.
    clock: The clock to use.
    update_time_interval: Ignored by this factory; the value is discarded.

  Returns:
    An agent.

  Raises:
    ValueError: If the config does not describe a main character.
  """
  del update_time_interval
  if not config.extras.get('main_character', False):
    raise ValueError('This function is meant for a main character '
                     'but it was called on a supporting character.')

  agent_name = config.name

  raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

  measurements = measurements_lib.Measurements()
  instructions = agent_components.instructions.Instructions(
      agent_name=agent_name,
      logging_channel=measurements.get_channel('Instructions').on_next,
  )

  time_display = agent_components.report_function.ReportFunction(
      function=clock.current_time_interval_str,
      pre_act_key='\nCurrent time',
      logging_channel=measurements.get_channel('TimeDisplay').on_next,
  )

  observation_label = '\nObservation'
  observation = agent_components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key=observation_label,
      logging_channel=measurements.get_channel('Observation').on_next,
  )
  # NOTE(review): an `ObservationsSinceLastUpdate` component used to be
  # constructed here with the same label and logging channel, but the instance
  # was discarded (never assigned, never registered below), so it could not
  # take part in the agent. The dead construction is dropped. If that
  # component was meant to be used, assign it and add it to
  # `entity_components` -- confirm the intent with the author.

  situation_representation_label = (
      f'\nQuestion: What situation is {agent_name} in right now?\nAnswer')
  situation_representation = (
      situation_representation_via_narrative.SituationRepresentation(
          model=model,
          clock_now=clock.now,
          pre_act_key=situation_representation_label,
          logging_channel=measurements.get_channel(
              'SituationRepresentation'
          ).on_next,
      )
  )

  if config.goal:
    goal_label = '\nOverarching goal'
    overarching_goal = agent_components.constant.Constant(
        state=config.goal,
        pre_act_key=goal_label,
        logging_channel=measurements.get_channel(goal_label).on_next)
  else:
    goal_label = None
    overarching_goal = None

  entity_components = (
      # Components that provide pre_act context.
      instructions,
      time_display,
      observation,
      situation_representation,
  )
  # Each component is registered under the name of its class.
  components_of_agent = {_get_class_name(component): component
                         for component in entity_components}
  components_of_agent[
      agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
          agent_components.memory_component.MemoryComponent(raw_memory))

  component_order = list(components_of_agent.keys())
  if overarching_goal is not None:
    components_of_agent[goal_label] = overarching_goal
    # Place goal after the instructions.
    component_order.insert(1, goal_label)

  act_component = agent_components.concat_act_component.ConcatActComponent(
      model=model,
      clock=clock,
      component_order=component_order,
      logging_channel=measurements.get_channel('ActComponent').on_next,
  )

  agent = entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=agent_name,
      act_component=act_component,
      context_components=components_of_agent,
      component_logging=measurements,
  )

  return agent


def save_to_json(
    agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
  """Serializes an agent's state to a JSON string.

  The result can be loaded afterwards with `rebuild_from_json`. It holds the
  state of every context component (keyed by component name), the state of the
  act component (under 'act_component') and, when the agent has one, its
  config (under 'agent_config'). The clock, model and embedder are not saved
  and have to be provided again when the agent is rebuilt.

  Args:
    agent: The agent to save. It must be in the `READY` phase.

  Returns:
    A JSON string representing the agent's state.

  Raises:
    ValueError: If the agent is not in the READY phase.
  """
  if agent.get_phase() != entity_component.Phase.READY:
    raise ValueError('The agent must be in the `READY` phase to be saved.')

  # Insertion order is kept: context components first, then the act component,
  # then the config.
  serialized = {}
  for component_name in agent.get_all_context_components():
    serialized[component_name] = (
        agent.get_component(component_name).get_state())
  serialized['act_component'] = agent.get_act_component().get_state()

  agent_config = agent.get_config()
  if agent_config is not None:
    serialized['agent_config'] = agent_config.to_dict()

  return json.dumps(serialized)


def rebuild_from_json(
    json_data: str,
    model: language_model.LanguageModel,
    clock: game_clock.MultiIntervalClock,
    embedder: Callable[[str], np.ndarray],
    memory_importance: Callable[[str], float] | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Rebuilds an agent from JSON data produced by `save_to_json`.

  A new associative memory and a new agent are created with `build_agent` from
  the saved config; the saved state is then loaded into every context
  component and into the act component.

  Args:
    json_data: JSON string holding the saved agent state.
    model: The language model to give to the rebuilt agent.
    clock: The clock to give to the rebuilt agent. It also supplies the
      current time and the step size of the new associative memory.
    embedder: Sentence embedder for the new associative memory.
    memory_importance: Optional importance function for the new associative
      memory.

  Returns:
    The rebuilt agent.

  Raises:
    ValueError: If the JSON data does not contain the agent config, or if it
      contains entries that were not consumed by any component.
    KeyError: If the saved state of a context component or of the act
      component is missing from the JSON data.
  """
  data = json.loads(json_data)

  new_agent_memory = associative_memory.AssociativeMemory(
      sentence_embedder=embedder,
      importance=memory_importance,
      clock=clock.now,
      clock_step_size=clock.get_step_size(),
  )

  if 'agent_config' not in data:
    raise ValueError('The JSON data does not contain the agent config.')
  agent_config = formative_memories.AgentConfig.from_dict(
      data.pop('agent_config')
  )

  agent = build_agent(
      config=agent_config,
      model=model,
      memory=new_agent_memory,
      clock=clock,
  )

  # Every consumed entry is popped so that anything left over can be reported.
  for component_name in agent.get_all_context_components():
    agent.get_component(component_name).set_state(data.pop(component_name))

  agent.get_act_component().set_state(data.pop('act_component'))

  # Raised explicitly instead of using `assert`: the JSON is external input and
  # assertions are skipped when Python runs with -O.
  if data:
    raise ValueError(f'Unused data {sorted(data)}')
  return agent
def _get_all_memories(
    memory_component_: agent_components.memory_component.MemoryComponent,
    add_time: bool = True,
    sort_by_time: bool = True,
    constant_score: float = 0.0,
) -> Sequence[memory_lib.MemoryResult]:
  """Returns all memories in the memory bank.

  Args:
    memory_component_: The memory component to retrieve memories from.
    add_time: Forwarded to `get_all_memories_as_text`; whether to add time.
    sort_by_time: Forwarded to `get_all_memories_as_text`; whether to sort by
      time.
    constant_score: The score value assigned to every returned memory.

  Returns:
    One `MemoryResult` per memory text, each carrying `constant_score`.
  """
  texts = memory_component_.get_all_memories_as_text(add_time=add_time,
                                                     sort_by_time=sort_by_time)
  return [memory_lib.MemoryResult(text=t, score=constant_score) for t in texts]


def _get_earliest_timepoint(
    memory_component_: agent_components.memory_component.MemoryComponent,
) -> datetime.datetime:
  """Returns the time of the earliest memory in the memory bank.

  Args:
    memory_component_: The memory component whose raw memory data frame is
      inspected.

  Returns:
    The first value of the 'time' column once the raw memories are sorted by
    time in ascending order. When the memory bank is empty a warning is logged
    and the current wall-clock time is returned instead.
  """
  memories_data_frame = memory_component_.get_raw_memory()
  if memories_data_frame.empty:
    absl_logging.warning('No memories found in memory bank.')
    return datetime.datetime.now()

  sorted_memories_data_frame = memories_data_frame.sort_values(
      'time', ascending=True)
  # `.iloc[0]` selects by position. Plain `[0]` would select the row whose
  # index label is 0, i.e. the first inserted memory, which is not necessarily
  # the earliest one after sorting.
  return sorted_memories_data_frame['time'].iloc[0]
- """ - super().__init__(pre_act_key) - self._model = model - self._clock_now = clock_now - self._memory_component_name = memory_component_name - self._logging_channel = logging_channel - - self._previous_time = None - self._situation_thus_far = None - - def _make_pre_act_value(self) -> str: - """Returns a representation of the current situation to pre act.""" - agent_name = self.get_entity().name - current_time = self._clock_now() - memory = self.get_entity().get_component( - self._memory_component_name, - type_=memory_component.MemoryComponent) - - if self._situation_thus_far is None: - self._previous_time = _get_earliest_timepoint(memory) - chain_of_thought = interactive_document.InteractiveDocument(self._model) - chain_of_thought.statement('~~ Creative Writing Assignment ~~') - chain_of_thought.statement(f'Protagonist: {agent_name}') - mems = '\n'.join([mem.text for mem in _get_all_memories(memory)]) - chain_of_thought.statement(f'Story fragments and world data:\n{mems}') - chain_of_thought.statement(f'Events continue after {current_time}') - self._situation_thus_far = chain_of_thought.open_question( - question=( - 'Narratively summarize the story fragments and world data. Give ' - 'special emphasis to atypical features of the setting such as ' - 'when and where the story takes place as well as any causal ' - 'mechanisms or affordances mentioned in the information ' - 'provided. Highlight the goals, personalities, occupations, ' - 'skills, and affordances of the named characters and ' - 'relationships between them. 
Use third-person omniscient ' - 'perspective.'), - max_tokens=1000, - terminators=(), - question_label='Exercise') - - interval_scorer = legacy_associative_memory.RetrieveTimeInterval( - time_from=self._previous_time, - time_until=current_time, - add_time=True, - ) - mems = [mem.text for mem in memory.retrieve(scoring_fn=interval_scorer)] - result = '\n'.join(mems) + '\n' - chain_of_thought = interactive_document.InteractiveDocument(self._model) - chain_of_thought.statement(f'Context:\n{self._situation_thus_far}') - chain_of_thought.statement(f'Protagonist: {agent_name}') - chain_of_thought.statement( - f'Thoughts and memories of {agent_name}:\n{result}' - ) - self._situation_thus_far = chain_of_thought.open_question( - question=( - 'What situation does the protagonist find themselves in? ' - 'Make sure to provide enough detail to give the ' - 'reader a comprehensive understanding of the world ' - 'inhabited by the protagonist, their affordances in that ' - 'world, actions they may be able to take, effects their ' - 'actions may produce, and what is currently going on.' 
- ), - max_tokens=1000, - terminators=(), - question_label='Exercise', - ) - chain_of_thought.statement(f'The current date and time is {current_time}') - - self._logging_channel({ - 'Key': self.get_pre_act_key(), - 'Value': self._situation_thus_far, - 'Chain of thought': chain_of_thought.view().text().splitlines(), - }) - - self._previous_time = current_time - - return self._situation_thus_far - - def get_state(self) -> entity_component.ComponentState: - """Converts the component to JSON data.""" - with self._lock: - if self._previous_time is None: - previous_time = '' - else: - previous_time = self._previous_time.strftime('%Y-%m-%d %H:%M:%S') - return { - 'previous_time': previous_time, - 'situation_thus_far': self._situation_thus_far, - } - - def set_state(self, state: entity_component.ComponentState) -> None: - """Sets the component state from JSON data.""" - with self._lock: - if state['previous_time']: - previous_time = datetime.datetime.strptime( - state['previous_time'], '%Y-%m-%d %H:%M:%S') - else: - previous_time = None - self._previous_time = previous_time - self._situation_thus_far = state['situation_thus_far'] - - class Universalization(action_spec_ignored.ActionSpecIgnored): """Consider ``what if everyone behaved that way?``.""" @@ -378,7 +214,7 @@ def build_agent( situation_representation_label = ( f'\nQuestion: What situation is {agent_name} in right now?\nAnswer') situation_representation = ( - SituationRepresentation( + situation_representation_via_narrative.SituationRepresentation( model=model, clock_now=clock.now, pre_act_key=situation_representation_label,