Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Anndata Possibly Containing Sparse Arrays #139

Merged
merged 2 commits
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/bayestme/bleeding_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def plot_before_after_cleanup(
fig=fig,
ax=ax2,
coords=after_correction.positions_tissue,
values=after_correction.reads[:, gene_idx_after],
values=after_correction.counts[:, gene_idx_after],
layout=after_correction.layout,
colormap=cmap,
plotting_coordinates=after_correction.positions,
Expand Down Expand Up @@ -698,19 +698,19 @@ def plot_bleeding(

# calculate bleeding ratio
all_counts = before_correction.raw_counts.sum()
tissue_counts = after_correction.reads.sum()
tissue_counts = after_correction.counts.sum()
bleed_ratio = 1 - tissue_counts / all_counts
logger.info("\t {:.3f}% bleeds out".format(bleed_ratio * 100))

# plot
plot_intissue = np.ones_like(raw_count) * np.nan
plot_intissue[before_correction.tissue_mask] = after_correction.reads[:, gene_idx]
plot_intissue[before_correction.tissue_mask] = after_correction.counts[:, gene_idx]
plot_outside = raw_count.copy().astype(float)
plot_outside[before_correction.tissue_mask] = np.nan

plot_data = [
before_correction.reads[:, gene_idx],
after_correction.reads[:, gene_idx],
before_correction.counts[:, gene_idx],
after_correction.counts[:, gene_idx],
before_correction.raw_counts[:, gene_idx][~before_correction.tissue_mask],
]
coords = [
Expand Down Expand Up @@ -810,7 +810,7 @@ def create_top_n_gene_bleeding_plots(
n_genes: int = 10,
):
top_gene_names = utils.get_top_gene_names_by_stddev(
reads=corrected_dataset.reads,
reads=corrected_dataset.counts,
gene_names=corrected_dataset.gene_names,
n_genes=n_genes,
)
Expand Down
19 changes: 13 additions & 6 deletions src/bayestme/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
import scipy.io as io
import scipy.sparse.csc
from scipy.sparse import issparse
import spatialdata_io
from scipy.sparse import csr_matrix

Expand Down Expand Up @@ -96,10 +97,6 @@ def __init__(self, adata: anndata.AnnData):
"""
self.adata: anndata.AnnData = adata

@property
def reads(self) -> ArrayType:
return self.adata[self.adata.obs[IN_TISSUE_ATTR]].X

@property
def positions_tissue(self) -> ArrayType:
return self.adata[self.adata.obs[IN_TISSUE_ATTR]].obsm[SPATIAL_ATTR]
Expand All @@ -118,11 +115,21 @@ def n_gene(self) -> int:

@property
def raw_counts(self) -> ArrayType:
return self.adata.X
X = self.adata.X

if issparse(X):
return X.todense()
else:
return X

@property
def counts(self) -> ArrayType:
return self.adata[self.adata.obs[IN_TISSUE_ATTR]].X
X = self.adata[self.adata.obs[IN_TISSUE_ATTR]].X

if issparse(X):
return X.todense()
else:
return X

@property
def positions(self) -> ArrayType:
Expand Down
4 changes: 2 additions & 2 deletions src/bayestme/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_properties_work_without_obs_names():

dataset = data.SpatialExpressionDataset(adata)

np.testing.assert_array_equal(dataset.reads, bleed_counts[tissue_mask])
np.testing.assert_array_equal(dataset.counts, bleed_counts[tissue_mask])
np.testing.assert_array_equal(dataset.positions_tissue, locations[tissue_mask])
np.testing.assert_array_equal(dataset.n_spot_in, tissue_mask.sum())
np.testing.assert_array_equal(dataset.raw_counts, bleed_counts)
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_properties_work_with_obs_names():

dataset = data.SpatialExpressionDataset(adata)

np.testing.assert_array_equal(dataset.reads, bleed_counts[tissue_mask])
np.testing.assert_array_equal(dataset.counts, bleed_counts[tissue_mask])
np.testing.assert_array_equal(dataset.positions_tissue, locations[tissue_mask])
np.testing.assert_array_equal(dataset.n_spot_in, tissue_mask.sum())
np.testing.assert_array_equal(dataset.raw_counts, bleed_counts)
Expand Down
2 changes: 1 addition & 1 deletion src/bayestme/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def sample_from_posterior(
) -> data.DeconvolutionResult:
if inference_type == InferenceType.MCMC:
return bayestme.mcmc.deconvolution.deconvolve(
reads=data.reads,
reads=data.counts,
edges=data.edges,
n_samples=n_samples,
n_burnin=mcmc_n_burn,
Expand Down
6 changes: 3 additions & 3 deletions src/bayestme/gene_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def select_top_genes_by_standard_deviation(
dataset: data.SpatialExpressionDataset, n_gene: int
) -> data.SpatialExpressionDataset:
# order genes by the standard deviation across spots
ordering = utils.get_stddev_ordering(dataset.reads)
ordering = utils.get_stddev_ordering(dataset.counts)

n_top_genes = min(n_gene, dataset.n_gene)

Expand All @@ -38,9 +38,9 @@ def select_top_genes_by_standard_deviation(
def filter_genes_by_spot_threshold(
dataset: data.SpatialExpressionDataset, spot_threshold: float
):
n_spots = dataset.reads.shape[0]
n_spots = dataset.counts.shape[0]

keep = (dataset.reads > 0).sum(axis=0) <= int(n_spots * spot_threshold)
keep = (dataset.counts > 0).sum(axis=0) <= int(n_spots * spot_threshold)

input_adata = dataset.adata
return data.SpatialExpressionDataset(input_adata[:, keep])
Expand Down
21 changes: 13 additions & 8 deletions src/bayestme/gene_filtering_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,14 @@ def test_select_top_genes_by_standard_deviation():
dataset, n_genes_filter
)

(n_spots_in, n_genes) = result.reads.shape
(n_spots_in, n_genes) = result.counts.shape

# Then: the 2 high variation genes are selected
assert n_genes == n_genes_filter
assert n_spots_in == 3
np.testing.assert_equal(
result.reads, np.array([[199, 200], [10000, 10001], [1, 3]], dtype=np.int64)
result.counts,
np.array([[199, 200], [10000, 10001], [1, 3]], dtype=np.int64),
)
np.testing.assert_equal(result.gene_names, np.array(["keep_me1", "keep_me2"]))

Expand Down Expand Up @@ -94,11 +95,13 @@ def test_filter_genes_by_spot_threshold():
dataset, spot_threshold=0.5
)

(n_spots_in, n_genes) = result.reads.shape
(n_spots_in, n_genes) = result.counts.shape

# Then: only the gene which appears in 0% of spots is kept
assert n_genes == 1
np.testing.assert_equal(result.reads, np.array([[0], [0], [0]], dtype=np.int64))
np.testing.assert_equal(
result.counts, np.array([[0], [0], [0]], dtype=np.int64)
)
np.testing.assert_equal(result.gene_names, np.array(["keep_me"]))


Expand Down Expand Up @@ -135,12 +138,12 @@ def test_filter_ribosome_genes():
# When: filter_ribosome_genes is called
result = gene_filtering.filter_ribosome_genes(dataset)

(n_spots_in, n_genes) = result.reads.shape
(n_spots_in, n_genes) = result.counts.shape

# Then: only the two genes with non matching names are kept
assert n_genes == 2
np.testing.assert_equal(
result.reads, np.array([[1, 2], [2, 3], [3, 4]], dtype=np.int64)
result.counts, np.array([[1, 2], [2, 3], [3, 4]], dtype=np.int64)
)
np.testing.assert_equal(result.gene_names, np.array(["other", "other2"]))

Expand Down Expand Up @@ -177,11 +180,13 @@ def test_filter_list_of_genes():
# When: filter_list_of_genes is called
result = gene_filtering.filter_list_of_genes(dataset, ["other", "other2"])

(n_spots_in, n_genes) = result.reads.shape
(n_spots_in, n_genes) = result.counts.shape

# Then: only the gene not on the list remains
assert n_genes == 1
np.testing.assert_equal(result.reads, np.array([[7], [8], [9]], dtype=np.int64))
np.testing.assert_equal(
result.counts, np.array([[7], [8], [9]], dtype=np.int64)
)
np.testing.assert_equal(result.gene_names, np.array(["RPL333"]))


Expand Down
2 changes: 1 addition & 1 deletion src/bayestme/semi_synthetic_spatial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ def test_semi_synthetic_spatial():
ad, "cluster", pos_ss, n_genes=n_genes, canvas_size=(36, 36), n_spatial_gene=5
)

assert stdata.reads.shape == (pos_ss.shape[0], n_genes)
assert stdata.counts.shape == (pos_ss.shape[0], n_genes)
Loading