-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmedian_subtraction.py
More file actions
71 lines (53 loc) · 2.8 KB
/
median_subtraction.py
File metadata and controls
71 lines (53 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 18:53:04 2019
@author: svc_ccg
"""
import os
import time

import numpy as np
from matplotlib import pyplot as plt
def medianSubtraction(datFile, minReferenceChannel=0, maxReferenceChannel=150,
                      chunksize=30000*10, sampleRate=30000, channelNumber=384,
                      readMode='r+', plot=True):
    """Subtract per-channel offsets and apply common median referencing.

    The file is interpreted as interleaved int16 samples laid out as
    [time, channel]. Each channel's offset (its median over the first
    ``chunksize`` samples, truncated to int16) is subtracted first; then, for
    every time point, the median across the reference channels is subtracted
    from all channels.

    Inputs:
        datFile: path to the binary data file.
        minReferenceChannel/maxReferenceChannel: slice bounds
            (``min:max``, i.e. max is exclusive) selecting the channels used
            for the median calculation.
        chunksize: number of sample points to process at a time.
        sampleRate: in Hz. Accepted for interface compatibility; not used by
            the computation.
        channelNumber: total probe channel count (columns in the file).
        readMode: numpy.memmap mode. With 'r+' changes are written in place
            on disk; with 'c' the input file is left untouched and the result
            is written next to it as ``<name>_medianSubtracted<ext>``.
        plot: if True, plots the standard deviation of each channel (over the
            first chunk) before and after referencing.

    Returns:
        None. The elapsed time is printed.

    Raises:
        ValueError: if chunksize is < 1, if the file size is not a whole
            number of ``channelNumber``-wide samples, or if the reference
            channel range selects no channels.
    """
    if chunksize < 1:
        raise ValueError('chunksize must be at least 1 sample')

    # time.clock() was removed in Python 3.8; perf_counter() is its
    # documented replacement for measuring elapsed time.
    starttime = time.perf_counter()

    # create memmap of file to process and reshape to [time, channel]
    raw = np.memmap(datFile, dtype='int16', mode=readMode)
    if raw.size % channelNumber != 0:
        raise ValueError(
            'File holds %d int16 values, which is not a multiple of '
            'channelNumber=%d' % (raw.size, channelNumber))
    d = raw.reshape(raw.size // channelNumber, channelNumber)
    nSamples = d.shape[0]

    # An empty reference slice would make every median NaN and write garbage.
    if d[:, minReferenceChannel:maxReferenceChannel].shape[1] == 0:
        raise ValueError(
            'Reference channel range %r:%r selects no channels'
            % (minReferenceChannel, maxReferenceChannel))

    # get channel offsets (we eventually want to subtract off any offsets
    # individual channels might have to center them on zero)
    offsets = np.median(d[:chunksize], axis=0).astype('int16')

    # plot pre filter standard deviation
    if plot:
        fig, ax = plt.subplots()
        ax.plot(np.std(d[:chunksize], axis=0))

    # main loop: loop through data chunks and subtract channel offsets and
    # median across channels
    limits = np.iinfo(np.int16)
    for start in range(0, nSamples, chunksize):
        end = min(start + chunksize, nSamples)

        # Work in float64 so that neither subtraction can wrap around the
        # int16 range; the result is saturated before being written back.
        work = d[start:end].astype(np.float64)

        # subtract offsets calculated above for each individual channel
        work -= offsets

        # subtract median across channels for every time point in chunk
        referenceMedian = np.median(
            work[:, minReferenceChannel:maxReferenceChannel], axis=1)
        work -= referenceMedian[:, None]

        # Saturate instead of overflowing; the assignment then truncates the
        # fractional part toward zero, exactly as the implicit cast always did.
        np.clip(work, limits.min, limits.max, out=work)
        d[start:end] = work

    # plot post filter standard deviation
    if plot:
        ax.plot(np.std(d[:chunksize], axis=0))
        ax.set_xlabel('channel')
        ax.set_ylabel('standard deviation')

    if readMode == 'c':
        # opened as copy-on-write: save median subtracted data to a new file
        outputDir, outputFile = os.path.split(datFile)
        outputFile, ext = os.path.splitext(outputFile)
        outputFile = outputFile + '_medianSubtracted' + ext
        # d is already int16, so it can be written without an extra copy
        d.tofile(os.path.join(outputDir, outputFile))
    else:
        # make sure in-place edits reach the disk before the map is released
        raw.flush()

    del d, raw

    elapsed = time.perf_counter() - starttime
    print('Time elapsed (s): ' + str(elapsed))