diff --git a/gto/api.py b/gto/api.py index 52ae5527..02f21840 100644 --- a/gto/api.py +++ b/gto/api.py @@ -1,8 +1,10 @@ -from typing import Union +from typing import List, Union import pandas as pd from git import Repo +from gto.config import CONFIG +from gto.ext import EnrichmentInfo from gto.index import FileIndexManager, RepoIndexManager from gto.registry import GitRegistry from gto.tag import parse_name @@ -82,7 +84,8 @@ def find_active_label(repo: Union[str, Repo], name: str, label: str): def check_ref(repo: Union[str, Repo], ref: str): """Find out what have been registered/promoted in the provided ref""" reg = GitRegistry.from_repo(repo) - ref = ref.removeprefix("refs/tags/") + if ref.startswith("refs/tags/"): + ref = ref[len("refs/tags/") :] if ref.startswith("refs/heads/"): ref = reg.repo.commit(ref).hexsha result = reg.check_ref(ref) @@ -173,3 +176,12 @@ def audit_promotion(repo: Union[str, Repo], dataframe: bool = False): df.sort_values("creation_date", ascending=False, inplace=True) df.set_index(["creation_date", "name"], inplace=True) return df + + +def describe(name: str) -> List[EnrichmentInfo]: + res = [] + for enrichment in CONFIG.enrichments: + enrichment_data = enrichment.describe(name) + if enrichment_data is not None: + res.append(enrichment_data) + return res diff --git a/gto/base.py b/gto/base.py index 425c4ec4..73141790 100644 --- a/gto/base.py +++ b/gto/base.py @@ -1,8 +1,9 @@ from datetime import datetime -from typing import Dict, FrozenSet, List, Optional +from typing import Dict, FrozenSet, List, Optional, overload import git from pydantic import BaseModel +from typing_extensions import Literal from gto.constants import Action from gto.index import ObjectCommits @@ -84,11 +85,31 @@ def latest_labels(self) -> Dict[str, BaseLabel]: labels[label.name] = label return labels + @overload def find_version( self, name: str = None, commit_hexsha: str = None, - raise_if_not_found=False, + raise_if_not_found: Literal[True] = ..., + skip_unregistered=True, + ) -> BaseVersion: + ... + + @overload + def find_version( + self, + name: str = None, + commit_hexsha: str = None, + raise_if_not_found: Literal[False] = ..., + skip_unregistered=True, + ) -> Optional[BaseVersion]: + ... + + def find_version( + self, + name: str = None, + commit_hexsha: str = None, + raise_if_not_found: bool = False, skip_unregistered=True, ) -> Optional[BaseVersion]: versions = [ diff --git a/gto/cli.py b/gto/cli.py index 3817b15b..a6dcb802 100644 --- a/gto/cli.py +++ b/gto/cli.py @@ -250,5 +250,13 @@ def print_index(repo: str, format: str): raise NotImplementedError("Unknown format") +@gto_command() +@arg_name +def describe(name: str): + infos = gto.api.describe(name) + for info in infos: + click.echo(info.get_human_readable()) + + if __name__ == "__main__": cli() diff --git a/gto/config.py b/gto/config.py index 95fe1d49..31825811 100644 --- a/gto/config.py +++ b/gto/config.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from pydantic import BaseSettings, validator +from pydantic import BaseModel, BaseSettings, validator from pydantic.env_settings import InitSettingsSource from ruamel.yaml import YAML @@ -10,11 +10,12 @@ from .constants import BRANCH, COMMIT, TAG from .exceptions import UnknownEnvironment +from .ext import Enrichment, find_enrichment_types, find_enrichments yaml = YAML(typ="safe", pure=True) yaml.default_flow_style = False -CONFIG_FILE = "gto.yaml" +CONFIG_FILE_NAME = "gto.yaml" def _set_location_init_source(init_source: InitSettingsSource): @@ -32,7 +33,7 @@ def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]: """ encoding = settings.__config__.env_file_encoding - config_file = getattr(settings, "CONFIG_FILE", CONFIG_FILE) + config_file = getattr(settings, "CONFIG_FILE", CONFIG_FILE_NAME) if not isinstance(config_file, Path): config_file = Path(config_file) if not config_file.exists(): @@ -42,6 +43,14 @@ def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]: return {k.upper(): v for k, v in conf.items()} if conf else {} +class EnrichmentConfig(BaseModel): + type: str + config: Dict = {} + + def load(self) -> Enrichment: + return find_enrichment_types()[self.type](**self.config) + + class RegistryConfig(BaseSettings): INDEX: str = "artifacts.yaml" VERSION_BASE: str = TAG @@ -52,7 +61,9 @@ class RegistryConfig(BaseSettings): ENV_BRANCH_MAPPING: Dict[str, str] = {} LOG_LEVEL: str = "INFO" DEBUG: bool = False - CONFIG_FILE: Optional[str] = CONFIG_FILE + ENRICHMENTS: List[EnrichmentConfig] = [] + AUTOLOAD_ENRICHMENTS: bool = True + CONFIG_FILE: Optional[str] = CONFIG_FILE_NAME @property def VERSION_SYSTEM_MAPPING(self): @@ -74,6 +85,13 @@ def ENV_MANAGERS_MAPPING(self): return {TAG: TagEnvManager, BRANCH: BranchEnvManager} + @property + def enrichments(self) -> List[Enrichment]: + res = [e.load() for e in self.ENRICHMENTS] + if self.AUTOLOAD_ENRICHMENTS: + return find_enrichments() + res + return res + def assert_env(self, name): if not self.check_env(name): raise UnknownEnvironment(name) diff --git a/gto/ext.py b/gto/ext.py new file mode 100644 index 00000000..b75be393 --- /dev/null +++ b/gto/ext.py @@ -0,0 +1,102 @@ +import subprocess +from abc import ABC, abstractmethod +from functools import lru_cache +from json import loads +from typing import Dict, List, Optional, Type, Union + +import entrypoints +from mlem.utils.importing import import_string +from pydantic import BaseModel, parse_obj_as, validator + +ENRICHMENT_ENRTYPOINT = "gto.enrichment" + + +class EnrichmentInfo(BaseModel, ABC): + source: str + + @abstractmethod + def get_object(self) -> BaseModel: + raise NotImplementedError + + def get_dict(self): + return self.get_object().dict() + + @abstractmethod + def get_human_readable(self) -> str: + raise NotImplementedError + + +class Enrichment(BaseModel, ABC): + @abstractmethod + def describe(self, obj: str) -> Optional[EnrichmentInfo]: + raise NotImplementedError + + +class CLIEnrichmentInfo(EnrichmentInfo): + data: Dict + repr: str + + def get_object(self) -> BaseModel: + return self + + def get_human_readable(self) -> str: + return self.repr + + +class CLIEnrichment(Enrichment): + cmd: str + info_type: Union[str, Type[EnrichmentInfo]] = CLIEnrichmentInfo + + @validator("info_type") + def info_class_validator( + cls, value + ): # pylint: disable=no-self-argument,no-self-use # noqa: B902 + if isinstance(value, type): + return value + info_class = import_string(value) + if not isinstance(info_class, type) or not issubclass( + info_class, EnrichmentInfo + ): + raise ValueError( + "Wrong value for info_type: should be class or string path to class (e.g. `package.module.ClassName`)" + ) + return info_class + + @property + def info_class(self) -> Type[EnrichmentInfo]: + return self.info_class_validator(self.info_type) + + def describe(self, obj: str) -> Optional[EnrichmentInfo]: + try: + data = loads(subprocess.check_output(self.cmd.split() + [obj])) + return parse_obj_as(self.info_class, data) + except subprocess.SubprocessError: + return None + + +@lru_cache() +def _find_enrichments(): + eps = entrypoints.get_group_named(ENRICHMENT_ENRTYPOINT) + return {k: ep.load() for k, ep in eps.items()} + + +@lru_cache() +def find_enrichments() -> List[Enrichment]: + enrichments = _find_enrichments() + res = [] + for e in enrichments: + if isinstance(e, type) and issubclass(e, Enrichment) and not e.__fields_set__: + res.append(e()) + if isinstance(e, Enrichment): + res.append(e) + return res + + +@lru_cache() +def find_enrichment_types() -> Dict[str, Type[Enrichment]]: + enrichments = _find_enrichments() + return { + k: e + for k, e in enrichments.items() + if isinstance(e, type) and issubclass(e, Enrichment) + } diff --git a/gto/ext_dvc.py b/gto/ext_dvc.py new file mode 100644 index 00000000..89cf0fef --- /dev/null +++ b/gto/ext_dvc.py @@ -0,0 +1,29 @@ +from typing import Optional + +from pydantic import BaseModel +from ruamel.yaml import safe_load + +from gto.ext import Enrichment, EnrichmentInfo + + +class DVCEnrichmentInfo(EnrichmentInfo): + source = "dvc" + size: int + hash: str + + def get_object(self) -> BaseModel: + return self + + def get_human_readable(self) -> str: + return f"""DVC-tracked [{self.size} bytes]""" + + +class DVCEnrichment(Enrichment): + def describe(self, obj: str) -> Optional[DVCEnrichmentInfo]: + try: + with open(obj + ".dvc", encoding="utf8") as f: + dvc_data = safe_load(f) + data = dvc_data["outs"][0] + return DVCEnrichmentInfo(size=data["size"], hash=data["md5"]) + except FileNotFoundError: + return None diff --git a/gto/ext_mlem.py b/gto/ext_mlem.py new file mode 100644 index 00000000..b0f1e7fb --- /dev/null +++ b/gto/ext_mlem.py @@ -0,0 +1,34 @@ +"""This is temporary file that should be moved to mlem.gto module""" +from typing import Optional + +from mlem.core.errors import MlemObjectNotFound +from mlem.core.metadata import load_meta +from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta +from pydantic import BaseModel + +from gto.ext import Enrichment, EnrichmentInfo + + +class MlemInfo(EnrichmentInfo): + source = "mlem" + meta: MlemMeta + + def get_object(self) -> BaseModel: + return self.meta + + def get_human_readable(self) -> str: + # TODO: create `.describe` method in MlemMeta https://github.com/iterative/mlem/issues/98 + description = f"""Mlem {self.meta.object_type}""" + if isinstance(self.meta, ModelMeta): + description += f": {self.meta.model_type.type}" + if isinstance(self.meta, DatasetMeta): + description += f": {self.meta.dataset.dataset_type.type}" + return description + + +class MlemEnrichment(Enrichment): + def describe(self, obj: str) -> Optional[MlemInfo]: + try: + return MlemInfo(meta=load_meta(obj)) + except MlemObjectNotFound: + return None diff --git a/gto/registry.py b/gto/registry.py index 33320fd9..b93a3e74 100644 --- a/gto/registry.py +++ b/gto/registry.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from gto.base import BaseManager, BaseObject, BaseRegistryState -from gto.config import CONFIG_FILE, RegistryConfig +from gto.config import CONFIG_FILE_NAME, RegistryConfig from gto.exceptions import ( NoActiveLabel, VersionAlreadyRegistered, @@ -32,7 +32,7 @@ def from_repo(cls, repo=Union[str, Repo], config=None): repo = git.Repo(repo) if config is None: config = RegistryConfig( - CONFIG_FILE=os.path.join(repo.working_dir, CONFIG_FILE) + CONFIG_FILE=os.path.join(repo.working_dir, CONFIG_FILE_NAME) ) return cls( diff --git a/gto/tag.py b/gto/tag.py index cd9c8140..766ad874 100644 --- a/gto/tag.py +++ b/gto/tag.py @@ -147,7 +147,7 @@ def label_from_tag(tag: git.Tag, obj: BaseObject) -> BaseLabel: object=mtag.name, version=obj.find_version( commit_hexsha=tag.commit.hexsha, raise_if_not_found=True - ).name, # type: ignore + ).name, name=mtag.label, creation_date=mtag.creation_date, author=tag.tag.tagger.name, diff --git a/setup.py b/setup.py index 1d0273b6..9243f6ca 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ "pydantic", "ruamel.yaml", "semver==3.0.0-dev.3", + "entrypoints", ] @@ -54,6 +55,11 @@ include_package_data=True, entry_points={ "console_scripts": ["gto = gto.cli:cli"], + "gto.enrichment": [ + "mlem = gto.ext_mlem:MlemEnrichment", + "dvc = gto.ext_dvc:DVCEnrichment", + "cli = gto.ext:CLIEnrichment", + ], }, cmdclass={"build_py": build_py}, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index c2f6f625..4cc6fec4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import pytest import gto -from gto.config import CONFIG_FILE +from gto.config import CONFIG_FILE_NAME @pytest.fixture @@ -24,7 +24,7 @@ def write_file(name, content): file.write(content) write_file( - CONFIG_FILE, + CONFIG_FILE_NAME, """ version_base: tag env_base: tag @@ -46,7 +46,7 @@ def write_file(name, content): file.write(content) write_file( - CONFIG_FILE, + CONFIG_FILE_NAME, """ version_base: tag env_base: tag