Skip to content

Commit

Permalink
[FIX] Optional prameters in Configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 19, 2022
1 parent 6cf74a1 commit b8aa2ef
Showing 1 changed file with 107 additions and 41 deletions.
148 changes: 107 additions & 41 deletions hourglass_tensorflow/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Set
from typing import Dict
from typing import List
from typing import Type
from typing import Union
from typing import Literal
from typing import Callable
Expand All @@ -18,11 +19,15 @@
from pydantic import BaseModel

from hourglass_tensorflow._errors import BadConfigurationError
from hourglass_tensorflow.datasets import HTFBaseDatasetHandler
from hourglass_tensorflow.utils.object_logger import ObjectLogger
from hourglass_tensorflow.utils.parsers._parse_import import get_dataset

# region BaseModels


class HTFDatasetSplitConfig(BaseModel):
on: bool = False
activate: bool = False
train_ratio: Optional[float] = 0.8


Expand All @@ -33,26 +38,29 @@ class HTFDatasetConfig(BaseModel):

class HTFDataInputConfig(BaseModel):
source: str
extension: List[str] = Field(default_factory=["png", "jpeg", "jpg"])
extensions: List[str] = Field(default_factory=["png", "jpeg", "jpg"])


class HTFDataOutputJointsFormatSuffixConfig(BaseModel):
x: str = "X"
y: str = "Y"
visible: str = "visible"

class Config:
extra = "allow"


class HTFDataOutputJointsFormatConfig(BaseModel):
suffix: HTFDataOutputJointsFormatSuffixConfig
suffix: Optional[HTFDataOutputJointsFormatSuffixConfig] = Field(
default_factory=HTFDataOutputJointsFormatSuffixConfig
)


class HTFDataOutputJointsConfig(BaseModel):
n: int = 16
naming_convention: str = "joint_{JOINT_ID}_{SUFFIX}"
format: HTFDataOutputJointsFormatConfig
format: Optional[HTFDataOutputJointsFormatConfig] = Field(
default_factory=HTFDataOutputJointsFormatConfig
)
names: List[str] = Field(
default_factory=[
"00_rAnkle",
Expand All @@ -78,7 +86,7 @@ class HTFDataOutputJointsConfig(BaseModel):
class HTFDataOutputConfig(BaseModel):
source: str
source_column: str = "image"
source_prefix: bool = False
source_prefixed: bool = False
prefix_columns: List[str] = Field(
default_factory=[
"is_training",
Expand Down Expand Up @@ -107,6 +115,11 @@ class HTFConfig(BaseModel):
dataset: HTFDatasetConfig


# endregion

# region ConfigParser


class HTFConfigParser(ObjectLogger):
"""Parse configuration files for `hourglass_tensorflow`
Expand Down Expand Up @@ -206,10 +219,17 @@ def _parse_config(self) -> None:
self._config = HTFConfig.parse_obj(self._data)


# endregion

# region Configuration Handler


class HTFConfigMeta(BaseModel):
available_images: Optional[Set[str]] = Field(default_factory=set)
label_type: Optional[Union[Literal["json"], Literal["csv"]]]
label_headers: Optional[List[str]] = Field(default_factory=set)
label_headers: Optional[List[str]] = Field(default_factory=list)
label_mapper: Optional[Dict[int, str]] = Field(default_factory=dict)
dataset_object: Optional[Type[HTFBaseDatasetHandler]]

class Config:
extra = "allow"
Expand All @@ -228,30 +248,43 @@ def __init__(self, config_file: str, verbose: bool = True) -> None:
def config(self) -> HTFConfig:
return self._config.config

@property
def _cfg_data(self) -> HTFDataConfig:
return self.config.data

@property
def _cfg_data_inp(self) -> HTFDataInputConfig:
return self.config.data.input

@property
def _cfg_data_out(self) -> HTFDataOutputConfig:
return self.config.data.output

# DATA - Prepare - Methods
def _list_input_images(self) -> None:
if not os.path.exists(self.config.data.input.source):
if not os.path.exists(self._cfg_data_inp.source):
raise BadConfigurationError(
f"Unable to find source folder {self.config.data.input.source}"
f"Unable to find source folder {self._cfg_data_inp.source}"
)
self.info(
f"Listing {self.config.data.input.extension} images in {self.config.data.input.source}"
f"Listing {self._cfg_data_inp.extensions} images in {self._cfg_data_inp.source}"
)
self._metadata.available_images = {
*itertools.chain(
[
glob.glob(os.path.join(self.config.data.input.source, f"*.{ext}"))
for ext in self.config.data.input.extension
self._metadata.available_images = set(
itertools.chain(
*[
glob(os.path.join(self._cfg_data_inp.source, f"*.{ext}"))
for ext in self._cfg_data_inp.extensions
]
)
}
)

def _valid_labels_header(self, df: pd.DataFrame, _error: bool = False) -> bool:
# Check if numbers of columns are valid
n_joint = self.config.data.output.joints.n
n_joint = self._cfg_data_out.joints.n
# TODO(@wbenbihi) UNNECESSARY BLOCK
# num_prefix_columns = len(self.config.data.output.prefix_columns)
# num_prefix_columns = len(self._cfg_data_out.prefix_columns)
# num_columns_per_joint = len(
# self.config.data.output.joints.format.suffix.__fields__
# self._cfg_data_out.joints.format.suffix.__fields__
# )
# estimated_column_size = n_joint * num_columns_per_joint + num_prefix_columns
# if len(df.columns) != estimated_column_size:
Expand All @@ -261,16 +294,17 @@ def _valid_labels_header(self, df: pd.DataFrame, _error: bool = False) -> bool:
# )
# return False
# Check if columns names are valid
naming_convention = self.config.data.output.joints.naming_convention
headers = self.config.data.output.prefix_columns + [
naming_convention.format(JOINT_ID=jid, SUFFIX=suffix.name)
naming_convention = self._cfg_data_out.joints.naming_convention
suffixes = self._cfg_data_out.joints.format.suffix
headers = self._cfg_data_out.prefix_columns + [
naming_convention.format(JOINT_ID=jid, SUFFIX=suffix)
for jid in range(n_joint)
for suffix in self.config.data.output.joints.format.suffix.__fields__.values()
for suffix in suffixes.__dict__.values()
]
if not set(df.columns).difference(set(headers)):
if set(headers).difference(set(list(df.columns))):
if _error:
raise BadConfigurationError(
f"Columns' name does not match configuration\n\tEXPECTED:\n\t{headers}\n\tRECEIVED:\n\t{df.columns}"
f"Columns' name does not match configuration\n\tEXPECTED:\n\t{headers}\n\tRECEIVED:\n\t{list(df.columns)}\n\tMISSING COLUMNS:\n\t{set(headers).difference(set(list(df.columns)))}"
)
return False
# If everything is good we store the expected headers in _metadata
Expand All @@ -279,48 +313,80 @@ def _valid_labels_header(self, df: pd.DataFrame, _error: bool = False) -> bool:

def _read_labels(self, _error: bool = False) -> bool:
# Check if data.output.source exists ?
if not os.path.exists(self.config.data.output.source):
raise BadConfigurationError(
f"Unable to find {self.config.data.output.source}"
)
if not os.path.exists(self._cfg_data_out.source):
raise BadConfigurationError(f"Unable to find {self._cfg_data_out.source}")
# Read Data
self.info(f"Reading labels from {self.config.data.output.source}")
self.info(f"Reading labels from {self._cfg_data_out.source}")
## Check if the file extension is in [.json, .csv]
if self.config.data.output.source.endswith(".json"):
if self._cfg_data_out.source.endswith(".json"):
self._metadata.label_type = "json"
labels = pd.read_json(self.config.data.output.source, orient="records")
elif self.config.data.output.source.endswith(".csv"):
labels = pd.read_json(self._cfg_data_out.source, orient="records")
elif self._cfg_data_out.source.endswith(".csv"):
self._metadata.label_type = "csv"
labels = pd.read_csv(self.config.data.output.source)
labels = pd.read_csv(self._cfg_data_out.source)
else:
raise BadConfigurationError(
f"{self.config.data.output.source} should be of type .json or .csv"
f"{self._cfg_data_out.source} should be of type .json or .csv"
)
if not isinstance(labels, pd.DataFrame):
raise BadConfigurationError(
f"{self.config.data.output.source} not parsable as pandas.DataFrame"
f"{self._cfg_data_out.source} not parsable as pandas.DataFrame"
)
# Validate expected labels columns
if not self._valid_labels_header(labels, _error=_error):
self.error("Labels are not matching")
return False
self._labels_df = labels[self._metadata.label_headers]
self._labels_df: pd.DataFrame = labels[self._metadata.label_headers]
self._metadata.label_mapper = {
label: i for i, label in enumerate(self._metadata.label_headers)
}
if self._cfg_data_out.source_prefixed:
# Now we also prefix the image column with the image folder
# in case the source_prefix attribute is set to false
folder_prefix = self._cfg_data_inp.source
source_column = self._cfg_data_out.source_column
self._labels_df = self._labels_df.assign(
**{
source_column: lambda x: os.path.join(
folder_prefix, x[source_column]
)
}
)
return True

def _prepare_input(self) -> None:
# List files in Input Source Folder
self._list_input_images()

def _validate_joints(self) -> None:
pass
def _validate_joints(self, _error: bool = True) -> bool:
conditions = [
len(self._cfg_data_out.joints.names) == self._cfg_data_out.joints.n
]
if not all(conditions):
if _error:
raise BadConfigurationError("Joints properties are not valid")
return False
return True

def _prepare_output(self) -> None:
def _prepare_output(self, _error: bool = True) -> None:
# Read the label file
self._read_labels()
self._validate_joints(_error=_error)
self._read_labels(_error=_error)

def prepare_data(self) -> None:
self._prepare_input()
self._prepare_output()

def prepare(self) -> None:
self.prepare_data()

# DATASET - Prepare - Methods

def _load_dataset_object(self) -> None:
self._metadata.dataset_object = get_dataset(self.config.dataset.object)

def prepare_dataset(self) -> None:
self._load_dataset_object()


# endregion

0 comments on commit b8aa2ef

Please sign in to comment.