Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 1f85201

Browse files
sudharsana-kjl authored and John Andersen committed
source: csv: Enable setting repo src_url using CSVSourceConfig
* Add test_key in CSVTest
* Add CSVSourceConfig
1 parent e479a4e commit 1f85201

File tree

3 files changed

+87
-6
lines changed

3 files changed

+87
-6
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
config for them.
2121
- shouldi example uses updated `MemoryOrchestrator.basic_config` method and
2222
includes more explanation in comments.
23+
- CSVSource allows for setting the Repo's `src_url` from a csv column
2324
### Fixed
2425
- Docs get version from dffml.version.VERSION.
2526
- FileSource zipfiles are wrapped with TextIOWrapper because CSVSource expects

dffml/source/csv.py

+40-4
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,23 @@
55
"""
66
import csv
77
import ast
8+
from typing import NamedTuple, Dict
89

910
from ..repo import Repo
1011
from .memory import MemorySource
11-
from .file import FileSource
12+
from .file import FileSource, FileSourceConfig
13+
from ..util.cli.arg import Arg
1214

1315
# Register a csv dialect named "strip". skipinitialspace=True makes the csv
# module ignore whitespace immediately following a delimiter.
csv.register_dialect("strip", skipinitialspace=True)
1416

1517

18+
class CSVSourceConfig(FileSourceConfig, NamedTuple):
    """
    Configuration for CSVSource.

    Holds the settings CSVSource.args registers: ``filename``, ``label``,
    ``readonly`` and ``key``.

    NOTE(review): typing.NamedTuple rejects additional base classes on newer
    Python versions (TypeError at class creation) and older versions appear
    to silently ignore them. Verify this class still imports, and that
    FileSourceConfig is really inherited from, on the supported interpreters.
    """

    # Path of the CSV file backing the source
    filename: str
    # Label for the data, "unlabeled" when not given.
    # NOTE(review): the tests skip test_label because labels are not
    # implemented yet for CSV files.
    label: str = "unlabeled"
    # Read-only flag, off by default. Presumably stops the source from writing
    # back to the file; confirm against FileSource.
    readonly: bool = False
    # Name of the CSV column whose value is used as each Repo's src_url. When
    # None, load_fd uses its row counter (as a string) as the src_url instead.
    key: str = None
23+
24+
1625
class CSVSource(FileSource, MemorySource):
1726
"""
1827
Uses a CSV file as the source of repo feature data
@@ -21,6 +30,30 @@ class CSVSource(FileSource, MemorySource):
2130
# Headers we've added to track data other than feature data for a repo
2231
CSV_HEADERS = ["prediction", "confidence", "classification"]
2332

33+
@classmethod
def args(cls, args, *above) -> Dict[str, Arg]:
    """
    Register the arguments this source accepts in ``args``

    Each option is stored with ``cls.config_set`` under the ``above``
    namespace: ``filename``, the ``readonly`` flag (off by default),
    ``label`` (defaults to "unlabeled") and ``key`` (defaults to None).
    The ``args`` mapping that was passed in is returned.
    """
    # Option name paired with its definition, in registration order
    option_table = (
        ("filename", Arg()),
        ("readonly", Arg(type=bool, action="store_true", default=False)),
        ("label", Arg(type=str, default="unlabeled")),
        ("key", Arg(type=str, default=None)),
    )
    for option_name, option_arg in option_table:
        cls.config_set(args, above, option_name, option_arg)
    return args
47+
48+
@classmethod
def config(cls, config, *above):
    """
    Build a CSVSourceConfig from ``config``

    Every field is looked up with ``cls.config_get`` under the ``above``
    namespace and handed to CSVSourceConfig by keyword.
    """
    # Lookup order is kept: filename, readonly, label, key
    field_names = ("filename", "readonly", "label", "key")
    field_values = {
        name: cls.config_get(config, above, name) for name in field_names
    }
    return CSVSourceConfig(**field_values)
56+
2457
async def load_fd(self, fd):
2558
"""
2659
Parses a CSV stream into Repo instances
@@ -45,6 +78,11 @@ async def load_fd(self, fd):
4578
repo_data["features"][key] = ast.literal_eval(value)
4679
except (SyntaxError, ValueError):
4780
repo_data["features"][key] = value
81+
if self.config.key is not None and self.config.key == key:
82+
src_url = value
83+
if self.config.key is None:
84+
src_url = str(i)
85+
i += 1
4886
# Correct types and structure of repo data from csv_meta
4987
if "classification" in csv_meta:
5088
repo_data.update(
@@ -59,9 +97,7 @@ async def load_fd(self, fd):
5997
}
6098
}
6199
)
62-
# Create the repo with the source URL being the row index
63-
repo = Repo(str(i), data=repo_data)
64-
i += 1
100+
repo = Repo(src_url, data=repo_data)
65101
self.mem[repo.src_url] = repo
66102
self.logger.debug("%r loaded %d records", self, len(self.mem))
67103

tests/source/test_csv.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,63 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (c) 2019 Intel Corporation
33
import unittest
4+
import tempfile
45

56
from dffml.source.file import FileSourceConfig
6-
from dffml.source.csv import CSVSource
7+
from dffml.source.csv import CSVSource, CSVSourceConfig
78
from dffml.util.testing.source import FileSourceTest
89
from dffml.util.asynctestcase import AsyncTestCase
10+
from dffml.repo import Repo
11+
from dffml.util.cli.arg import parse_unknown
912

1013

1114
class TestCSVSource(FileSourceTest, AsyncTestCase):
    """
    Runs the shared FileSourceTest cases against CSVSource and checks the
    CSV specific configuration, including the ``key`` column option
    """

    async def setUpSource(self):
        """
        Create the CSVSource under test, backed by self.testfile
        """
        source_config = CSVSourceConfig(filename=self.testfile)
        return CSVSource(source_config)

    @unittest.skip("Labels not implemented yet for CSV files")
    async def test_label(self):
        """
        Labels not implemented yet for CSV files
        """

    def test_config_readonly_default(self):
        """
        Only the filename is given, every other option takes its default
        """
        argv = ["--source-file-filename", "feedface"]
        config = CSVSource.config(parse_unknown(*argv))
        self.assertEqual(config.filename, "feedface")
        self.assertEqual(config.label, "unlabeled")
        self.assertEqual(config.key, None)
        self.assertFalse(config.readonly)

    def test_config_readonly_set(self):
        """
        Every option is given explicitly and must show up in the config
        """
        argv = [
            "--source-file-filename",
            "feedface",
            "--source-file-label",
            "default-label",
            "--source-file-key",
            "SourceURLColumn",
            "--source-file-readonly",
        ]
        config = CSVSource.config(parse_unknown(*argv))
        self.assertEqual(config.filename, "feedface")
        self.assertEqual(config.label, "default-label")
        self.assertEqual(config.key, "SourceURLColumn")
        self.assertTrue(config.readonly)

    async def test_key(self):
        """
        With key="KeyHeader" repos are looked up by that column's value
        """
        csv_rows = (b"KeyHeader,ValueColumn\n", b"a,42\n", b"b,420\n")
        with tempfile.NamedTemporaryFile() as fileobj:
            for row in csv_rows:
                fileobj.write(row)
            # Rewind after writing, before the source opens the file by name
            fileobj.seek(0)
            source_config = CSVSourceConfig(
                filename=fileobj.name, key="KeyHeader"
            )
            async with CSVSource(source_config) as source:
                async with source() as sctx:
                    repo_a = await sctx.repo("a")
                    repo_b = await sctx.repo("b")
                    self.assertEqual(repo_a.data.features["ValueColumn"], 42)
                    self.assertEqual(repo_b.data.features["ValueColumn"], 420)

0 commit comments

Comments
 (0)