Skip to content

Commit

Permalink
Expose an unload method for SBT nodes (#784)
Browse files Browse the repository at this point in the history
* add unload method for SBT nodes
* default unload_data to False, except in sourmash search CLI
  • Loading branch information
luizirber authored Jan 15, 2020
1 parent 2f05b07 commit cba11c5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
3 changes: 2 additions & 1 deletion sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ def search(args):
# do the actual search
results = search_databases(query, databases,
args.threshold, args.containment,
args.best_only, args.ignore_abundance)
args.best_only, args.ignore_abundance,
unload_data=True)

n_matches = len(results)
if args.best_only:
Expand Down
19 changes: 17 additions & 2 deletions sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def add_node(self, node):
def find(self, search_fn, *args, **kwargs):
"Search the tree using `search_fn`."

unload_data = kwargs.get("unload_data", False)

# initialize search queue with top node of tree
matches = []
visited, queue = set(), [0]
Expand Down Expand Up @@ -261,6 +263,10 @@ def find(self, search_fn, *args, **kwargs):
queue.insert(0, c.pos)
else: # bfs
queue.extend(c.pos for c in self.children(node_p))

if unload_data:
node_g.unload()

return matches

def search(self, query, *args, **kwargs):
Expand All @@ -272,6 +278,7 @@ def search(self, query, *args, **kwargs):
ignore_abundance = kwargs['ignore_abundance']
do_containment = kwargs['do_containment']
best_only = kwargs['best_only']
unload_data = kwargs.get('unload_data', False)

search_fn = search_minhashes
query_match = lambda x: query.similarity(
Expand All @@ -296,7 +303,7 @@ def search(self, query, *args, **kwargs):

# now, search!
results = []
for leaf in self.find(search_fn, tree_query, threshold):
for leaf in self.find(search_fn, tree_query, threshold, unload_data=unload_data):
similarity = query_match(leaf.data)

# tree search should always/only return matches above threshold
Expand All @@ -312,6 +319,8 @@ def gather(self, query, *args, **kwargs):
# use a tree search function that keeps track of its best match.
search_fn = GatherMinHashes().search

unload_data = kwargs.get('unload_data', False)

leaf = next(iter(self.leaves()))
tree_mh = leaf.data.minhash
scaled = tree_mh.scaled
Expand All @@ -320,7 +329,7 @@ def gather(self, query, *args, **kwargs):
threshold = threshold_bp / (len(query.minhash) * scaled)

results = []
for leaf in self.find(search_fn, query, 0.0):
for leaf in self.find(search_fn, query, threshold, unload_data=unload_data):
leaf_e = leaf.data.minhash
similarity = query.minhash.containment_ignore_maxhash(leaf_e)
if similarity > 0.0:
Expand Down Expand Up @@ -1061,6 +1070,9 @@ def data(self):
def data(self, new_data):
self._data = new_data

# Release this node's in-memory data by resetting `_data` to None.
# NOTE(review): presumably the `data` property repopulates `_data` on the
# next access -- confirm against the getter, which is not shown in this hunk.
def unload(self):
self._data = None

@staticmethod
def load(info, storage=None):
new_node = Node(info['factory'],
Expand Down Expand Up @@ -1115,6 +1127,9 @@ def data(self):
def data(self, new_data):
self._data = new_data

# Release this node's in-memory data by resetting `_data` to None.
# NOTE(review): presumably the `data` property repopulates `_data` on the
# next access -- confirm against the getter, which is not shown in this hunk.
def unload(self):
self._data = None

def save(self, path):
# We need to do this tempfile dance because khmer only loads
# data from files.
Expand Down
5 changes: 2 additions & 3 deletions sourmash/sbtmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def search_sbt_index(tree, query, threshold):
for match_sig, similarity in search_sbt_index(tree, query, threshold):
...
"""
for leaf in tree.find(search_minhashes, query, threshold):
for leaf in tree.find(search_minhashes, query, threshold, unload_data=True):
similarity = query.similarity(leaf.data)
yield leaf.data, similarity

Expand Down Expand Up @@ -175,8 +175,7 @@ def search(self, node, sig, threshold, results=None):
return 0


def search_minhashes_containment(node, sig, threshold,
results=None, downsample=True):
def search_minhashes_containment(node, sig, threshold, results=None, downsample=True):
mins = sig.minhash.get_mins()

if isinstance(node, SigLeaf):
Expand Down
5 changes: 3 additions & 2 deletions sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ def format_bp(bp):


def search_databases(query, databases, threshold, do_containment, best_only,
ignore_abundance):
ignore_abundance, unload_data=False):
results = []
found_md5 = set()
for (obj, filename, filetype) in databases:
search_iter = obj.search(query, threshold=threshold,
do_containment=do_containment,
ignore_abundance=ignore_abundance,
best_only=best_only)
best_only=best_only,
unload_data=unload_data)
for (similarity, match, filename) in search_iter:
md5 = match.md5sum()
if md5 not in found_md5:
Expand Down

0 comments on commit cba11c5

Please sign in to comment.