"""
import csv
import ast
- from typing import NamedTuple, Dict
+ import asyncio
+ from typing import NamedTuple, Dict, List
+ from dataclasses import dataclass
+ from contextlib import asynccontextmanager

from ..repo import Repo
from .memory import MemorySource

csv.register_dialect("strip", skipinitialspace=True)


+ @dataclass
+ class OpenCSVFile:
+     write_out: Dict
+     active: int
+     lock: asyncio.Lock
+     write_back_key: bool = True
+     write_back_label: bool = False
+
+     async def inc(self):
+         async with self.lock:
+             self.active += 1
+
+     async def dec(self):
+         async with self.lock:
+             self.active -= 1
+             return bool(self.active < 1)
+
+
+ CSV_SOURCE_CONFIG_DEFAULT_KEY = "src_url"
+ CSV_SOURCE_CONFIG_DEFAULT_LABEL = "unlabeled"
+ CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN = "label"
+
+
class CSVSourceConfig(FileSourceConfig, NamedTuple):
    filename: str
-     label: str = "unlabeled"
    readonly: bool = False
-     key: str = None
+     key: str = CSV_SOURCE_CONFIG_DEFAULT_KEY
+     label: str = CSV_SOURCE_CONFIG_DEFAULT_LABEL
+     label_column: str = CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN


+ # CSVSource is a bit of a mess
@entry_point("csv")
class CSVSource(FileSource, MemorySource):
    """
@@ -32,6 +60,29 @@ class CSVSource(FileSource, MemorySource):
    # Headers we've added to track data other than feature data for a repo
    CSV_HEADERS = ["prediction", "confidence"]

+     OPEN_CSV_FILES: Dict[str, OpenCSVFile] = {}
+     OPEN_CSV_FILES_LOCK: asyncio.Lock = asyncio.Lock()
+
+     @asynccontextmanager
+     async def _open_csv(self, fd=None):
+         async with self.OPEN_CSV_FILES_LOCK:
+             if self.config.filename not in self.OPEN_CSV_FILES:
+                 self.logger.debug(f"{self.config.filename} first open")
+                 open_file = OpenCSVFile(
+                     active=1, lock=asyncio.Lock(), write_out={}
+                 )
+                 self.OPEN_CSV_FILES[self.config.filename] = open_file
+                 if fd is not None:
+                     await self.read_csv(fd, open_file)
+             else:
+                 self.logger.debug(f"{self.config.filename} already open")
+                 await self.OPEN_CSV_FILES[self.config.filename].inc()
+         yield self.OPEN_CSV_FILES[self.config.filename]
+
+     async def _empty_file_init(self):
+         async with self._open_csv():
+             return {}
+
    @classmethod
    def args(cls, args, *above) -> Dict[str, Arg]:
        cls.config_set(args, above, "filename", Arg())
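
Note: OPEN_CSV_FILES and its lock are class-level on purpose: two CSVSource instances configured with the same filename but different labels resolve to the same OpenCSVFile, so only the last one to close triggers a write-back. A hypothetical usage sketch (assumes sources are entered and exited with `async with`, with the FileSource base handling open/close there; construction style illustrative):

    import asyncio

    async def main():
        training = CSVSource(CSVSourceConfig(filename="data.csv", label="training"))
        validation = CSVSource(CSVSourceConfig(filename="data.csv", label="validation"))
        async with training, validation:
            ...  # each source sees only its own label's repos in self.mem
        # Closing the second source is what actually writes data.csv

    asyncio.run(main())
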
@@ -42,9 +93,18 @@ def args(cls, args, *above) -> Dict[str, Arg]:
            Arg(type=bool, action="store_true", default=False),
        )
        cls.config_set(
-             args, above, "label", Arg(type=str, default="unlabeled")
+             args,
+             above,
+             "label",
+             Arg(type=str, default=CSV_SOURCE_CONFIG_DEFAULT_LABEL),
        )
-         cls.config_set(args, above, "key", Arg(type=str, default=None))
+         cls.config_set(
+             args,
+             above,
+             "labelcol",
+             Arg(type=str, default=CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN),
+         )
+         cls.config_set(args, above, "key", Arg(type=str, default="src_url"))
        return args

    @classmethod
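
Note: the new labelcol option flows through args()/config() into CSVSourceConfig.label_column. A sketch of the configuration an equivalent command line would boil down to (flag spelling and values illustrative):

    # e.g. something along the lines of: ... csv -source-filename iris.csv \
    #          -source-labelcol class -source-key id -source-label training
    config = CSVSourceConfig(
        filename="iris.csv",
        key="id",
        label="training",
        label_column="class",
    )
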
@@ -54,38 +114,53 @@ def config(cls, config, *above):
            readonly=cls.config_get(config, above, "readonly"),
            label=cls.config_get(config, above, "label"),
            key=cls.config_get(config, above, "key"),
+             label_column=cls.config_get(config, above, "labelcol"),
        )

-     async def load_fd(self, fd):
-         """
-         Parses a CSV stream into Repo instances
-         """
-         i = 0
-         self.mem = {}
-         for data in csv.DictReader(fd, dialect="strip"):
+     async def read_csv(self, fd, open_file):
+         dict_reader = csv.DictReader(fd, dialect="strip")
+         # Record which headers were present when the file was opened
+         if self.config.key not in dict_reader.fieldnames:
+             open_file.write_back_key = False
+         if self.config.label_column in dict_reader.fieldnames:
+             open_file.write_back_label = True
+         # Store all the repos by their label in write_out
+         open_file.write_out = {}
+         # With no key column, a per-label row index stands in for src_url
+         index = {}
+         for row in dict_reader:
+             # Grab the label from the row
+             label = row.get(self.config.label_column, self.config.label)
+             if self.config.label_column in row:
+                 del row[self.config.label_column]
+             index.setdefault(label, 0)
+             # Grab the src_url from the row
+             src_url = row.get(self.config.key, str(index[label]))
+             if self.config.key in row:
+                 del row[self.config.key]
+             else:
+                 index[label] += 1
            # Repo data we are going to parse from this row (must include
            # features).
-             repo_data = {"features": {}}
+             repo_data = {}
            # Parse headers we as the CSV source added
            csv_meta = {}
            for header in self.CSV_HEADERS:
-                 if not data.get(header) is None and data[header] != "":
-                     csv_meta[header] = data[header]
+                 value = row.get(header, None)
+                 if value is not None and value != "":
+                     csv_meta[header] = value
                # Remove from feature data
-                 del data[header]
-             # Parse feature data
-             for key, value in data.items():
+                 row.pop(header, None)
+             # Set the features
+             features = {}
+             for key, value in row.items():
                if value != "":
                    try:
-                         repo_data["features"][key] = ast.literal_eval(value)
+                         features[key] = ast.literal_eval(value)
                    except (SyntaxError, ValueError):
-                         repo_data["features"][key] = value
-                 if self.config.key is not None and self.config.key == key:
-                     src_url = value
-             if self.config.key is None:
-                 src_url = str(i)
-                 i += 1
-             # Correct types and structure of repo data from csv_meta
+                         features[key] = value
+             if features:
+                 repo_data["features"] = features
            if "prediction" in csv_meta and "confidence" in csv_meta:
                repo_data.update(
                    {
@@ -95,32 +170,67 @@ async def load_fd(self, fd):
                        }
                    }
                )
-             repo = Repo(src_url, data=repo_data)
-             self.mem[repo.src_url] = repo
+             # If there was no data in the row, skip it
+             if not repo_data and src_url == str(index[label] - 1):
+                 continue
+             # Add the repo to our internal memory representation
+             open_file.write_out.setdefault(label, {})
+             open_file.write_out[label][src_url] = Repo(src_url, data=repo_data)
+
+     async def load_fd(self, fd):
+         """
+         Parses a CSV stream into Repo instances
+         """
+         async with self._open_csv(fd) as open_file:
+             self.mem = open_file.write_out.get(self.config.label, {})
        self.logger.debug("%r loaded %d records", self, len(self.mem))

    async def dump_fd(self, fd):
        """
        Dumps data into a CSV stream
        """
-         # Sample some headers without iterating all the way through
-         fieldnames = []
-         for repo in self.mem.values():
-             fieldnames = list(repo.data.features.keys())
-             break
-         # Add our headers
-         fieldnames += self.CSV_HEADERS
-         # Write out the file
-         writer = csv.DictWriter(fd, fieldnames=fieldnames)
-         writer.writeheader()
-         # Write out rows in order
-         for repo in self.mem.values():
-             repo_data = repo.dict()
-             row = {}
-             for key, value in repo_data["features"].items():
-                 row[key] = value
-             if "prediction" in repo_data:
-                 row["prediction"] = repo_data["prediction"]["value"]
-                 row["confidence"] = repo_data["prediction"]["confidence"]
-             writer.writerow(row)
+         async with self.OPEN_CSV_FILES_LOCK:
+             open_file = self.OPEN_CSV_FILES[self.config.filename]
+             open_file.write_out.setdefault(self.config.label, {})
+             open_file.write_out[self.config.label].update(self.mem)
+             # Bail if this is not the last open source for this file
+             if not (await open_file.dec()):
+                 return
+             # Add our headers
+             fieldnames = (
+                 [] if not open_file.write_back_key else [self.config.key]
+             )
+             fieldnames.append(self.config.label_column)
+             # Get all the feature names
+             feature_fieldnames = set()
+             for label, repos in open_file.write_out.items():
+                 for repo in repos.values():
+                     feature_fieldnames |= set(repo.data.features.keys())
+             fieldnames += list(feature_fieldnames)
+             fieldnames += self.CSV_HEADERS
+             self.logger.debug(f"fieldnames: {fieldnames}")
+             # Write out the file
+             writer = csv.DictWriter(fd, fieldnames=fieldnames)
+             writer.writeheader()
+             for label, repos in open_file.write_out.items():
+                 for repo in repos.values():
+                     repo_data = repo.dict()
+                     row = {name: "" for name in fieldnames}
+                     # Always write the label
+                     row[self.config.label_column] = label
+                     # Write the key if it existed in the original file
+                     if open_file.write_back_key:
+                         row[self.config.key] = repo.src_url
+                     # Write the features
+                     for key, value in repo_data.get("features", {}).items():
+                         row[key] = value
+                     # Write the prediction
+                     if "prediction" in repo_data:
+                         row["prediction"] = repo_data["prediction"]["value"]
+                         row["confidence"] = repo_data["prediction"][
+                             "confidence"
+                         ]
+                     writer.writerow(row)
+             del self.OPEN_CSV_FILES[self.config.filename]
+             self.logger.debug(f"{self.config.filename} written")
        self.logger.debug("%r saved %d records", self, len(self.mem))
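
Note: end to end, dump_fd() now always emits the label column, keeps the key column only when the input had one, and unions feature names across every label. A small illustrative round trip, assuming the defaults above:

    import io

    # No "src_url" column, so write_back_key stays False and each row gets a
    # per-label index ("0" for training, "0" for validation) as its src_url;
    # on write-back the key column is not re-invented.
    fd = io.StringIO(
        "label,SepalLength,SepalWidth\n"
        "training,5.1,3.5\n"
        "validation,6.2,2.9\n"
    )
    # After load_fd()/dump_fd() the file would carry a header like:
    #     label,SepalLength,SepalWidth,prediction,confidence
    # with feature column order unspecified (names come out of a set).
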