Skip to content

Commit

Permalink
Support A Single Anndata File Input for scRNA Prior (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffquinn-msk authored Nov 21, 2024
1 parent 48dd08f commit 849b112
Show file tree
Hide file tree
Showing 19 changed files with 229 additions and 116 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM pytorch/pytorch:latest
FROM bitnami/pytorch:latest

USER root

Expand Down
45 changes: 0 additions & 45 deletions docs/fine_mapping_workflow.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Contents
nextflow
data_format
example_workflow
fine_mapping_workflow
reference_scrna
command_line_interface
api
Github <https://github.com/tansey-lab/bayestme>
Expand Down
23 changes: 23 additions & 0 deletions docs/reference_scrna.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Reference scRNA
===============

BayesTME can take advantage of reference scRNA data in anndata h5ad format.

This sets a prior on the expression profiles of the cell types in the spatial data and,
as a consequence, also sets the number of cell types.

Provide your anndata archive to ``deconvolve`` like so:

.. code::

    deconvolve --expression-truth companion_scRNA.h5ad \
        --reference-scrna-celltype-column celltype \
        --reference-scrna-sample-column sample

``deconvolve`` will consider all the samples jointly to determine
a prior on celltype expression profiles.

We assume the values of ``--reference-scrna-celltype-column`` and
``--reference-scrna-sample-column`` are attributes of the ``obs`` table in your
anndata object.
4 changes: 3 additions & 1 deletion main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ workflow {
BAYESTME_BASIC_VISIUM_ANALYSIS(
Channel.fromList( [tuple([id: "sample", single_end: false],
file(params.input),
params.n_cell_types) ])
params.n_cell_types,
params.reference_scrna ? [] : file(params.reference_scrna, checkIfExists: true)
) ])
)
}
8 changes: 8 additions & 0 deletions nextflow.config
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ params {
input = null
n_cell_types = null
outdir = null
n_deconvolution_genes = 1000
// Filter genes occurring in more than this percentage of spots
spot_threshold = 0.9
// filter ribosomal genes
filter_ribosomal_genes = true

reference_scrna_sample_column = 'sample'
reference_scrna_celltype_column = 'celltype'
}

process {
Expand Down
4 changes: 4 additions & 0 deletions nextflow/modules/bayestme/bayestme_deconvolution/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ process BAYESTME_DECONVOLUTION {
deconvolve --adata ${adata} \
--adata-output "${prefix}/dataset_deconvolved.h5ad" \
--output "${prefix}/deconvolution_samples.h5" \
--reference-scrna-celltype-column ${params.reference_scrna_sample_column} \
--reference-scrna-sample-column ${params.reference_scrna_gene_column} \
${spatial_smoothing_parameter_flag} \
${n_components_flag} \
${expression_truth_flag} \
Expand All @@ -50,6 +52,8 @@ process BAYESTME_DECONVOLUTION {
mkdir plots
plot_deconvolution --adata "${prefix}/dataset_deconvolved_marker_genes.h5ad" \
--output-dir "${prefix}/plots" \
--reference-scrna-celltype-column ${params.reference_scrna_sample_column} \
--reference-scrna-sample-column ${params.reference_scrna_gene_column} \
${expression_truth_flag} \
${args3}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ include { BAYESTME_SPATIAL_TRANSCRIPTIONAL_PROGRAMS } from '../../../modules/bay
workflow BAYESTME_BASIC_VISIUM_ANALYSIS {

take:
ch_input // channel: [ val(meta), path(spaceranger_dir), val(n_cell_types) ]
ch_input // channel: [ val(meta), path(spaceranger_dir), val(n_cell_types), path(reference_scrna) ]

main:

Expand All @@ -16,11 +16,10 @@ workflow BAYESTME_BASIC_VISIUM_ANALYSIS {
filter_genes_input = BAYESTME_LOAD_SPACERANGER.out.adata.map { tuple(
it[0],
it[1],
true,
1000,
0.9,
[])
}
params.filter_ribosomal_genes.toBoolean(),
params.n_deconvolution_genes,
params.spot_threshold)
}.join( ch_input.map { tuple(it[0], it[3] ? it[3] : []) } )

BAYESTME_FILTER_GENES( filter_genes_input )

Expand Down
17 changes: 13 additions & 4 deletions src/bayestme/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,20 @@ def add_deconvolution_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--expression-truth",
help="Use expression ground truth from one or matched samples that have been processed "
"with the seurat companion scRNA fine mapping workflow. This flag can be provided multiple times"
" for multiple matched samples.",
help="Matched scRNA data in h5ad format, will be used to enforce a prior on celltypes and expression.",
type=str,
default=None,
)
parser.add_argument(
"--reference-scrna-celltype-column",
help="The name of the column with celltype id in the matched scRNA anndata.",
type=str,
default=None,
)
parser.add_argument(
"--reference-scrna-sample-column",
help="The name of the column with sample id in the matched scRNA anndata.",
type=str,
action="append",
default=None,
)
parser.add_argument(
Expand Down
23 changes: 9 additions & 14 deletions src/bayestme/cli/deconvolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import os

import anndata

import bayestme
import bayestme.cli.common
import bayestme.data
Expand Down Expand Up @@ -57,22 +59,15 @@ def main():
rng = create_rng(args.seed)

if args.expression_truth:
expression_truth_samples = []
for fn in args.expression_truth:
expression_truth_samples.append(
bayestme.expression_truth.load_expression_truth(dataset, fn)
)
n_components = expression_truth_samples[0].shape[0]

if not len(set([x.shape for x in expression_truth_samples])) == 1:
raise RuntimeError(
"Multiple expression truth arrays were provided, and they have different dimensions. "
"Please ensure --expression-truth arguments are correct."
expression_truth = (
bayestme.expression_truth.calculate_celltype_profile_prior_from_adata(
args.expression_truth,
dataset.gene_names,
celltype_column=args.reference_scrna_celltype_column,
sample_column=args.reference_scrna_sample_column,
)

expression_truth = bayestme.expression_truth.combine_multiple_expression_truth(
expression_truth_samples
)
n_components = expression_truth.shape[0]
else:
expression_truth = None
n_components = args.n_components
Expand Down
2 changes: 1 addition & 1 deletion src/bayestme/cli/deconvolve_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_deconvolve_with_expression_truth():
"bayestme.deconvolution.sample_from_posterior"
) as deconvolve_mock:
with mock.patch(
"bayestme.expression_truth.load_expression_truth"
"bayestme.expression_truth.calculate_celltype_profile_prior_from_adata"
) as load_expression_truth_mock:
expression_truth = np.zeros((9, 10))
load_expression_truth_mock.return_value = expression_truth
Expand Down
2 changes: 1 addition & 1 deletion src/bayestme/cli/filter_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_parser():
)
parser.add_argument(
"--expression-truth",
help="Filter out genes not found in all expression truth datasets.",
help="Anndata h5ad file with reference scRNA data.",
type=str,
action="append",
default=None,
Expand Down
33 changes: 20 additions & 13 deletions src/bayestme/cli/plot_deconvolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import logging

import anndata
import pandas
import bayestme.log_config
import bayestme.plot.deconvolution
Expand All @@ -25,11 +27,20 @@ def get_parser():
)
parser.add_argument(
"--expression-truth",
help="Use expression ground truth from one or matched samples that have been processed "
"with the seurat companion scRNA fine mapping workflow. This flag can be provided multiple times"
" for multiple matched samples.",
help="Use expression ground truth from one or matched scRNA datasets.",
type=str,
default=None,
)
parser.add_argument(
"--reference-scrna-celltype-column",
help="The name of the column with celltype id in the matched scRNA anndata.",
type=str,
default=None,
)
parser.add_argument(
"--reference-scrna-sample-column",
help="The name of the column with sample id in the matched scRNA anndata.",
type=str,
action="append",
default=None,
)
bayestme.log_config.add_logging_args(parser)
Expand All @@ -49,15 +60,11 @@ def main():
cell_type_names = None

if args.expression_truth is not None:
cell_type_names = pandas.read_csv(
args.expression_truth[0], index_col=0
).columns.tolist()

# pad cell type names up to length stdata.n_cell_types
i = 1
while len(cell_type_names) < stdata.n_cell_types:
cell_type_names.append(f"unknown_{i}")
i += 1
cell_type_names = sorted(
anndata.read_h5ad(args.expression_truth)
.obs[args.reference_scrna_celltype_column]
.unique()
)

bayestme.plot.deconvolution.plot_deconvolution(
stdata=stdata, output_dir=args.output_dir, cell_type_names=cell_type_names
Expand Down
19 changes: 15 additions & 4 deletions src/bayestme/cli/plot_deconvolution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tempfile
from unittest import mock

import anndata
import numpy as np
import pandas

Expand Down Expand Up @@ -153,14 +154,20 @@ def test_plot_deconvolution_with_cell_type_names_from_exp_truth():
stdata=dataset, result=deconvolve_results
)

fake_expression_truth = pandas.DataFrame(np.random.poisson(10, size=(50, 5)))
fake_expression_truth.index = np.array(["gene{}".format(x) for x in range(n_genes)])
fake_expression_truth.columns = ["type 1", "type 2", "type 3", "type 4", "type 5"]
fake_expression_truth = anndata.AnnData(
X=np.random.poisson(10, size=(50, 5)),
obs=pandas.DataFrame(
{
"sample": ["sample"] * 50,
"celltype": ["type 1", "type 2", "type 3", "type 4", "type 5"] * 10,
}
),
)

tmpdir = tempfile.mkdtemp()

fake_expression_truth_fn = os.path.join(tmpdir, "expression_truth.csv")
fake_expression_truth.to_csv(fake_expression_truth_fn)
fake_expression_truth.write_h5ad(fake_expression_truth_fn)

stdata_fn = os.path.join(tmpdir, "data.h5")
deconvolve_results_fn = os.path.join(tmpdir, "deconvolve.h5")
Expand All @@ -175,6 +182,10 @@ def test_plot_deconvolution_with_cell_type_names_from_exp_truth():
tmpdir,
"--expression-truth",
fake_expression_truth_fn,
"--reference-scrna-celltype-column",
"celltype",
"--reference-scrna-sample-column",
"sample",
]

with mock.patch(
Expand Down
43 changes: 42 additions & 1 deletion src/bayestme/expression_truth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import List
from typing import List, Optional

import numpy as np
import pandas
import pyro
import scanpy
import torch
import anndata
from anndata import AnnData
from pyro import distributions as dist
from pyro.infer import MCMC, NUTS

Expand Down Expand Up @@ -93,3 +96,41 @@ def load_expression_truth(stdata: data.SpatialExpressionDataset, seurat_output:
expression_truth = phi_k_truth_normalized.T

return expression_truth


def _normalize_celltype_counts(counts):
    """Convert per-celltype summed counts into smoothed expression profiles.

    :param counts: 2D array-like, one row per cell type and one column per gene.
    :return: Array of the same shape where each row has been add-one smoothed
        and divided by (row total + number of genes), then floored at 1e-10.
    """
    counts = np.asarray(counts)
    # Add-one smoothing: the denominator adds one pseudo-count per gene so
    # that each row of the result sums to 1.
    profiles = (counts + 1) / (counts.sum(axis=1)[:, None] + counts.shape[1])
    return np.clip(profiles, 1e-10, None)


def _sorted_celltype_profiles(ad, celltype_column):
    """Aggregate ``ad`` by cell type and return (labels, profiles) sorted by label.

    Rows are ordered by cell type name so that the row order is deterministic
    and comparable between samples (and with a sorted list of cell type names).
    """
    aggregated = scanpy.get.aggregate(ad, celltype_column, "sum")
    order = np.argsort(aggregated.obs_names)
    labels = [str(label) for label in aggregated.obs_names[order]]
    return labels, _normalize_celltype_counts(aggregated.layers["sum"][order])


def calculate_celltype_profile_prior_from_adata(
    fn, gene_names, celltype_column: str, sample_column: Optional[str] = None
):
    """Compute a cell type expression-profile prior from a reference scRNA h5ad file.

    Observations with a null ``celltype_column`` are dropped and the data is
    restricted to ``gene_names`` (in that order). Counts are summed per cell
    type, add-one smoothed and normalized per cell type. Rows of the result are
    ordered by cell type name in both the single-sample and multi-sample paths.

    :param fn: Path to the reference scRNA anndata (``.h5ad``) file.
    :param gene_names: Gene names to select from the reference; passed directly
        as the variable indexer of the AnnData object.
    :param celltype_column: Name of the ``obs`` column holding the cell type label.
    :param sample_column: Optional name of the ``obs`` column holding the sample
        id. When given, observations with a null sample id are dropped, a
        profile matrix is computed per sample and the per-sample matrices are
        merged with ``combine_multiple_expression_truth``.
    :return: Without ``sample_column``: an array with one row per cell type and
        one column per gene. With ``sample_column``: whatever
        ``combine_multiple_expression_truth`` returns for the list of
        per-sample arrays.
    :raises ValueError: If the samples do not all contain the same set of cell
        types, since their profile rows could not be aligned.
    """
    ad = anndata.read_h5ad(fn)
    ad = ad[ad.obs[celltype_column].notnull()].copy()
    ad = ad[:, gene_names].copy()

    if sample_column is None:
        _, profiles = _sorted_celltype_profiles(ad, celltype_column)
        return profiles

    ad = ad[ad.obs[sample_column].notnull()].copy()

    results = []
    reference_labels = None
    for sample_id in ad.obs[sample_column].unique():
        ad_sample = ad[ad.obs[sample_column] == sample_id]
        labels, profiles = _sorted_celltype_profiles(ad_sample, celltype_column)

        # Rows are matched across samples purely by position, so every sample
        # must provide exactly the same cell types.
        if reference_labels is None:
            reference_labels = labels
        elif labels != reference_labels:
            raise ValueError(
                f"Sample {sample_id!r} has cell types {labels} in column "
                f"{celltype_column!r}, expected {reference_labels}. All samples "
                "must contain the same set of cell types."
            )

        results.append(profiles)

    return combine_multiple_expression_truth(results)
Loading

0 comments on commit 849b112

Please sign in to comment.