5
5
"""
6
6
import csv
7
7
import ast
8
+ from typing import NamedTuple , Dict
8
9
9
10
from ..repo import Repo
10
11
from .memory import MemorySource
11
- from .file import FileSource
12
+ from .file import FileSource , FileSourceConfig
13
+ from ..util .cli .arg import Arg
12
14
13
15
csv .register_dialect ("strip" , skipinitialspace = True )
14
16
15
17
18
+ class CSVSourceConfig (FileSourceConfig , NamedTuple ):
19
+ filename : str
20
+ label : str = "unlabeled"
21
+ readonly : bool = False
22
+ key : str = None
23
+
24
+
16
25
class CSVSource (FileSource , MemorySource ):
17
26
"""
18
27
Uses a CSV file as the source of repo feature data
@@ -21,6 +30,30 @@ class CSVSource(FileSource, MemorySource):
21
30
# Headers we've added to track data other than feature data for a repo
22
31
CSV_HEADERS = ["prediction" , "confidence" , "classification" ]
23
32
33
+ @classmethod
34
+ def args (cls , args , * above ) -> Dict [str , Arg ]:
35
+ cls .config_set (args , above , "filename" , Arg ())
36
+ cls .config_set (
37
+ args ,
38
+ above ,
39
+ "readonly" ,
40
+ Arg (type = bool , action = "store_true" , default = False ),
41
+ )
42
+ cls .config_set (
43
+ args , above , "label" , Arg (type = str , default = "unlabeled" )
44
+ )
45
+ cls .config_set (args , above , "key" , Arg (type = str , default = None ))
46
+ return args
47
+
48
+ @classmethod
49
+ def config (cls , config , * above ):
50
+ return CSVSourceConfig (
51
+ filename = cls .config_get (config , above , "filename" ),
52
+ readonly = cls .config_get (config , above , "readonly" ),
53
+ label = cls .config_get (config , above , "label" ),
54
+ key = cls .config_get (config , above , "key" ),
55
+ )
56
+
24
57
async def load_fd (self , fd ):
25
58
"""
26
59
Parses a CSV stream into Repo instances
@@ -45,6 +78,11 @@ async def load_fd(self, fd):
45
78
repo_data ["features" ][key ] = ast .literal_eval (value )
46
79
except (SyntaxError , ValueError ):
47
80
repo_data ["features" ][key ] = value
81
+ if self .config .key is not None and self .config .key == key :
82
+ src_url = value
83
+ if self .config .key is None :
84
+ src_url = str (i )
85
+ i += 1
48
86
# Correct types and structure of repo data from csv_meta
49
87
if "classification" in csv_meta :
50
88
repo_data .update (
@@ -59,9 +97,7 @@ async def load_fd(self, fd):
59
97
}
60
98
}
61
99
)
62
- # Create the repo with the source URL being the row index
63
- repo = Repo (str (i ), data = repo_data )
64
- i += 1
100
+ repo = Repo (src_url , data = repo_data )
65
101
self .mem [repo .src_url ] = repo
66
102
self .logger .debug ("%r loaded %d records" , self , len (self .mem ))
67
103
0 commit comments