Wfh/json schema evaluation #12389

Merged · 18 commits · Oct 27, 2023
95 changes: 93 additions & 2 deletions docs/docs/guides/evaluation/string/json.ipynb
@@ -221,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"id": "7a8f3ec5-1cde-4b0e-80cd-ac0ac290d375",
"metadata": {},
"outputs": [
@@ -261,11 +261,102 @@
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "6b15d18e-9b97-434f-905c-70acd4c35aea",
"metadata": {},
"source": [
"## JsonSchemaEvaluator\n",
"\n",
"The `JsonSchemaEvaluator` validates a JSON prediction against a provided JSON schema. If the prediction conforms to the schema, it returns a score of True (indicating no errors). Otherwise, it returns a score of 0 (indicating an error).\n",
"\n",
"### Overview:\n",
"- **Requires Input?**: Yes\n",
"- **Requires Reference?**: Yes (A JSON schema)\n",
"- **Score**: True (No errors) or False (Error occurred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "85afcf33-d2f4-406e-9d8f-15dc0a4772f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'score': True}\n"
]
}
],
"source": [
"from langchain.evaluation import JsonSchemaEvaluator\n",
"\n",
"evaluator = JsonSchemaEvaluator()\n",
"# Equivalently\n",
"# evaluator = load_evaluator(\"json_schema_validation\")\n",
"\n",
"result = evaluator.evaluate_strings(\n",
" prediction='{\"name\": \"John\", \"age\": 30}',\n",
" reference={\n",
" \"type\": \"object\",\n",
" \"properties\": {\"name\": {\"type\": \"string\"}, \"age\": {\"type\": \"integer\"}},\n",
" },\n",
")\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "bb5b89f6-0c87-4335-9091-55fd67a0565f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'score': True}\n"
]
}
],
"source": [
"result = evaluator.evaluate_strings(\n",
" prediction='{\"name\": \"John\", \"age\": 30}',\n",
" reference='{\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\"}, \"age\": {\"type\": \"integer\"}}}',\n",
")\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ff914d24-36bc-482a-a9ba-259cd0dd2a52",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'score': False, 'reasoning': \"<ValidationError: '30 is less than the minimum of 66'>\"}\n"
]
}
],
"source": [
"result = evaluator.evaluate_strings(\n",
" prediction='{\"name\": \"John\", \"age\": 30}',\n",
" reference='{\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\"},'\n",
" '\"age\": {\"type\": \"integer\", \"minimum\": 66}}}',\n",
")\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b073f12d-4603-481c-8081-fab1af6bfcfe",
"metadata": {},
"outputs": [],
"source": []
}
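Not shown in the notebook: `_parse_json` in the new `json_schema.py` (further down in this diff) also accepts any object with a callable `schema()` method as the reference, so a Pydantic-style model class can stand in for a raw schema. A minimal sketch, assuming a pydantic v1-style `.schema()` method and a hypothetical `Person` model:

```python
# Sketch only: `Person` is a hypothetical model, not part of this PR.
# _parse_json calls .schema() on any object that exposes it, which
# matches the pydantic v1-style API.
from pydantic import BaseModel

from langchain.evaluation import JsonSchemaEvaluator


class Person(BaseModel):
    name: str
    age: int


evaluator = JsonSchemaEvaluator()
result = evaluator.evaluate_strings(
    prediction='{"name": "John", "age": 30}',
    reference=Person,  # resolved via Person.schema() to a JSON schema dict
)
print(result)  # expected: {'score': True}
```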
2 changes: 2 additions & 0 deletions libs/langchain/langchain/evaluation/__init__.py
@@ -74,6 +74,7 @@
JsonValidityEvaluator,
)
from langchain.evaluation.parsing.json_distance import JsonEditDistanceEvaluator
from langchain.evaluation.parsing.json_schema import JsonSchemaEvaluator
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain.evaluation.schema import (
@@ -122,4 +123,5 @@
"JsonValidityEvaluator",
"JsonEqualityEvaluator",
"JsonEditDistanceEvaluator",
"JsonSchemaEvaluator",
]
2 changes: 2 additions & 0 deletions libs/langchain/langchain/evaluation/loading.py
@@ -20,6 +20,7 @@
JsonValidityEvaluator,
)
from langchain.evaluation.parsing.json_distance import JsonEditDistanceEvaluator
from langchain.evaluation.parsing.json_schema import JsonSchemaEvaluator
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
from langchain.evaluation.regex_match.base import RegexMatchStringEvaluator
from langchain.evaluation.schema import EvaluatorType, LLMEvalChain, StringEvaluator
@@ -88,6 +89,7 @@ def load_dataset(uri: str) -> List[Dict]:
EvaluatorType.JSON_VALIDITY: JsonValidityEvaluator,
EvaluatorType.JSON_EQUALITY: JsonEqualityEvaluator,
EvaluatorType.JSON_EDIT_DISTANCE: JsonEditDistanceEvaluator,
EvaluatorType.JSON_SCHEMA_VALIDATION: JsonSchemaEvaluator,
EvaluatorType.REGEX_MATCH: RegexMatchStringEvaluator,
EvaluatorType.EXACT_MATCH: ExactMatchStringEvaluator,
}
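With the registry entry above in place, the evaluator can also be constructed by name through `load_evaluator`, mirroring the commented-out "Equivalently" line in the notebook. A quick usage sketch:

```python
from langchain.evaluation import load_evaluator

# "json_schema_validation" is the EvaluatorType value wired up above
evaluator = load_evaluator("json_schema_validation")
result = evaluator.evaluate_strings(
    prediction='{"name": "John", "age": 30}',
    reference='{"type": "object", "properties": {"age": {"type": "integer"}}}',
)
print(result)  # {'score': True}
```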
4 changes: 2 additions & 2 deletions libs/langchain/langchain/evaluation/parsing/base.py
@@ -51,7 +51,7 @@ def _evaluate_strings(
prediction: str,
input: Optional[str] = None,
reference: Optional[str] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> dict:
"""Evaluate the prediction string.
@@ -134,7 +134,7 @@ def _evaluate_strings(
prediction: str,
input: Optional[str] = None,
reference: Optional[str] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> dict:
"""Evaluate the prediction string.
8 changes: 5 additions & 3 deletions libs/langchain/langchain/evaluation/parsing/json_distance.py
@@ -38,7 +38,7 @@ def __init__(
self,
string_distance: Optional[Callable[[str, str], float]] = None,
canonicalize: Optional[Callable[[Any], Any]] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> None:
super().__init__()
if string_distance is not None:
@@ -58,7 +58,9 @@ def __init__(
self._canonicalize = canonicalize
else:
self._canonicalize = lambda x: json.dumps(
- x, separators=(",", ":"), sort_keys=True  # eliminate whitespace
+ x,
+ separators=(",", ":"),
+ sort_keys=True,  # eliminate whitespace
)

@property
@@ -83,7 +85,7 @@ def _evaluate_strings(
prediction: str,
input: Optional[str] = None,
reference: Optional[str] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> dict:
parsed = self._canonicalize(self._parse_json(prediction))
label = self._canonicalize(self._parse_json(reference))
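The reformatted lambda above is the evaluator's default canonicalizer: compact separators strip whitespace and `sort_keys` normalizes key order, so semantically equal JSON serializes identically before the string distance is computed. A standalone sketch of that behavior:

```python
import json


def canonicalize(x):
    # Mirrors the default in JsonEditDistanceEvaluator above:
    # compact separators drop whitespace, sort_keys fixes key order.
    return json.dumps(x, separators=(",", ":"), sort_keys=True)


a = json.loads('{"b": 1,  "a": 2}')
b = json.loads('{"a": 2, "b": 1}')
assert canonicalize(a) == canonicalize(b) == '{"a":2,"b":1}'
```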
95 changes: 95 additions & 0 deletions libs/langchain/langchain/evaluation/parsing/json_schema.py
@@ -0,0 +1,95 @@
from typing import Any, Union

from langchain.evaluation.schema import StringEvaluator
from langchain.output_parsers.json import parse_json_markdown


class JsonSchemaEvaluator(StringEvaluator):
"""An evaluator that validates a JSON prediction against a JSON schema reference.
This evaluator checks if a given JSON prediction conforms to the provided JSON schema.
If the prediction is valid, the score is True (no errors). Otherwise, the score is False (error occurred).
Attributes:
requires_input (bool): Whether the evaluator requires input.
requires_reference (bool): Whether the evaluator requires reference.
evaluation_name (str): The name of the evaluation.
Examples:
evaluator = JsonSchemaEvaluator()
result = evaluator.evaluate_strings(
prediction='{"name": "John", "age": 30}',
reference={
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
}
}
)
assert result["score"] is not None
""" # noqa: E501

def __init__(self, **kwargs: Any) -> None:
"""Initializes the JsonSchemaEvaluator.
Args:
**kwargs: Additional keyword arguments.
Raises:
ImportError: If the jsonschema package is not installed.
"""
super().__init__()
try:
import jsonschema # noqa: F401
except ImportError:
raise ImportError(
"The JsonSchemaEvaluator requires the jsonschema package."
" Please install it with `pip install jsonschema`."
)

@property
def requires_input(self) -> bool:
"""Returns whether the evaluator requires input."""
return False

@property
def requires_reference(self) -> bool:
"""Returns whether the evaluator requires reference."""
return True

@property
def evaluation_name(self) -> str:
"""Returns the name of the evaluation."""
return "json_schema_validation"

def _parse_json(self, node: Any) -> Union[dict, list, None, float, bool, int, str]:
if isinstance(node, str):
return parse_json_markdown(node)
elif hasattr(node, "schema") and callable(getattr(node, "schema")):
# Pydantic model
return getattr(node, "schema")()
return node

def _validate(self, prediction: Any, schema: Any) -> dict:
from jsonschema import ValidationError, validate # noqa: F401

try:
validate(instance=prediction, schema=schema)
return {
"score": True,
}
except ValidationError as e:
return {"score": False, "reasoning": repr(e)}

def _evaluate_strings(
self,
prediction: Union[str, Any],
input: Union[str, Any] = None,
reference: Union[str, Any] = None,
**kwargs: Any,
) -> dict:
parsed_prediction = self._parse_json(prediction)
schema = self._parse_json(reference)
return self._validate(parsed_prediction, schema)
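The `_validate` method above is a thin wrapper over the `jsonschema` package. A standalone sketch of the same success/failure contract, using the schema from the notebook's failing example:

```python
from jsonschema import ValidationError, validate

schema = {
    "type": "object",
    "properties": {"age": {"type": "integer", "minimum": 66}},
}

try:
    validate(instance={"name": "John", "age": 30}, schema=schema)
    result = {"score": True}
except ValidationError as e:
    # repr(e) is what the evaluator surfaces under "reasoning"
    result = {"score": False, "reasoning": repr(e)}

print(result)
# {'score': False, 'reasoning': "<ValidationError: '30 is less than the minimum of 66'>"}
```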
16 changes: 9 additions & 7 deletions libs/langchain/langchain/evaluation/schema.py
@@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
- from typing import Any, Optional, Sequence, Tuple
+ from typing import Any, Optional, Sequence, Tuple, Union
from warnings import warn

from langchain.chains.base import Chain
@@ -66,6 +66,8 @@ class EvaluatorType(str, Enum):
"""Check if a prediction is equal to a reference JSON."""
JSON_EDIT_DISTANCE = "json_edit_distance"
"""Compute the edit distance between two JSON strings after canonicalization."""
JSON_SCHEMA_VALIDATION = "json_schema_validation"
"""Check if a prediction is valid JSON according to a JSON schema."""


class LLMEvalChain(Chain):
@@ -144,9 +146,9 @@ def requires_reference(self) -> bool:
def _evaluate_strings(
self,
*,
- prediction: str,
- reference: Optional[str] = None,
- input: Optional[str] = None,
+ prediction: Union[str, Any],
+ reference: Optional[Union[str, Any]] = None,
+ input: Optional[Union[str, Any]] = None,
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
@@ -167,9 +169,9 @@ async def _aevaluate_strings(
async def _aevaluate_strings(
self,
*,
- prediction: str,
- reference: Optional[str] = None,
- input: Optional[str] = None,
+ prediction: Union[str, Any],
+ reference: Optional[Union[str, Any]] = None,
+ input: Optional[Union[str, Any]] = None,
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.