Skip to content

Commit

Permalink
feat: store Entry suffix separately (#503)
Browse files Browse the repository at this point in the history
* feat: save entry suffix separately

* change LibraryPrefs to allow identical values, add test
  • Loading branch information
yedpodtrzitko authored Oct 7, 2024
1 parent 1c7aaf0 commit e075282
Show file tree
Hide file tree
Showing 15 changed files with 303 additions and 110 deletions.
11 changes: 0 additions & 11 deletions tagstudio/src/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from enum import Enum

VERSION: str = "9.3.2" # Major.Minor.Patch
VERSION_BRANCH: str = "" # Usually "" or "Pre-Release"

# The folder & file names where TagStudio keeps its data relative to a library.
TS_FOLDER_NAME: str = ".TagStudio"
BACKUP_FOLDER_NAME: str = "backups"
COLLAGE_FOLDER_NAME: str = "collages"
LIBRARY_FILENAME: str = "ts_library.json"

# TODO: Turn this whitelist into a user-configurable blacklist.
IMAGE_TYPES: list[str] = [
Expand Down Expand Up @@ -122,13 +119,5 @@
+ SHORTCUT_TYPES
)


TAG_FAVORITE = 1
TAG_ARCHIVED = 0


class LibraryPrefs(Enum):
IS_EXCLUDE_LIST = True
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
PAGE_SIZE: int = 500
DB_VERSION: int = 1
40 changes: 40 additions & 0 deletions tagstudio/src/core/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path

import structlog
from PySide6.QtCore import QSettings
from src.core.constants import TS_FOLDER_NAME
from src.core.enums import SettingItems
from src.core.library.alchemy.library import LibraryStatus

logger = structlog.get_logger(__name__)


class DriverMixin:
settings: QSettings

def evaluate_path(self, open_path: str | None) -> LibraryStatus:
"""Check if the path of library is valid."""
library_path: Path | None = None
if open_path:
library_path = Path(open_path)
if not library_path.exists():
logger.error("Path does not exist.", open_path=open_path)
return LibraryStatus(success=False, message="Path does not exist.")
elif self.settings.value(
SettingItems.START_LOAD_LAST, defaultValue=True, type=bool
) and self.settings.value(SettingItems.LAST_LIBRARY):
library_path = Path(str(self.settings.value(SettingItems.LAST_LIBRARY)))
if not (library_path / TS_FOLDER_NAME).exists():
logger.error(
"TagStudio folder does not exist.",
library_path=library_path,
ts_folder=TS_FOLDER_NAME,
)
self.settings.setValue(SettingItems.LAST_LIBRARY, "")
# dont consider this a fatal error, just skip opening the library
library_path = None

return LibraryStatus(
success=True,
library_path=library_path,
)
30 changes: 30 additions & 0 deletions tagstudio/src/core/enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import enum
from typing import Any
from uuid import uuid4


class SettingItems(str, enum.Enum):
Expand Down Expand Up @@ -31,3 +33,31 @@ class MacroID(enum.Enum):
BUILD_URL = "build_url"
MATCH = "match"
CLEAN_URL = "clean_url"


class DefaultEnum(enum.Enum):
"""Allow saving multiple identical values in property called .default."""

default: Any

def __new__(cls, value):
# Create the enum instance
obj = object.__new__(cls)
# make value random
obj._value_ = uuid4()
# assign the actual value into .default property
obj.default = value
return obj

@property
def value(self):
raise AttributeError("access the value via .default property instead")


class LibraryPrefs(DefaultEnum):
"""Library preferences with default value accessible via .default property."""

IS_EXCLUDE_LIST = True
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
PAGE_SIZE: int = 500
DB_VERSION: int = 2
16 changes: 8 additions & 8 deletions tagstudio/src/core/library/alchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ class BaseField(Base):
__abstract__ = True

@declared_attr
def id(cls) -> Mapped[int]: # noqa: N805
def id(self) -> Mapped[int]:
return mapped_column(primary_key=True, autoincrement=True)

@declared_attr
def type_key(cls) -> Mapped[str]: # noqa: N805
def type_key(self) -> Mapped[str]:
return mapped_column(ForeignKey("value_type.key"))

@declared_attr
def type(cls) -> Mapped[ValueType]: # noqa: N805
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore
def type(self) -> Mapped[ValueType]:
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore

@declared_attr
def entry_id(cls) -> Mapped[int]: # noqa: N805
def entry_id(self) -> Mapped[int]:
return mapped_column(ForeignKey("entries.id"))

@declared_attr
def entry(cls) -> Mapped[Entry]: # noqa: N805
return relationship(foreign_keys=[cls.entry_id]) # type: ignore
def entry(self) -> Mapped[Entry]:
return relationship(foreign_keys=[self.entry_id]) # type: ignore

@declared_attr
def position(cls) -> Mapped[int]: # noqa: N805
def position(self) -> Mapped[int]:
return mapped_column(default=0)

def __hash__(self):
Expand Down
91 changes: 69 additions & 22 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import shutil
import sys
import unicodedata
from dataclasses import dataclass
from datetime import UTC, datetime
Expand Down Expand Up @@ -34,8 +35,8 @@
TAG_ARCHIVED,
TAG_FAVORITE,
TS_FOLDER_NAME,
LibraryPrefs,
)
from ...enums import LibraryPrefs
from .db import make_tables
from .enums import FieldTypeEnum, FilterState, TagColor
from .fields import (
Expand All @@ -48,8 +49,6 @@
from .joins import TagField, TagSubtag
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType

LIBRARY_FILENAME: str = "ts_library.sqlite"

logger = structlog.get_logger(__name__)


Expand Down Expand Up @@ -115,6 +114,15 @@ def __getitem__(self, index: int) -> Entry:
return self.items[index]


@dataclass
class LibraryStatus:
"""Keep status of library opening operation."""

success: bool
library_path: Path | None = None
message: str | None = None


class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

Expand All @@ -123,30 +131,28 @@ class Library:
engine: Engine | None
folder: Folder | None

FILENAME: str = "ts_library.sqlite"

def close(self):
if self.engine:
self.engine.dispose()
self.library_dir = None
self.storage_path = None
self.folder = None

def open_library(self, library_dir: Path | str, storage_path: str | None = None) -> None:
if isinstance(library_dir, str):
library_dir = Path(library_dir)

self.library_dir = library_dir
def open_library(self, library_dir: Path, storage_path: str | None = None) -> LibraryStatus:
if storage_path == ":memory:":
self.storage_path = storage_path
else:
self.verify_ts_folders(self.library_dir)
self.storage_path = self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME
self.verify_ts_folders(library_dir)
self.storage_path = library_dir / TS_FOLDER_NAME / self.FILENAME

connection_string = URL.create(
drivername="sqlite",
database=str(self.storage_path),
)

logger.info("opening library", connection_string=connection_string)
logger.info("opening library", library_dir=library_dir, connection_string=connection_string)
self.engine = create_engine(connection_string)
with Session(self.engine) as session:
make_tables(self.engine)
Expand All @@ -159,9 +165,24 @@ def open_library(self, library_dir: Path | str, storage_path: str | None = None)
# default tags may exist already
session.rollback()

if "pytest" not in sys.modules:
db_version = session.scalar(
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
)

if not db_version:
# TODO - remove after #503 is merged and LibraryPrefs.DB_VERSION increased again
return LibraryStatus(
success=False,
message=(
"Library version mismatch.\n"
f"Found: v0, expected: v{LibraryPrefs.DB_VERSION.default}"
),
)

for pref in LibraryPrefs:
try:
session.add(Preferences(key=pref.name, value=pref.value))
session.add(Preferences(key=pref.name, value=pref.default))
session.commit()
except IntegrityError:
logger.debug("preference already exists", pref=pref)
Expand All @@ -183,11 +204,30 @@ def open_library(self, library_dir: Path | str, storage_path: str | None = None)
logger.debug("ValueType already exists", field=field)
session.rollback()

db_version = session.scalar(
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
)
# if the db version is different, we cant proceed
if db_version.value != LibraryPrefs.DB_VERSION.default:
logger.error(
"DB version mismatch",
db_version=db_version.value,
expected=LibraryPrefs.DB_VERSION.default,
)
# TODO - handle migration
return LibraryStatus(
success=False,
message=(
"Library version mismatch.\n"
f"Found: v{db_version.value}, expected: v{LibraryPrefs.DB_VERSION.default}"
),
)

# check if folder matching current path exists already
self.folder = session.scalar(select(Folder).where(Folder.path == self.library_dir))
self.folder = session.scalar(select(Folder).where(Folder.path == library_dir))
if not self.folder:
folder = Folder(
path=self.library_dir,
path=library_dir,
uuid=str(uuid4()),
)
session.add(folder)
Expand All @@ -196,6 +236,10 @@ def open_library(self, library_dir: Path | str, storage_path: str | None = None)
session.commit()
self.folder = folder

# everything is fine, set the library path
self.library_dir = library_dir
return LibraryStatus(success=True, library_path=library_dir)

@property
def default_fields(self) -> list[BaseField]:
with Session(self.engine) as session:
Expand Down Expand Up @@ -324,15 +368,18 @@ def add_entries(self, items: list[Entry]) -> list[int]:

with Session(self.engine) as session:
# add all items
session.add_all(items)
session.flush()

new_ids = [item.id for item in items]
try:
session.add_all(items)
session.commit()
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
return []

new_ids = [item.id for item in items]
session.expunge_all()

session.commit()

return new_ids

def remove_entries(self, entry_ids: list[int]) -> None:
Expand Down Expand Up @@ -396,9 +443,9 @@ def search_library(

if not search.id: # if `id` is set, we don't need to filter by extensions
if extensions and is_exclude_list:
statement = statement.where(Entry.path.notilike(f"%.{','.join(extensions)}"))
statement = statement.where(Entry.suffix.notin_(extensions))
elif extensions:
statement = statement.where(Entry.path.ilike(f"%.{','.join(extensions)}"))
statement = statement.where(Entry.suffix.in_(extensions))

statement = statement.options(
selectinload(Entry.text_fields),
Expand Down Expand Up @@ -770,7 +817,7 @@ def save_library_backup_to_disk(self) -> Path:
target_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME / filename

shutil.copy2(
self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME,
self.library_dir / TS_FOLDER_NAME / self.FILENAME,
target_path,
)

Expand Down
3 changes: 3 additions & 0 deletions tagstudio/src/core/library/alchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class Entry(Base):
folder: Mapped[Folder] = relationship("Folder")

path: Mapped[Path] = mapped_column(PathType, unique=True)
suffix: Mapped[str] = mapped_column()

text_fields: Mapped[list[TextField]] = relationship(
back_populates="entry",
Expand Down Expand Up @@ -177,6 +178,8 @@ def __init__(
self.path = path
self.folder = folder

self.suffix = path.suffix.lstrip(".").lower()

for field in fields:
if isinstance(field, TextField):
self.text_fields.append(field)
Expand Down
2 changes: 2 additions & 0 deletions tagstudio/src/core/library/json/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def compressed_dict(self):
class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

FILENAME: str = "ts_library.json"

def __init__(self) -> None:
# Library Info =========================================================
self.library_dir: Path = None
Expand Down
4 changes: 2 additions & 2 deletions tagstudio/src/qt/modals/file_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
QVBoxLayout,
QWidget,
)
from src.core.constants import LibraryPrefs
from src.core.enums import LibraryPrefs
from src.core.library import Library
from src.qt.widgets.panel import PanelWidget

Expand Down Expand Up @@ -104,7 +104,7 @@ def save(self):
for i in range(self.table.rowCount()):
ext = self.table.item(i, 0)
if ext and ext.text().strip():
extensions.append(ext.text().strip().lower())
extensions.append(ext.text().strip().lstrip(".").lower())

# save preference
self.lib.set_prefs(LibraryPrefs.EXTENSION_LIST, extensions)
Loading

0 comments on commit e075282

Please sign in to comment.