Skip to content

Commit 09ba600

Browse files
committed
Merge branch 'fix-nonadditive' into mypy
2 parents e99003a + d5dd802 commit 09ba600

14 files changed

+736
-577
lines changed

pgscatalog.core/mypy.ini

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../mypy.ini

pgscatalog.core/src/pgscatalog/core/__init__.py

-8
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
ScoringFiles,
1212
ScoringFile,
1313
NormalisedScoringFile,
14-
ScoreVariant,
1514
GenomeBuild,
16-
TargetVariants,
17-
TargetVariant,
18-
TargetType,
1915
RelabelArgs,
2016
relabel,
2117
relabel_write,
@@ -32,15 +28,11 @@
3228
__all__ = [
3329
"ScoringFiles",
3430
"ScoringFile",
35-
"ScoreVariant",
3631
"Config",
3732
"GenomeBuild",
3833
"CatalogQuery",
3934
"ScoreQueryResult",
4035
"CatalogCategory",
41-
"TargetVariant",
42-
"TargetVariants",
43-
"TargetType",
4436
"NormalisedScoringFile",
4537
"RelabelArgs",
4638
"relabel",

pgscatalog.core/src/pgscatalog/core/cli/_combine.py

-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
standardised format, combining them, and calculating some statistics. Only really
33
useful for the CLI, not good for importing elsewhere."""
44

5-
import collections
65
import csv
76
import functools
87
import gzip
@@ -13,13 +12,6 @@
1312
logger = logging.getLogger(__name__)
1413

1514

16-
def get_variant_log(batch):
17-
# these statistics can only be generated while iterating through variants
18-
n_variants = collections.Counter("n_variants" for item in batch)
19-
hm_source = collections.Counter(getattr(item, "hm_source") for item in batch)
20-
return n_variants + hm_source
21-
22-
2315
class DataWriter:
2416
def __init__(self, filename):
2517
self.filename = filename

pgscatalog.core/src/pgscatalog/core/cli/combine_cli.py

+36-14
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import argparse
2-
import json
32
import logging
43
import pathlib
54
import sys
65
import textwrap
6+
from typing import Optional
77

88
from tqdm import tqdm
9-
from ..lib import GenomeBuild, ScoringFile, ScoreVariant, EffectTypeError
109

11-
from ._combine import get_variant_log, TextFileWriter
10+
from ..lib.models import ScoreLog, ScoreLogs, ScoreVariant
11+
from ..lib import GenomeBuild, ScoringFile, EffectTypeError
12+
13+
from ._combine import TextFileWriter
1214

1315
logger = logging.getLogger(__name__)
1416

@@ -40,6 +42,12 @@ def run():
4042
target_build = GenomeBuild.from_string(args.target_build)
4143

4244
for x in scoring_files:
45+
if x.genome_build is None and target_build is not None:
46+
raise ValueError(
47+
f"Can't combine files with missing build in "
48+
f"header when requesting {target_build=}"
49+
)
50+
4351
if x.genome_build != target_build and not args.liftover:
4452
raise ValueError(
4553
f"Can't combine scoring file with genome build {x.genome_build!r} when {target_build=} without --liftover"
@@ -61,10 +69,10 @@ def run():
6169
liftover_kwargs = {"liftover": False}
6270

6371
n_finished = 0
64-
good_scores = []
65-
6672
for scorefile in tqdm(scoring_files, total=len(scoring_files)):
6773
logger.info(f"Processing {scorefile.pgs_id}")
74+
normalised_score: Optional[list[ScoreVariant]] = None
75+
is_compatible = True
6876
try:
6977
normalised_score = list(
7078
scorefile.normalise(
@@ -77,20 +85,38 @@ def run():
7785
logger.warning(
7886
f"Unsupported non-additive effect types in {scorefile=}, skipping"
7987
)
80-
continue
88+
is_compatible = False
8189
else:
8290
# TODO: go back to parallel execution + write to multiple files
8391
writer = TextFileWriter(compress=compress_output, filename=out_path)
8492

8593
# model_dump returns a dict with a subset of keys
8694
dumped_variants = (
87-
x.model_dump(include=ScoreVariant.output_fields)
95+
x.model_dump(include=set(ScoreVariant.output_fields))
8896
for x in normalised_score
8997
)
9098
writer.write(dumped_variants)
91-
variant_log.append(get_variant_log(normalised_score))
9299
n_finished += 1
93-
good_scores.append(scorefile)
100+
finally:
101+
# grab essential information only for the score log
102+
if normalised_score is not None:
103+
log_variants = (
104+
x.model_dump(include={"accession", "row_nr", "hm_source"})
105+
for x in normalised_score
106+
)
107+
else:
108+
log_variants = None
109+
110+
log = ScoreLog(
111+
header=scorefile.header,
112+
variants=log_variants,
113+
compatible_effect_type=is_compatible,
114+
)
115+
if log.variants_are_missing:
116+
logger.warning(
117+
f"{log.variant_count_difference} fewer variants in output compared to original file"
118+
)
119+
variant_log.append(log)
94120

95121
if n_finished == 0:
96122
raise ValueError(
@@ -100,14 +126,10 @@ def run():
100126
if n_finished != len(scoring_files):
101127
logger.warning(f"{len(scoring_files) - n_finished} scoring files were skipped")
102128

103-
score_log = []
104-
for sf, log in zip(good_scores, variant_log, strict=True):
105-
score_log.append(sf.get_log(variant_log=log))
106-
107129
log_out_path = pathlib.Path(args.outfile).parent / args.logfile
108130
with open(log_out_path, "w") as f:
109131
logger.info(f"Writing log to {f.name}")
110-
json.dump(score_log, f, indent=4)
132+
f.write(ScoreLogs(variant_log).model_dump_json())
111133

112134
logger.info("Combining complete")
113135

pgscatalog.core/src/pgscatalog/core/lib/__init__.py

-6
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from .genomebuild import GenomeBuild
44
from .catalogapi import ScoreQueryResult, CatalogQuery, CatalogCategory
55
from .scorefiles import ScoringFiles, ScoringFile, NormalisedScoringFile
6-
from .scorevariant import ScoreVariant
7-
from .targetvariants import TargetVariants, TargetVariant, TargetType
86
from ._relabel import RelabelArgs, relabel, relabel_write
97
from ._sortpaths import effect_type_keyfunc, chrom_keyfunc
108
from .pgsexceptions import (
@@ -51,15 +49,11 @@
5149
"SamplesheetFormatError",
5250
"ScoringFiles",
5351
"ScoringFile",
54-
"ScoreVariant",
5552
"Config",
5653
"GenomeBuild",
5754
"CatalogQuery",
5855
"ScoreQueryResult",
5956
"CatalogCategory",
60-
"TargetVariant",
61-
"TargetVariants",
62-
"TargetType",
6357
"NormalisedScoringFile",
6458
"RelabelArgs",
6559
"relabel",

pgscatalog.core/src/pgscatalog/core/lib/_normalise.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ def normalise(
3232
if liftover:
3333
variants = lift(
3434
scoring_file=scoring_file,
35-
harmonised=scoring_file.harmonised,
35+
harmonised=scoring_file.is_harmonised,
3636
current_build=scoring_file.genome_build,
3737
target_build=target_build,
3838
chain_dir=chain_dir,
3939
)
4040
else:
4141
variants = scoring_file.variants
4242

43-
variants = remap_harmonised(variants, scoring_file.harmonised, target_build)
43+
variants = remap_harmonised(variants, scoring_file.is_harmonised, target_build)
4444

4545
if drop_missing:
4646
variants = drop_hla(variants)
@@ -100,7 +100,7 @@ def check_duplicates(variants):
100100
def drop_hla(variants):
101101
"""Drop HLA alleles from a list of ScoreVariants
102102
103-
>>> from .scorevariant import ScoreVariant
103+
>>> from .models import ScoreVariant
104104
>>> variant = ScoreVariant(**{"effect_allele": "A", "effect_weight": 5, "accession": "test", "row_nr": 0, "chr_name": "1", "chr_position": 1})
105105
>>> list(drop_hla([variant])) # doctest: +ELLIPSIS
106106
[ScoreVariant(..., effect_allele=Allele(allele='A', is_snp=True), ...
@@ -127,7 +127,7 @@ def drop_hla(variants):
127127
def assign_other_allele(variants):
128128
"""Check if there's more than one possible other allele, remove if true
129129
130-
>>> from .scorevariant import ScoreVariant
130+
>>> from .models import ScoreVariant
131131
>>> variant = ScoreVariant(**{"chr_position": 1, "rsID": None, "chr_name": "1", "effect_allele": "A", "effect_weight": 5, "other_allele": "A", "row_nr": 0, "accession": "test"})
132132
>>> list(assign_other_allele([variant]))[0] # doctest: +ELLIPSIS
133133
ScoreVariant(..., effect_allele=Allele(allele='A', is_snp=True), other_allele=Allele(allele='A', is_snp=True), ...)
@@ -154,7 +154,7 @@ def remap_harmonised(variants, harmonised, target_build):
154154
In this case chr_name, chr_position, and other allele are missing.
155155
Perhaps authors submitted rsID and effect allele originally:
156156
157-
>>> from .scorevariant import ScoreVariant
157+
>>> from .models import ScoreVariant
158158
>>> variant = ScoreVariant(**{"chr_position": 1, "rsID": None, "chr_name": "2", "effect_allele": "A", "effect_weight": 5, "accession": "test", "hm_chr": "1", "hm_pos": 100, "hm_rsID": "testrsid", "hm_inferOtherAllele": "A", "row_nr": 0})
159159
>>> variant
160160
ScoreVariant(..., effect_allele=Allele(allele='A', is_snp=True), other_allele=None, ...
@@ -184,7 +184,7 @@ def check_effect_allele(variants, drop_missing=False):
184184
"""
185185
Odd effect allele:
186186
187-
>>> from .scorevariant import ScoreVariant
187+
>>> from .models import ScoreVariant
188188
>>> variant = ScoreVariant(**{"effect_allele": "Z", "effect_weight": 5, "accession": "test", "row_nr": 0, "chr_name": "1", "chr_position": 1})
189189
>>> list(check_effect_allele([variant], drop_missing=True)) # doctest: +ELLIPSIS
190190
[]

pgscatalog.core/src/pgscatalog/core/lib/_read.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from xopen import xopen
99

10-
from .scorevariant import ScoreVariant
10+
from .models import ScoreVariant
1111

1212
logger = logging.getLogger(__name__)
1313

@@ -43,23 +43,6 @@ def read_rows_lazy(
4343
row_nr += 1
4444

4545

46-
def generate_header_lines(f):
47-
"""Header lines in a PGS Catalog scoring file are structured like:
48-
49-
#pgs_id=PGS000348
50-
#pgs_name=PRS_PrCa
51-
52-
Files can be big, so we want to only read header lines and stop immediately
53-
"""
54-
for line in f:
55-
if line.startswith("#"):
56-
if "=" in line:
57-
yield line.strip()
58-
else:
59-
# stop reading lines
60-
break
61-
62-
6346
def get_columns(path):
6447
"""Grab column labels from a PGS Catalog scoring file. line_no is useful to skip the header"""
6548
with xopen(path, mode="rt") as f:
@@ -87,6 +70,23 @@ def detect_wide(cols: list[str]) -> bool:
8770
return False
8871

8972

73+
def generate_header_lines(f):
74+
"""Header lines in a PGS Catalog scoring file are structured like:
75+
76+
#pgs_id=PGS000348
77+
#pgs_name=PRS_PrCa
78+
79+
Files can be big, so we want to only read header lines and stop immediately
80+
"""
81+
for line in f:
82+
if line.startswith("#"):
83+
if "=" in line:
84+
yield line.strip()
85+
else:
86+
# stop reading lines
87+
break
88+
89+
9090
def read_header(path: pathlib.Path) -> dict:
9191
"""Parses the header of a PGS Catalog format scoring file into a dictionary"""
9292
header = {}

0 commit comments

Comments (0)