Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit c854e40

Browse files
sudharsana-kj and pdxjohnny
authored and
John Andersen
committed
source: csv: Add label
Co-authored-by: John Andersen <johnandersenpdx@gmail.com> Signed-off-by: John Andersen <johnandersenpdx@gmail.com>
1 parent 7afd85a commit c854e40

File tree

5 files changed

+297
-62
lines changed

5 files changed

+297
-62
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5858
- A temporary directory is used to replicate `mktemp -u` functionality so as to
5959
provide tests using a FileSource with a valid tempfile name.
6060
- Labels for JSON sources
61+
- Labels for CSV sources
6162
- util.cli CMD's correctly set the description of subparsers instead of their
6263
help, they also accept the `CLI_FORMATTER_CLASS` property.
6364
- CSV source now has `entry_point` decoration

dffml/source/csv.py

+158-48
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
"""
66
import csv
77
import ast
8-
from typing import NamedTuple, Dict
8+
import asyncio
9+
from typing import NamedTuple, Dict, List
10+
from dataclasses import dataclass
11+
from contextlib import asynccontextmanager
912

1013
from ..repo import Repo
1114
from .memory import MemorySource
@@ -16,13 +19,38 @@
1619
csv.register_dialect("strip", skipinitialspace=True)
1720

1821

22+
@dataclass
23+
class OpenCSVFile:
24+
write_out: Dict
25+
active: int
26+
lock: asyncio.Lock
27+
write_back_key: bool = True
28+
write_back_label: bool = False
29+
30+
async def inc(self):
31+
async with self.lock:
32+
self.active += 1
33+
34+
async def dec(self):
35+
async with self.lock:
36+
self.active -= 1
37+
return bool(self.active < 1)
38+
39+
40+
CSV_SOURCE_CONFIG_DEFAULT_KEY = "src_url"
41+
CSV_SOURCE_CONFIG_DEFAULT_LABEL = "unlabeled"
42+
CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN = "label"
43+
44+
1945
class CSVSourceConfig(FileSourceConfig, NamedTuple):
2046
filename: str
21-
label: str = "unlabeled"
2247
readonly: bool = False
23-
key: str = None
48+
key: str = CSV_SOURCE_CONFIG_DEFAULT_KEY
49+
label: str = CSV_SOURCE_CONFIG_DEFAULT_LABEL
50+
label_column: str = CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN
2451

2552

53+
# CSVSource is a bit of a mess
2654
@entry_point("csv")
2755
class CSVSource(FileSource, MemorySource):
2856
"""
@@ -32,6 +60,29 @@ class CSVSource(FileSource, MemorySource):
3260
# Headers we've added to track data other than feature data for a repo
3361
CSV_HEADERS = ["prediction", "confidence"]
3462

63+
OPEN_CSV_FILES: Dict[str, OpenCSVFile] = {}
64+
OPEN_CSV_FILES_LOCK: asyncio.Lock = asyncio.Lock()
65+
66+
@asynccontextmanager
67+
async def _open_csv(self, fd=None):
68+
async with self.OPEN_CSV_FILES_LOCK:
69+
if self.config.filename not in self.OPEN_CSV_FILES:
70+
self.logger.debug(f"{self.config.filename} first open")
71+
open_file = OpenCSVFile(
72+
active=1, lock=asyncio.Lock(), write_out={}
73+
)
74+
self.OPEN_CSV_FILES[self.config.filename] = open_file
75+
if fd is not None:
76+
await self.read_csv(fd, open_file)
77+
else:
78+
self.logger.debug(f"{self.config.filename} already open")
79+
await self.OPEN_CSV_FILES[self.config.filename].inc()
80+
yield self.OPEN_CSV_FILES[self.config.filename]
81+
82+
async def _empty_file_init(self):
83+
async with self._open_csv():
84+
return {}
85+
3586
@classmethod
3687
def args(cls, args, *above) -> Dict[str, Arg]:
3788
cls.config_set(args, above, "filename", Arg())
@@ -42,9 +93,18 @@ def args(cls, args, *above) -> Dict[str, Arg]:
4293
Arg(type=bool, action="store_true", default=False),
4394
)
4495
cls.config_set(
45-
args, above, "label", Arg(type=str, default="unlabeled")
96+
args,
97+
above,
98+
"label",
99+
Arg(type=str, default=CSV_SOURCE_CONFIG_DEFAULT_LABEL),
46100
)
47-
cls.config_set(args, above, "key", Arg(type=str, default=None))
101+
cls.config_set(
102+
args,
103+
above,
104+
"labelcol",
105+
Arg(type=str, default=CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN),
106+
)
107+
cls.config_set(args, above, "key", Arg(type=str, default="src_url"))
48108
return args
49109

50110
@classmethod
@@ -54,38 +114,53 @@ def config(cls, config, *above):
54114
readonly=cls.config_get(config, above, "readonly"),
55115
label=cls.config_get(config, above, "label"),
56116
key=cls.config_get(config, above, "key"),
117+
label_column=cls.config_get(config, above, "labelcol"),
57118
)
58119

59-
async def load_fd(self, fd):
60-
"""
61-
Parses a CSV stream into Repo instances
62-
"""
63-
i = 0
64-
self.mem = {}
65-
for data in csv.DictReader(fd, dialect="strip"):
120+
async def read_csv(self, fd, open_file):
121+
dict_reader = csv.DictReader(fd, dialect="strip")
122+
# Record what headers are present when the file was opened
123+
if not self.config.key in dict_reader.fieldnames:
124+
open_file.write_back_key = False
125+
if self.config.label_column in dict_reader.fieldnames:
126+
open_file.write_back_label = True
127+
# Store all the repos by their label in write_out
128+
open_file.write_out = {}
129+
# If there is no key track row index to be used as src_url by label
130+
index = {}
131+
for row in dict_reader:
132+
# Grab label from row
133+
label = row.get(self.config.label_column, self.config.label)
134+
if self.config.label_column in row:
135+
del row[self.config.label_column]
136+
index.setdefault(label, 0)
137+
# Grab src_url from row
138+
src_url = row.get(self.config.key, index[label])
139+
if self.config.key in row:
140+
del row[self.config.key]
141+
else:
142+
index[label] += 1
66143
# Repo data we are going to parse from this row (must include
67144
# features).
68-
repo_data = {"features": {}}
145+
repo_data = {}
69146
# Parse headers we as the CSV source added
70147
csv_meta = {}
71148
for header in self.CSV_HEADERS:
72-
if not data.get(header) is None and data[header] != "":
73-
csv_meta[header] = data[header]
149+
value = row.get(header, None)
150+
if value is not None and value != "":
151+
csv_meta[header] = row[header]
74152
# Remove from feature data
75-
del data[header]
76-
# Parse feature data
77-
for key, value in data.items():
153+
del row[header]
154+
# Set the features
155+
features = {}
156+
for key, value in row.items():
78157
if value != "":
79158
try:
80-
repo_data["features"][key] = ast.literal_eval(value)
159+
features[key] = ast.literal_eval(value)
81160
except (SyntaxError, ValueError):
82-
repo_data["features"][key] = value
83-
if self.config.key is not None and self.config.key == key:
84-
src_url = value
85-
if self.config.key is None:
86-
src_url = str(i)
87-
i += 1
88-
# Correct types and structure of repo data from csv_meta
161+
features[key] = value
162+
if features:
163+
repo_data["features"] = features
89164
if "prediction" in csv_meta and "confidence" in csv_meta:
90165
repo_data.update(
91166
{
@@ -95,32 +170,67 @@ async def load_fd(self, fd):
95170
}
96171
}
97172
)
98-
repo = Repo(src_url, data=repo_data)
99-
self.mem[repo.src_url] = repo
173+
# If there was no data in the row, skip it
174+
if not repo_data and src_url == str(index[label] - 1):
175+
continue
176+
# Add the repo to our internal memory representation
177+
open_file.write_out.setdefault(label, {})
178+
open_file.write_out[label][src_url] = Repo(src_url, data=repo_data)
179+
180+
async def load_fd(self, fd):
181+
"""
182+
Parses a CSV stream into Repo instances
183+
"""
184+
async with self._open_csv(fd) as open_file:
185+
self.mem = open_file.write_out.get(self.config.label, {})
100186
self.logger.debug("%r loaded %d records", self, len(self.mem))
101187

102188
async def dump_fd(self, fd):
103189
"""
104190
Dumps data into a CSV stream
105191
"""
106-
# Sample some headers without iterating all the way through
107-
fieldnames = []
108-
for repo in self.mem.values():
109-
fieldnames = list(repo.data.features.keys())
110-
break
111-
# Add our headers
112-
fieldnames += self.CSV_HEADERS
113-
# Write out the file
114-
writer = csv.DictWriter(fd, fieldnames=fieldnames)
115-
writer.writeheader()
116-
# Write out rows in order
117-
for repo in self.mem.values():
118-
repo_data = repo.dict()
119-
row = {}
120-
for key, value in repo_data["features"].items():
121-
row[key] = value
122-
if "prediction" in repo_data:
123-
row["prediction"] = repo_data["prediction"]["value"]
124-
row["confidence"] = repo_data["prediction"]["confidence"]
125-
writer.writerow(row)
192+
async with self.OPEN_CSV_FILES_LOCK:
193+
open_file = self.OPEN_CSV_FILES[self.config.filename]
194+
open_file.write_out.setdefault(self.config.label, {})
195+
open_file.write_out[self.config.label].update(self.mem)
196+
# Bail if not last open source for this file
197+
if not (await open_file.dec()):
198+
return
199+
# Add our headers
200+
fieldnames = (
201+
[] if not open_file.write_back_key else [self.config.key]
202+
)
203+
fieldnames.append(self.config.label_column)
204+
# Get all the feature names
205+
feature_fieldnames = set()
206+
for label, repos in open_file.write_out.items():
207+
for repo in repos.values():
208+
feature_fieldnames |= set(repo.data.features.keys())
209+
fieldnames += list(feature_fieldnames)
210+
fieldnames += self.CSV_HEADERS
211+
self.logger.debug(f"fieldnames: {fieldnames}")
212+
# Write out the file
213+
writer = csv.DictWriter(fd, fieldnames=fieldnames)
214+
writer.writeheader()
215+
for label, repos in open_file.write_out.items():
216+
for repo in repos.values():
217+
repo_data = repo.dict()
218+
row = {name: "" for name in fieldnames}
219+
# Always write the label
220+
row[self.config.label_column] = label
221+
# Write the key if it existed
222+
if open_file.write_back_key:
223+
row[self.config.key] = repo.src_url
224+
# Write the features
225+
for key, value in repo_data.get("features", {}).items():
226+
row[key] = value
227+
# Write the prediction
228+
if "prediction" in repo_data:
229+
row["prediction"] = repo_data["prediction"]["value"]
230+
row["confidence"] = repo_data["prediction"][
231+
"confidence"
232+
]
233+
writer.writerow(row)
234+
del self.OPEN_CSV_FILES[self.config.filename]
235+
self.logger.debug(f"{self.config.filename} written")
126236
self.logger.debug("%r saved %d records", self, len(self.mem))

tests/service/test_dev.py

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ async def generic_test(self, name, package_specific_files):
8787
package_name,
8888
package_specific_files,
8989
)
90+
else: # pragma: no cov
91+
pass
9092

9193
async def test_model(self):
9294
await self.generic_test(

tests/source/test_csv.py

+45-10
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
# Copyright (c) 2019 Intel Corporation
33
import unittest
44
import tempfile
5+
import os
6+
import csv
7+
import random
8+
import pathlib
59

6-
from dffml.source.file import FileSourceConfig
710
from dffml.source.csv import CSVSource, CSVSourceConfig
811
from dffml.util.testing.source import FileSourceTest
912
from dffml.util.asynctestcase import AsyncTestCase
@@ -15,35 +18,67 @@ class TestCSVSource(FileSourceTest, AsyncTestCase):
1518
async def setUpSource(self):
1619
return CSVSource(CSVSourceConfig(filename=self.testfile))
1720

18-
@unittest.skip("Labels not implemented yet for CSV files")
1921
async def test_label(self):
20-
"""
21-
Labels not implemented yet for CSV files
22-
"""
22+
with tempfile.TemporaryDirectory() as testdir:
23+
self.testfile = os.path.join(testdir, str(random.random()))
24+
unlabeled = await self.setUpSource()
25+
labeled = await self.setUpSource()
26+
labeled.config = labeled.config._replace(label="somelabel")
27+
async with unlabeled, labeled:
28+
async with unlabeled() as uctx, labeled() as lctx:
29+
await uctx.update(
30+
Repo("0", data={"features": {"feed": 1}})
31+
)
32+
await lctx.update(
33+
Repo("0", data={"features": {"face": 2}})
34+
)
35+
# async with unlabeled, labeled:
36+
async with unlabeled() as uctx, labeled() as lctx:
37+
repo = await uctx.repo("0")
38+
self.assertIn("feed", repo.features())
39+
repo = await lctx.repo("0")
40+
self.assertIn("face", repo.features())
41+
with open(self.testfile, "r") as fd:
42+
dict_reader = csv.DictReader(fd, dialect="strip")
43+
rows = {
44+
row["label"]: {row["src_url"]: row} for row in dict_reader
45+
}
46+
self.assertIn("unlabeled", rows)
47+
self.assertIn("somelabel", rows)
48+
self.assertIn("0", rows["unlabeled"])
49+
self.assertIn("0", rows["somelabel"])
50+
self.assertIn("feed", rows["unlabeled"]["0"])
51+
self.assertIn("face", rows["somelabel"]["0"])
52+
self.assertEqual("1", rows["unlabeled"]["0"]["feed"])
53+
self.assertEqual("2", rows["somelabel"]["0"]["face"])
2354

24-
def test_config_readonly_default(self):
55+
def test_config_default(self):
2556
config = CSVSource.config(
2657
parse_unknown("--source-csv-filename", "feedface")
2758
)
2859
self.assertEqual(config.filename, "feedface")
2960
self.assertEqual(config.label, "unlabeled")
30-
self.assertEqual(config.key, None)
61+
self.assertEqual(config.label_column, "label")
62+
self.assertEqual(config.key, "src_url")
3163
self.assertFalse(config.readonly)
3264

33-
def test_config_readonly_set(self):
65+
def test_config_set(self):
3466
config = CSVSource.config(
3567
parse_unknown(
3668
"--source-csv-filename",
3769
"feedface",
3870
"--source-csv-label",
3971
"default-label",
72+
"--source-csv-labelcol",
73+
"dffml_label",
4074
"--source-csv-key",
4175
"SourceURLColumn",
4276
"--source-csv-readonly",
4377
)
4478
)
4579
self.assertEqual(config.filename, "feedface")
4680
self.assertEqual(config.label, "default-label")
81+
self.assertEqual(config.label_column, "dffml_label")
4782
self.assertEqual(config.key, "SourceURLColumn")
4883
self.assertTrue(config.readonly)
4984

@@ -59,5 +94,5 @@ async def test_key(self):
5994
async with source() as sctx:
6095
repo_a = await sctx.repo("a")
6196
repo_b = await sctx.repo("b")
62-
self.assertEqual(repo_a.data.features["ValueColumn"], 42)
63-
self.assertEqual(repo_b.data.features["ValueColumn"], 420)
97+
self.assertEqual(repo_a.feature("ValueColumn"), 42)
98+
self.assertEqual(repo_b.feature("ValueColumn"), 420)

0 commit comments

Comments
 (0)