diff --git a/camel/benchmarks/__init__.py b/camel/benchmarks/__init__.py
index d4e5816007..0b0ac28b65 100644
--- a/camel/benchmarks/__init__.py
+++ b/camel/benchmarks/__init__.py
@@ -16,11 +16,18 @@
 from .apibench import APIBenchBenchmark
 from .base import BaseBenchmark
 from .gaia import DefaultGAIARetriever, GAIABenchmark
+from .math_benchmarks.gsm8k import GSM8KBenchmark
+from .math_benchmarks.math_base import MathBenchmark, Mode
+from .math_benchmarks.math_bench import MATHBenchmark
 from .nexus import NexusBenchmark
 from .ragbench import RAGBenchBenchmark
 
 __all__ = [
     "BaseBenchmark",
+    "MathBenchmark",
+    "Mode",
+    "MATHBenchmark",
+    "GSM8KBenchmark",
     "GAIABenchmark",
     "DefaultGAIARetriever",
     "NexusBenchmark",
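With the re-exports above, the new classes can be imported directly from `camel.benchmarks`; a minimal sketch using only the names that `__all__` now exposes:

    from camel.benchmarks import GSM8KBenchmark, MATHBenchmark, MathBenchmark, Mode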
diff --git a/camel/benchmarks/math_benchmarks/gsm8k.py b/camel/benchmarks/math_benchmarks/gsm8k.py
new file mode 100644
index 0000000000..554029ed0b
--- /dev/null
+++ b/camel/benchmarks/math_benchmarks/gsm8k.py
@@ -0,0 +1,173 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+from typing import Any, Dict, List
+
+from camel.agents import ChatAgent
+from camel.benchmarks.math_benchmarks.math_base import MathBenchmark, Mode
+from camel.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class GSM8KBenchmark(MathBenchmark):
+    r"""
+    Benchmark for evaluating ChatAgents on the GSM8K dataset, a collection of
+    grade-school-level math problems sourced from Hugging Face Hub.
+
+    Attributes:
+        DATASET_NAME (str): The name of the dataset.
+        DATASET_REPO (str): The dataset's repository on Hugging Face.
+        QUESTION_COLUMN (str): The column containing math problems.
+        ANSWER_COLUMN (str): The column containing solutions.
+    """
+
+    import pandas as pd
+    from datasets import load_dataset
+
+    DATASET_NAME = "gsm8k"
+    DATASET_REPO = "openai/gsm8k"
+    QUESTION_COLUMN = "question"
+    ANSWER_COLUMN = "answer"
+
+    def __init__(self, data_dir: str, save_to: str, processes: int = 1):
+        r"""
+        Initializes the GSM8K Benchmark instance.
+
+        Args:
+            data_dir (str): Directory for storing the dataset.
+            save_to (str): Path for saving benchmark results.
+            processes (int, optional): Number of parallel processes.
+                Defaults to 1.
+        """
+        super().__init__(
+            name="GSM8K",
+            data_dir=data_dir,
+            save_to=save_to,
+            processes=processes,
+        )
+        self._data: Dict[str, List[Dict[str, Any]]] = {}
+
+    def download(self) -> "GSM8KBenchmark":
+        r"""
+        Ensures the GSM8K dataset is available locally. Uses Hugging Face
+        Datasets for automatic caching and management.
+
+        Returns:
+            GSM8KBenchmark: The benchmark instance after downloading.
+        """
+        logger.info("Ensuring GSM8K dataset is downloaded...")
+        _ = GSM8KBenchmark.load_dataset(
+            self.DATASET_REPO, 'main', cache_dir=str(self.data_dir)
+        )
+
+        logger.info("GSM8K dataset is ready.")
+        return self
+
+    def load(self, force_download: bool = False) -> "GSM8KBenchmark":
+        r"""
+        Loads the GSM8K dataset into memory, optionally forcing a re-download.
+
+        Args:
+            force_download (bool, optional): Whether to force re-downloading
+                the dataset. Defaults to False.
+
+        Returns:
+            GSM8KBenchmark: The benchmark instance after loading.
+        """
+        logger.info("Loading GSM8K dataset...")
+
+        dataset = GSM8KBenchmark.load_dataset(
+            self.DATASET_REPO,
+            'main',
+            cache_dir=str(self.data_dir),
+            download_mode="force_redownload"
+            if force_download
+            else "reuse_dataset_if_exists",
+        )
+
+        self._data = {
+            "train": dataset["train"].to_pandas().to_dict(orient="records"),
+            "test": dataset["test"].to_pandas().to_dict(orient="records"),
+        }
+        return self
+
+    @property
+    def valid(self) -> List[Dict[str, Any]]:
+        r"""
+        Returns an empty list since GSM8K does not have a validation set.
+
+        Returns:
+            List[Dict[str, Any]]: An empty list.
+        """
+        return []
+
+    def _prepare_dataset(self, dataset: List[Dict[str, Any]]) -> pd.DataFrame:
+        r"""
+        Prepares the dataset by extracting numeric solutions from the answer
+        field.
+
+        Args:
+            dataset (List[Dict[str, Any]]): The dataset to process.
+
+        Returns:
+            pd.DataFrame: The processed dataset with extracted solutions.
+        """
+        df = self.pd.DataFrame(dataset)
+        df["solution"] = df["answer"].str.extract(r"####\s*(-?\d+)")[0]
+        return df
+
+    def _generate_solutions(
+        self, agent: ChatAgent, dataset: pd.DataFrame, mode: Mode
+    ) -> pd.DataFrame:
+        r"""
+        Efficiently generates responses for each math problem using the
+        ChatAgent, ensuring the agent resets between questions without
+        unnecessary instantiations.
+
+        Args:
+            agent (ChatAgent): The agent responsible for generating answers.
+            dataset (pd.DataFrame): The dataset containing math problems.
+            mode (Mode): The evaluation mode for generating multiple responses.
+
+        Returns:
+            pd.DataFrame: The dataset with generated answers.
+        """
+
+        def generate_answer(question: str) -> List[str]:
+            r"""
+            Generate `k` responses while resetting the agent after each
+            question.
+            """
+            agent.reset()  # Ensuring statelessness
+            return [
+                agent.step(question).msgs[0].content for _ in range(mode.k)
+            ]
+
+        dataset["answers"] = dataset["question"].apply(generate_answer)
+        return dataset
+
+    def _preprocess_answers(self, raw_answers: pd.Series) -> pd.Series:
+        r"""
+        Extracts numeric answers from generated responses using a regular
+        expression.
+
+        Args:
+            raw_answers (pd.Series): The series containing raw model-generated
+                responses.
+
+        Returns:
+            pd.Series: Extracted numeric answers.
+        """
+        return raw_answers.str.extract(r"####\s*(-?\d+)")[0]
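For reference, a small standalone sketch of the `#### <number>` extraction that `_prepare_dataset` and `_preprocess_answers` rely on; the sample strings below are made up for illustration:

    import pandas as pd

    raw = pd.Series(["Step 1: add the numbers. #### 42", "#### -7"])
    print(raw.str.extract(r"####\s*(-?\d+)")[0].tolist())  # ['42', '-7']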
diff --git a/camel/benchmarks/math_benchmarks/math_base.py b/camel/benchmarks/math_benchmarks/math_base.py
new file mode 100644
index 0000000000..9aeb7dece7
--- /dev/null
+++ b/camel/benchmarks/math_benchmarks/math_base.py
@@ -0,0 +1,253 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
+
+from math_verify import parse, verify
+
+from camel.agents import ChatAgent
+from camel.benchmarks import BaseBenchmark
+from camel.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class Mode:
+    r"""
+    Defines different evaluation modes for benchmarking.
+
+    Attributes:
+        VALID_MODES (set): Supported evaluation modes.
+        mode (Literal["pass@k", "majority voting"]): Selected evaluation mode.
+        k (Optional[int]): Parameter defining attempts or votes required.
+    """
+
+    VALID_MODES: ClassVar[set[str]] = {"pass@k", "majority voting"}
+
+    def __init__(
+        self,
+        mode: Literal["pass@k", "majority voting"],
+        k: Optional[int] = None,
+    ):
+        r"""
+        Initializes a Mode object.
+
+        Args:
+            mode (Literal["pass@k", "majority voting"]): The evaluation mode.
+            k (Optional[int]): Number of attempts (for "pass@k") or votes
+                (for "majority voting").
+
+        Raises:
+            ValueError: If `k` is not valid for the selected mode.
+        """
+        self.mode = mode
+
+        if mode == "pass@k":
+            if k is None or k < 1:
+                raise ValueError("`k` must be at least 1 for 'pass@k'.")
+            self.k = k
+
+        elif mode == "majority voting":
+            if k is None or k < 2:
+                raise ValueError(
+                    "`k` must be at least 2 for 'majority voting'."
+                )
+            self.k = k
+
+        else:
+            raise ValueError(
+                f"Invalid mode '{mode}'. Supported modes: {self.VALID_MODES}"
+            )
+
+    def __repr__(self) -> str:
+        r"""Returns a string representation of the Mode instance."""
+        return f"Mode(mode={self.mode}, k={self.k})"
+
+
+class MathBenchmark(BaseBenchmark):
+    r"""
+    Benchmark class for evaluating mathematical problem-solving capabilities.
+
+    Inherits from:
+        BaseBenchmark
+    """
+
+    import numpy as np
+    import pandas as pd
+
+    def __init__(
+        self, name: str, data_dir: str, save_to: str, processes: int = 1
+    ):
+        r"""
+        Initializes the MathBenchmark class.
+
+        Args:
+            name (str): Name of the benchmark.
+            data_dir (str): Directory containing the dataset.
+            save_to (str): Path to save the benchmark results.
+            processes (int, optional): Number of parallel processes.
+                Defaults to 1.
+        """
+        super().__init__(name, data_dir, save_to, processes)
+
+    def run(
+        self,
+        agent: ChatAgent,
+        on: Literal["train", "valid", "test"],
+        randomize: bool = False,
+        subset: Optional[int] = None,
+        mode: Optional[Mode] = None,
+        *args,
+        **kwargs,
+    ) -> "MathBenchmark":
+        r"""
+        Runs the benchmark, evaluates answers, and saves results as JSON.
+
+        Args:
+            agent (ChatAgent): The agent used to generate answers.
+            on (Literal["train", "valid", "test"]): The dataset split to use.
+            randomize (bool, optional): Whether to randomize dataset order.
+                Defaults to False.
+            subset (Optional[int], optional): Number of problems to process.
+                Defaults to None (all).
+            mode (Mode, optional): The evaluation mode. Defaults to
+                Mode("pass@k", 1).
+
+        Returns:
+            MathBenchmark: The benchmark instance with results.
+
+        Raises:
+            ValueError: If an invalid dataset split is specified.
+            TypeError: If the results are not in the expected format.
+        """
+
+        if mode is None:
+            mode = Mode("pass@k", 1)
+
+        logger.info(
+            f"Running {mode.mode} evaluation on {on} set with k={mode.k}"
+        )
+
+        if on not in ["train", "test", "valid"]:
+            raise ValueError(
+                f"Invalid dataset split '{on}'. Use 'train', 'valid' (empty), "
+                f"or 'test'."
+            )
+
+        if not self._data:
+            self.load()
+
+        dataset = self._prepare_dataset(self._data[on])
+
+        # TODO: Fix seed for reproducibility
+        if randomize:
+            dataset = dataset.sample(frac=1).reset_index(drop=True)
+
+        if subset:
+            dataset = dataset[:subset]
+
+        # Generate solutions for each question in the dataset
+        results = self._generate_solutions(
+            agent, dataset, mode, *args, **kwargs
+        )
+
+        # Ensure the results are in the expected format
+        if isinstance(results, dict):
+            results = self.pd.DataFrame(results)
+
+        if not isinstance(results, self.pd.DataFrame):
+            raise TypeError(
+                "Expected results as a pandas DataFrame or dictionary."
+            )
+
+        if (
+            "answers" not in results.columns
+            or "solution" not in results.columns
+        ):
+            raise ValueError(
+                "Results must contain 'answers' and 'solution' columns."
+            )
+
+        # Process answers based on mode
+        results["correct"] = results.apply(
+            lambda row: self._evaluate(row, mode), axis=1
+        )
+
+        # Save results as JSON
+        save_dir = Path(self.save_to)
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        save_path = Path(self.save_to) / f"{self.name}_results.json"
+        results.to_json(save_path, orient="records", indent=2)
+
+        logger.info(f"Results saved to {save_path}")
+
+        self._results = results.to_dict(orient="records")
+
+        return self
+
+    def _evaluate(self, row: pd.Series, mode: Mode) -> bool:
+        r"""
+        Evaluate model predictions based on the chosen evaluation mode.
+        """
+        answers = row["answers"]
+        solution = row["solution"]
+
+        if not isinstance(answers, list):
+            raise ValueError(
+                f"Expected 'answers' to be a list, but got {type(answers)}"
+            )
+
+        if mode.mode == "pass@k":
+            responses = row["answers"][: mode.k]
+            return any(
+                verify(parse(response), parse(solution))
+                for response in responses
+            )
+
+        elif mode.mode == "majority voting":
+            most_common = self.pd.Series(answers).mode().iloc[0]
+            return verify(parse(most_common), parse(solution))
+
+        return False  # Default case
+
+    @abstractmethod
+    def _prepare_dataset(self, dataset: List[Dict[str, Any]]) -> pd.DataFrame:
+        r"""
+        Method to further prepare the dataset, like renaming or normalizing
+        columns.
+        """
+        pass
+
+    @abstractmethod
+    def _generate_solutions(
+        self,
+        agent: ChatAgent,
+        dataset: pd.DataFrame,
+        mode: Mode,
+        *args,
+        **kwargs,
+    ) -> Union[pd.DataFrame, Dict[str, List[Any]]]:
+        r"""
+        Method to be implemented by subclasses.
+
+        This method should return a pandas DataFrame or a dictionary with:
+        - "answers": List of generated answers for each problem.
+        - "solution": The correct solution.
+        """
+        pass
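To make the two `Mode` settings concrete, here is a self-contained sketch of what `_evaluate` computes for each of them, calling `math_verify` directly on hand-written answers (the values are illustrative, and `Counter` is only a stand-in for the `pandas.Series.mode` call used in the class):

    from collections import Counter

    from math_verify import parse, verify

    answers = ["42", "41", "42"]
    solution = "42"

    # pass@k: correct if any of the first k answers verifies against the solution
    pass_at_3 = any(verify(parse(a), parse(solution)) for a in answers[:3])

    # majority voting: correct if the most common answer verifies
    majority = Counter(answers).most_common(1)[0][0]
    majority_ok = verify(parse(majority), parse(solution))

    print(pass_at_3, majority_ok)  # True True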
diff --git a/camel/benchmarks/math_benchmarks/math_bench.py b/camel/benchmarks/math_benchmarks/math_bench.py
new file mode 100644
index 0000000000..e13629170b
--- /dev/null
+++ b/camel/benchmarks/math_benchmarks/math_bench.py
@@ -0,0 +1,225 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+from typing import Any, ClassVar, Dict, List
+
+from camel.agents import ChatAgent
+from camel.benchmarks import MathBenchmark, Mode
+from camel.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class MATHBenchmark(MathBenchmark):
+    r"""
+    Benchmark for evaluating ChatAgents on the MATH dataset, a collection of
+    high school-level and competition-style math problems sourced from the
+    Hugging Face Hub.
+
+    Attributes:
+        DATASET_NAME (str): The name of the dataset.
+        DATASET_REPO (str): The dataset's repository on Hugging Face.
+        DATASET_CONFIGS (List[str]):
+            The different subcategories in the dataset.
+    """
+
+    import pandas as pd
+    from datasets import load_dataset
+
+    DATASET_NAME: ClassVar[str] = "math"
+    DATASET_REPO: ClassVar[str] = "EleutherAI/hendrycks_math"
+    DATASET_CONFIGS: ClassVar[list[str]] = [
+        "algebra",
+        "counting_and_probability",
+        "geometry",
+        "intermediate_algebra",
+        "number_theory",
+        "prealgebra",
+        "precalculus",
+    ]
+
+    def __init__(self, data_dir: str, save_to: str, processes: int = 1):
+        r"""
+        Initializes the MATH Benchmark instance.
+
+        Args:
+            data_dir (str): Directory for storing the dataset.
+            save_to (str): Path for saving benchmark results.
+            processes (int, optional): Number of parallel processes.
+                Defaults to 1.
+        """
+        super().__init__(
+            name="MATH",
+            data_dir=data_dir,
+            save_to=save_to,
+            processes=processes,
+        )
+        self._data: Dict[str, List[Dict[str, Any]]] = {}
+
+    def download(self) -> "MATHBenchmark":
+        r"""
+        Ensures the MATH dataset is available locally. Uses Hugging Face
+        Datasets for automatic caching and management.
+
+        Returns:
+            MATHBenchmark: The benchmark instance after downloading.
+        """
+        logger.info("Ensuring MATH dataset is downloaded...")
+        for config in self.DATASET_CONFIGS:
+            _ = MATHBenchmark.load_dataset(
+                self.DATASET_REPO, config, cache_dir=str(self.data_dir)
+            )
+        logger.info("MATH dataset is ready.")
+        return self
+
+    def load(self, force_download: bool = False) -> "MATHBenchmark":
+        r"""
+        Loads the MATH dataset into memory, optionally forcing a re-download.
+
+        Args:
+            force_download (bool, optional): Whether to force re-downloading
+                the dataset. Defaults to False.
+
+        Returns:
+            MATHBenchmark: The benchmark instance after loading.
+        """
+        logger.info("Loading MATH dataset...")
+
+        self._data = {"train": [], "test": []}
+
+        for config in self.DATASET_CONFIGS:
+            dataset = MATHBenchmark.load_dataset(
+                self.DATASET_REPO,
+                config,
+                cache_dir=str(self.data_dir),
+                download_mode="force_redownload"
+                if force_download
+                else "reuse_dataset_if_exists",
+            )
+
+            # Convert to pandas DataFrame and add a `config` column
+            train_df = dataset["train"].to_pandas()
+            train_df["config"] = config
+            self._data["train"].extend(train_df.to_dict(orient="records"))
+
+            test_df = dataset["test"].to_pandas()
+            test_df["config"] = config
+            self._data["test"].extend(test_df.to_dict(orient="records"))
+
+        return self
+
+    @property
+    def valid(self) -> List[Dict[str, Any]]:
+        r"""
+        Returns an empty list since the MATH dataset does not have a validation
+        set.
+
+        Returns:
+            List[Dict[str, Any]]: An empty list.
+        """
+        return []
+
+    def _prepare_dataset(self, dataset: List[Dict[str, Any]]) -> pd.DataFrame:
+        r"""
+        Prepares the dataset by extracting solutions from provided answers.
+
+        - Renames the "problem" column to "questions" for consistency.
+        - Extracts the final answer from solutions wrapped in `\boxed{}`.
+
+        Args:
+            dataset (List[Dict[str, Any]]): The dataset to process.
+
+        Returns:
+            pd.DataFrame: The processed dataset with extracted solutions.
+        """
+        df = self.pd.DataFrame(dataset)
+        df.rename(columns={"problem": "questions"}, inplace=True)
+
+        def extract_boxed(text: str) -> str:
+            r"""
+            Extracts the content inside the first `\boxed{}`.
+
+            Args:
+                text (str): The solution text containing `\boxed{}`.
+
+            Returns:
+                str: The extracted final answer.
+
+            Raises:
+                ValueError: If the answer cannot be extracted properly.
+            """
+            start_seq = r"\boxed{"
+            stack = []  # Stack to track `{}` nesting
+            content: List[str] = []
+            inside_boxed = False
+            i = 0
+
+            while i < len(text):
+                if (
+                    text[i : i + len(start_seq)] == start_seq
+                    and not inside_boxed
+                ):
+                    inside_boxed = True
+                    stack.append("{")
+                    i += len(start_seq)  # Skip `\boxed{`
+                    continue
+
+                if inside_boxed:
+                    if text[i] == "{":
+                        stack.append("{")
+                    elif text[i] == "}":
+                        stack.pop()
+                        # If stack is empty, we've closed `\boxed{}` correctly
+                        if not stack:
+                            return "".join(content)
+
+                    content.append(text[i])
+
+                i += 1
+
+            raise ValueError(f"Couldn't extract value from solution: {text}")
+
+        df["solutions"] = df["solution"].apply(extract_boxed)
+
+        return df
+
+    def _generate_solutions(
+        self, agent: ChatAgent, dataset: pd.DataFrame, mode: Mode
+    ) -> pd.DataFrame:
+        r"""
+        Efficiently generates responses for each math problem using the
+        ChatAgent, ensuring the agent resets between questions without
+        unnecessary instantiations.
+
+        Args:
+            agent (ChatAgent): The agent responsible for generating answers.
+            dataset (pd.DataFrame): The dataset containing math problems.
+            mode (Mode): The evaluation mode for generating multiple responses.
+
+        Returns:
+            pd.DataFrame: The dataset with generated answers.
+        """
+
+        def generate_answer(question: str) -> List[str]:
+            r"""
+            Generate `k` responses while resetting the agent after each
+            question.
+            """
+            agent.reset()  # Ensuring statelessness
+            return [
+                agent.step(question).msgs[0].content for _ in range(mode.k)
+            ]
+
+        dataset["answers"] = dataset["questions"].apply(generate_answer)
+        return dataset
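The brace matching in `extract_boxed` is what lets nested answers such as `\boxed{\frac{1}{2}}` survive intact; a compact standalone sketch of the same idea (the helper name is illustrative, not part of the PR):

    def extract_boxed_sketch(text: str) -> str:
        # find the first \boxed{ and walk the braces until it closes
        start = text.index(r"\boxed{") + len(r"\boxed{")
        depth, out = 1, []
        for ch in text[start:]:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return "".join(out)
            out.append(ch)
        raise ValueError("unbalanced \\boxed{...}")

    print(extract_boxed_sketch(r"So the answer is $\boxed{\frac{1}{2}}$"))  # \frac{1}{2}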
diff --git a/examples/benchmarks/gsm8k.py b/examples/benchmarks/gsm8k.py
new file mode 100644
index 0000000000..e3622b5b0a
--- /dev/null
+++ b/examples/benchmarks/gsm8k.py
@@ -0,0 +1,38 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+from camel.agents import ChatAgent
+from camel.benchmarks import GSM8KBenchmark
+
+# Set up the agent to be benchmarked
+agent = ChatAgent()
+
+# Set up the Grade School Math (GSM8K) benchmark
+benchmark = GSM8KBenchmark(data_dir="GSM8K-Data", save_to="GSM8KResults")
+benchmark.download()
+
+# Run the benchmark to get results
+benchmark = benchmark.run(agent, on="test", subset=10)
+
+total_answers = len(benchmark.results)
+correct_answers = sum(row["correct"] for row in benchmark.results)
+
+print("Total:", total_answers)
+print("Correct:", correct_answers)
+'''
+===============================================================================
+Total: 10
+Correct: 9
+===============================================================================
+'''
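The same example can be switched to majority voting by passing an explicit `Mode`; a hedged variant of the call above (it is only meaningful if the agent's sampling produces different responses across the k attempts):

    from camel.benchmarks import Mode

    benchmark = benchmark.run(
        agent, on="test", subset=10, mode=Mode("majority voting", 3)
    )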
diff --git a/examples/benchmarks/math_bench.py b/examples/benchmarks/math_bench.py
new file mode 100644
index 0000000000..dc58a72007
--- /dev/null
+++ b/examples/benchmarks/math_bench.py
@@ -0,0 +1,38 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+from camel.agents import ChatAgent
+from camel.benchmarks.math_benchmarks.math_bench import MATHBenchmark
+
+# Set up the agent to be benchmarked
+agent = ChatAgent()
+
+# Set up the Hendrycks MATH benchmark
+benchmark = MATHBenchmark(data_dir="MATH-Data", save_to="MATHResults")
+benchmark.download()
+
+# Run the benchmark to get results
+benchmark = benchmark.run(agent, on="test", subset=10)
+
+total_answers = len(benchmark.results)
+correct_answers = sum(row["correct"] for row in benchmark.results)
+
+print("Total:", total_answers)
+print("Correct:", correct_answers)
+'''
+===============================================================================
+Total: 10
+Correct: 9
+===============================================================================
+'''
diff --git a/pyproject.toml b/pyproject.toml
index 4179e19ebf..89c7243ce8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ torch = [
 soundfile = { version = "^0.13", optional = true }
 sentencepiece = { version = "^0.2", optional = true }
 opencv-python = { version = "^4", optional = true }
+math-verify = { version = "^0.7", optional = true }
 
 # Core RAG components
 sentence-transformers = { version = "^3.0.1", optional = true }
@@ -271,6 +272,7 @@ huggingface = [
     "soundfile",
     "sentencepiece",
     "opencv-python",
+    "math-verify",
 ]
 
 # Storage solutions
@@ -477,6 +479,7 @@ module = [
     "huggingface_hub",
     "huggingface_hub.utils._errors",
    "huggingface_hub.errors",
+    "math_verify",
     "wikipedia",
     "linkup-sdk",
     "duckduckgo_search",
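Since `math-verify` is added as an optional dependency, it is only pulled in through the extras group that lists it; assuming the published package name and the `huggingface` extra shown above, installation would look roughly like:

    pip install "camel-ai[huggingface]"
    # or, in a poetry-managed checkout of the repository:
    poetry install -E huggingface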
diff --git a/test/benchmarks/test_gsm8k_benchmark.py b/test/benchmarks/test_gsm8k_benchmark.py
new file mode 100644
index 0000000000..12b1605f14
--- /dev/null
+++ b/test/benchmarks/test_gsm8k_benchmark.py
@@ -0,0 +1,87 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, mock_open, patch
+
+import pandas as pd
+import pytest
+
+from camel.agents import ChatAgent
+from camel.benchmarks import GSM8KBenchmark, Mode
+
+SAMPLE_DATA = [
+    {"question": "What is 5 + 7?", "answer": "#### 12"},
+    {"question": "Find the product of 8 and 3.", "answer": "#### 24"},
+]
+
+
+@pytest.fixture
+def benchmark():
+    r"""Fixture to initialize GSM8KBenchmark
+    with a fully mocked file system."""
+    with (
+        patch("pathlib.Path.mkdir"),
+        patch("pathlib.Path.is_dir", return_value=True),
+        patch("pathlib.Path.exists", return_value=True),
+    ):
+        temp_dir = tempfile.mkdtemp()
+        return GSM8KBenchmark(data_dir=Path(temp_dir), save_to=Path(temp_dir))
+
+
+@patch("builtins.open", new_callable=mock_open)
+@patch("pathlib.Path.is_dir", return_value=True)
+@patch("pathlib.Path.exists", return_value=True)
+def test_run(mock_exists, mock_is_dir, mock_file, benchmark):
+    r"""Test that GSM8KBenchmark runs correctly and writes expected results."""
+    benchmark._data = {"test": SAMPLE_DATA}
+    mock_agent = MagicMock(spec=ChatAgent)
+    mock_agent.step.return_value.msgs = [MagicMock(content="#### 12")]
+    results = benchmark.run(
+        agent=mock_agent, on="test", mode=Mode("pass@k", 1)
+    )
+    assert "correct" in results._results[0]
+    mock_file().write.assert_called()
+
+
+def test_prepare_dataset(benchmark):
+    r"""Test that _prepare_dataset extracts solutions correctly."""
+    df = benchmark._prepare_dataset(SAMPLE_DATA)
+    assert "solution" in df.columns
+    assert list(df["solution"]) == ["12", "24"]
+
+
+def test_preprocess_answers(benchmark):
+    r"""Test that _preprocess_answers correctly extracts numeric values
+    from answers."""
+    raw_answers = pd.Series(
+        ["#### 12", "#### 24", "Mock test with text and numbers 13 #### -7"]
+    )
+    processed = benchmark._preprocess_answers(raw_answers)
+    assert list(processed) == ["12", "24", "-7"]
+
+
+def test_download():
+    r"""Test that GSM8KBenchmark downloads
+    the dataset to the data/ directory."""
+
+    data_dir = Path("data/")
+    save_to = Path("data/")
+
+    benchmark = GSM8KBenchmark(data_dir=str(data_dir), save_to=str(save_to))
+    benchmark.download()
+
+    assert data_dir.exists(), "Data directory was not created!"
+    dataset_files = list(data_dir.glob("**/*"))
+    assert len(dataset_files) > 0, "Dataset files were not downloaded!"
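The run/prepare/preprocess tests above are fully mocked, while `test_download` fetches the real dataset; when running offline, the mocked ones can be selected with something along these lines:

    pytest test/benchmarks/test_gsm8k_benchmark.py -k "not download"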
diff --git a/test/benchmarks/test_math_bench_benchmark.py b/test/benchmarks/test_math_bench_benchmark.py
new file mode 100644
index 0000000000..67c019d566
--- /dev/null
+++ b/test/benchmarks/test_math_bench_benchmark.py
@@ -0,0 +1,97 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from camel.agents import ChatAgent
+from camel.benchmarks import MATHBenchmark, Mode
+
+SAMPLE_DATA = [
+    {
+        "problem": "Solve for x: x^2 - 4 = 0",
+        "solution": r"Let's solve: $x^2 - 4 = 0 \boxed{2}$",
+    },
+    {
+        "problem": "What is the sum of the first 10 positive integers?",
+        "solution": r"Using the formula: $\boxed{55}$",
+    },
+]
+
+
+@pytest.fixture
+def benchmark():
+    r"""Fixture to initialize MATHBenchmark with a fully mocked file system."""
+    with (
+        patch("pathlib.Path.mkdir"),
+        patch("pathlib.Path.is_dir", return_value=True),
+        patch("pathlib.Path.exists", return_value=True),
+    ):
+        temp_dir = tempfile.mkdtemp()
+        return MATHBenchmark(data_dir=Path(temp_dir), save_to=Path(temp_dir))
+
+
+def test_prepare_dataset(benchmark):
+    r"""Test that _prepare_dataset extracts solutions correctly."""
+    df = benchmark._prepare_dataset(SAMPLE_DATA)
+    assert "solutions" in df.columns
+    assert list(df["solutions"]) == ["2", "55"]
+
+
+@patch("builtins.open")
+@patch("pathlib.Path.is_dir", return_value=True)
+@patch("pathlib.Path.exists", return_value=True)
+def test_run(mock_exists, mock_is_dir, mock_file, benchmark):
+    r"""Test that MATHBenchmark runs correctly and writes expected results."""
+    benchmark._data = {"test": SAMPLE_DATA}
+    mock_agent = MagicMock(spec=ChatAgent)
+    mock_agent.step.return_value.msgs = [MagicMock(content=r"\boxed{2}")]
+
+    results = benchmark.run(
+        agent=mock_agent, on="test", mode=Mode("pass@k", 1)
+    )
+    assert "correct" in results._results[0]
+    mock_file().write.assert_called()
+
+
+def test_generate_solutions(benchmark):
+    r"""Test that _generate_solutions properly calls ChatAgent
+    and formats responses."""
+    df = benchmark._prepare_dataset(SAMPLE_DATA)
+    mock_agent = MagicMock(spec=ChatAgent)
+    mock_agent.step.return_value.msgs = [MagicMock(content=r"\boxed{2}")]
+
+    result_df = benchmark._generate_solutions(
+        mock_agent, df, Mode("pass@k", 1)
+    )
+    assert "answers" in result_df.columns
+    assert result_df["answers"].apply(lambda x: x[0] == r"\boxed{2}").all()
+
+
+def test_download_math():
+    """Test that MATHBenchmark downloads the dataset to the data/ directory."""
+
+    data_dir = Path("data/")
+    save_to = Path("data/")
+
+    benchmark = MATHBenchmark(data_dir=str(data_dir), save_to=str(save_to))
+    benchmark.download()
+
+    assert data_dir.exists(), "Data directory was not created!"
+    dataset_files = list(data_dir.glob("**/*"))
+    assert len(dataset_files) > 0, "Dataset files were not downloaded!"
+
+    print("MATH dataset downloaded successfully.")
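Beyond `benchmark.results`, `run()` also writes a `<name>_results.json` records file under `save_to`; a short sketch of computing accuracy from that file for the MATH example above (the path follows the example's `save_to="MATHResults"` and the benchmark name "MATH"):

    import json

    with open("MATHResults/MATH_results.json") as f:
        records = json.load(f)

    accuracy = sum(r["correct"] for r in records) / len(records)
    print(f"accuracy: {accuracy:.2%}")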