Skip to content

Commit

Permalink
fix: fixed bugs in rag
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Nov 15, 2024
1 parent 59d0e70 commit 445e49a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
14 changes: 8 additions & 6 deletions gptme/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path

import tomlkit
Expand Down Expand Up @@ -73,12 +74,12 @@ class ProjectConfig:
def get_config() -> Config:
global _config
if _config is None:
_config = load_config()
_config = _load_config()
return _config


def load_config() -> Config:
config = _load_config()
def _load_config() -> Config:
config = _load_config_doc()
assert "prompt" in config, "prompt key missing in config"
assert "env" in config, "env key missing in config"
prompt = config.pop("prompt")
Expand All @@ -88,7 +89,7 @@ def load_config() -> Config:
return Config(prompt=prompt, env=env)


def _load_config() -> tomlkit.TOMLDocument:
def _load_config_doc() -> tomlkit.TOMLDocument:
# Check if the config file exists
if not os.path.exists(config_path):
# If not, create it and write some default settings
Expand All @@ -106,7 +107,7 @@ def _load_config() -> tomlkit.TOMLDocument:


def set_config_value(key: str, value: str) -> None: # pragma: no cover
doc: TOMLDocument | Container = _load_config()
doc: TOMLDocument | Container = _load_config_doc()

# Set the value
keypath = key.split(".")
Expand All @@ -121,9 +122,10 @@ def set_config_value(key: str, value: str) -> None: # pragma: no cover

# Reload config
global _config
_config = load_config()
_config = _load_config()


@lru_cache
def get_project_config(workspace: Path) -> ProjectConfig | None:
project_config_paths = [
p
Expand Down
4 changes: 3 additions & 1 deletion gptme/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def get_context(self, query: str, max_tokens: int = 1000) -> list[Context]:
class RAGContextProvider(ContextProvider):
"""Context provider using RAG."""

# TODO: refactor this to share code with rag tool

def __init__(self):
try:
self._has_rag = True
Expand All @@ -45,7 +47,7 @@ def __init__(self):
# Storage configuration
self.indexer = gptme_rag.Indexer(
persist_directory=config.rag.get("index_path", "~/.cache/gptme/rag"),
collection_name=config.rag.get("collection", "gptme_docs"),
collection_name=config.rag.get("collection", "default"),
)

# Context enhancement configuration
Expand Down
4 changes: 2 additions & 2 deletions gptme/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dotenv import load_dotenv

from .config import config_path, load_config, set_config_value
from .config import config_path, get_config, set_config_value
from .llm import init_llm
from .models import (
PROVIDERS,
Expand All @@ -30,7 +30,7 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None)
logger.debug("Started")
load_dotenv()

config = load_config()
config = get_config()

# get from config
if not model:
Expand Down

0 comments on commit 445e49a

Please sign in to comment.