Skip to content

Commit 60701a7

Browse files
committed
debug for scprint-1
1 parent f20f131 commit 60701a7

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

scdataloader/datamodule.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def __init__(
204204
genedf=genedf,
205205
n_bins=n_bins,
206206
)
207+
self.gene_subset = gene_subset
207208
self.n_bins = n_bins
208209
self.validation_split = validation_split
209210
self.test_split = test_split
@@ -294,7 +295,10 @@ def genes(self) -> list:
294295
Returns:
295296
list
296297
"""
297-
return self.dataset.genedf.index.tolist()
298+
if self.gene_subset is not None:
299+
return self.gene_subset
300+
else:
301+
return self.dataset.genedf.index.tolist()
298302

299303
@property
300304
def genes_dict(self):
@@ -932,6 +936,17 @@ def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
932936
print(f"Processing {n:,} elements in {n_chunks} chunks...")
933937

934938
# Process in chunks to limit memory usage
939+
if n_workers == 1:
940+
# Process sequentially without multiprocessing
941+
for i in tqdm(
942+
range(n_chunks), total=n_chunks, desc="Processing chunks sequentially"
943+
):
944+
start_idx = i * chunk_size
945+
end_idx = min((i + 1) * chunk_size, n)
946+
results.append(
947+
self._process_chunk_with_slice((start_idx, end_idx, labels))
948+
)
949+
return self._merge_chunk_results(results)
935950
with ProcessPoolExecutor(
936951
max_workers=n_workers, mp_context=mp.get_context("spawn")
937952
) as executor:

0 commit comments

Comments
 (0)