Merge pull request #56 from YeoLab/pivot_on_edits_1219
Fix concatenation of coverage adatas
ekofman authored Dec 25, 2024
2 parents 88727dd + 5b0d874 commit 32120ef
Showing 2 changed files with 135 additions and 27 deletions.
119 changes: 94 additions & 25 deletions src/utils.py
@@ -1017,35 +1017,106 @@ def merge_files_by_chromosome(args):



def concatenate_coverage_adatas(adata_dict):
# Step 1: Get the union of all variable names (columns)
all_vars = set()
def create_zeros_df(indices, columns):
return pd.DataFrame.sparse.from_spmatrix(
sp.csr_matrix((len(indices), len(columns))),
index=indices,
columns=columns
)
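
# Illustrative usage sketch (not part of this commit; the barcodes and positions below are hypothetical):
# create_zeros_df returns an all-zero, sparse-backed DataFrame that can be joined onto an
# existing coverage frame without allocating dense zeros.
#
# example_zeros = create_zeros_df(
#     indices=["AAACCCAAGAAACCAT-1", "AAACCCATCAGCCCAG-1"],  # hypothetical barcodes
#     columns=["9:3000508", "9:3000527", "9:3000528"]         # hypothetical positions
# )
# example_zeros.shape           # (2, 3)
# example_zeros.sparse.density  # 0.0 -- backed by sparse storage throughout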


def sum_adata_cellwise(adatas):
"""
Sums the values of multiple AnnData objects cell-wise (across observations).
Assumes all AnnData objects have the same `obs` and `var` indices.
Parameters:
adatas (list of AnnData): List of AnnData objects to sum.
Returns:
AnnData: A new AnnData object with summed values.
"""
# Check that all AnnData objects have the same indices
obs_names = adatas[0].obs_names
var_names = adatas[0].var_names
for adata in adatas:
assert np.array_equal(adata.obs_names, obs_names), "All obs indices must match."
assert np.array_equal(adata.var_names, var_names), "All var indices must match."

# Sum the X matrices across all AnnData objects
summed_X = sum(adata.X for adata in adatas)

# If the result is not sparse, convert it to a sparse matrix
if not sp.issparse(summed_X):
summed_X = sp.csr_matrix(summed_X)

# Create a new AnnData object with the summed matrix
summed_adata = ad.AnnData(X=summed_X, obs=adatas[0].obs, var=adatas[0].var)

return summed_adata
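
# Illustrative usage sketch (not part of this commit; the tiny matrices and names are hypothetical):
# two AnnData objects with identical obs/var indices are summed entry by entry.
#
# a = ad.AnnData(X=sp.csr_matrix([[1.0, 0.0], [0.0, 2.0]]),
#                obs=pd.DataFrame(index=["cellA-1", "cellB-1"]),
#                var=pd.DataFrame(index=["9:100", "9:200"]))
# b = ad.AnnData(X=sp.csr_matrix([[0.0, 3.0], [4.0, 0.0]]),
#                obs=pd.DataFrame(index=["cellA-1", "cellB-1"]),
#                var=pd.DataFrame(index=["9:100", "9:200"]))
# sum_adata_cellwise([a, b]).X.todense()  # [[1., 3.], [4., 2.]]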


def combine_coverage_adatas(adata_dict):
"""
Combines multiple AnnData objects into one by aligning each to the union of all barcodes
(obs_names) and positions (var_names), zero-filling missing entries, and summing the
aligned matrices element-wise.
"""
# Step 1: Get the union of all positions (var_names) and barcodes (obs_names)
all_pos = set()
all_obs = set()

for adata in adata_dict.values():
all_vars.update(adata.var_names)
all_pos.update(adata.var.index)
all_obs.update(adata.obs.index)

# Convert to sorted lists for consistent ordering
all_pos = sorted(list(all_pos))
all_obs = sorted(list(all_obs))

# Convert to a sorted list for consistent ordering
all_vars = sorted(all_vars)

# Step 2: Align each AnnData object to the union of variables
aligned_adatas = []
adatas_processed = 0
for adata in adata_dict.values():
# Find missing variables
missing_vars = [var for var in all_vars if var not in adata.var_names]
# Create a sparse matrix of zeros for missing variables
zeros_matrix = sp.csr_matrix((adata.n_obs, len(missing_vars)))
# Concatenate the existing matrix with the zeros matrix
existing_matrix = adata[:, [var for var in adata.var_names if var in all_vars]].X
aligned_matrix = sp.hstack([existing_matrix, zeros_matrix], format='csr')
adatas_processed += 1
print("\t{}/{} adatas concatenated...".format(adatas_processed, len(adata_dict)))

dense_adata = pd.DataFrame(adata.X.todense(), index=adata.obs.index, columns=adata.var.index)
missing_barcodes = [b for b in all_obs if b not in dense_adata.index]
missing_positions = [p for p in all_pos if p not in dense_adata.columns]

# First, add a set of zero-filled rows for the missing barcodes, using the existing positions as columns
missing_barcode_fix = create_zeros_df(missing_barcodes, list(dense_adata.columns))

dense_adata = (dense_adata.T).join(missing_barcode_fix.T, how='inner').T
assert(len(dense_adata) == len(all_obs))

# Now, add a set of empty columns for the missing positions, with all barcodes as indices
# Create a sparse DataFrame with zeros
missing_position_fix = create_zeros_df(all_obs, missing_positions)

dense_adata = dense_adata.join(missing_position_fix, how='inner')
assert(len(dense_adata.columns) == len(all_pos))

# Reorder the columns (rather than just relabeling them) so every aligned frame shares the same sorted position order
dense_adata = dense_adata[sorted(dense_adata.columns)]
dense_adata = dense_adata.sort_index()

# Create a new DataFrame for `var` to include all_vars
new_var = pd.DataFrame(index=all_vars)
aligned_adata = ad.AnnData(X=aligned_matrix, obs=adata.obs.copy(), var=new_var)
aligned_adata.var_names = all_vars
aligned_adatas.append(aligned_adata)
# Convert the dense matrix to a sparse CSR matrix
sparse_X = sp.csr_matrix(dense_adata)
# Create a new AnnData object with the sparse matrix

# Correctly assign obs and var as DataFrames
obs_df = pd.DataFrame(index=all_obs)
var_df = pd.DataFrame(index=all_pos)

# Step 3: Concatenate the aligned AnnData objects along the observation axis
combined_adata = ad.concat(aligned_adatas, axis=0, join='outer')
sparse_adata = ad.AnnData(X=sparse_X, obs=obs_df, var=var_df)

aligned_adatas.append(sparse_adata)

# Step 3: Sum the aligned AnnData objects element-wise into a single combined AnnData
combined_adata = sum_adata_cellwise(aligned_adatas)

assert(combined_adata.shape == (len(all_obs), len(all_pos)))
print("Shape of combined adata is {}".format(combined_adata.shape))
return combined_adata
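
# Illustrative usage sketch (not part of this commit; the per-contig AnnData objects are hypothetical):
# two single-entry coverage matrices with disjoint barcodes and positions combine into a
# 2 x 2 matrix over the union of both, with zeros wherever no coverage was observed.
#
# chr9 = ad.AnnData(X=sp.csr_matrix([[5.0]]),
#                   obs=pd.DataFrame(index=["cellA-1"]),
#                   var=pd.DataFrame(index=["9:3000508"]))
# chr10 = ad.AnnData(X=sp.csr_matrix([[7.0]]),
#                    obs=pd.DataFrame(index=["cellB-1"]),
#                    var=pd.DataFrame(index=["10:500"]))
# combined = combine_coverage_adatas({"9": chr9, "10": chr10})
# combined.shape                                 # (2, 2)
# combined["cellA-1", "9:3000508"].X.todense()   # 5.0; all other entries are 0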


@@ -1090,19 +1161,17 @@ def prepare_matrix_files_multiprocess(output_matrix_folder,
# Move the per-contig coverage matrices h5ad files into a subfolder to keep the output area clean
os.makedirs(f"{output_folder}/per_contig_coverage_matrices", exist_ok=True)

h5_filepaths = glob(f'{output_folder}/*comprehensive_coverage_matrix.h5ad')
h5_filepaths = glob(f'{output_folder}/*_comprehensive_coverage_matrix.h5ad')
adata_dict = {}
for h5_filepath in h5_filepaths:
print(f"\t{h5_filepath}")
adata_dict[h5_filepath.split('/')[-1].split('_comprehensive')[0]] = ad.read_h5ad(h5_filepath)

shutil.move(h5_filepath, f"{output_folder}/per_contig_coverage_matrices/{h5_filepath.split('/')[-1]}")

combined_adata = concatenate_coverage_adatas(adata_dict)
combined_adata = combine_coverage_adatas(adata_dict)
combined_adata.write_h5ad(f"{output_folder}/comprehensive_coverage_matrix.h5ad")
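
# Illustrative sketch (not part of this commit; the path below is hypothetical): the keys of
# adata_dict are the per-contig prefixes of the h5ad filenames.
#
# path = f"{output_folder}/9_comprehensive_coverage_matrix.h5ad"
# path.split('/')[-1].split('_comprehensive')[0]  # -> "9"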





def merge_depth_files(output_folder, output_suffix=''):
43 changes: 41 additions & 2 deletions tests/integration_tests_auto_check.py
@@ -472,24 +472,63 @@ def get_all_edited_positions_and_barcodes(test_folder):
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

coverage_adata, ct_edits_adata, ag_edits_adata, gc_edits_adata = get_all_edited_positions_and_barcodes(test_folder)

print('gc_edits_adata: {}'.format(len(gc_edits_adata)))
print('ag_edits_adata: {}'.format(len(ag_edits_adata)))
print('ct_edits_adata: {}'.format(len(ct_edits_adata)))

try:
print('ct_edits_adata', pd.DataFrame(ct_edits_adata.X.todense(),
index=ct_edits_adata.obs.index,
columns=ct_edits_adata.var.index), '\n')

print('ag_edits_adata', pd.DataFrame(ag_edits_adata.X.todense(),
index=ag_edits_adata.obs.index,
columns=ag_edits_adata.var.index), '\n')

print('gc_edits_adata', pd.DataFrame(gc_edits_adata.X.todense(),
index=gc_edits_adata.obs.index,
columns=gc_edits_adata.var.index), '\n')

print('coverage_adata', pd.DataFrame(coverage_adata.X.todense(),
index=coverage_adata.obs.index,
columns=coverage_adata.var.index)[['9:3000508', '9:3000527', '9:3000528']], '\n')

print('\t', ct_edits_adata['GGGACCTTCGAGCCAC-1','9:3000528'].X.todense())
print('\t', coverage_adata['GGGACCTTCGAGCCAC-1','9:3000528'].X.todense())

assert(ct_edits_adata['GGGACCTTCGAGCCAC-1','9:3000528'].X.todense() == 1)
assert(coverage_adata['GGGACCTTCGAGCCAC-1','9:3000528'].X.todense() == 12)
print("\t\t9:3000528 passed...")

print('\t', ct_edits_adata['GATCCCTCAGTAACGG-1','9:3000508'].X.todense())
print('\t', coverage_adata['GATCCCTCAGTAACGG-1','9:3000508'].X.todense())

assert(ct_edits_adata['GATCCCTCAGTAACGG-1','9:3000508'].X.todense() == 1)
assert(coverage_adata['GATCCCTCAGTAACGG-1','9:3000508'].X.todense() == 3)
print("\t\t9:3000508 passed...")

print('\t', ag_edits_adata['GGGACCTTCGAGCCAC-1','9:3000527'].X.todense())
print('\t', coverage_adata['GGGACCTTCGAGCCAC-1','9:3000527'].X.todense())

assert(ag_edits_adata['GGGACCTTCGAGCCAC-1','9:3000527'].X.todense() == 10)
assert(coverage_adata['GGGACCTTCGAGCCAC-1','9:3000527'].X.todense() == 12)
print("\t\t9:3000527 passed...")

assert(gc_edits_adata['GATCCCTCAGTAACGG-1','9:3000525'].X.todense() == 1)
print('\t', gc_edits_adata['GATCCCTCAGTAACGG-1','9:3000525'].X.todense())
print('\t', coverage_adata['GATCCCTCAGTAACGG-1','9:3000525'].X.todense())

assert(gc_edits_adata['GATCCCTCAGTAACGG-1','9:3000525'].X.todense() == 1)
assert(coverage_adata['GATCCCTCAGTAACGG-1','9:3000525'].X.todense() == 4)
print("\t\t9:3000525 passed...")

print("\n\t >>> coverage matrix and edit matrix values confirmation passed! <<<\n")

except Exception as e:
print(e)
print('Error', e)
print("Exception: Expected edit and coverage values not found in sparse matrices")
failures += 1
raise(e)


print("There were {} failures".format(failures))
