Improvements and validation #32

Merged 2 commits on Feb 28, 2025

2 changes: 1 addition & 1 deletion .gitignore
@@ -1,5 +1,5 @@
/.ruff_cache
/.venv
/.venv*
__pycache__
.intentguard
.mypy_cache
3 changes: 3 additions & 0 deletions ai_research/dataset_generation/domain/category.go
@@ -16,6 +16,9 @@ var items = []string{
"Code Style",
"Business Logic",
"Miscellaneous",
"Function/Method Calls and Arguments",
"Variable Scope and Lifetime",
"Control Flow Understanding",
}

func GetRandomCategory() string {
2 changes: 1 addition & 1 deletion intentguard/infrastructure/fs_judgement_cache.py
@@ -24,7 +24,7 @@ class FsJudgementCache(JudgementCache):
"""

def __init__(self):
self.cache_dir = Path(".intentguard")
self.cache_dir = Path(".intentguard") / "cache"
logger.debug("Initialized cache directory at %s", self.cache_dir)

def _get_cache_file_path(
211 changes: 132 additions & 79 deletions intentguard/infrastructure/llamafile.py
@@ -2,7 +2,6 @@
import json
import logging
import platform
import re
import subprocess
import time
import os
@@ -12,6 +11,7 @@
from typing import List
import threading
import atexit
import socket

from intentguard.app.inference_options import InferenceOptions
from intentguard.app.inference_provider import InferenceProvider
@@ -26,10 +26,11 @@
CONTEXT_SIZE = 8192
MODEL_FILENAME = "IntentGuard-1.Q8_0.gguf"
MODEL_NAME = "IntentGuard-1"
LLAMAFILE_URL = "https://github.com/Mozilla-Ocho/llamafile/releases/download/0.8.17/llamafile-0.8.17" # URL for llamafile
LLAMAFILE_SHA256 = "1041e05b2c254674e03c66052b1a6cf646e8b15ebd29a195c77fed92cac60d6b" # SHA-256 checksum for llamafile
LLAMAFILE_URL = "https://github.com/Mozilla-Ocho/llamafile/releases/download/0.9.0/llamafile-0.9.0" # URL for llamafile
LLAMAFILE_SHA256 = "5a93cafd16abe61e79761575436339693806385a1f0b7d625024e9f91e71bcf1" # SHA-256 checksum for llamafile
GGUF_URL = "https://huggingface.co/kdunee/IntentGuard-1/resolve/main/IntentGuard-1.Q8_0.gguf" # URL for the GGUF file
GGUF_SHA256 = "0cb9476a129e7fc68b419ab86397b9ce4309b0d5faf6ba5d18629e796ca01383" # SHA-256 checksum for the GGUF file
MAX_RETRY_ATTEMPTS = 3 # Maximum number of retries for handling connection errors

STORAGE_DIR = Path(".intentguard")

@@ -50,7 +51,7 @@ def verify_checksum(file_path: Path, expected_sha256: str) -> bool:

def download_file(url: str, target_path: Path, expected_sha256: str):
"""Download a file and verify its checksum."""
print(f"Downloading {url} to {target_path}...")
logger.info(f"Downloading {url} to {target_path}...")

# Create parent directories if they don't exist
target_path.parent.mkdir(parents=True, exist_ok=True)
@@ -63,22 +64,33 @@ def download_file(url: str, target_path: Path, expected_sha256: str):
target_path.unlink() # Delete the file if checksum verification fails
raise ValueError(f"Checksum verification failed for {target_path}")

print(f"Successfully downloaded and verified {target_path}")
logger.info(f"Successfully downloaded and verified {target_path}")


def ensure_file(url: str, target_path: Path, expected_sha256: str):
"""Ensure a file exists with the correct checksum."""
if target_path.exists():
if verify_checksum(target_path, expected_sha256):
print(
logger.debug(
f"{target_path} already exists with correct checksum, skipping download"
)
return
print(f"{target_path} exists but has incorrect checksum, re-downloading")
logger.debug(f"{target_path} exists but has incorrect checksum, re-downloading")
target_path.unlink()
download_file(url, target_path, expected_sha256)


def get_free_port():
"""
Dynamically finds a free port on localhost.
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("localhost", 0)) # Bind to port 0 to let OS choose a free port
port = sock.getsockname()[1] # Get the port number assigned by the OS
sock.close() # Close the socket (port is still considered occupied temporarily)
return port


class Llamafile(InferenceProvider):
"""
Implementation of InferenceProvider using a local Llamafile server.
@@ -119,6 +131,7 @@ def shutdown(self):
logger.debug("Killed llamafile server process")
except Exception as e:
logger.warning("Failed to kill llamafile server process: %s", e)

self._process = None
self._port = None

@@ -131,8 +144,7 @@ def _ensure_process(self):
concurrent initialization attempts.

Raises:
Exception: If the server fails to start, if the port cannot be
detected within the startup timeout period, or if file downloads
Exception: If the server fails to start, if file downloads
or verification fail.
"""
if self._process is not None:
@@ -148,6 +160,9 @@
ensure_file(LLAMAFILE_URL, llamafile_path, LLAMAFILE_SHA256)
ensure_file(GGUF_URL, model_path, GGUF_SHA256)

# Get a free port and use it directly
self._port = get_free_port()

command = [
str(llamafile_path),
"--server",
@@ -157,6 +172,8 @@
str(CONTEXT_SIZE),
"--host",
"127.0.0.1",
"--port",
str(self._port),
"--nobrowser",
]

@@ -169,74 +186,47 @@
logger.warning(f"Failed to make llamafile executable: {e}")
command.insert(0, "sh")

# Start the process with stdout/stderr redirected to devnull
self._process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)

# Give the server a moment to start
start_time = time.time()
while time.time() - start_time < STARTUP_TIMEOUT_SECONDS:
line = self._process.stderr.readline()
if not line:
if self._process.poll() is not None:
break
continue
match = re.search(
r"llama server listening at http://127.0.0.1:(\d+)", line
)
if match:
self._port = int(match.group(1))
break

if not self._port:
status = self._process.poll()
if status is not None:
# Check if the process is still running
if self._process.poll() is not None:
status = self._process.poll()
logger.error(
"Llamafile server failed to start with status %d", status
)
raise Exception(f"llamafile exited with status {status}")
else:
logger.error(
"Could not detect Llamafile server port within timeout period"
)
self._process.kill()
raise Exception("Could not find port in llamafile output")

logger.info("Llamafile server started successfully on port %d", self._port)
# Try to connect to the server to verify it's ready
try:
with socket.create_connection(("127.0.0.1", self._port), timeout=1):
logger.info(
"Llamafile server started successfully on port %d",
self._port,
)
return
except (socket.timeout, ConnectionRefusedError):
# Wait a bit before trying again
time.sleep(1)

# If we get here, the server didn't start within the timeout
self._process.kill()
self._process = None
self._port = None
raise Exception(
f"Llamafile server failed to start within {STARTUP_TIMEOUT_SECONDS} seconds"
)

def predict(
self, prompt: List[Message], inference_options: InferenceOptions
) -> Evaluation:
def _send_http_request(self, payload: dict) -> dict:
"""
Generate a prediction using the Llamafile server.

Ensures the server is running (starting it if necessary) and makes an HTTP
request to the local server's chat completions endpoint, formatting the
input according to the OpenAI API specification. The response is expected
to be a JSON object containing a result and optional explanation.

Args:
prompt: List of messages forming the input prompt
inference_options: Configuration options for the inference

Returns:
An Evaluation object containing the model's assessment

Raises:
Exception: If the server fails to start, returns an error, if the
response cannot be parsed, or if the request times out
Sends an HTTP request to the Llamafile server with the given payload
and returns the parsed JSON response.
"""
self._ensure_process()
logger.debug(
"Preparing prediction request with temperature %.2f",
inference_options.temperature,
)
messages = [{"role": m.role, "content": m.content} for m in prompt]
payload = {
"model": MODEL_NAME,
"messages": messages,
"temperature": inference_options.temperature,
}

conn = http.client.HTTPConnection(
"127.0.0.1", self._port, timeout=INFERENCE_TIMEOUT_SECONDS
)
@@ -258,24 +248,87 @@ def predict(
raise Exception(error_msg)

json_response = json.loads(data)

if not json_response["choices"]:
if not json_response.get("choices"):
error_msg = f"Llamafile API returned no choices: {json_response}"
logger.error(error_msg)
raise Exception(error_msg)
return json_response

generated_text = json_response["choices"][0]["message"]["content"]
if generated_text.endswith("<|eot_id|>"):
generated_text = generated_text[: -len("<|eot_id|>")]
def predict(
self, prompt: List[Message], inference_options: InferenceOptions
) -> Evaluation:
"""
Generate a prediction using the Llamafile server.

try:
llm_response = json.loads(generated_text)
If a timeout occurs, the method will retry up to MAX_RETRY_ATTEMPTS times,
restarting the server process between attempts.
"""
messages = [{"role": m.role, "content": m.content} for m in prompt]
payload = {
"model": MODEL_NAME,
"messages": messages,
"temperature": inference_options.temperature,
}

return Evaluation(
result=llm_response["result"],
explanation=llm_response["explanation"],
)
except json.JSONDecodeError as e:
error_msg = f"Could not parse Llamafile response: {generated_text}"
logger.error(error_msg)
raise Exception(error_msg) from e
attempts = 0
last_error = None

while attempts < MAX_RETRY_ATTEMPTS:
attempts += 1
try:
self._ensure_process()
logger.debug(
f"Attempt {attempts}/{MAX_RETRY_ATTEMPTS}: Preparing prediction request with temperature {inference_options.temperature:.2f}"
)

json_response = self._send_http_request(payload)
generated_text = json_response["choices"][0]["message"]["content"]
if generated_text.endswith("<|eot_id|>"):
generated_text = generated_text[: -len("<|eot_id|>")]

# Fix common JSON parsing issues
generated_text = (
generated_text.replace('"""', '\\"\\"\\"')
.replace('\\\\"\\"\\"', '\\"\\"\\"')
.replace('\\\\"', '\\"')
)

try:
llm_response = json.loads(generated_text)
return Evaluation(
result=llm_response["result"],
explanation=llm_response["explanation"],
)
except json.JSONDecodeError as e:
error_msg = f"Could not parse Llamafile response: {generated_text}"
logger.error(error_msg)
raise Exception(error_msg) from e

except (
socket.timeout,
TimeoutError,
ConnectionRefusedError,
ConnectionError,
http.client.HTTPException,
) as e:
last_error = e
logger.warning(
f"Error occurred during attempt {attempts}/{MAX_RETRY_ATTEMPTS}: {e}"
)
if attempts < MAX_RETRY_ATTEMPTS:
logger.info("Restarting llamafile process and retrying...")
self.shutdown() # Kill the existing process
else:
logger.error(
f"Failed after {MAX_RETRY_ATTEMPTS} attempts due to timeouts"
)
raise Exception(
f"Llamafile request failed after {MAX_RETRY_ATTEMPTS} attempts"
) from last_error
except Exception as e:
logger.error(f"Error during prediction: {e}")
raise

raise Exception(
f"Llamafile request failed after {MAX_RETRY_ATTEMPTS} attempts"
) from last_error
31 changes: 31 additions & 0 deletions validation/README.md
@@ -0,0 +1,31 @@
# Model Validation Framework

This framework provides a systematic approach to evaluating the model's performance on code property verification tasks.

## Methodology

### Test Configuration
- Each test example is evaluated 15 times in total
- Evaluations are organized into 5 trials of 3 evaluations each
- A majority-vote mechanism is applied within each trial (jury size = 3)

### Success Criteria
- A single trial succeeds if the majority vote within its jury agrees (≥2 out of 3)
- A test example passes only if ALL 5 trials succeed
- This strict requirement ensures high confidence in the model's consistency
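
A minimal sketch of this voting scheme, where `evaluate_once` stands in for a single model evaluation returning a boolean verdict (the function and constant names are illustrative, not taken from the codebase):

```python
JURY_SIZE = 3   # evaluations per trial
NUM_TRIALS = 5  # trials per test example (5 x 3 = 15 evaluations)


def run_trial(evaluate_once, example) -> bool:
    """A trial succeeds if the majority of its jury agrees (>= 2 of 3)."""
    votes = [evaluate_once(example) for _ in range(JURY_SIZE)]
    return sum(votes) > JURY_SIZE // 2


def example_passes(evaluate_once, example) -> bool:
    """A test example passes only if all trials succeed."""
    return all(run_trial(evaluate_once, example) for _ in range(NUM_TRIALS))
```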

## Metrics

The framework calculates the following metrics:

### Primary Metrics
- **Accuracy**: (True Positives + True Negatives) / Total Cases
- **Precision**: True Positives / (True Positives + False Positives)
- **Recall**: True Positives / (True Positives + False Negatives)
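
These metrics map directly onto confusion-matrix counts; a small helper along the following lines (illustrative, not part of the framework's API) computes them while guarding against division by zero:

```python
def compute_metrics(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Compute accuracy, precision, and recall from confusion-matrix counts."""
    total = tp + tn + fp + fn
    return {
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }
```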

## Implementation Notes

- The multiple trial approach helps identify inconsistencies in model behavior
- The strict all-trials-must-pass requirement minimizes false positives
- Caching must be disabled during validation experiments to ensure independent evaluations
- Each evaluation must be performed with a fresh model context
3 changes: 3 additions & 0 deletions validation/requirements.txt
@@ -0,0 +1,3 @@
datasets
../
tqdm