Skip to content

Commit

Permalink
Add support for metadata in embed-multi command
Browse files Browse the repository at this point in the history
  • Loading branch information
msmart committed Feb 28, 2025
1 parent bf80b8a commit bdc26d8
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 5 deletions.
24 changes: 19 additions & 5 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from sqlite_utils.utils import rows_from_file, Format
import sys
import textwrap
from typing import cast, Optional, Iterable, Union, Tuple
from typing import cast, Optional, Iterable, Union, Tuple, Dict, Any
import warnings
import yaml

Expand Down Expand Up @@ -1937,6 +1937,9 @@ def embed_multi(
are assumed to be text that should be concatenated together
in order to calculate the embeddings.
When using JSON input, you can specify a "metadata" field in your objects
which will be stored with the embeddings.
Input data can come from one of three sources:
\b
Expand Down Expand Up @@ -2044,23 +2047,34 @@ def load_rows(fp):
rows, label="Embedding", show_percent=True, length=expected_length
) as rows:

def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
def tuples() -> Iterable[Tuple[str, Union[bytes, str], Optional[Dict[str, Any]]]]:
for row in rows:
values = list(row.values())
id: str = prefix + str(values[0])
content: Optional[Union[bytes, str]] = None
metadata = None

# Check if there's a metadata field in the row (for JSON input)
if "metadata" in row and isinstance(row["metadata"], dict):
metadata = row["metadata"]

if binary:
content = cast(bytes, values[1])
else:
content = " ".join(v or "" for v in values[1:])
# Filter out metadata field if present
content_values = [v for i, v in enumerate(values[1:])
if list(row.keys())[i+1] != "metadata"]
content = " ".join(v or "" for v in content_values)

if prepend and isinstance(content, str):
content = prepend + content
yield id, content or ""

yield id, content or "", metadata

embed_kwargs = {"store": store}
if batch_size:
embed_kwargs["batch_size"] = batch_size
collection_obj.embed_multi(tuples(), **embed_kwargs)
collection_obj.embed_multi_with_metadata(tuples(), **embed_kwargs)


@cli.command()
Expand Down
107 changes: 107 additions & 0 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,69 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix, prepend):
]


def test_embed_multi_sql_with_metadata(tmpdir):
    """embed-multi --sql: rows whose metadata column is a JSON *string*
    still embed correctly, with the metadata column excluded from content."""
    db_path = str(tmpdir / "embeddings.db")
    source_db_path = str(tmpdir / "source.db")

    # Populate a source database whose rows carry metadata serialized as
    # JSON strings (not dicts).
    docs = [
        {
            "id": 1,
            "title": "Introduction",
            "text": "Welcome to the document",
            "metadata": json.dumps({"category": "intro", "importance": "high"}),
        },
        {
            "id": 2,
            "title": "Conclusion",
            "text": "Thank you for reading",
            "metadata": json.dumps({"category": "ending", "importance": "medium"}),
        },
    ]
    source_db = sqlite_utils.Database(source_db_path)
    source_db["content_with_meta"].insert_all(docs, pk="id")

    # Sanity-check the fixture data before invoking the CLI.
    assert source_db["content_with_meta"].count == 2

    args = [
        "embed-multi",
        "documents",
        "-d",
        db_path,
        "--sql",
        "select id, title, text, metadata from content_with_meta",
        "--attach",
        "source",
        source_db_path,
        "-m",
        "embed-demo",
        "--store",
    ]
    result = CliRunner().invoke(cli, args, catch_exceptions=False)
    assert result.exit_code == 0

    # Both source rows should have produced embeddings.
    embeddings_db = sqlite_utils.Database(db_path)
    assert embeddings_db["embeddings"].count == 2

    # Only content is verified here: the implementation stores metadata
    # only when it is an actual JSON object (dict), not a string.
    rows = list(embeddings_db.query("select id, content from embeddings order by id"))
    assert len(rows) == 2

    # The metadata column must not leak into the concatenated content.
    assert rows[0]["content"] == "Introduction Welcome to the document"
    assert rows[1]["content"] == "Conclusion Thank you for reading"


def test_embed_multi_batch_size(embed_demo, tmpdir):
db_path = str(tmpdir / "data.db")
runner = CliRunner()
Expand Down Expand Up @@ -416,6 +479,50 @@ def test_embed_multi_batch_size(embed_demo, tmpdir):
assert embed_demo.batch_count == 13


def test_embed_multi_with_metadata(tmpdir):
    """embed-multi with JSON file input: per-object "metadata" dicts are
    persisted alongside the embeddings."""
    db_path = str(tmpdir / "embeddings.db")

    # JSON input where each object carries a metadata dict.
    records = [
        {"id": 1, "content": "hello world", "metadata": {"source": "test", "tags": ["greeting"]}},
        {"id": 2, "content": "goodbye world", "metadata": {"source": "test", "tags": ["farewell"]}},
    ]
    json_path = tmpdir / "data_with_metadata.json"
    json_path.write_text(json.dumps(records), "utf-8")

    cli_args = [
        "embed-multi",
        "items-with-metadata",
        str(json_path),
        "-d",
        db_path,
        "-m",
        "embed-demo",
        "--store",
    ]
    result = CliRunner().invoke(cli, cli_args, catch_exceptions=False)
    assert result.exit_code == 0

    # Both records should be embedded.
    db = sqlite_utils.Database(db_path)
    assert db["embeddings"].count == 2

    stored = list(db.query("SELECT id, metadata FROM embeddings ORDER BY id"))
    assert len(stored) == 2

    # Each row's metadata must round-trip verbatim through the database.
    assert json.loads(stored[0]["metadata"]) == {"source": "test", "tags": ["greeting"]}
    assert json.loads(stored[1]["metadata"]) == {"source": "test", "tags": ["farewell"]}


@pytest.fixture
def multi_files(tmpdir):
db_path = str(tmpdir / "files.db")
Expand Down

0 comments on commit bdc26d8

Please sign in to comment.