Safely escape CSV fields with double quotes when they contain a comma #189

Open · wants to merge 1 commit into master
rdbtools/memprofiler.py (86 changes: 46 additions & 40 deletions)
@@ -32,13 +32,13 @@ def next_record(self, record):
self.add_aggregate('database_memory', 'all', record.bytes)
self.add_aggregate('type_memory', record.type, record.bytes)
self.add_aggregate('encoding_memory', record.encoding, record.bytes)

self.add_aggregate('type_count', record.type, 1)
self.add_aggregate('encoding_count', record.encoding, 1)

self.add_histogram(record.type + "_length", record.size)
self.add_histogram(record.type + "_memory", (record.bytes/10) * 10)

if record.type == 'list':
self.add_scatter('list_memory_by_length', record.bytes, record.size)
elif record.type == 'hash':
@@ -57,12 +57,12 @@ def next_record(self, record):
def add_aggregate(self, heading, subheading, metric):
if not heading in self.aggregates :
self.aggregates[heading] = {}

if not subheading in self.aggregates[heading]:
self.aggregates[heading][subheading] = 0

self.aggregates[heading][subheading] += metric

def add_histogram(self, heading, metric):
if not heading in self.histograms:
self.histograms[heading] = {}
@@ -71,18 +71,18 @@ def add_histogram(self, heading, metric):
self.histograms[heading][metric] = 1
else :
self.histograms[heading][metric] += 1

def add_scatter(self, heading, x, y):
if not heading in self.scatters:
self.scatters[heading] = []
self.scatters[heading].append([x, y])

def set_metadata(self, key, val):
self.metadata[key] = val

def get_json(self):
return json.dumps({"aggregates": self.aggregates, "scatters": self.scatters, "histograms": self.histograms, "metadata": self.metadata})

class PrintAllKeys(object):
def __init__(self, out, bytes, largest):
self._bytes = bytes
@@ -94,14 +94,20 @@ def __init__(self, out, bytes, largest):

if self._largest is not None:
self._heap = []


def _safe_csv(self, s):
if "," in s:
return "\"" + s.replace("\"", "\"\"") + "\""
else:
return s

def next_record(self, record) :
if record.key is None:
return # some records are not keys (e.g. dict)
if self._largest is None:
if self._bytes is None or record.bytes >= int(self._bytes):
rec_str = "%d,%s,%s,%d,%s,%d,%d,%s\n" % (
- record.database, record.type, record.key, record.bytes, record.encoding, record.size,
+ record.database, record.type, self._safe_csv(record.key), record.bytes, record.encoding, record.size,
record.len_largest_element,
record.expiry.isoformat() if record.expiry else '')
self._out.write(codecs.encode(rec_str, 'latin-1'))
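For context: the new _safe_csv helper applies the standard CSV quoting rule (wrap the field in double quotes and double any embedded quotes), but only when the value contains a comma; a key with embedded quotes and no comma still passes through unquoted. A minimal standalone sketch of the behavior, where safe_csv is a hypothetical module-level twin of the method:

    def safe_csv(s):
        # Quote only when a comma forces it; double any embedded quotes.
        if "," in s:
            return '"' + s.replace('"', '""') + '"'
        return s

    print(safe_csv('simplekey'))  # simplekey    (no comma: unchanged)
    print(safe_csv('a,"b",c'))    # "a,""b"",c"  (quoted, inner quotes doubled)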
@@ -120,7 +126,7 @@ def end_rdb(self):
class PrintJustKeys(object):
def __init__(self, out):
self._out = out

def next_record(self, record):
self._out.write(codecs.encode("%s\n" % record.key, 'latin-1'))

@@ -197,58 +203,58 @@ def set(self, key, value, expiry, info):
length = self.element_length(value)
self.emit_record("string", key, size, self._current_encoding, length, length, expiry)
self.end_key()

def start_hash(self, key, length, expiry, info):
self._current_encoding = info['encoding']
self._current_length = length
self._key_expiry = expiry
size = self.top_level_object_overhead(key, expiry)

if 'sizeof_value' in info:
size += info['sizeof_value']
elif 'encoding' in info and info['encoding'] == 'hashtable':
size += self.hashtable_overhead(length)
else:
raise Exception('start_hash', 'Could not find encoding or sizeof_value in info object %s' % info)
self._current_size = size

def hset(self, key, field, value):
if(self.element_length(field) > self._len_largest_element) :
self._len_largest_element = self.element_length(field)
if(self.element_length(value) > self._len_largest_element) :
self._len_largest_element = self.element_length(value)

if self._current_encoding == 'hashtable':
self._current_size += self.sizeof_string(field)
self._current_size += self.sizeof_string(value)
self._current_size += self.hashtable_entry_overhead()
if self._redis_version < StrictVersion('4.0'):
self._current_size += 2*self.robj_overhead()

def end_hash(self, key):
self.emit_record("hash", key, self._current_size, self._current_encoding, self._current_length,
self._len_largest_element, self._key_expiry)
self.end_key()

def start_set(self, key, cardinality, expiry, info):
# A set is exactly like a hashmap
self.start_hash(key, cardinality, expiry, info)

def sadd(self, key, member):
if(self.element_length(member) > self._len_largest_element) :
self._len_largest_element = self.element_length(member)

if self._current_encoding == 'hashtable':
self._current_size += self.sizeof_string(member)
self._current_size += self.hashtable_entry_overhead()
if self._redis_version < StrictVersion('4.0'):
self._current_size += self.robj_overhead()

def end_set(self, key):
self.emit_record("set", key, self._current_size, self._current_encoding, self._current_length,
self._len_largest_element, self._key_expiry)
self.end_key()

def start_list(self, key, expiry, info):
self._current_length = 0
self._list_items_size = 0 # size of all elements in case list ends up using linked list
@@ -272,7 +278,7 @@ def start_list(self, key, expiry, info):
self._list_max_ziplist_value = 64

self._current_size = size

def rpush(self, key, value):
self._current_length += 1
# in linked list, when the robj has integer encoding, the value consumes no memory on top of the robj
@@ -385,30 +391,30 @@ def start_sorted_set(self, key, length, expiry, info):
else:
raise Exception('start_sorted_set', 'Could not find encoding or sizeof_value in info object %s' % info)
self._current_size = size

def zadd(self, key, score, member):
if(self.element_length(member) > self._len_largest_element):
self._len_largest_element = self.element_length(member)

if self._current_encoding == 'skiplist':
self._current_size += 8 # score (double)
self._current_size += self.sizeof_string(member)
if self._redis_version < StrictVersion('4.0'):
self._current_size += self.robj_overhead()
self._current_size += self.skiplist_entry_overhead()

def end_sorted_set(self, key):
self.emit_record("sortedset", key, self._current_size, self._current_encoding, self._current_length,
self._len_largest_element, self._key_expiry)
self.end_key()

def end_key(self):
self._db_keys += 1
self._current_encoding = None
self._current_size = 0
self._len_largest_element = 0
self._key_expiry = None

def sizeof_string(self, string):
# https://github.com/antirez/redis/blob/unstable/src/sds.h
try:
@@ -433,7 +439,7 @@ def sizeof_string(self, string):
return self.malloc_overhead(l + 1 + 16 + 1)

def top_level_object_overhead(self, key, expiry):
# Each top level object is an entry in a dictionary, and so we have to include
# the overhead of a dictionary entry
return self.hashtable_entry_overhead() + self.sizeof_string(key) + self.robj_overhead() + self.key_expiry_overhead(expiry)

@@ -445,25 +451,25 @@ def key_expiry_overhead(self, expiry):
# Key expiry is stored in a hashtable, so we have to pay for the cost of a hashtable entry
# The timestamp itself is stored as an int64, which is a 8 bytes
return self.hashtable_entry_overhead() + 8

def hashtable_overhead(self, size):
# See https://github.com/antirez/redis/blob/unstable/src/dict.h
# See the structures dict and dictht
# 2 * (3 unsigned longs + 1 pointer) + int + long + 2 pointers
#
# Additionally, see **table in dictht
# The length of the table is the next power of 2
# When the hashtable is rehashing, another instance of **table is created
# Due to the possibility of rehashing during loading, we calculate the worst
# case in which both tables are allocated, and so multiply
# the size of **table by 1.5
return 4 + 7*self.sizeof_long() + 4*self.sizeof_pointer() + self.next_power(size)*self.sizeof_pointer()*1.5

def hashtable_entry_overhead(self):
# See https://github.com/antirez/redis/blob/unstable/src/dict.h
# Each dictEntry has 2 pointers + int64
return 2*self.sizeof_pointer() + 8
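To make the overhead arithmetic above concrete, here is a worked example assuming a 64-bit build (sizeof_pointer == sizeof_long == 8); the figures are illustrative, not taken from the patch:

    # Assumed 64-bit sizes: pointer = long = 8 bytes.
    entry  = 2*8 + 8                    # dictEntry: 2 pointers + int64 = 24
    expiry = entry + 8                  # expiry dictEntry + int64 stamp = 32
    table  = 4 + 7*8 + 4*8 + 128*8*1.5  # dict sized for 100 keys,
                                        # since next_power(100) = 128
    print(entry, expiry, table)         # 24 32 1628.0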

def linkedlist_overhead(self):
# See https://github.com/antirez/redis/blob/unstable/src/adlist.h
# A list has 5 pointers + an unsigned long
@@ -514,24 +520,24 @@ def ziplist_entry_overhead(self, value):

def skiplist_overhead(self, size):
return 2*self.sizeof_pointer() + self.hashtable_overhead(size) + (2*self.sizeof_pointer() + 16)

def skiplist_entry_overhead(self):
return self.hashtable_entry_overhead() + 2*self.sizeof_pointer() + 8 + (self.sizeof_pointer() + 8) * self.zset_random_level()

def robj_overhead(self):
return self.sizeof_pointer() + 8

def malloc_overhead(self, size):
alloc = get_jemalloc_allocation(size)
self._total_internal_frag += alloc - size
return alloc

def size_t(self):
return self.sizeof_pointer()

def sizeof_pointer(self):
return self._pointer_size

def sizeof_long(self):
return self._long_size

@@ -540,13 +546,13 @@ def next_power(self, size):
while (power <= size) :
power = power << 1
return power

def zset_random_level(self):
level = 1
rint = random.randint(0, 0xFFFF)
while (rint < ZSKIPLIST_P * 0xFFFF):
level += 1
rint = random.randint(0, 0xFFFF)
if level < ZSKIPLIST_MAXLEVEL :
return level
else:
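An aside on zset_random_level above: each additional level is taken with probability ZSKIPLIST_P, so the expected level is 1/(1 - ZSKIPLIST_P); with Redis's usual ZSKIPLIST_P = 0.25 (an assumption here, mirroring the constants this module defines elsewhere) that works out to about 1.33. A quick self-contained Monte Carlo check:

    import random

    # Assumed values, mirroring Redis's defaults.
    ZSKIPLIST_P, ZSKIPLIST_MAXLEVEL = 0.25, 32

    def level():
        l = 1
        while random.random() < ZSKIPLIST_P and l < ZSKIPLIST_MAXLEVEL:
            l += 1
        return l

    n = 100000
    print(sum(level() for _ in range(n)) / n)  # ~1.33, i.e. 1/(1 - 0.25)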
Binary file added tests/dumps/key-with-comma.rdb
tests/memprofiler_tests.py (12 changes: 9 additions & 3 deletions)
@@ -22,7 +22,9 @@
0,string,simplekey,72,string,7,7,
0,module,foo,101,ReJSON-RL,1,101,
"""

CSV_WITH_COMMA = """database,type,key,size_in_bytes,encoding,num_elements,len_largest_element,expiry
0,string,"a,""b"",c",64,string,4,4,
"""
class Stats(object):
def __init__(self):
self.sums = {}
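About the new CSV_WITH_COMMA fixture above: tests/dumps/key-with-comma.rdb presumably holds a single string key containing both commas and double quotes. One hypothetical way to produce such a dump with the redis-py client and a local server (the committed fixture may have been generated differently):

    import redis  # assumes redis-py and a Redis server on localhost

    r = redis.Redis()
    r.flushall()              # start from an empty keyspace
    r.set('a,"b",c', 'abcd')  # 4-byte value matches the expected CSV row
    r.save()                  # writes dump.rdb in the server's working dir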
@@ -54,7 +56,7 @@ def get_csv(dump_file_name):
buff = BytesIO()
callback = MemoryCallback(PrintAllKeys(buff, None, None), 64)
parser = RdbParser(callback)
parser.parse(os.path.join(os.path.dirname(__file__),
'dumps', dump_file_name))
csv = buff.getvalue().decode()
return csv
@@ -75,6 +77,10 @@ def test_csv_with_module(self):
csv = get_csv('redis_40_with_module.rdb')
self.assertEquals(csv, CSV_WITH_MODULE)

def test_csv_key_with_comma(self):
csv = get_csv('key-with-comma.rdb')
self.assertEquals(csv, CSV_WITH_COMMA)

def test_expiry(self):
stats = get_stats('keys_with_expiry.rdb')

@@ -85,7 +91,7 @@ def test_expiry(self):
self.assertEquals(expiry.hour, 10)
self.assertEquals(expiry.minute, 11)
self.assertEquals(expiry.second, 12)
self.assertEquals(expiry.microsecond, 573000)

def test_len_largest_element(self):
stats = get_stats('ziplist_that_compresses_easily.rdb')