@@ -6,7 +6,6 @@ val noOutput = CLA.parseFlag "no-output"
6
6
val verbose = CLA.parseFlag " verbose"
7
7
val filename = List.hd (CLA.positional ())
8
8
handle _ => Util.die " missing input filename"
9
- val showParsed = CLA.parseFlag " check-show-parsed"
10
9
val noBoundsChecks = CLA.parseFlag " unsafe-no-bounds-checks"
11
10
12
11
(* capacity 20000 should be a reasonable choice according to the spec of
@@ -16,6 +15,8 @@ val noBoundsChecks = CLA.parseFlag "unsafe-no-bounds-checks"
16
15
val capacity = CLA.parseInt " table-capacity" 19997
17
16
18
17
val blockSize = CLA.parseInt " block-size" 100000
18
+ val bufferSize = CLA.parseInt " buffer-size" 1000
19
+ val contentionFactor = CLA.parseInt " contention-factor" 8
19
20
20
21
(* =======================================================================
21
22
* a few utilities
@@ -34,13 +35,31 @@ fun assert b msg =
34
35
35
36
(* ======================================================================== *)
36
37
37
- (* By parameterizing the whole thing by {numBytes, getBytes}, we can
38
- * easily enable/disable bounds checking by passing in two different versions
39
- * of `getByte`. See bottom of file where we instantiate Main(...).
40
- *)
41
- functor Main (val numBytes: int val getByte: int -> Word8.word) =
38
+ fun bb b =
39
+ if b then " yes" else " no"
40
+
41
+ val _ = vprint (" no-output? " ^ bb noOutput ^ " \n " )
42
+ val _ = vprint (" unsafe-no-bounds-checks? " ^ bb noBoundsChecks ^ " \n " )
43
+ val _ = vprint (" table-capacity " ^ Int.toString capacity ^ " \n " )
44
+ val _ = vprint (" block-size " ^ Int.toString blockSize ^ " \n " )
45
+ val _ = vprint (" buffer-size " ^ Int.toString bufferSize ^ " \n " )
46
+ val _ = vprint
47
+ (" contention-factor " ^ Int.toString contentionFactor ^ " \n " )
48
+
49
+ (* ======================================================================== *)
50
+
51
+ functor Main
52
+ (Args:
53
+ sig
54
+ val numBytes: int
55
+ val getByte: int -> Word8.word
56
+ val getBytes: {offset: int, buffer: Word8.word array} -> int
57
+ val arraySub: 'a array * int -> 'a
58
+ end ) =
42
59
struct
43
60
61
+ open Args
62
+
44
63
(* ======================================================================
45
64
* Parsing.
46
65
*)
@@ -49,15 +68,82 @@ struct
49
68
val newline_id: Word8.word = 0wxA (* #"\n" *)
50
69
val dash_id: Word8.word = 0wx2D (* #"-" *)
51
70
val zero_id: Word8.word = 0wx30 (* #"0" *)
71
+ val dot_id: Word8.word = 0wx2E (* #"." *)
52
72
53
73
type index = int
54
74
type station_name = string
55
75
type measurement = int
56
76
77
+
78
+ datatype buffer = Buffer of {buffer: Word8.word array, offset: int, size: int}
79
+
80
+
81
+ fun newBuffer capacity =
82
+ Buffer {buffer = ForkJoin.alloc capacity, offset = 0 , size = 0 }
83
+
84
+
85
+ fun fillBuffer (Buffer {buffer, ...}) start =
86
+ let val size' = getBytes {offset = start, buffer = buffer}
87
+ in Buffer {buffer = buffer, offset = start, size = size'}
88
+ end
89
+
90
+
91
+ fun readBufferByte (b as Buffer {buffer, offset, size}) i =
92
+ if i >= offset andalso i < offset + size then
93
+ (b, arraySub (buffer, i - offset))
94
+ else
95
+ readBufferByte (fillBuffer b i) i
96
+
97
+
98
+ fun bufferLoop (Buffer {buffer, offset, size}) {start, continue, z, func} =
99
+ let
100
+ fun finish offset bufferSize acc i =
101
+ ( Buffer {buffer = buffer, size = bufferSize, offset = offset}
102
+ , offset + i
103
+ , acc
104
+ )
105
+
106
+ fun loop offset bufferSize acc i =
107
+ if i < bufferSize then
108
+ let
109
+ val byte = arraySub (buffer, i)
110
+ in
111
+ if continue byte then
112
+ loop offset bufferSize (func (acc, byte)) (i + 1 )
113
+ else
114
+ finish offset bufferSize acc i
115
+ end
116
+ else if offset + i >= numBytes then
117
+ finish offset bufferSize acc i
118
+ else
119
+ let
120
+ val offset' = offset + bufferSize
121
+ val bufferSize' = getBytes {offset = offset', buffer = buffer}
122
+ in
123
+ loop offset' bufferSize' acc 0
124
+ end
125
+ in
126
+ if start < offset orelse start >= offset + size then
127
+ (* this will immediately fill the buffer *)
128
+ loop start 0 z 0
129
+ else
130
+ (* can reuse some of the existing buffer *)
131
+ loop offset size z (start - offset)
132
+ end
133
+
134
+
135
+ fun findNextBuffered buffer c i =
136
+ let
137
+ val (buffer, position, ()) = bufferLoop buffer
138
+ {start = i, continue = fn byte => byte <> c, z = (), func = fn _ => ()}
139
+ in
140
+ (buffer, if position >= numBytes then NONE else SOME position)
141
+ end
142
+
143
+
57
144
fun findNext c i =
58
- if i >= numBytes then NONE
59
- else if getByte i = c then SOME i
60
- else findNext c (i + 1 )
145
+ #2 (findNextBuffered (newBuffer 10 ) c i)
146
+
61
147
62
148
fun parseStationName (start: index) : int * station_name =
63
149
let
@@ -69,35 +155,28 @@ struct
69
155
)
70
156
end
71
157
72
- fun parseMeasurement (start: index) : (int * measurement) =
73
- let
74
- val stop = valOf (findNext newline_id start)
75
158
159
+ fun parseMeasurement buffer (start: index) : (buffer * int * measurement) =
160
+ let
161
+ val (buffer, firstByte) = readBufferByte buffer start
76
162
val (start, isNeg) =
77
- if getByte start = dash_id then (start + 1 , true ) else (start, false )
78
-
79
- val numDigits = stop - start - 1 (* exclude the dot *)
80
- fun getDigit i =
81
- let
82
- val c =
83
- if i < numDigits - 1 then getByte (start + i)
84
- else getByte (start + i + 1 )
85
- in
86
- Word8.toInt (c - zero_id)
87
- end
88
-
89
- val x = Util.loop (0 , numDigits) 0 (fn (acc, i) => 10 * acc + getDigit i)
163
+ if firstByte = dash_id then (start + 1 , true ) else (start, false )
164
+ val (buffer, stop, x) = bufferLoop buffer
165
+ { start = start
166
+ , continue = fn byte => byte <> newline_id
167
+ , z = 0
168
+ , func = fn (acc, byte) =>
169
+ if byte = dot_id then acc
170
+ else 10 * acc + (Word8.toInt (byte - zero_id))
171
+ }
90
172
in
91
- (stop, if isNeg then ~x else x)
173
+ (buffer, stop, if isNeg then ~x else x)
92
174
end
93
175
94
176
95
177
fun getStationName start =
96
178
#2 (parseStationName start)
97
179
98
- fun getMeasurement start =
99
- #2 (parseMeasurement start)
100
-
101
180
102
181
(* ==========================================================================
103
182
* Define the hash table type. We identify entries by their starting index.
@@ -198,7 +277,7 @@ struct
198
277
end
199
278
end
200
279
in
201
- loop (Array.sub (arr, i))
280
+ loop (arraySub (arr, i))
202
281
end
203
282
204
283
@@ -232,7 +311,11 @@ struct
232
311
end
233
312
234
313
235
- structure T = PackedWeightedHashTable (structure K = Key structure W = Weight)
314
+ structure T =
315
+ PackedWeightedHashTable
316
+ (structure K = Key
317
+ structure W = Weight
318
+ val contentionFactor = contentionFactor)
236
319
237
320
(* ==========================================================================
238
321
* do the main loop. split the input into blocks and parse each block
@@ -243,21 +326,21 @@ struct
243
326
let
244
327
val table = T.make {capacity = capacity}
245
328
246
- fun loop cursor stop =
329
+ fun loop buffer cursor stop =
247
330
if cursor >= stop then
248
331
()
249
332
else
250
333
let
251
334
val start = cursor
252
- val cursor = valOf (findNext semicolon_id cursor)
335
+ val (buffer, cursor) = findNextBuffered buffer semicolon_id cursor
336
+ val cursor = valOf cursor
253
337
val cursor = cursor + 1 (* get past the ";" *)
254
- val (cursor, m) = parseMeasurement cursor
338
+ val (buffer, cursor, m) = parseMeasurement buffer cursor
255
339
val cursor = cursor + 1 (* get past the newline character *)
256
-
257
340
val weight = {min = m, max = m, tot = m, count = 1 }
258
341
in
259
342
T.insertCombineWeights table (start, weight);
260
- loop cursor stop
343
+ loop buffer cursor stop
261
344
end
262
345
263
346
fun findLineStart i =
@@ -273,7 +356,7 @@ struct
273
356
val start = findLineStart (b * blockSize)
274
357
val stop = findLineStart ((b + 1 ) * blockSize)
275
358
in
276
- loop start stop
359
+ loop (newBuffer bufferSize) start stop
277
360
end ))
278
361
279
362
val compacted = reportTime " compact" (fn _ =>
@@ -326,35 +409,61 @@ end
326
409
327
410
(* ======================================================================= *)
328
411
329
- val _ = vprint (" loading " ^ filename ^ " \n " )
412
+ (* val _ = vprint ("loading " ^ filename ^ "\n") * )
330
413
331
- val contents: Word8.word Seq.t = reportTime " load file" (fn _ =>
414
+ (* val contents: Word8.word Seq.t = reportTime "load file" (fn _ =>
332
415
ReadFile.contentsBinSeq filename)
333
416
334
417
val contents =
335
418
let val (arr, i, _) = ArraySlice.base contents
336
419
in if i = 0 then arr else Util.die ("whoops! strip away Seq failed")
337
- end
420
+ end *)
338
421
339
422
340
- (* ======================================================================= *)
423
+ val file = MPL.File.openFile filename
424
+
425
+ val numBytes = MPL.File.size file
341
426
427
+ fun getByte i = MPL.File.readWord8 file i
428
+
429
+ fun getBytes {offset, buffer} =
430
+ let
431
+ val count = Int.max (0 , Int.min (Array.length buffer, numBytes - offset))
432
+ (* val _ = print
433
+ ("getBytes " ^ Int.toString offset ^ " "
434
+ ^ Int.toString (Array.length buffer) ^ " " ^ Int.toString count ^ "\n") *)
435
+ val buffer = ArraySlice.slice (buffer, 0 , SOME count)
436
+ in
437
+ MPL.File.readWord8s file offset buffer;
438
+ count
439
+ end
440
+
441
+ (* ======================================================================= *)
342
442
343
443
structure MainWithBoundsChecks =
344
444
Main
345
- (val numBytes = Array.length contents
346
- fun getByte i = Array.sub (contents, i))
445
+ (struct
446
+ val numBytes = numBytes
447
+ val getByte = getByte
448
+ val getBytes = getBytes
449
+ val arraySub = Array.sub
450
+ end )
451
+
347
452
structure MainNoBoundsChecks =
348
453
Main
349
- (val numBytes = Array.length contents
350
- fun getByte i = Unsafe.Array.sub (contents, i))
351
-
454
+ (struct
455
+ val numBytes = numBytes
456
+ val getByte = getByte
457
+ val getBytes = getBytes
458
+ val arraySub = Unsafe.Array.sub
459
+ end )
352
460
353
461
val _ =
354
462
if noBoundsChecks then MainNoBoundsChecks.main ()
355
463
else MainWithBoundsChecks.main ()
356
464
357
465
466
+ val _ = MPL.File.closeFile file
358
467
val stop_time = Time.now ()
359
468
val _ = vprint
360
469
(" \n total time: " ^ Time.fmt 4 (Time.- (stop_time, start_time)) ^ " s\n " )
0 commit comments