Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: speed improvements for primer pair hit building from single primer hits #99

Merged
merged 13 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 116 additions & 26 deletions prymer/offtarget/offtarget_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
""" # noqa: E501

import itertools
from collections import defaultdict
from contextlib import AbstractContextManager
from dataclasses import dataclass
from dataclasses import field
Expand All @@ -83,13 +84,15 @@
from types import TracebackType
from typing import Optional
from typing import Self
from typing import TypeAlias
from typing import TypeVar

from ordered_set import OrderedSet

from prymer.api.oligo import Oligo
from prymer.api.primer_pair import PrimerPair
from prymer.api.span import Span
from prymer.api.span import Strand
from prymer.offtarget.bwa import BWA_EXECUTABLE_NAME
from prymer.offtarget.bwa import BwaAlnInteractive
from prymer.offtarget.bwa import BwaHit
Expand All @@ -98,6 +101,9 @@

PrimerType = TypeVar("PrimerType", bound=Oligo)

ReferenceName: TypeAlias = str
"""Alias for a reference sequence name."""


@dataclass(init=True, frozen=True)
class OffTargetResult:
Expand Down Expand Up @@ -334,27 +340,78 @@ def _build_off_target_result(
result: OffTargetResult

# Get the mappings for the left primer and right primer respectively
p1: BwaResult = hits_by_primer[primer_pair.left_primer.bases]
p2: BwaResult = hits_by_primer[primer_pair.right_primer.bases]

# Get all possible amplicons from the left_primer_mappings and right_primer_mappings
# primer hits, filtering if there are too many for either
if p1.hit_count > self._max_primer_hits or p2.hit_count > self._max_primer_hits:
left_bwa_result: BwaResult = hits_by_primer[primer_pair.left_primer.bases]
right_bwa_result: BwaResult = hits_by_primer[primer_pair.right_primer.bases]

# If there are too many hits, this primer pair will not pass. Exit early.
if (
left_bwa_result.hit_count > self._max_primer_hits
or right_bwa_result.hit_count > self._max_primer_hits
):
result = OffTargetResult(primer_pair=primer_pair, passes=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just return here? Is there any value to caching the result in this case, where it's so little work to compute the result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caching is required by the unit test at

# Test that using the cache (or not) does not affect the results

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, ok. Perhaps I'll "fix" this in a separate PR.

else:
amplicons = self._to_amplicons(p1.hits, p2.hits, self._max_amplicon_size)
result = OffTargetResult(
primer_pair=primer_pair,
passes=self._min_primer_pair_hits <= len(amplicons) <= self._max_primer_pair_hits,
spans=amplicons if self._keep_spans else [],
left_primer_spans=(
[self._hit_to_span(h) for h in p1.hits] if self._keep_primer_spans else []
),
right_primer_spans=(
[self._hit_to_span(h) for h in p2.hits] if self._keep_primer_spans else []
),
if self._cache_results:
self._primer_pair_cache[primer_pair] = replace(result, cached=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If keeping this (and again on line 408), why not just set cached=self._cache_results at construction time, so there's no need to replace(cached=True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic is that the returned object should have cached=False, because it's the first time it's been generated, and the cached object has cached=True for when it's retrieved from the cache. This jives with the description of cached at

cached: True if this result is part of a cache, False otherwise. This is useful for testing

        cached: True if this result is part of a cache, False otherwise.  This is useful for testing

I suggest we come back to the cache issue on a subsequent PR.

return result

# Map the hits by reference name
left_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
left_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
right_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
right_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)

# Split the hits for left and right by reference name and strand
for hit in left_bwa_result.hits:
if hit.negative:
left_negative_hits[hit.refname].append(hit)
else:
left_positive_hits[hit.refname].append(hit)

for hit in right_bwa_result.hits:
if hit.negative:
right_negative_hits[hit.refname].append(hit)
else:
right_positive_hits[hit.refname].append(hit)

refnames: set[ReferenceName] = {
h.refname for h in itertools.chain(left_bwa_result.hits, right_bwa_result.hits)
}

# Build amplicons from hits on the same reference with valid relative orientation
amplicons: list[Span] = []
for refname in refnames:
amplicons.extend(
self._to_amplicons(
positive_hits=left_positive_hits[refname],
negative_hits=right_negative_hits[refname],
max_len=self._max_amplicon_size,
strand=Strand.POSITIVE,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to think about strand here. I think it's probably mostly irrelevant and could just always be set to positive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the PrimerPair code, it's not explicitly enforced by a __post_init__ but is strongly assumed by e.g. calculate_amplicon_span that the left primer is left of the right primer. This implies that the left primer is on the positive strand and the right primer is on the negative strand. All amplicons are coded as on the positive strand (default value for Span, again, not explicit).

When a hit to a primer pair is in the same orientation as the primer pair definition, e.g. left positive and right negative, then it makes sense for the amplicon to also be on the positive strand.

If the hit is in the opposite orientation to the primer pair definition, the amplicon should be on the negative strand.

I think it makes sense to assign strandedness like this:

  • left primer + / right primer - = amplicon +
  • left primer + / left primer - = amplicon +
  • right primer + / left primer - = amplicon -
  • right primer + / right primer - = amplicon -

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed elsewhere, will leave questions of strand and L/L R/R hits to a separate PR.

)
)
amplicons.extend(
self._to_amplicons(
positive_hits=right_positive_hits[refname],
negative_hits=left_negative_hits[refname],
max_len=self._max_amplicon_size,
strand=Strand.NEGATIVE,
)
)

result = OffTargetResult(
primer_pair=primer_pair,
passes=self._min_primer_pair_hits <= len(amplicons) <= self._max_primer_pair_hits,
spans=amplicons if self._keep_spans else [],
left_primer_spans=(
[self._hit_to_span(h) for h in left_bwa_result.hits]
if self._keep_primer_spans
else []
),
right_primer_spans=(
[self._hit_to_span(h) for h in right_bwa_result.hits]
if self._keep_primer_spans
else []
),
)

if self._cache_results:
self._primer_pair_cache[primer_pair] = replace(result, cached=True)

Expand Down Expand Up @@ -410,19 +467,52 @@ def mappings_of(self, primers: list[PrimerType]) -> dict[str, BwaResult]:
return hits_by_primer

@staticmethod
def _to_amplicons(lefts: list[BwaHit], rights: list[BwaHit], max_len: int) -> list[Span]:
def _to_amplicons(
positive_hits: list[BwaHit], negative_hits: list[BwaHit], max_len: int, strand: Strand
) -> list[Span]:
"""Takes a set of hits for one or more left primers and right primers and constructs
amplicon mappings anywhere a left primer hit and a right primer hit align in F/R
orientation up to `maxLen` apart on the same reference. Primers may not overlap.

Args:
positive_hits: List of hits on the positive strand for one of the primers in the pair.
negative_hits: List of hits on the negative strand for the other primer in the pair.
max_len: Maximum length of amplicons to consider.
strand: The strand of the amplicon to generate. Set to Strand.POSITIVE if
`positive_hits` are for the left primer and `negative_hits` are for the right
primer. Set to Strand.NEGATIVE if `positive_hits` are for the right primer and
`negative_hits` are for the left primer.

Raises:
ValueError: If any of the positive hits are not on the positive strand, or any of the
negative hits are not on the negative strand. If hits are present on more than one
reference.
"""
amplicons: list[Span] = []
for h1, h2 in itertools.product(lefts, rights):
if h1.negative == h2.negative or h1.refname != h2.refname: # not F/R orientation
continue
if any(h.negative for h in positive_hits):
raise ValueError("Positive hits must be on the positive strand.")
if any(not h.negative for h in negative_hits):
raise ValueError("Negative hits must be on the negative strand.")

plus, minus = (h2, h1) if h1.negative else (h1, h2)
if minus.start > plus.end and (minus.end - plus.start + 1) <= max_len:
amplicons.append(Span(refname=plus.refname, start=plus.start, end=minus.end))
refnames: set[ReferenceName] = {
h.refname for h in itertools.chain(positive_hits, negative_hits)
}
if len(refnames) > 1:
raise ValueError(f"Hits are present on more than one reference: {refnames}")

amplicons: list[Span] = []
for positive_hit, negative_hit in itertools.product(positive_hits, negative_hits):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is good enough if it meets your performance needs @ameynert. As noted elsewhere, if you had many many hits to one reference, you could speed this up by sorting the positive and negative strand hits, looping over list indices, and avoiding whole swathes of pairs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tested this separately and the speed improvement by sorting the pos/neg strand hits is significant when there are a large number of hits returned, e.g. in a scenario where the max mismatches parameters are relaxed, one of the primer sequences is getting 100+ hits, and there are multiple highly similar reference contigs in the indexed genome that's being searched.

if (
negative_hit.start > positive_hit.end
and negative_hit.end - positive_hit.start + 1 <= max_len
):
amplicons.append(
Span(
refname=positive_hit.refname,
start=positive_hit.start,
end=negative_hit.end,
strand=strand,
)
)

return amplicons

Expand Down
142 changes: 117 additions & 25 deletions tests/offtarget/test_offtarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from prymer.offtarget.bwa import BWA_EXECUTABLE_NAME
from prymer.offtarget.bwa import BwaHit
from prymer.offtarget.bwa import BwaResult
from prymer.offtarget.bwa import Query
from prymer.offtarget.offtarget_detector import OffTargetDetector
from prymer.offtarget.offtarget_detector import OffTargetResult

Expand Down Expand Up @@ -171,68 +172,159 @@ def test_mappings_of(ref_fasta: Path, cache_results: bool) -> None:
assert results_dict[p2.bases].hits[0] == expected_hit2


# Test building an OffTargetResult for a primer pair with left/right hits on different references
# and in different orientations
def test_build_off_target_result(ref_fasta: Path) -> None:
hits_by_primer: dict[str, BwaResult] = {
"A" * 100: BwaResult(
query=Query(
id="left",
bases="A" * 100,
),
hit_count=3,
hits=[
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr1", 400, True, "100M", 0),
BwaHit.build("chr2", 100, False, "100M", 0),
BwaHit.build("chr3", 700, True, "100M", 0),
],
),
"C" * 100: BwaResult(
query=Query(
id="right",
bases="C" * 100,
),
hit_count=2,
hits=[
BwaHit.build("chr1", 800, False, "100M", 0),
BwaHit.build("chr1", 200, True, "100M", 0),
BwaHit.build("chr3", 600, False, "100M", 0),
],
),
}

primer_pair = PrimerPair(
left_primer=Oligo(
tm=50,
penalty=0,
span=Span(refname="chr10", start=100, end=199, strand=Strand.POSITIVE),
bases="A" * 100,
),
right_primer=Oligo(
tm=50,
penalty=0,
span=Span(refname="chr10", start=300, end=399, strand=Strand.NEGATIVE),
bases="C" * 100,
),
amplicon_tm=100,
penalty=0,
)

with _build_detector(
ref_fasta=ref_fasta, max_primer_hits=10, max_primer_pair_hits=10
) as detector:
off_target_result: OffTargetResult = detector._build_off_target_result(
primer_pair=primer_pair,
hits_by_primer=hits_by_primer,
)

assert off_target_result.spans == [
Span(refname="chr1", start=100, end=299, strand=Strand.POSITIVE),
Span(refname="chr3", start=600, end=799, strand=Strand.NEGATIVE),
]


# Test that using the cache (or not) does not affect the results
@pytest.mark.parametrize("cache_results", [True, False])
@pytest.mark.parametrize(
"test_id, left, right, expected",
"test_id, positive, negative, strand, expected",
[
(
"No mappings - different refnames",
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr2", 100, True, "100M", 0),
[],
),
(
"No mappings - FF pair",
BwaHit.build("chr1", 100, True, "100M", 0),
BwaHit.build("chr1", 100, True, "100M", 0),
[],
),
(
"No mappings - RR pair",
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr1", 100, False, "100M", 0),
[],
),
(
"No mappings - overlapping primers (1bp overlap)",
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr1", 199, True, "100M", 0),
Strand.POSITIVE,
[],
),
(
"No mappings - amplicon size too big (1bp too big)",
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr1", 151, True, "100M", 0),
Strand.POSITIVE,
[],
),
(
"Mappings - FR pair (R1 F)",
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr1", 200, True, "100M", 0),
[Span(refname="chr1", start=100, end=299)],
Strand.POSITIVE,
[Span(refname="chr1", start=100, end=299, strand=Strand.POSITIVE)],
),
(
"Mappings - FR pair (R1 R)",
BwaHit.build("chr1", 200, True, "100M", 0),
BwaHit.build("chr1", 100, False, "100M", 0),
[Span(refname="chr1", start=100, end=299)],
BwaHit.build("chr1", 200, True, "100M", 0),
Strand.NEGATIVE,
[Span(refname="chr1", start=100, end=299, strand=Strand.NEGATIVE)],
),
],
)
def test_to_amplicons(
ref_fasta: Path,
test_id: str,
left: BwaHit,
right: BwaHit,
positive: BwaHit,
negative: BwaHit,
strand: Strand,
expected: list[Span],
cache_results: bool,
) -> None:
with _build_detector(ref_fasta=ref_fasta, cache_results=cache_results) as detector:
actual = detector._to_amplicons(lefts=[left], rights=[right], max_len=250)
actual = detector._to_amplicons(
positive_hits=[positive], negative_hits=[negative], max_len=250, strand=strand
)
assert actual == expected, test_id


@pytest.mark.parametrize("cache_results", [True, False])
@pytest.mark.parametrize(
"positive, negative, expected_error",
[
(
# No mappings - different refnames
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr2", 100, True, "100M", 0),
"Hits are present on more than one reference",
),
(
# No mappings - FF pair
BwaHit.build("chr1", 100, True, "100M", 0),
BwaHit.build("chr1", 100, True, "100M", 0),
"Positive hits must be on the positive strand",
),
(
# No mappings - RR pair
BwaHit.build("chr1", 100, False, "100M", 0),
BwaHit.build("chr1", 100, False, "100M", 0),
"Negative hits must be on the negative strand",
),
],
)
def test_to_amplicons_value_error(
ref_fasta: Path,
positive: BwaHit,
negative: BwaHit,
expected_error: str,
cache_results: bool,
) -> None:
with (
_build_detector(ref_fasta=ref_fasta, cache_results=cache_results) as detector,
pytest.raises(ValueError, match=expected_error),
):
detector._to_amplicons(
positive_hits=[positive], negative_hits=[negative], max_len=250, strand=Strand.POSITIVE
)


def test_generic_filter(ref_fasta: Path) -> None:
"""
This test isn't intended to validate any runtime assertions, but is a minimal example for the
Expand Down
Loading