Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ext #52

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions gto/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Union
from typing import List, Union

import pandas as pd
from git import Repo

from gto.config import CONFIG
from gto.ext import EnrichmentInfo
from gto.index import FileIndexManager, RepoIndexManager
from gto.registry import GitRegistry
from gto.tag import parse_name
Expand Down Expand Up @@ -82,7 +84,8 @@ def find_active_label(repo: Union[str, Repo], name: str, label: str):
def check_ref(repo: Union[str, Repo], ref: str):
"""Find out what has been registered/promoted in the provided ref"""
reg = GitRegistry.from_repo(repo)
ref = ref.removeprefix("refs/tags/")
if ref.startswith("refs/tags/"):
ref = ref[len("refs/tags/") :]
if ref.startswith("refs/heads/"):
ref = reg.repo.commit(ref).hexsha
result = reg.check_ref(ref)
Expand Down Expand Up @@ -173,3 +176,12 @@ def audit_promotion(repo: Union[str, Repo], dataframe: bool = False):
df.sort_values("creation_date", ascending=False, inplace=True)
df.set_index(["creation_date", "name"], inplace=True)
return df


def describe(name: str) -> List[EnrichmentInfo]:
    """Collect enrichment info about `name` from every configured enrichment.

    Each enrichment in `CONFIG.enrichments` is asked, in order, to describe
    `name`. Enrichments that answer `None` are left out of the result.
    """
    # Lazy generator: each enrichment is still queried exactly once, in order.
    answers = (enrichment.describe(name) for enrichment in CONFIG.enrichments)
    return [info for info in answers if info is not None]
25 changes: 23 additions & 2 deletions gto/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import datetime
from typing import Dict, FrozenSet, List, Optional
from typing import Dict, FrozenSet, List, Optional, overload

import git
from pydantic import BaseModel
from typing_extensions import Literal

from gto.constants import Action
from gto.index import ObjectCommits
Expand Down Expand Up @@ -84,11 +85,31 @@ def latest_labels(self) -> Dict[str, BaseLabel]:
labels[label.name] = label
return labels

@overload
def find_version(
self,
name: str = None,
commit_hexsha: str = None,
raise_if_not_found=False,
raise_if_not_found: Literal[True] = ...,
skip_unregistered=True,
) -> BaseVersion:
...

@overload
def find_version(
self,
name: str = None,
commit_hexsha: str = None,
raise_if_not_found: Literal[False] = ...,
skip_unregistered=True,
) -> Optional[BaseVersion]:
...

def find_version(
self,
name: str = None,
commit_hexsha: str = None,
raise_if_not_found: bool = False,
skip_unregistered=True,
) -> Optional[BaseVersion]:
versions = [
Expand Down
8 changes: 8 additions & 0 deletions gto/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,13 @@ def print_index(repo: str, format: str):
raise NotImplementedError("Unknown format")


@gto_command()
@arg_name
def describe(name: str):
    # Print one human-readable line per enrichment that returned info for
    # `name`. Deliberately a comment rather than a docstring, so that any help
    # text derived from the docstring stays unchanged.
    for enrichment_info in gto.api.describe(name):
        readable = enrichment_info.get_human_readable()
        click.echo(readable)


if __name__ == "__main__":
cli()
26 changes: 22 additions & 4 deletions gto/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import BaseSettings, validator
from pydantic import BaseModel, BaseSettings, validator
from pydantic.env_settings import InitSettingsSource
from ruamel.yaml import YAML

from gto.versions import AbstractVersion

from .constants import BRANCH, COMMIT, TAG
from .exceptions import UnknownEnvironment
from .ext import Enrichment, find_enrichment_types, find_enrichments

yaml = YAML(typ="safe", pure=True)
yaml.default_flow_style = False

CONFIG_FILE = "gto.yaml"
CONFIG_FILE_NAME = "gto.yaml"


def _set_location_init_source(init_source: InitSettingsSource):
Expand All @@ -32,7 +33,7 @@ def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]:
"""

encoding = settings.__config__.env_file_encoding
config_file = getattr(settings, "CONFIG_FILE", CONFIG_FILE)
config_file = getattr(settings, "CONFIG_FILE", CONFIG_FILE_NAME)
if not isinstance(config_file, Path):
config_file = Path(config_file)
if not config_file.exists():
Expand All @@ -42,6 +43,14 @@ def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]:
return {k.upper(): v for k, v in conf.items()} if conf else {}


class EnrichmentConfig(BaseModel):
    """Config entry naming an enrichment type and the kwargs to build it with."""

    # Key under which the enrichment class is looked up in
    # `find_enrichment_types()`.
    type: str
    # Keyword arguments passed to the enrichment class constructor.
    config: Dict = {}

    def load(self) -> Enrichment:
        """Instantiate the enrichment described by this entry.

        Looks up `self.type` in `find_enrichment_types()` and calls the
        resulting class with `self.config` unpacked as keyword arguments.
        An unknown `type` surfaces as the `KeyError` raised by that lookup.
        """
        enrichment_cls = find_enrichment_types()[self.type]
        return enrichment_cls(**self.config)


class RegistryConfig(BaseSettings):
INDEX: str = "artifacts.yaml"
VERSION_BASE: str = TAG
Expand All @@ -52,7 +61,9 @@ class RegistryConfig(BaseSettings):
ENV_BRANCH_MAPPING: Dict[str, str] = {}
LOG_LEVEL: str = "INFO"
DEBUG: bool = False
CONFIG_FILE: Optional[str] = CONFIG_FILE
ENRICHMENTS: List[EnrichmentConfig] = []
AUTOLOAD_ENRICHMENTS: bool = True
CONFIG_FILE: Optional[str] = CONFIG_FILE_NAME

@property
def VERSION_SYSTEM_MAPPING(self):
Expand All @@ -74,6 +85,13 @@ def ENV_MANAGERS_MAPPING(self):

return {TAG: TagEnvManager, BRANCH: BranchEnvManager}

@property
def enrichments(self) -> List[Enrichment]:
    """Enrichments to use: the explicitly configured ones, optionally preceded
    by the ones returned by `find_enrichments()`.

    Entries from `ENRICHMENTS` are always loaded. When `AUTOLOAD_ENRICHMENTS`
    is true, the result of `find_enrichments()` is placed in front of them.
    """
    configured = [entry.load() for entry in self.ENRICHMENTS]
    if not self.AUTOLOAD_ENRICHMENTS:
        return configured
    # `+` builds a new list, so the list returned by find_enrichments() is
    # not mutated.
    return find_enrichments() + configured

def assert_env(self, name):
if not self.check_env(name):
raise UnknownEnvironment(name)
Expand Down
102 changes: 102 additions & 0 deletions gto/ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import subprocess
from abc import ABC, abstractmethod
from functools import lru_cache
from json import loads
from typing import Dict, List, Optional, Type, Union

import entrypoints
from mlem.utils.importing import import_string
from pydantic import BaseModel, parse_obj_as, validator

ENRICHMENT_ENRTYPOINT = "gto.enrichment"


class EnrichmentInfo(BaseModel, ABC):
    """Abstract base for the information an enrichment returns about an object."""

    # Identifier of where the info came from; subclasses in this package set
    # it to values such as "dvc" or "mlem".
    source: str

    @abstractmethod
    def get_object(self) -> BaseModel:
        """Return the pydantic model holding the enrichment data."""
        raise NotImplementedError

    def get_dict(self):
        """Return the model from `get_object()` converted with `.dict()`."""
        payload = self.get_object()
        return payload.dict()

    @abstractmethod
    def get_human_readable(self) -> str:
        """Return a short textual description meant for display to a user."""
        raise NotImplementedError


class Enrichment(BaseModel, ABC):
    """Abstract base for a source of additional information about objects."""

    @abstractmethod
    def describe(self, obj: str) -> Optional[EnrichmentInfo]:
        """Return info about `obj`, or `None` when there is nothing to report.

        Args:
            obj: string identifying the object to describe.
        """
        raise NotImplementedError


class CLIEnrichmentInfo(EnrichmentInfo):
    """Default info model for `CLIEnrichment`: raw data plus a ready-made text."""

    # Arbitrary payload provided alongside the textual representation.
    data: Dict
    # Text returned verbatim by `get_human_readable()`.
    repr: str

    def get_object(self) -> BaseModel:
        """Return this instance itself; it is already the data model."""
        return self

    def get_human_readable(self) -> str:
        """Return the stored `repr` string unchanged."""
        return self.repr


class CLIEnrichment(Enrichment):
    """Enrichment that gets its data by running an external command.

    The command in `cmd` is run with the object name appended as its last
    argument. Its standard output is parsed as JSON and converted into
    `info_type`.
    """

    # Command line to run; tokenized with shell-like quoting rules.
    cmd: str
    # Either an `EnrichmentInfo` subclass or a dotted import path to one.
    info_type: Union[str, Type[EnrichmentInfo]] = CLIEnrichmentInfo

    @validator("info_type")
    def info_class_validator(
        cls, value
    ):  # pylint: disable=no-self-argument,no-self-use # noqa: B902
        """Resolve `info_type` to a class.

        A value that is already a class is returned unchanged. Anything else
        is imported with `import_string` and must be an `EnrichmentInfo`
        subclass.

        Raises:
            ValueError: if the imported object is not an `EnrichmentInfo`
                subclass.
        """
        if isinstance(value, type):
            return value
        info_class = import_string(value)
        if not isinstance(info_class, type) or not issubclass(
            info_class, EnrichmentInfo
        ):
            raise ValueError(
                "Wrong value for info_type: should be class or string path to class (e.g. `package.module.ClassName`)"
            )
        return info_class

    @property
    def info_class(self) -> Type[EnrichmentInfo]:
        """The class `info_type` resolves to (re-runs the validator on it)."""
        return self.info_class_validator(self.info_type)

    def describe(self, obj: str) -> Optional[EnrichmentInfo]:
        """Run the command for `obj` and parse its JSON output.

        Args:
            obj: object name, passed to the command as its final argument.

        Returns:
            The parsed info, or `None` when running the command raises a
            `subprocess.SubprocessError` (for example a non-zero exit status).
        """
        # Local imports keep this block self-contained.
        import os
        import shlex

        # shlex honours quoting, so an argument such as "--label 'my model'"
        # stays one token; plain str.split() would break it apart. Non-POSIX
        # mode on Windows keeps backslashes in paths intact.
        argv = shlex.split(self.cmd, posix=os.name != "nt")
        try:
            output = subprocess.check_output(argv + [obj])
        except subprocess.SubprocessError:
            return None
        # Parsing happens outside the try block: malformed output is a real
        # error and is not silently treated as "nothing to describe".
        return parse_obj_as(self.info_class, loads(output))


@lru_cache()
def _find_enrichments():
    """Load every object registered under the enrichment entry point group.

    Returns a dict mapping entry point name to the loaded object. The result
    is cached, so entry points are loaded only once per process.
    """
    loaded = {}
    group = entrypoints.get_group_named(ENRICHMENT_ENRTYPOINT)
    for ep_name, entry_point in group.items():
        loaded[ep_name] = entry_point.load()
    return loaded


@lru_cache()
def find_enrichments() -> List[Enrichment]:
    """Return ready-to-use enrichment instances discovered via entry points.

    An entry point may expose either:

    * an `Enrichment` subclass, which is instantiated without arguments if
      none of its fields is required, and skipped otherwise; or
    * an `Enrichment` instance, which is used as is.

    The result is cached.
    """
    res = []
    # Iterate over the loaded objects. Iterating the mapping itself yields
    # entry point names (strings), which can never match the checks below.
    for e in _find_enrichments().values():
        if isinstance(e, type) and issubclass(e, Enrichment):
            # Only classes that can be built with no arguments are
            # auto-instantiated. `__fields_set__` is per-instance state and
            # says nothing about the class, so check for required fields.
            # NOTE(review): relies on pydantic v1 `__fields__` /
            # `ModelField.required` - confirm against the pinned version.
            if not any(field.required for field in e.__fields__.values()):
                res.append(e())
        elif isinstance(e, Enrichment):
            res.append(e)
    return res


@lru_cache()
def find_enrichment_types() -> Dict[str, Type[Enrichment]]:
    """Return the entry points that expose `Enrichment` subclasses.

    Maps entry point name to class. Entry points whose loaded object is not
    an `Enrichment` subclass are left out. The result is cached.
    """
    enrichment_types: Dict[str, Type[Enrichment]] = {}
    for ep_name, loaded in _find_enrichments().items():
        if isinstance(loaded, type) and issubclass(loaded, Enrichment):
            enrichment_types[ep_name] = loaded
    return enrichment_types
29 changes: 29 additions & 0 deletions gto/ext_dvc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Optional

from pydantic import BaseModel
from ruamel.yaml import safe_load

from gto.ext import Enrichment, EnrichmentInfo


class DVCEnrichmentInfo(EnrichmentInfo):
    """Info about an object that has a `.dvc` file next to it."""

    source = "dvc"
    # Value of the `size` key of the first `outs` entry in the .dvc file.
    size: int
    # Value of the `md5` key of the first `outs` entry in the .dvc file.
    hash: str

    def get_object(self) -> BaseModel:
        """Return this instance itself; it is already the data model."""
        return self

    def get_human_readable(self) -> str:
        """Return a one-line summary including the recorded size."""
        return f"DVC-tracked [{self.size} bytes]"


class DVCEnrichment(Enrichment):
    """Enrichment that reads size and hash from the `<obj>.dvc` file."""

    def describe(self, obj: str) -> Optional[DVCEnrichmentInfo]:
        """Describe `obj` using the `.dvc` file that sits next to it.

        Args:
            obj: path of the object; ".dvc" is appended to find its metadata.

        Returns:
            Info built from the first `outs` entry of the file, or `None`
            when the `.dvc` file does not exist.
        """
        dvc_path = obj + ".dvc"
        try:
            with open(dvc_path, encoding="utf8") as dvc_file:
                dvc_data = safe_load(dvc_file)
        except FileNotFoundError:
            return None
        first_out = dvc_data["outs"][0]
        return DVCEnrichmentInfo(size=first_out["size"], hash=first_out["md5"])
34 changes: 34 additions & 0 deletions gto/ext_mlem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""This is a temporary file that should be moved to the mlem.gto module"""
from typing import Optional

from mlem.core.errors import MlemObjectNotFound
from mlem.core.metadata import load_meta
from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta
from pydantic import BaseModel

from gto.ext import Enrichment, EnrichmentInfo


class MlemInfo(EnrichmentInfo):
    """Info about an object for which MLEM metadata could be loaded."""

    source = "mlem"
    # The loaded MLEM metadata object.
    meta: MlemMeta

    def get_object(self) -> BaseModel:
        """Return the wrapped MLEM metadata object."""
        return self.meta

    def get_human_readable(self) -> str:
        """Return "Mlem <object_type>", plus a type suffix for models and datasets."""
        # TODO: create `.describe` method in MlemMeta https://github.com/iterative/mlem/issues/98
        pieces = [f"Mlem {self.meta.object_type}"]
        if isinstance(self.meta, ModelMeta):
            pieces.append(f": {self.meta.model_type.type}")
        if isinstance(self.meta, DatasetMeta):
            pieces.append(f": {self.meta.dataset.dataset_type.type}")
        return "".join(pieces)


class MlemEnrichment(Enrichment):
    """Enrichment that describes objects through their MLEM metadata."""

    def describe(self, obj: str) -> Optional[MlemInfo]:
        """Load MLEM metadata for `obj` and wrap it in `MlemInfo`.

        Returns `None` when `load_meta` raises `MlemObjectNotFound`.
        """
        try:
            meta = load_meta(obj)
        except MlemObjectNotFound:
            return None
        return MlemInfo(meta=meta)
4 changes: 2 additions & 2 deletions gto/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel

from gto.base import BaseManager, BaseObject, BaseRegistryState
from gto.config import CONFIG_FILE, RegistryConfig
from gto.config import CONFIG_FILE_NAME, RegistryConfig
from gto.exceptions import (
NoActiveLabel,
VersionAlreadyRegistered,
Expand All @@ -32,7 +32,7 @@ def from_repo(cls, repo=Union[str, Repo], config=None):
repo = git.Repo(repo)
if config is None:
config = RegistryConfig(
CONFIG_FILE=os.path.join(repo.working_dir, CONFIG_FILE)
CONFIG_FILE=os.path.join(repo.working_dir, CONFIG_FILE_NAME)
)

return cls(
Expand Down
2 changes: 1 addition & 1 deletion gto/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def label_from_tag(tag: git.Tag, obj: BaseObject) -> BaseLabel:
object=mtag.name,
version=obj.find_version(
commit_hexsha=tag.commit.hexsha, raise_if_not_found=True
).name, # type: ignore
).name,
name=mtag.label,
creation_date=mtag.creation_date,
author=tag.tag.tagger.name,
Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"pydantic",
"ruamel.yaml",
"semver==3.0.0-dev.3",
"entrypoints",
]


Expand Down Expand Up @@ -54,6 +55,11 @@
include_package_data=True,
entry_points={
"console_scripts": ["gto = gto.cli:cli"],
"gto.enrichment": [
"mlem = gto.ext_mlem:MlemEnrichment",
"dvc = gto.ext_dvc:DVCEnrichment",
"cli = gto.ext:CLIEnrichment",
],
},
cmdclass={"build_py": build_py},
zip_safe=False,
Expand Down
Loading