Commit 6cbd4f8
Added Argument for Aggregation
1 parent 21dae65 · commit 6cbd4f8

15 files changed: +124 −92 lines changed
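What the new argument looks like from the calling side: a minimal sketch, not taken from this commit. It assumes Faithfulness is importable from groqeval.metrics.faithfulness (matching the file path in the diff below), that the Groq client picks up GROQ_API_KEY from the environment, and that the context and output strings are placeholders.

    # Sketch of the new aggregation argument (illustrative, not from this commit).
    import statistics
    from groq import Groq
    from groqeval.metrics.faithfulness import Faithfulness

    metric = Faithfulness(
        groq_client=Groq(),  # assumes GROQ_API_KEY is set in the environment
        context=["The Eiffel Tower is in Paris.", "It was completed in 1889."],
        output="The Eiffel Tower, finished in 1889, stands in Paris.",
    )

    metric.score()['score']       # default aggregation: statistics.mean of the claim scores
    metric.score(max)['score']    # any callable over the list of per-claim scores works
    metric.score(min)['score']

Bias and Toxicity default to max instead of statistics.mean, as their diffs below show.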

groqeval/metrics/answer_relevance.py

+5 −16

@@ -1,6 +1,7 @@
 # groqeval/metrics/answer_relevance.py
 import json
 from groq import Groq
+from cachetools import cached, TTLCache
 from groqeval.models.output import Output, ScoredOutput
 from groqeval.metrics.base_metric import BaseMetric
 
@@ -75,6 +76,7 @@ def output_decomposition(self):
         self.logger.info("Decomposition of the Output into Statements: %s", response.choices[0].message.content)
         return Output.model_validate_json(response.choices[0].message.content)
 
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
     def score_relevance(self):
         """
         Each identified statement is then scored on a scale from 1 (completely irrelevant)
@@ -96,19 +98,6 @@ def score_relevance(self):
         self.logger.info("Breakdown of the Answer Relevance Score: %s", response.choices[0].message.content)
         return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
 
-    def score(self):
-        """
-        Aggregation of individual scores and final result.
-        """
-        scored_output, output_dictionary = self.score_relevance()
-        if scored_output.scores:
-            average_score = sum([output.score for output in scored_output.scores]) / len(scored_output.scores)
-            return {
-                'score': average_score,
-                'score_breakdown': output_dictionary
-            }
-        else:
-            return {
-                'score': 0, # Default to 0 if there are no sentences to score
-                'score_breakdown': output_dictionary
-            }
+    @property
+    def scoring_function(self):
+        return self.score_relevance
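The @cached(cache=TTLCache(maxsize=100, ttl=300)) decorator added above, and repeated for every metric in this commit, memoizes the scorer for 300 seconds, so repeated score() calls with different aggregations reuse one Groq round trip instead of re-querying. A standalone sketch of the same cachetools pattern, with a made-up ExpensiveScorer class; note that cachetools derives the cache key from the call arguments, so each argument, including self, must be hashable.

    # Illustration only: the cachetools pattern used in this commit, outside groqeval.
    import time
    from cachetools import cached, TTLCache

    class ExpensiveScorer:
        @cached(cache=TTLCache(maxsize=100, ttl=300))  # entries expire after 300 s
        def score_relevance(self):
            time.sleep(1)          # stand-in for an LLM API round trip
            return 7

    scorer = ExpensiveScorer()
    scorer.score_relevance()       # cache miss: takes about a second
    scorer.score_relevance()       # cache hit within the TTL: returns immediately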

groqeval/metrics/base_metric.py

+32 −6

@@ -1,17 +1,21 @@
 import logging
+import statistics
+from abc import ABC,abstractmethod
+from cachetools import cached, TTLCache
 from groq import Groq
 
-class BaseMetric:
+class BaseMetric(ABC):
     """
     The Base Metric class.
     """
     def __init__(self, groq_client: Groq, verbose: bool = None):
         self.groq_client = groq_client
+        self.aggregation = statistics.mean
         self.logger = logging.getLogger(__name__)
         if verbose:
             self.logger.setLevel(logging.INFO)
 
-
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
     def groq_chat_completion(self, messages, model, temperature=0.5, response_format=None):
         """
         Groq's chat completion API
@@ -43,8 +47,30 @@ def check_data_types(self, **kwargs):
             else:
                 if not all(isinstance(item, str) for item in value):
                     raise TypeError(f"All items in '{key}' must be strings")
-
-
-
-    def score(self):
+
+    @property
+    @abstractmethod
+    def scoring_function(self):
+        """
+        This property should be implemented by each child class
+        """
         raise NotImplementedError("This method should be overridden by subclasses")
+
+    def score(self, aggregation = None):
+        """
+        Aggregation of individual scores and final result.
+        """
+        if aggregation is not None:
+            self.aggregation = aggregation
+        scored_output, output_dictionary = self.scoring_function()
+        if scored_output.scores:
+            average_score = self.aggregation([output.score for output in scored_output.scores])
+            return {
+                'score': average_score,
+                'score_breakdown': output_dictionary
+            }
+        else:
+            return {
+                'score': 0, # Default to 0 if there are no sentences to score
+                'score_breakdown': output_dictionary
+            }
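With score() hoisted into BaseMetric, a concrete metric only needs to expose a scoring_function property that returns its scorer, and may optionally override self.aggregation. A hypothetical subclass sketch follows; the Conciseness name and its stubbed scorer are illustrative and not part of this commit.

    # Hypothetical metric built on the refactored BaseMetric (illustrative only).
    from groq import Groq
    from groqeval.metrics.base_metric import BaseMetric

    class Conciseness(BaseMetric):
        def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs):
            super().__init__(groq_client, kwargs.get('verbose'))
            self.output = output
            self.prompt = prompt
            self.aggregation = max   # optional: override the statistics.mean default
            self.check_data_types(prompt=prompt, output=output)

        def score_conciseness(self):
            # Would call self.groq_chat_completion(...) and return a
            # (ScoredOutput, dict) pair, as the other metrics do.
            raise NotImplementedError

        @property
        def scoring_function(self):
            return self.score_conciseness

The inherited score(aggregation=None) then handles both the constructor default and any per-call override.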

groqeval/metrics/bias.py

+8 −14

@@ -1,6 +1,7 @@
 # groqeval/metrics/bias.py
 import json
 from groq import Groq
+from cachetools import cached, TTLCache
 from groqeval.models.output import Output, ScoredOutput
 from groqeval.metrics.base_metric import BaseMetric
 
@@ -16,6 +17,8 @@ def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs):
         super().__init__(groq_client, kwargs.get('verbose'))
         self.output = output
         self.prompt = prompt
+        self.aggregation = max
+
         self.check_data_types(prompt=prompt, output=output)
 
     @property
@@ -79,6 +82,7 @@ def output_decomposition(self):
         self.logger.info("Decomposition of the Output into Opinions: %s", response.choices[0].message.content)
         return Output.model_validate_json(response.choices[0].message.content)
 
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
     def score_bias(self):
         """
         Each opinion in the output is scored on a scale from 1 (completely unbiased)
@@ -99,17 +103,7 @@ def score_bias(self):
         )
         self.logger.info("Breakdown of the Bias Score: %s", response.choices[0].message.content)
         return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
-
-    def score(self):
-        scored_output, output_dictionary = self.score_bias()
-        if scored_output.scores:
-            average_score = max([output.score for output in scored_output.scores])
-            return {
-                'score': average_score,
-                'score_breakdown': output_dictionary
-            }
-        else:
-            return {
-                'score': 0, # Default to 0 if there are no sentences to score
-                'score_breakdown': output_dictionary
-            }
+
+    @property
+    def scoring_function(self):
+        return self.score_bias
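Bias (and Toxicity further down) now sets self.aggregation = max in __init__, so the most biased opinion drives the headline score, while a caller-supplied aggregation still takes precedence. A short usage sketch, assuming the class is exposed as groqeval.metrics.bias.Bias and using placeholder prompt and output text:

    # Illustrative call pattern; import path and strings are assumptions.
    import statistics
    from groq import Groq
    from groqeval.metrics.bias import Bias

    bias = Bias(Groq(), output="City life is simply better than rural life.",
                prompt="Compare urban and rural living.")

    bias.score()['score']                  # aggregated with max (set in __init__)
    bias.score(statistics.mean)['score']   # explicit argument overrides the default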

groqeval/metrics/context_relevance.py

+5 −13

@@ -2,6 +2,7 @@
 import json
 from typing import List
 from groq import Groq
+from cachetools import cached, TTLCache
 from groqeval.models.context import Context, ScoredContext
 from groqeval.metrics.base_metric import BaseMetric
 
@@ -88,6 +89,7 @@ def context_decomposition(self):
         self.logger.info("Decomposition of the Context into Statements: %s", response.choices[0].message.content)
         return Context.model_validate_json(response.choices[0].message.content)
 
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
     def score_relevance(self):
         """
         Each statement of context is evaluated to determine if it can be
@@ -113,16 +115,6 @@ def score_relevance(self):
         self.logger.info("Breakdown of the Context Relevance Score: %s", response.choices[0].message.content)
         return ScoredContext.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
 
-    def score(self):
-        scored_context, output_dictionary = self.score_relevance()
-        if scored_context.scores:
-            average_score = sum([context.score for context in scored_context.scores]) / len(scored_context.scores)
-            return {
-                'score': average_score,
-                'score_breakdown': output_dictionary
-            }
-        else:
-            return {
-                'score': 0, # Default to 0 if there are no sentences to score
-                'score_breakdown': output_dictionary
-            }
+    @property
+    def scoring_function(self):
+        return self.score_relevance

groqeval/metrics/faithfulness.py

+7 −15

@@ -2,6 +2,7 @@
 import json
 from typing import List
 from groq import Groq
+from cachetools import cached, TTLCache
 from groqeval.models.output import Output, ScoredOutput
 from groqeval.metrics.base_metric import BaseMetric
 
@@ -15,7 +16,7 @@ class Faithfulness(BaseMetric):
     def __init__(self, groq_client: Groq, context: List[str], output: str, **kwargs):
         super().__init__(groq_client, kwargs.get('verbose'))
         self.context = context
-        self.output = output
+        self.output = output
         self.check_data_types(context=context, output=output)
 
     @property
@@ -89,6 +90,7 @@ def output_decomposition(self):
         self.logger.info("Decomposition of the Output into Claims: %s", response.choices[0].message.content)
         return Output.model_validate_json(response.choices[0].message.content)
 
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
    def score_faithfulness(self):
         """
         Claims are then scored on a scale from 1 to 10.
@@ -114,17 +116,7 @@ def score_faithfulness(self):
         )
         self.logger.info("Breakdown of the Faithfulness Score: %s", response.choices[0].message.content)
         return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
-
-    def score(self):
-        scored_output, output_dictionary = self.score_faithfulness()
-        if scored_output.scores:
-            average_score = sum([output.score for output in scored_output.scores]) / len(scored_output.scores)
-            return {
-                'score': average_score,
-                'score_breakdown': output_dictionary
-            }
-        else:
-            return {
-                'score': 0, # Default to 0 if there are no sentences to score
-                'score_breakdown': output_dictionary
-            }
+
+    @property
+    def scoring_function(self):
+        return self.score_faithfulness

groqeval/metrics/hallucination.py

+5 −13

@@ -2,6 +2,7 @@
 import json
 from typing import List
 from groq import Groq
+from cachetools import cached, TTLCache
 from groqeval.models.context import Context, ScoredContext
 from groqeval.metrics.base_metric import BaseMetric
 
@@ -98,6 +99,7 @@ def context_decomposition(self):
         self.logger.info("Decomposition of the Context into Statements: %s", response.choices[0].message.content)
         return Context.model_validate_json(response.choices[0].message.content)
 
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
     def score_hallucination(self):
         """
         The hallucination metric evaluates the alignment between an output and its context,
@@ -119,16 +121,6 @@ def score_hallucination(self):
         self.logger.info("Breakdown of the Hallucination Score: %s", response.choices[0].message.content)
         return ScoredContext.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
 
-    def score(self):
-        scored_context, output_dictionary = self.score_hallucination()
-        if scored_context.scores:
-            average_score = sum([context.score for context in scored_context.scores]) / len(scored_context.scores)
-            return {
-                'score': average_score,
-                'score_breakdown': output_dictionary
-            }
-        else:
-            return {
-                'score': 0, # Default to 0 if there are no sentences to score
-                'score_breakdown': output_dictionary
-            }
+    @property
+    def scoring_function(self):
+        return self.score_hallucination

groqeval/metrics/toxicity.py

+8 −13

@@ -1,6 +1,7 @@
 # groqeval/metrics/toxicity.py
 import json
 from groq import Groq
+from cachetools import cached, TTLCache
 from groqeval.models.output import Output, ScoredOutput
 from groqeval.metrics.base_metric import BaseMetric
 
@@ -16,6 +17,8 @@ def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs):
         super().__init__(groq_client, kwargs.get('verbose'))
         self.output = output
         self.prompt = prompt
+        self.aggregation = max
+
         self.check_data_types(prompt=prompt, output=output)
 
 
@@ -78,6 +81,7 @@ def output_decomposition(self):
         self.logger.info("Breakdown of the Toxicity Score: %s", response.choices[0].message.content)
         return Output.model_validate_json(response.choices[0].message.content)
 
+    @cached(cache=TTLCache(maxsize=100, ttl=300))
     def score_toxicity(self):
         """
         Each phrase is examined to see if it represents an opinion
@@ -100,16 +104,7 @@ def score_toxicity(self):
         )
         return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
 
-    def score(self):
-        scored_output, output_dictionary = self.score_toxicity()
-        if scored_output.scores:
-            average_score = max([output.score for output in scored_output.scores])
-            return {
-                'score': average_score,
-                'score_breakdown': output_dictionary
-            }
-        else:
-            return {
-                'score': 0, # Default to 0 if there are no sentences to score
-                'score_breakdown': output_dictionary
-            }
+    @property
+    def scoring_function(self):
+        return self.score_toxicity
+

pyproject.toml

+2 −1

@@ -22,7 +22,8 @@ requires-python = ">=3.10"
 
 dependencies = [
     "groq>=0.9.0",
-    "pydantic>=2.7.4"
+    "pydantic>=2.7.4",
+    "cachetools>=5.3.3"
 ]
 
 [tool.twine]

requirements.txt

+2 −1

@@ -1,2 +1,3 @@
 groq==0.9.0
-pydantic==2.7.4
+pydantic==2.7.4
+cachetools==5.3.3

tests/metrics/test_answer_relevance.py

+8 −0

@@ -1,5 +1,6 @@
 import pytest
 import math
+import statistics
 
 @pytest.mark.parametrize("prompt, output, expected_score_range", [
     ("Discuss the implications of quantum computing on data security.",
@@ -23,6 +24,13 @@ def test_answer_relevance_scoring(evaluator, prompt, output, expected_score_range):
     result = answer_relevance.score()
     assert math.ceil(result['score']) >= expected_score_range[0] and math.floor(result['score']) <= expected_score_range[1], f"Score {result['score']} not in range {expected_score_range}"
 
+    max_score = answer_relevance.score(max)['score']
+    mean_score = answer_relevance.score(statistics.mean)['score']
+    min_score = answer_relevance.score(min)['score']
+
+    assert max_score >= mean_score, f"Max score {max_score} is not greater than mean score {mean_score}"
+    assert min_score <= mean_score, f"Min score {min_score} is not less than mean score {mean_score}"
+
 def test_answer_relevance_empty_output(evaluator):
     prompt = "What are the benefits of meditation?"
     output = ""

tests/metrics/test_bias.py

+8 −0

@@ -1,5 +1,6 @@
 import pytest
 import math
+import statistics
 
 @pytest.mark.parametrize("prompt, output, expected_score_range", [
@@ -24,6 +25,13 @@ def test_bias_scoring(evaluator, prompt, output, expected_score_range):
     result = bias.score()
     assert math.ceil(result['score']) >= expected_score_range[0] and math.floor(result['score']) <= expected_score_range[1], f"Score {result['score']} not in range {expected_score_range}"
 
+    max_score = bias.score(max)['score']
+    mean_score = bias.score(statistics.mean)['score']
+    min_score = bias.score(min)['score']
+
+    assert max_score >= mean_score, f"Max score {max_score} is not greater than mean score {mean_score}"
+    assert min_score <= mean_score, f"Min score {min_score} is not less than mean score {mean_score}"
+
 def test_bias_empty_output(evaluator):
     prompt = "Evaluate the current role of renewable energy in economic development."
     output = ""

tests/metrics/test_context_relevance.py

+9 −0

@@ -1,5 +1,6 @@
 import pytest
 import math
+import statistics
 
 @pytest.mark.parametrize("prompt, context, expected_score_range", [
     ("Describe the impact of climate change on polar bears.",
@@ -23,6 +24,14 @@ def test_context_relevance_scoring(evaluator, prompt, context, expected_score_range):
     result = context_relevance.score()
     assert math.ceil(result['score']) >= expected_score_range[0] and math.floor(result['score']) <= expected_score_range[1], f"Score {result['score']} not in range {expected_score_range}"
 
+    max_score = context_relevance.score(max)['score']
+    mean_score = context_relevance.score(statistics.mean)['score']
+    min_score = context_relevance.score(min)['score']
+
+    assert max_score >= mean_score, f"Max score {max_score} is not greater than mean score {mean_score}"
+    assert min_score <= mean_score, f"Min score {min_score} is not less than mean score {mean_score}"
+
+
 def test_context_relevance_empty_context(evaluator):
     prompt = "What are the benefits of meditation?"
     context = []
