diff --git a/pyproject.toml b/pyproject.toml index 3adbdaf..d5f1eaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ "tqdm>=4.64.0,<5", "shapely>=1.8.0,<2.1", "sksparse-minimal>=0.2", - "geopandas>=0.14.1,<1" + "geopandas>=0.14.1,<1", + "spatialdata-io>=0.0.9,<1" ] [project.optional-dependencies] diff --git a/src/bayestme/data.py b/src/bayestme/data.py index f921888..b006ce2 100644 --- a/src/bayestme/data.py +++ b/src/bayestme/data.py @@ -9,6 +9,7 @@ import pandas as pd import scipy.io as io import scipy.sparse.csc +import spatialdata_io from scipy.sparse import csr_matrix from bayestme import utils @@ -243,61 +244,18 @@ def read_spaceranger(cls, data_path): 3) /spatial for position list :return: SpatialExpressionDataset """ - raw_count_path = os.path.join(data_path, "raw_feature_bc_matrix/matrix.mtx.gz") - filtered_count_path = os.path.join( - data_path, "filtered_feature_bc_matrix/matrix.mtx.gz" - ) - features_path = os.path.join(data_path, "raw_feature_bc_matrix/features.tsv.gz") - barcodes_path = os.path.join(data_path, "raw_feature_bc_matrix/barcodes.tsv.gz") - - tissue_positions_lists = glob.glob( - os.path.join(data_path, "spatial/tissue_positions*.*") - ) - - positions_path = [fn for fn in tissue_positions_lists if is_csv_tsv(fn)][0] - - if is_tsv(positions_path): - positions_df = pd.read_csv( - positions_path, sep="\t", header=0, index_col=0, names=None - ) - elif is_csv(positions_path): - positions_df = pd.read_csv( - positions_path, header=0, index_col=0, names=None - ) - else: - raise RuntimeError("No positions list found in spaceranger directory") - - raw_count = np.array(io.mmread(raw_count_path).todense()) - filtered_count = np.array(io.mmread(filtered_count_path).todense()) - features = np.array(pd.read_csv(features_path, header=None, sep="\t"))[ - :, 1 - ].astype(str) - barcodes = pd.read_csv(barcodes_path, header=None, sep="\t") - n_spots = raw_count.shape[1] - n_genes = raw_count.shape[0] - logger.info("detected {} spots, {} genes".format(n_spots, n_genes)) - positions = positions_df.loc[barcodes[0]][ - [POSITIONS_X_COLUMN, POSITIONS_Y_COLUMN] - ].to_numpy() - tissue_mask = positions_df[IN_TISSUE_ATTR].to_numpy().astype(bool) - n_spot_in = tissue_mask.sum() - logger.info("\t {} spots in tissue sample".format(n_spot_in)) - all_counts = raw_count.sum() - tissue_counts = filtered_count.sum() - logger.info( - "\t {:.3f}% UMI counts bleeds out".format( - (1 - tissue_counts / all_counts) * 100 - ) - ) + sd = spatialdata_io.visium(data_path, dataset_id="visium") + ad = sd.table + tissue_mask = ad.obs.in_tissue.astype(bool).values return cls.from_arrays( - raw_counts=raw_count.T, - positions=positions, + raw_counts=ad.X, + positions=ad.obs[["array_row", "array_col"]].values, tissue_mask=tissue_mask, - gene_names=features, + gene_names=ad.var_names.values, layout=Layout.HEX, - edges=utils.get_edges(positions[tissue_mask], Layout.HEX), - barcodes=barcodes[0].to_numpy(), + edges=utils.get_edges(ad.obsm["spatial"][tissue_mask], Layout.HEX), + barcodes=ad.obs_names.values, ) @classmethod