
Registration API Reference

registration

B2Registration dataclass

Bases: B2Session

Registration class for processing B2 (vrControl) session data.

This class handles the preprocessing and registration of behavioral and imaging data from vrControl experiments. It processes timeline data, behavioral data, imaging data (suite2p outputs), red cell identification, and creates mappings between behavioral and imaging data.

Parameters:

mouse_name : str
    Name of the mouse. (required)
date_string : str
    Date string in format "YYYY-MM-DD". (required)
session_id : str
    Session identifier. (required)
opts : B2RegistrationOpts or dict, optional
    Registration options. If a dict, it is converted to B2RegistrationOpts.
    Default is B2RegistrationOpts().
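The opts coercion accepts an instance, any dataclass, or a mapping. A minimal sketch of that pattern, using a stand-in Opts dataclass (its fields are hypothetical; B2RegistrationOpts defines its own):

```python
from dataclasses import dataclass, is_dataclass, asdict
from collections.abc import Mapping

# Stand-in for B2RegistrationOpts; the fields here are hypothetical.
@dataclass
class Opts:
    imaging: bool = True
    oasis: bool = False

def coerce_opts(opts):
    # Mirrors the constructor logic in the source listing below: pass
    # instances through, round-trip other dataclasses via asdict, unpack
    # mappings, and reject everything else.
    if isinstance(opts, Opts):
        return opts
    if is_dataclass(opts):
        return Opts(**asdict(opts))
    if isinstance(opts, Mapping):
        return Opts(**opts)
    raise TypeError(f"opts must be Opts, a dataclass, or a mapping; got {type(opts)}")

print(coerce_opts({"oasis": True}))  # Opts(imaging=True, oasis=True)
```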

Attributes:

opts : B2RegistrationOpts
    Registration options.
tl_file : dict
    Timeline data loaded from the Timeline.mat file.
vr_file : dict
    Behavioral data loaded from the VRBehavior_trial.mat file.

Raises:

ValueError
    If the session directory does not exist.
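The class description mentions creating mappings between behavioral and imaging data; helpers.nearestpoint is assumed to return, for each query time, the index of the nearest reference time. A self-contained numpy sketch of that mapping (the sample times below are hypothetical):

```python
import numpy as np

def nearest_index(query_times, reference_times):
    # For each query time, find the index of the nearest reference time.
    # np.searchsorted locates the insertion point; we then pick whichever
    # neighbor (left or right) is closer.
    idx = np.searchsorted(reference_times, query_times)
    idx = np.clip(idx, 1, len(reference_times) - 1)
    left = reference_times[idx - 1]
    right = reference_times[idx]
    idx -= query_times - left < right - query_times
    return idx

behave_times = np.array([0.00, 0.01, 0.02, 0.03, 0.04])  # hypothetical 100 Hz behavior samples
frame_times = np.array([0.000, 0.033, 0.066])            # hypothetical ~30 Hz imaging volumes
print(nearest_index(behave_times, frame_times))  # [0 0 1 1 1]
```

This is the shape of the array saved as "positionTracking.mpci": one imaging-frame index per behavioral sample.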

Source code in vrAnalysis/registration/register.py
@dataclass(init=False)
class B2Registration(B2Session):
    """
    Registration class for processing B2 (vrControl) session data.

    This class handles the preprocessing and registration of behavioral and
    imaging data from vrControl experiments. It processes timeline data,
    behavioral data, imaging data (suite2p outputs), red cell identification,
    and creates mappings between behavioral and imaging data.

    Parameters
    ----------
    mouse_name : str
        Name of the mouse.
    date_string : str
        Date string in format "YYYY-MM-DD".
    session_id : str
        Session identifier.
    opts : B2RegistrationOpts or dict, optional
        Registration options. If a dict, will be converted to B2RegistrationOpts.
        Default is B2RegistrationOpts().

    Attributes
    ----------
    opts : B2RegistrationOpts
        Registration options.
    tl_file : dict
        Timeline data loaded from Timeline.mat file.
    vr_file : dict
        Behavioral data loaded from VRBehavior_trial.mat file.

    Raises
    ------
    ValueError
        If session directory does not exist.
    """

    _for_registration: bool = True

    def __init__(
        self,
        mouse_name: str,
        date_string: str,
        session_id: str,
        opts: Union[B2RegistrationOpts, dict] = B2RegistrationOpts(),
    ):
        """
        Initialize B2Registration object.

        Parameters
        ----------
        mouse_name : str
            Name of the mouse.
        date_string : str
            Date string in format "YYYY-MM-DD".
        session_id : str
            Session identifier.
        opts : B2RegistrationOpts or dict, optional
            Registration options. If a dict, will be converted to B2RegistrationOpts.
            Default is B2RegistrationOpts().

        Raises
        ------
        ValueError
            If session directory does not exist.
        """
        super().__init__(mouse_name, date_string, session_id)
        if isinstance(opts, B2RegistrationOpts):
            self.opts = opts
        elif is_dataclass(opts):
            self.opts = B2RegistrationOpts(**asdict(opts))
        elif isinstance(opts, Mapping):
            self.opts = B2RegistrationOpts(**opts)
        else:
            raise TypeError(f"opts must be B2RegistrationOpts, a dataclass, or a mapping; got {type(opts)}")

        if not self.data_path.exists():
            raise ValueError(f"Session directory does not exist for {self.session_print()}")

        if not self.one_path.exists():
            self.one_path.mkdir(parents=True)

    def _additional_loading(self):
        """
        Override to skip loading registered data.

        Registration objects produce registered data rather than load it,
        so we skip the parent's _additional_loading() which tries to load
        from saved JSON files.

        Notes
        -----
        This method intentionally does nothing. Registration objects create
        data rather than loading pre-existing registered data.
        """
        pass

    def register(self):
        """
        Register the session by running all preprocessing steps.

        This is the main entry point for registration. It runs all preprocessing
        steps (timeline, behavior, imaging, red cells, facecam, behavior-to-imaging)
        and saves session parameters.

        See Also
        --------
        do_preprocessing : Run all preprocessing steps.
        save_session_prms : Save session parameters to oneData.
        """
        self.do_preprocessing()
        self.save_session_prms()

    def do_preprocessing(self):
        """
        Run all preprocessing steps for the session.

        Processes timeline data, behavioral data, imaging data, red cell
        identification, facecam data (placeholder), and creates mappings
        between behavioral and imaging data.

        Notes
        -----
        Processing steps are run in order:
        1. Timeline processing
        2. Behavior processing
        3. Imaging processing
        4. Red cell processing (if enabled)
        5. Facecam processing (not yet implemented)
        6. Behavior-to-imaging mapping
        """
        if self.opts.clearOne:
            self.clear_one_data(certainty=True)
        self.process_timeline()
        self.process_behavior()
        self.process_imaging()
        self.process_red_cells()
        self.process_facecam()
        self.process_behavior_to_imaging()

    # --------------------------------------------------------------- preprocessing methods ------------------------------------------------------------
    def process_timeline(self):
        """
        Process timeline data from rigbox.

        Extracts timestamps, rotary encoder position, lick times, reward times,
        and trial start times from the Timeline.mat file. Processes photodiode
        signals to align trial starts with imaging frames.

        Notes
        -----
        This method:
        - Loads timeline structure from Timeline.mat
        - Converts rotary encoder to position
        - Detects licks and rewards
        - Processes photodiode signal to find trial start frames
        - Saves timeline data to oneData
        """
        # load these files for raw behavioral & timeline data
        self.load_timeline_structure()
        self.load_behavior_structure()

        # get time stamps, photodiode, trial start and end times, room position, lick times, trial idx, visual data visible
        mpepStartTimes = []
        for mt, me in zip(self.tl_file["mpepUDPTimes"], self.tl_file["mpepUDPEvents"]):
            if isinstance(me, str):
                if "TrialStart" in me:
                    mpepStartTimes.append(mt)
                elif "StimStart" in me:
                    mpepStartTimes.append(mt)

        mpepStartTimes = np.array(mpepStartTimes)
        timestamps = self.get_timeline_var("timestamps")  # load timestamps

        # Get rotary position -- (load timeline measurement of rotary encoder, which is a circular position counter, use vrExperiment function to convert to a running measurement of position)
        rotaryEncoder = self.get_timeline_var("rotaryEncoder")
        rotaryPosition = self.convert_rotary_encoder_to_position(rotaryEncoder, self.vr_file["rigInfo"])

        # Get Licks (uses an edge counter)
        lickDetector = self.get_timeline_var("lickDetector")  # load lick detector copy
        lickSamples = np.where(helpers.diffsame(lickDetector) == 1)[0].astype(np.uint64)  # timeline samples of lick times

        # Get Reward Commands (measures voltage of output -- assume it's either low or high)
        rewardCommand = self.get_timeline_var("rewardCommand")  # load reward command signal
        rewardCommand = np.round(rewardCommand / np.max(rewardCommand))
        rewardSamples = np.where(helpers.diffsame(rewardCommand) > 0.5)[0].astype(np.uint64)  # timeline samples when reward was delivered

        # Now process photodiode signal
        photodiode = self.get_timeline_var("photoDiode")  # load photodiode signal

        # Remove any slow trends
        pdDetrend = sp.signal.detrend(photodiode)
        pdDetrend = (pdDetrend - pdDetrend.min()) / np.ptp(pdDetrend)

        # median filter and take smooth derivative
        hfpd = 10
        refreshRate = 30  # hz
        refreshSamples = int(1.0 / refreshRate / np.mean(np.diff(timestamps)))
        pdMedFilt = sp.ndimage.median_filter(pdDetrend, size=refreshSamples)
        pdDerivative, pdIndex = helpers.fivePointDer(pdMedFilt, hfpd, returnIndex=True)
        pdDerivative = sp.stats.zscore(pdDerivative)
        pdDerTime = timestamps[pdIndex]

        # find upward and downward peaks, not perfect but in practice close enough
        locUp = sp.signal.find_peaks(pdDerivative, height=1, distance=refreshSamples / 2)
        locDn = sp.signal.find_peaks(-pdDerivative, height=1, distance=refreshSamples / 2)
        flipTimes = np.concatenate((pdDerTime[locUp[0]], pdDerTime[locDn[0]]))
        flipValue = np.concatenate((np.ones(len(locUp[0])), np.zeros(len(locDn[0]))))
        flipSortIdx = np.argsort(flipTimes)
        flipTimes = flipTimes[flipSortIdx]
        flipValue = flipValue[flipSortIdx]

        # Naive method: just look for flips before and after each trialstart/trialend mpep message.
        # A more sophisticated method would use the times of the photodiode ramps, but those are really just for safety and rare manual curation...
        firstFlipIndex = np.array([np.where(flipTimes >= mpepStart)[0][0] for mpepStart in mpepStartTimes])
        startTrialIndex = helpers.nearestpoint(flipTimes[firstFlipIndex], timestamps)[0]  # returns frame index of first photodiode flip in each trial

        # Check that first flip is always down -- all of the vrControl code prepares trials in this way
        if datetime.strptime(self.date, "%Y-%m-%d") >= datetime.strptime("2022-08-30", "%Y-%m-%d"):
            # But it didn't prepare it this way before august 30th :(
            assert np.all(flipValue[firstFlipIndex] == 0), f"In session {self.session_print()}, first flips in trial are not all down!!"

        # Check shapes of timeline arrays
        assert timestamps.ndim == 1, "timelineTimestamps is not a 1-d array!"
        assert timestamps.shape == rotaryPosition.shape, "timeline timestamps and rotary position arrays do not have the same shape!"

        # Save timeline oneData
        self.saveone(timestamps, "wheelPosition.times")
        self.saveone(rotaryPosition, "wheelPosition.position")
        self.saveone(timestamps[lickSamples], "licks.times")
        self.saveone(timestamps[rewardSamples], "rewards.times")
        self.saveone(timestamps[startTrialIndex], "trials.startTimes")
        self.preprocessing.append("timeline")

    def process_behavior(self):
        """
        Process behavioral data from vrControl.

        Processes behavioral data using the appropriate behavior processing
        function based on the vrBehaviorVersion option. Extracts trial-level
        and sample-level behavioral data and aligns timestamps to the timeline.

        See Also
        --------
        register_behavior : Dispatcher function for behavior processing.
        """
        self = register_behavior(self, self.opts.vrBehaviorVersion)

        # Confirm that vrBehavior has been processed
        self.preprocessing.append("vrBehavior")

    def process_imaging(self):
        """
        Process imaging data from suite2p outputs.

        Loads suite2p outputs, identifies available planes and outputs,
        handles frame count mismatches between suite2p and timeline,
        optionally recomputes deconvolution using OASIS, and saves imaging
        data to oneData format.

        Raises
        ------
        ValueError
            If imaging is requested but suite2p directory does not exist, or
            if required suite2p outputs are missing, or if frame count
            mismatches cannot be resolved.

        Notes
        -----
        This method:
        - Identifies planes and available suite2p outputs
        - Checks for required outputs (stat, ops, F, Fneu, iscell, spks)
        - Handles frame count mismatches between suite2p and timeline
        - Optionally recomputes deconvolution using OASIS
        - Saves imaging data to oneData
        """
        if not self.opts.imaging:
            print(f"In session {self.session_print()}, imaging is set to False in opts.imaging. Skipping image processing.")
            return None

        if not self.s2p_path.exists():
            raise ValueError(f"In session {self.session_print()}, suite2p processing was requested but suite2p directory does not exist.")

        # identifies which planes were processed through suite2p (assume that those are all available planes)
        # identifies which s2p outputs are available from each plane
        self.set_value("planeNames", [plane.parts[-1] for plane in self.s2p_path.glob("plane*/")])
        self.set_value("planeIDs", [int(planeName[5:]) for planeName in self.get_value("planeNames")])
        npysInPlanes = [[npy.stem for npy in list((self.s2p_path / planeName).glob("*.npy"))] for planeName in self.get_value("planeNames")]
        commonNPYs = list(set.intersection(*[set(npy) for npy in npysInPlanes]))
        unionNPYs = list(set.union(*[set(npy) for npy in npysInPlanes]))
        if set(commonNPYs) < set(unionNPYs):
            print(
                f"The following npy files are present in some but not all plane folders within session {self.session_print()}: {list(set(unionNPYs) - set(commonNPYs))}"
            )
            print(f"Each plane folder contains the following npy files: {commonNPYs}")
        self.set_value("available", commonNPYs)  # a list of npy files available in each plane folder

        # required variables (anything else is either optional or can be computed independently)
        required = ["stat", "ops", "F", "Fneu", "iscell"]
        if not self.opts.oasis:
            # add deconvolved spikes to required variable if we aren't recomputing it here
            required.append("spks")
        for varName in required:
            assert varName in self.get_value("available"), f"{self.session_print()} is missing {varName} in at least one suite2p folder!"
        # get number of ROIs in each plane
        self.set_value("roiPerPlane", [iscell.shape[0] for iscell in self.load_s2p("iscell", concatenate=False)])
        # get number of frames in each plane (might be different!)
        self.set_value("framePerPlane", [F.shape[1] for F in self.load_s2p("F", concatenate=False)])
        assert_msg = f"The frame count in {self.session_print()} varies by more than 1 frame! ({self.get_value('framePerPlane')})"
        assert np.max(self.get_value("framePerPlane")) - np.min(self.get_value("framePerPlane")) <= 1, assert_msg
        self.set_value("numROIs", np.sum(self.get_value("roiPerPlane")))  # number of ROIs in session
        # number of frames to use when retrieving imaging data (might be overwritten to something smaller if timeline handled improperly)
        self.set_value("numFrames", np.min(self.get_value("framePerPlane")))

        # Get timeline sample corresponding to each imaging volume
        timeline_timestamps = self.loadone("wheelPosition.times")
        changeFrames = (
            np.append(
                0,
                np.diff(np.ceil(self.get_timeline_var("neuralFrames") / len(self.get_value("planeIDs")))),
            )
            == 1
        )
        frame_samples = np.where(changeFrames)[0]  # TTLs for each volume (increments by 1 for each plane)
        frame_to_time = timeline_timestamps[frame_samples]  # get timelineTimestamps of each imaging volume

        # Handle mismatch between the number of imaging frames saved by scanImage (and propagated through suite2p) and timeline's measurement of the scanImage frame counter
        if len(frame_to_time) != self.get_value("numFrames"):
            if len(frame_to_time) - 1 == self.get_value("numFrames"):
                # If frame_to_time had one more frame, just trim it and assume everything is fine. This happens when a new volume was started but not finished, so it does not require communicating with the user.
                frame_samples = frame_samples[:-1]
                frame_to_time = frame_to_time[:-1]
            elif len(frame_to_time) - 2 == self.get_value("numFrames"):
                print(
                    "frame_to_time had 2 more than suite2p output. This happens sometimes. I don't like it. I think it's because scanimage sends a TTL before starting the frame"
                )
                frame_samples = frame_samples[:-2]
                frame_to_time = frame_to_time[:-2]
            else:
                # If frameSamples has too few frames, it's possible that the scanImage signal to timeline was broken but scanImage still continued normally.
                numMissing = self.get_value("numFrames") - len(frame_samples)  # measure number of missing frames
                if numMissing < 0:
                    # If frameSamples had many more frames, generate an error -- something went wrong that needs manual inspection
                    print(
                        f"In session {self.session_print()}, frameSamples has {len(frame_samples)} elements, but {self.get_value('numFrames')} frames were reported in suite2p. Cannot resolve."
                    )
                    raise ValueError("Cannot fix mismatches when suite2p data is missing!")
                # It's possible that the scanImage signal to timeline was broken but scanImage still continued normally.
                if numMissing > 1:
                    print(
                        f"In session {self.session_print()}, more than one frameSamples sample was missing. Consider using tiff timelineTimestamps to reproduce accurately."
                    )
                print(
                    (
                        f"In session {self.session_print()}, frameSamples has {len(frame_samples)} elements, but {self.get_value('numFrames')} frames were saved by suite2p. "
                        "Will extend frameSamples using the typical sampling rate and nearestpoint algorithm."
                    )
                )
                # If frame_to_time difference vector is consistent within 1%, then use mean (which is a little more accurate), otherwise use median
                frame_to_time = timeline_timestamps[frame_samples]
                medianFramePeriod = np.median(np.diff(frame_to_time))  # measure median sample period
                consistentFrames = np.all(
                    np.abs(np.log(np.diff(frame_to_time) / medianFramePeriod)) < np.log(1.01)
                )  # True if all frames take within 1% of median frame period
                if consistentFrames:
                    samplePeriod_f2t = np.mean(np.diff(frame_to_time))
                else:
                    samplePeriod_f2t = np.median(np.diff(frame_to_time))
                appendFrames = frame_to_time[-1] + samplePeriod_f2t * (
                    np.arange(numMissing) + 1
                )  # add elements to frame_to_time, assume sampling rate was perfect
                frame_to_time = np.concatenate((frame_to_time, appendFrames))
                frame_samples = helpers.nearestpoint(frame_to_time, timeline_timestamps)[0]

        # average percentage difference between all sample intervals and the median interval -- just a useful metric to save
        self.set_value(
            "samplingDeviationMedianPercentError", np.exp(np.mean(np.abs(np.log(np.diff(frame_to_time) / np.median(np.diff(frame_to_time))))))
        )
        self.set_value(
            "samplingDeviationMaximumPercentError", np.exp(np.max(np.abs(np.log(np.diff(frame_to_time) / np.median(np.diff(frame_to_time))))))
        )

        # recompute deconvolution if requested
        spks = self.load_s2p("spks")
        if self.opts.oasis:
            # set parameters for oasis and get corrected fluorescence traces
            g = np.exp(-1 / self.opts.tau / self.opts.fs)
            fcorr = self.loadfcorr(try_from_one=False)
            results = oasis_deconvolution(fcorr, g)
            ospks = np.stack(results)

            # Check that the shape is correct
            msg = f"In session {self.session_print()}, oasis was run and did not produce the same shaped array as suite2p spks..."
            assert ospks.shape == spks.shape, msg

        # save oneData (no assertions needed: load_s2p() handles shape checks and this function already handled any mismatch between frameSamples and the suite2p output)
        self.saveone(frame_to_time, "mpci.times")
        self.saveone(LoadingRecipe("S2P", "F", transforms=["transpose"]), "mpci.roiActivityF")
        self.saveone(LoadingRecipe("S2P", "Fneu", transforms=["transpose"]), "mpci.roiNeuropilActivityF")
        self.saveone(LoadingRecipe("S2P", "spks", transforms=["transpose"]), "mpci.roiActivityDeconvolved")
        if "redcell" in self.get_value("available"):
            self.saveone(LoadingRecipe("S2P", "redcell", transforms=["idx_column1"]), "mpciROIs.redS2P")
        self.saveone(LoadingRecipe("S2P", "iscell"), "mpciROIs.isCell")
        self.saveone(self.get_roi_position(), "mpciROIs.stackPosition")
        if self.opts.oasis:
            self.saveone(ospks.T, "mpci.roiActivityDeconvolvedOasis")
        self.preprocessing.append("imaging")

    def process_facecam(self):
        """
        Process facecam data.

        Placeholder for facecam preprocessing. Not yet implemented.

        Notes
        -----
        This method currently only prints a message indicating that facecam
        preprocessing has not been implemented yet.
        """
        print("Facecam preprocessing has not been coded yet!")

    def process_behavior_to_imaging(self):
        """
        Create mapping from behavioral frames to imaging frames.

        Computes the nearest imaging frame for each behavioral sample and
        saves the mapping to oneData.

        Notes
        -----
        This method is skipped if imaging is disabled in opts. The mapping
        is saved as "positionTracking.mpci" in oneData.
        """
        if not self.opts.imaging:
            print(f"In session {self.session_print()}, imaging is set to False in opts.imaging. Skipping behavior2imaging processing.")
            return None

        # compute translation mapping from behave frames to imaging frames
        idx_behave_to_frame = helpers.nearestpoint(self.loadone("positionTracking.times"), self.loadone("mpci.times"))[0]
        self.saveone(idx_behave_to_frame.astype(int), "positionTracking.mpci")

    def process_red_cells(self):
        """
        Process red cell features for identification.

        Computes red cell features (dot product, Pearson correlation, phase
        correlation) using RedCellProcessing and saves them to oneData.
        Initializes red cell index and manual assignment arrays.

        Notes
        -----
        This method is skipped if imaging or redCellProcessing is disabled in opts,
        or if redcell output is not available in suite2p. The computed features
        are saved to oneData for later use in red cell identification.
        """
        if not (self.opts.imaging) or not (self.opts.redCellProcessing):
            return  # if not requested, skip function
        # if imaging was processed and redCellProcessing was requested, then try to preprocess red cell features
        if "redcell" not in self.get_value("available"):
            print(f"In session {self.session_print()}, 'redcell' is not an available suite2p output, although 'redCellProcessing' was requested.")
            return

        # create RedCellProcessing object
        # b2session_of_self = B2Session.create(self.mouse_name, self.date, self.session_id, for_registration=True)
        red_cell_processing = RedCellProcessing(self)

        # compute red-features
        dot_parameters = {"lowcut": 12, "highcut": 250, "order": 3, "fs": 512}
        corr_parameters = {"width": 20, "lowcut": 12, "highcut": 250, "order": 3, "fs": 512}
        phase_parameters = {"width": 40, "eps": 1e6, "winFunc": "hamming"}

        print(f"Computing red cell features for {self.session_print()}... (usually takes 10-20 seconds)")
        dot_product = red_cell_processing.compute_dot(plane_idx=None, **dot_parameters)
        corr_coeff = red_cell_processing.compute_corr(plane_idx=None, **corr_parameters)
        phase_corr = red_cell_processing.cropped_phase_correlation(plane_idx=None, **phase_parameters)[3]

        # initialize annotations
        self.saveone(np.full(self.get_value("numROIs"), False), "mpciROIs.redCellIdx")
        self.saveone(np.full((2, self.get_value("numROIs")), False), "mpciROIs.redCellManualAssignments")

        # save oneData
        self.saveone(dot_product, "mpciROIs.redDotProduct")
        self.saveone(corr_coeff, "mpciROIs.redPearson")
        self.saveone(phase_corr, "mpciROIs.redPhaseCorrelation")
        self.saveone(np.array(dot_parameters), "parametersRedDotProduct.keyValuePairs")
        self.saveone(np.array(corr_parameters), "parametersRedPearson.keyValuePairs")
        self.saveone(np.array(phase_parameters), "parametersRedPhaseCorrelation.keyValuePairs")

    # -------------------------------------- methods for handling timeline data produced by rigbox ------------------------------------------------------------
    def load_timeline_structure(self):
        """
        Load timeline structure from Timeline.mat file.

        Loads the Timeline.mat file produced by rigbox and stores it in
        self.tl_file. The file is expected to be named
        "{date}_{session_id}_{mouse_name}_Timeline.mat".

        Notes
        -----
        The timeline file contains raw DAQ data, timestamps, and hardware
        input measurements from the experimental rig.
        """
        tl_file_name = self.data_path / f"{self.date}_{self.session_id}_{self.mouse_name}_Timeline.mat"  # timeline.mat file name
        self.tl_file = scio.loadmat(tl_file_name, simplify_cells=True)["Timeline"]  # load matlab structure

    def timeline_inputs(self, ignore_timestamps=False):
        """
        Get list of available timeline input names.

        Parameters
        ----------
        ignore_timestamps : bool, optional
            If True, return only hardware input names. If False, include
            "timestamps" as the first element. Default is False.

        Returns
        -------
        list of str
            List of timeline input names.
        """
        if not hasattr(self, "tl_file"):
            self.load_timeline_structure()
        hw_inputs = [hwInput["name"] for hwInput in self.tl_file["hw"]["inputs"]]
        if ignore_timestamps:
            return hw_inputs
        return ["timestamps", *hw_inputs]

    def get_timeline_var(self, var_name):
        """
        Get a timeline variable by name.

        Parameters
        ----------
        var_name : str
            Name of the timeline variable to retrieve. Can be "timestamps"
            or any hardware input name.

        Returns
        -------
        np.ndarray
            Timeline variable data as a 1D array.

        Raises
        ------
        AssertionError
            If var_name is not a valid timeline variable name.
        """
        if not hasattr(self, "tl_file"):
            self.load_timeline_structure()
        if var_name == "timestamps":
            return self.tl_file["rawDAQTimestamps"]
        else:
            inputNames = self.timeline_inputs(ignore_timestamps=True)
            assert var_name in inputNames, f"{var_name} is not a timeline variable in session {self.session_print()}"
            return np.squeeze(self.tl_file["rawDAQData"][:, np.where([inputName == var_name for inputName in inputNames])[0]])

    def convert_rotary_encoder_to_position(self, rotaryEncoder, rigInfo):
        """
        Convert rotary encoder counts to position in centimeters.

        The rotary encoder is a counter with a large range that sometimes
        wraps around. This method handles wrap-around, computes cumulative
        movement, and scales to centimeters.

        Parameters
        ----------
        rotaryEncoder : np.ndarray
            Rotary encoder counts from timeline.
        rigInfo : DefaultRigInfo or similar
            Rig information containing rotary encoder parameters.

        Returns
        -------
        np.ndarray
            Position in centimeters, shape (num_samples,).
        """
        # the rotary encoder is a counter with a large range that sometimes wraps around
        # first get the changes in encoder position, fix any wrap-around jumps, then take the cumulative movement and scale to centimeters
        rotary_movement = helpers.diffsame(rotaryEncoder)
        idx_high_values = rotary_movement > 2 ** (rigInfo.rotaryRange - 1)
        idx_low_values = rotary_movement < -(2 ** (rigInfo.rotaryRange - 1))
        rotary_movement[idx_high_values] -= 2**rigInfo.rotaryRange
        rotary_movement[idx_low_values] += 2**rigInfo.rotaryRange
        return rigInfo.rotEncSign * np.cumsum(rotary_movement) * (2 * np.pi * rigInfo.wheelRadius) / rigInfo.wheelToVR

    # -------------------------------------- methods for handling vrBehavior data produced by vrControl ------------------------------------------------------------
    def load_behavior_structure(self):
        """
        Load behavioral structure from VRBehavior_trial.mat file.

        Loads the VRBehavior_trial.mat file produced by vrControl and stores
        it in self.vr_file. If rigInfo is missing, uses DefaultRigInfo as
        a fallback.

        Notes
        -----
        The behavior file contains trial-level and sample-level behavioral
        data including timestamps, positions, rewards, and licks. The file
        is expected to be named
        "{date}_{session_id}_{mouse_name}_VRBehavior_trial.mat".
        """
        vr_file_name = self.data_path / f"{self.date}_{self.session_id}_{self.mouse_name}_VRBehavior_trial.mat"  # vrBehavior output file name
        self.vr_file = scio.loadmat(vr_file_name, struct_as_record=False, squeeze_me=True)
        if "rigInfo" not in self.vr_file.keys():
            print(f"Assuming default settings for B2 using `DefaultRigInfo()` in session: {self.session_print()}!!!")
            self.vr_file["rigInfo"] = DefaultRigInfo()
        if not (hasattr(self.vr_file["rigInfo"], "rotaryRange")):
            self.vr_file["rigInfo"].rotaryRange = 32

    def convert_dense(self, data: Union[np.ndarray, sp.sparse.spmatrix]) -> np.ndarray:
        """
        Convert sparse or dense array to dense numpy array.

        Truncates data to numTrials rows and converts sparse matrices to
        dense arrays.

        Parameters
        ----------
        data : np.ndarray or scipy.sparse.spmatrix
            Input data, which may be sparse or dense.

        Returns
        -------
        np.ndarray
            Dense numpy array, truncated to numTrials rows and squeezed.
        """
        data = data[: self.get_value("numTrials")]
        if sp.sparse.issparse(data):
            data = data.toarray().squeeze()
        else:
            data = np.asarray(data).squeeze()
        return data

    def create_index(self, time_stamps):
        """
        Create index arrays for non-zero/non-NaN timestamps per trial.

        Parameters
        ----------
        time_stamps : np.ndarray
            Timestamps as (numTrials x numSamples) dense numpy array.

        Returns
        -------
        list of np.ndarray
            List of index arrays, one per trial, indicating which samples
            have valid data (non-NaN or non-zero).
        """
        # requires timestamps as (numTrials x numSamples) dense numpy array
        if np.any(np.isnan(time_stamps)):
            return [np.where(~np.isnan(t))[0] for t in time_stamps]  # in case we have dense timestamps with nans where no data
        else:
            return [np.nonzero(t)[0] for t in time_stamps]  # in case we have sparse timestamps with 0s where no data

    def get_vr_data(self, data, nzindex):
        """
        Extract valid data samples using index arrays.

        Parameters
        ----------
        data : np.ndarray
            Data array, shape (numTrials, numSamples).
        nzindex : list of np.ndarray
            List of index arrays, one per trial, indicating valid samples.

        Returns
        -------
        list of np.ndarray
            List of data arrays, one per trial, containing only valid samples.
        """
        return [d[nz] for (d, nz) in zip(data, nzindex)]

__init__(mouse_name, date_string, session_id, opts=B2RegistrationOpts())

Initialize B2Registration object.

Parameters:

- mouse_name (str): Name of the mouse. Required.
- date_string (str): Date string in format "YYYY-MM-DD". Required.
- session_id (str): Session identifier. Required.
- opts (B2RegistrationOpts or dict): Registration options. If a dict, will be converted to B2RegistrationOpts. Default: B2RegistrationOpts().

Raises:

- ValueError: If session directory does not exist.

Source code in vrAnalysis/registration/register.py
def __init__(
    self,
    mouse_name: str,
    date_string: str,
    session_id: str,
    opts: Union[B2RegistrationOpts, dict] = B2RegistrationOpts(),
):
    """
    Initialize B2Registration object.

    Parameters
    ----------
    mouse_name : str
        Name of the mouse.
    date_string : str
        Date string in format "YYYY-MM-DD".
    session_id : str
        Session identifier.
    opts : B2RegistrationOpts or dict, optional
        Registration options. If a dict, will be converted to B2RegistrationOpts.
        Default is B2RegistrationOpts().

    Raises
    ------
    ValueError
        If session directory does not exist.
    """
    super().__init__(mouse_name, date_string, session_id)
    if isinstance(opts, B2RegistrationOpts):
        self.opts = opts
    elif is_dataclass(opts):
        self.opts = B2RegistrationOpts(**asdict(opts))
    elif isinstance(opts, Mapping):
        self.opts = B2RegistrationOpts(**opts)
    else:
        raise TypeError(f"opts must be B2RegistrationOpts, a dataclass, or a mapping; got {type(opts)}")

    if not self.data_path.exists():
        raise ValueError(f"Session directory does not exist for {self.session_print()}")

    if not self.one_path.exists():
        self.one_path.mkdir(parents=True)
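The opts-handling branch above accepts three input shapes. A minimal standalone sketch of the same dispatch rule, using a hypothetical `Opts` dataclass in place of `B2RegistrationOpts`:

```python
from dataclasses import dataclass, is_dataclass, asdict
from collections.abc import Mapping

# Hypothetical stand-in for B2RegistrationOpts, just to illustrate the
# normalization rule (field names here are assumptions, not the real options).
@dataclass
class Opts:
    imaging: bool = True
    oasis: bool = False

def normalize_opts(opts):
    # Accept an Opts instance, any dataclass with matching fields, or a mapping.
    if isinstance(opts, Opts):
        return opts
    if is_dataclass(opts):
        return Opts(**asdict(opts))
    if isinstance(opts, Mapping):
        return Opts(**opts)
    raise TypeError(f"opts must be Opts, a dataclass, or a mapping; got {type(opts)}")

print(normalize_opts({"imaging": False}))  # Opts(imaging=False, oasis=False)
```

Anything else (e.g. a plain string) raises TypeError, mirroring the constructor above.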

convert_dense(data)

Convert sparse or dense array to dense numpy array.

Truncates data to numTrials rows and converts sparse matrices to dense arrays.

Parameters:

- data (ndarray or spmatrix): Input data, which may be sparse or dense. Required.

Returns:

- ndarray: Dense numpy array, truncated to numTrials rows and squeezed.

Source code in vrAnalysis/registration/register.py
def convert_dense(self, data: Union[np.ndarray, sp.sparse.spmatrix]) -> np.ndarray:
    """
    Convert sparse or dense array to dense numpy array.

    Truncates data to numTrials rows and converts sparse matrices to
    dense arrays.

    Parameters
    ----------
    data : np.ndarray or scipy.sparse.spmatrix
        Input data, which may be sparse or dense.

    Returns
    -------
    np.ndarray
        Dense numpy array, truncated to numTrials rows and squeezed.
    """
    data = data[: self.get_value("numTrials")]
    if sp.sparse.issparse(data):
        data = data.toarray().squeeze()
    else:
        data = np.asarray(data).squeeze()
    return data
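The truncate-then-densify logic can be exercised standalone. A sketch with a hypothetical trial count standing in for `self.get_value("numTrials")`:

```python
import numpy as np
import scipy.sparse as sp

def to_dense(data, num_trials):
    # Keep only the first num_trials rows, then densify and squeeze,
    # mirroring convert_dense above.
    data = data[:num_trials]
    if sp.issparse(data):
        return data.toarray().squeeze()
    return np.asarray(data).squeeze()

# 5 trials stored sparsely, but only 3 are valid in this hypothetical session:
sparse_data = sp.csr_matrix(np.arange(20.0).reshape(5, 4))
dense = to_dense(sparse_data, num_trials=3)
print(dense.shape)  # (3, 4)
```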

convert_rotary_encoder_to_position(rotaryEncoder, rigInfo)

Convert rotary encoder counts to position in centimeters.

The rotary encoder is a counter with a large range that sometimes wraps around. This method handles wrap-around, computes cumulative movement, and scales to centimeters.

Parameters:

- rotaryEncoder (ndarray): Rotary encoder counts from timeline. Required.
- rigInfo (DefaultRigInfo or similar): Rig information containing rotary encoder parameters. Required.

Returns:

- ndarray: Position in centimeters, shape (num_samples,).

Source code in vrAnalysis/registration/register.py
def convert_rotary_encoder_to_position(self, rotaryEncoder, rigInfo):
    """
    Convert rotary encoder counts to position in centimeters.

    The rotary encoder is a counter with a large range that sometimes
    wraps around. This method handles wrap-around, computes cumulative
    movement, and scales to centimeters.

    Parameters
    ----------
    rotaryEncoder : np.ndarray
        Rotary encoder counts from timeline.
    rigInfo : DefaultRigInfo or similar
        Rig information containing rotary encoder parameters.

    Returns
    -------
    np.ndarray
        Position in centimeters, shape (num_samples,).
    """
    # rotary encoder is a counter with a big range that sometimes wraps around its axis
    # first get changes in encoder position, fix any big jumps in value, take the cumulative movement, and scale to centimeters
    rotary_movement = helpers.diffsame(rotaryEncoder)
    idx_high_values = rotary_movement > 2 ** (rigInfo.rotaryRange - 1)
    idx_low_values = rotary_movement < -(2 ** (rigInfo.rotaryRange - 1))
    rotary_movement[idx_high_values] -= 2**rigInfo.rotaryRange
    rotary_movement[idx_low_values] += 2**rigInfo.rotaryRange
    return rigInfo.rotEncSign * np.cumsum(rotary_movement) * (2 * np.pi * rigInfo.wheelRadius) / rigInfo.wheelToVR
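The wrap-around correction is easiest to see on synthetic counter data. A standalone sketch of the same arithmetic, with made-up default parameter values standing in for the rigInfo fields (`rotaryRange`, `wheelRadius`, `wheelToVR`, `rotEncSign`):

```python
import numpy as np

def rotary_to_position(counts, rotary_range=32, wheel_radius=9.75,
                       wheel_to_vr=4000.0, sign=1):
    # Parameter values are illustrative assumptions, not real rig settings.
    # np.diff with prepend stands in for helpers.diffsame (same-length diff).
    movement = np.diff(counts, prepend=counts[0]).astype(float)
    # Undo counter wrap-around: jumps larger than half the counter range
    # are artifacts of the counter overflowing, not real movement.
    movement[movement > 2 ** (rotary_range - 1)] -= 2 ** rotary_range
    movement[movement < -(2 ** (rotary_range - 1))] += 2 ** rotary_range
    return sign * np.cumsum(movement) * (2 * np.pi * wheel_radius) / wheel_to_vr

# 8-bit counter (range 256) wrapping from 254 to 2: true movement is +4, not -252.
counts = np.array([250, 254, 2])
print(rotary_to_position(counts, rotary_range=8, wheel_radius=1.0,
                         wheel_to_vr=2 * np.pi))  # [0. 4. 8.]
```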

create_index(time_stamps)

Create index arrays for non-zero/non-NaN timestamps per trial.

Parameters:

- time_stamps (ndarray): Timestamps as (numTrials x numSamples) dense numpy array. Required.

Returns:

- list of np.ndarray: List of index arrays, one per trial, indicating which samples have valid data (non-NaN or non-zero).

Source code in vrAnalysis/registration/register.py
def create_index(self, time_stamps):
    """
    Create index arrays for non-zero/non-NaN timestamps per trial.

    Parameters
    ----------
    time_stamps : np.ndarray
        Timestamps as (numTrials x numSamples) dense numpy array.

    Returns
    -------
    list of np.ndarray
        List of index arrays, one per trial, indicating which samples
        have valid data (non-NaN or non-zero).
    """
    # requires timestamps as (numTrials x numSamples) dense numpy array
    if np.any(np.isnan(time_stamps)):
        return [np.where(~np.isnan(t))[0] for t in time_stamps]  # in case we have dense timestamps with nans where no data
    else:
        return [np.nonzero(t)[0] for t in time_stamps]  # in case we have sparse timestamps with 0s where no data
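The two padding conventions (NaN-padded dense arrays vs. zero-padded arrays that came from sparse storage) can be demonstrated standalone, together with the per-trial extraction pattern used by get_vr_data:

```python
import numpy as np

def valid_index(time_stamps):
    # NaN-padded dense timestamps -> keep non-NaN samples;
    # zero-padded (formerly sparse) timestamps -> keep nonzero samples.
    if np.any(np.isnan(time_stamps)):
        return [np.where(~np.isnan(t))[0] for t in time_stamps]
    return [np.nonzero(t)[0] for t in time_stamps]

# Two trials of different lengths, zero-padded after each trial ends:
ts = np.array([[0.1, 0.2, 0.3, 0.0],
               [0.1, 0.2, 0.0, 0.0]])
idx = valid_index(ts)
# Extract only the valid samples per trial (same pattern as get_vr_data):
position = np.array([[1.0, 2.0, 3.0, 0.0],
                     [4.0, 5.0, 0.0, 0.0]])
per_trial = [p[nz] for p, nz in zip(position, idx)]
print([list(p) for p in per_trial])  # [[1.0, 2.0, 3.0], [4.0, 5.0]]
```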

do_preprocessing()

Run all preprocessing steps for the session.

Processes timeline data, behavioral data, imaging data, red cell identification, facecam data (placeholder), and creates mappings between behavioral and imaging data.

Notes

Processing steps are run in order:

1. Timeline processing
2. Behavior processing
3. Imaging processing
4. Red cell processing (if enabled)
5. Facecam processing (not yet implemented)
6. Behavior-to-imaging mapping

Source code in vrAnalysis/registration/register.py
def do_preprocessing(self):
    """
    Run all preprocessing steps for the session.

    Processes timeline data, behavioral data, imaging data, red cell
    identification, facecam data (placeholder), and creates mappings
    between behavioral and imaging data.

    Notes
    -----
    Processing steps are run in order:
    1. Timeline processing
    2. Behavior processing
    3. Imaging processing
    4. Red cell processing (if enabled)
    5. Facecam processing (not yet implemented)
    6. Behavior-to-imaging mapping
    """
    if self.opts.clearOne:
        self.clear_one_data(certainty=True)
    self.process_timeline()
    self.process_behavior()
    self.process_imaging()
    self.process_red_cells()
    self.process_facecam()
    self.process_behavior_to_imaging()

get_timeline_var(var_name)

Get a timeline variable by name.

Parameters:

- var_name (str): Name of the timeline variable to retrieve. Can be "timestamps" or any hardware input name. Required.

Returns:

- ndarray: Timeline variable data as a 1D array.

Raises:

- AssertionError: If var_name is not a valid timeline variable name.

Source code in vrAnalysis/registration/register.py
def get_timeline_var(self, var_name):
    """
    Get a timeline variable by name.

    Parameters
    ----------
    var_name : str
        Name of the timeline variable to retrieve. Can be "timestamps"
        or any hardware input name.

    Returns
    -------
    np.ndarray
        Timeline variable data as a 1D array.

    Raises
    ------
    AssertionError
        If var_name is not a valid timeline variable name.
    """
    if not hasattr(self, "tl_file"):
        self.load_timeline_structure()
    if var_name == "timestamps":
        return self.tl_file["rawDAQTimestamps"]
    else:
        inputNames = self.timeline_inputs(ignore_timestamps=True)
        assert var_name in inputNames, f"{var_name} is not a tl_file in session {self.session_print()}"
        return np.squeeze(self.tl_file["rawDAQData"][:, np.where([inputName == var_name for inputName in inputNames])[0]])
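The column lookup for a hardware input reduces to matching the requested name against the input list and slicing the raw DAQ matrix. A minimal sketch with made-up input names and data (the real names come from the Timeline file):

```python
import numpy as np

# Hypothetical timeline inputs and raw DAQ data, shape (samples, inputs).
input_names = ["rotaryEncoder", "photoDiode", "neuralFrames"]
raw_daq_data = np.arange(12).reshape(4, 3)

var_name = "photoDiode"
assert var_name in input_names, f"{var_name} is not a timeline input"
# Select the matching column and squeeze to a 1D trace, as in get_timeline_var.
column = np.where([name == var_name for name in input_names])[0]
trace = np.squeeze(raw_daq_data[:, column])
print(trace)  # [ 1  4  7 10]
```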

get_vr_data(data, nzindex)

Extract valid data samples using index arrays.

Parameters:

- data (ndarray): Data array, shape (numTrials, numSamples). Required.
- nzindex (list of np.ndarray): List of index arrays, one per trial, indicating valid samples. Required.

Returns:

- list of np.ndarray: List of data arrays, one per trial, containing only valid samples.

Source code in vrAnalysis/registration/register.py
def get_vr_data(self, data, nzindex):
    """
    Extract valid data samples using index arrays.

    Parameters
    ----------
    data : np.ndarray
        Data array, shape (numTrials, numSamples).
    nzindex : list of np.ndarray
        List of index arrays, one per trial, indicating valid samples.

    Returns
    -------
    list of np.ndarray
        List of data arrays, one per trial, containing only valid samples.
    """
    return [d[nz] for (d, nz) in zip(data, nzindex)]

load_behavior_structure()

Load behavioral structure from VRBehavior_trial.mat file.

Loads the VRBehavior_trial.mat file produced by vrControl and stores it in self.vr_file. If rigInfo is missing, uses DefaultRigInfo as a fallback.

Notes

The behavior file contains trial-level and sample-level behavioral data including timestamps, positions, rewards, and licks. The file is expected to be named "{date}_{session_id}_{mouse_name}_VRBehavior_trial.mat".

Source code in vrAnalysis/registration/register.py
def load_behavior_structure(self):
    """
    Load behavioral structure from VRBehavior_trial.mat file.

    Loads the VRBehavior_trial.mat file produced by vrControl and stores
    it in self.vr_file. If rigInfo is missing, uses DefaultRigInfo as
    a fallback.

    Notes
    -----
    The behavior file contains trial-level and sample-level behavioral
    data including timestamps, positions, rewards, and licks. The file
    is expected to be named
    "{date}_{session_id}_{mouse_name}_VRBehavior_trial.mat".
    """
    vr_file_name = self.data_path / f"{self.date}_{self.session_id}_{self.mouse_name}_VRBehavior_trial.mat"  # vrBehavior output file name
    self.vr_file = scio.loadmat(vr_file_name, struct_as_record=False, squeeze_me=True)
    if "rigInfo" not in self.vr_file.keys():
        print(f"Assuming default settings for B2 using `DefaultRigInfo()` in session: {self.session_print()}!!!")
        self.vr_file["rigInfo"] = DefaultRigInfo()
    if not (hasattr(self.vr_file["rigInfo"], "rotaryRange")):
        self.vr_file["rigInfo"].rotaryRange = 32

load_timeline_structure()

Load timeline structure from Timeline.mat file.

Loads the Timeline.mat file produced by rigbox and stores it in self.tl_file. The file is expected to be named "{date}_{session_id}_{mouse_name}_Timeline.mat".

Notes

The timeline file contains raw DAQ data, timestamps, and hardware input measurements from the experimental rig.

Source code in vrAnalysis/registration/register.py
def load_timeline_structure(self):
    """
    Load timeline structure from Timeline.mat file.

    Loads the Timeline.mat file produced by rigbox and stores it in
    self.tl_file. The file is expected to be named
    "{date}_{session_id}_{mouse_name}_Timeline.mat".

    Notes
    -----
    The timeline file contains raw DAQ data, timestamps, and hardware
    input measurements from the experimental rig.
    """
    tl_file_name = self.data_path / f"{self.date}_{self.session_id}_{self.mouse_name}_Timeline.mat"  # timeline.mat file name
    self.tl_file = scio.loadmat(tl_file_name, simplify_cells=True)["Timeline"]  # load matlab structure

process_behavior()

Process behavioral data from vrControl.

Processes behavioral data using the appropriate behavior processing function based on the vrBehaviorVersion option. Extracts trial-level and sample-level behavioral data and aligns timestamps to the timeline.

See Also

register_behavior : Dispatcher function for behavior processing.

Source code in vrAnalysis/registration/register.py
def process_behavior(self):
    """
    Process behavioral data from vrControl.

    Processes behavioral data using the appropriate behavior processing
    function based on the vrBehaviorVersion option. Extracts trial-level
    and sample-level behavioral data and aligns timestamps to the timeline.

    See Also
    --------
    register_behavior : Dispatcher function for behavior processing.
    """
    self = register_behavior(self, self.opts.vrBehaviorVersion)

    # Confirm that vrBehavior has been processed
    self.preprocessing.append("vrBehavior")

process_behavior_to_imaging()

Create mapping from behavioral frames to imaging frames.

Computes the nearest imaging frame for each behavioral sample and saves the mapping to oneData.

Notes

This method is skipped if imaging is disabled in opts. The mapping is saved as "positionTracking.mpci" in oneData.

Source code in vrAnalysis/registration/register.py
def process_behavior_to_imaging(self):
    """
    Create mapping from behavioral frames to imaging frames.

    Computes the nearest imaging frame for each behavioral sample and
    saves the mapping to oneData.

    Notes
    -----
    This method is skipped if imaging is disabled in opts. The mapping
    is saved as "positionTracking.mpci" in oneData.
    """
    if not self.opts.imaging:
        print(f"In session {self.session_print()}, imaging setting set to False in opts['imaging']. Skipping behavior2imaging processing.")
        return None

    # compute translation mapping from behave frames to imaging frames
    idx_behave_to_frame = helpers.nearestpoint(self.loadone("positionTracking.times"), self.loadone("mpci.times"))[0]
    self.saveone(idx_behave_to_frame.astype(int), "positionTracking.mpci")
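helpers.nearestpoint is internal to this codebase; purely as an illustration, the same nearest-frame mapping can be sketched with np.searchsorted (assuming both time vectors are sorted, which timestamps are):

```python
import numpy as np

def nearest_index(query, reference):
    # For each query time, return the index of the nearest reference time.
    # A simple stand-in for helpers.nearestpoint, not the library function itself.
    reference = np.asarray(reference)
    pos = np.searchsorted(reference, query)
    pos = np.clip(pos, 1, len(reference) - 1)
    left, right = reference[pos - 1], reference[pos]
    # Pick whichever neighbor is closer (ties go to the earlier frame).
    return np.where(query - left <= right - query, pos - 1, pos)

behave_times = np.array([0.00, 0.03, 0.06, 0.09])  # behavioral sample times
frame_times = np.array([0.00, 0.05, 0.10])          # imaging volume times
print(nearest_index(behave_times, frame_times))     # [0 1 1 2]
```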

process_facecam()

Process facecam data.

Placeholder for facecam preprocessing. Not yet implemented.

Notes

This method currently only prints a message indicating that facecam preprocessing has not been implemented yet.

Source code in vrAnalysis/registration/register.py
def process_facecam(self):
    """
    Process facecam data.

    Placeholder for facecam preprocessing. Not yet implemented.

    Notes
    -----
    This method currently only prints a message indicating that facecam
    preprocessing has not been implemented yet.
    """
    print("Facecam preprocessing has not been coded yet!")

process_imaging()

Process imaging data from suite2p outputs.

Loads suite2p outputs, identifies available planes and outputs, handles frame count mismatches between suite2p and timeline, optionally recomputes deconvolution using OASIS, and saves imaging data to oneData format.

Raises:

- ValueError: If imaging is requested but suite2p directory does not exist, or if required suite2p outputs are missing, or if frame count mismatches cannot be resolved.

Notes

This method:

- Identifies planes and available suite2p outputs
- Checks for required outputs (stat, ops, F, Fneu, iscell, spks)
- Handles frame count mismatches between suite2p and timeline
- Optionally recomputes deconvolution using OASIS
- Saves imaging data to oneData

Source code in vrAnalysis/registration/register.py
def process_imaging(self):
    """
    Process imaging data from suite2p outputs.

    Loads suite2p outputs, identifies available planes and outputs,
    handles frame count mismatches between suite2p and timeline,
    optionally recomputes deconvolution using OASIS, and saves imaging
    data to oneData format.

    Raises
    ------
    ValueError
        If imaging is requested but suite2p directory does not exist, or
        if required suite2p outputs are missing, or if frame count
        mismatches cannot be resolved.

    Notes
    -----
    This method:
    - Identifies planes and available suite2p outputs
    - Checks for required outputs (stat, ops, F, Fneu, iscell, spks)
    - Handles frame count mismatches between suite2p and timeline
    - Optionally recomputes deconvolution using OASIS
    - Saves imaging data to oneData
    """
    if not self.opts.imaging:
        print(f"In session {self.session_print()}, imaging setting set to False in opts['imaging']. Skipping image processing.")
        return None

    if not self.s2p_path.exists():
        raise ValueError(f"In session {self.session_print()}, suite2p processing was requested but suite2p directory does not exist.")

    # identifies which planes were processed through suite2p (assume that those are all available planes)
    # identifies which s2p outputs are available from each plane
    self.set_value("planeNames", [plane.parts[-1] for plane in self.s2p_path.glob("plane*/")])
    self.set_value("planeIDs", [int(planeName[5:]) for planeName in self.get_value("planeNames")])
    npysInPlanes = [[npy.stem for npy in list((self.s2p_path / planeName).glob("*.npy"))] for planeName in self.get_value("planeNames")]
    commonNPYs = list(set.intersection(*[set(npy) for npy in npysInPlanes]))
    unionNPYs = list(set.union(*[set(npy) for npy in npysInPlanes]))
    if set(commonNPYs) < set(unionNPYs):
        print(
            f"The following npy files are present in some but not all plane folders within session {self.session_print()}: {list(set(unionNPYs) - set(commonNPYs))}"
        )
        print(f"Each plane folder contains the following npy files: {commonNPYs}")
    self.set_value("available", commonNPYs)  # a list of npy files available in each plane folder

    # required variables (anything else is either optional or can be computed independently)
    required = ["stat", "ops", "F", "Fneu", "iscell"]
    if not self.opts.oasis:
        # add deconvolved spikes to required variable if we aren't recomputing it here
        required.append("spks")
    for varName in required:
        assert varName in self.get_value("available"), f"{self.session_print()} is missing {varName} in at least one suite2p folder!"
    # get number of ROIs in each plane
    self.set_value("roiPerPlane", [iscell.shape[0] for iscell in self.load_s2p("iscell", concatenate=False)])
    # get number of frames in each plane (might be different!)
    self.set_value("framePerPlane", [F.shape[1] for F in self.load_s2p("F", concatenate=False)])
    assert_msg = f"The frame count in {self.session_print()} varies by more than 1 frame! ({self.get_value('framePerPlane')})"
    assert np.max(self.get_value("framePerPlane")) - np.min(self.get_value("framePerPlane")) <= 1, assert_msg
    self.set_value("numROIs", np.sum(self.get_value("roiPerPlane")))  # number of ROIs in session
    # number of frames to use when retrieving imaging data (might be overwritten to something smaller if timeline handled improperly)
    self.set_value("numFrames", np.min(self.get_value("framePerPlane")))

    # Get timeline sample corresponding to each imaging volume
    timeline_timestamps = self.loadone("wheelPosition.times")
    changeFrames = (
        np.append(
            0,
            np.diff(np.ceil(self.get_timeline_var("neuralFrames") / len(self.get_value("planeIDs")))),
        )
        == 1
    )
    frame_samples = np.where(changeFrames)[0]  # TTLs for each volume (increments by 1 for each plane)
    frame_to_time = timeline_timestamps[frame_samples]  # get timelineTimestamps of each imaging volume

    # Handle mismatch between number of imaging frames saved by scanImage (and propagated through suite2p), and between timeline's measurement of the scanImage frame counter
    if len(frame_to_time) != self.get_value("numFrames"):
        if len(frame_to_time) - 1 == self.get_value("numFrames"):
            # If frame_to_time has one extra frame, just trim it and assume everything is fine. This happens when a new volume was started but not finished, so it does not require notifying the user.
            frame_samples = frame_samples[:-1]
            frame_to_time = frame_to_time[:-1]
        elif len(frame_to_time) - 2 == self.get_value("numFrames"):
            print(
                "frame_to_time had 2 more than suite2p output. This happens sometimes. I don't like it. I think it's because scanimage sends a TTL before starting the frame"
            )
            frame_samples = frame_samples[:-2]
            frame_to_time = frame_to_time[:-2]
        else:
            # If frameSamples has too few frames, it's possible that the scanImage signal to timeline was broken but scanImage still continued normally.
            numMissing = self.get_value("numFrames") - len(frame_samples)  # measure number of missing frames
            if numMissing < 0:
                # If frameSamples had many more frames, generate an error -- something went wrong that needs manual inspection
                print(
                    f"In session {self.session_print()}, frameSamples has {len(frame_samples)} elements, but {self.get_value('numFrames')} frames were reported in suite2p. Cannot resolve."
                )
                raise ValueError("Cannot fix mismatches when suite2p data is missing!")
            # It's possible that the scanImage signal to timeline was broken but scanImage still continued normally.
            if numMissing > 1:
                print(
                    f"In session {self.session_print()}, more than one frameSamples sample was missing. Consider using tiff timelineTimestamps to reproduce accurately."
                )
            print(
                (
                    f"In session {self.session_print()}, frameSamples has {len(frame_samples)} elements, but {self.get_value('numFrames')} frames were saved by suite2p. "
                    "Will extend frameSamples using the typical sampling rate and nearestpoint algorithm."
                )
            )
            # If frame_to_time difference vector is consistent within 1%, then use mean (which is a little more accurate), otherwise use median
            frame_to_time = timeline_timestamps[frame_samples]
            medianFramePeriod = np.median(np.diff(frame_to_time))  # measure median sample period
            consistentFrames = np.all(
                np.abs(np.log(np.diff(frame_to_time) / medianFramePeriod)) < np.log(1.01)
            )  # True if all frames take within 1% of median frame period
            if consistentFrames:
                samplePeriod_f2t = np.mean(np.diff(frame_to_time))
            else:
                samplePeriod_f2t = np.median(np.diff(frame_to_time))
            appendFrames = frame_to_time[-1] + samplePeriod_f2t * (
                np.arange(numMissing) + 1
            )  # add elements to frame_to_time, assume sampling rate was perfect
            frame_to_time = np.concatenate((frame_to_time, appendFrames))
            frame_samples = helpers.nearestpoint(frame_to_time, timeline_timestamps)[0]

    # average percentage difference between all samples differences and median -- just a useful metric to be saved --
    self.set_value(
        "samplingDeviationMedianPercentError", np.exp(np.mean(np.abs(np.log(np.diff(frame_to_time) / np.median(np.diff(frame_to_time))))))
    )
    self.set_value(
        "samplingDeviationMaximumPercentError", np.exp(np.max(np.abs(np.log(np.diff(frame_to_time) / np.median(np.diff(frame_to_time))))))
    )

    # recompute deconvolution if requested
    spks = self.load_s2p("spks")
    if self.opts.oasis:
        # set parameters for oasis and get corrected fluorescence traces
        g = np.exp(-1 / self.opts.tau / self.opts.fs)
        fcorr = self.loadfcorr(try_from_one=False)
        results = oasis_deconvolution(fcorr, g)
        ospks = np.stack(results)

        # Check that the shape is correct
        msg = f"In session {self.session_print()}, oasis was run and did not produce the same shaped array as suite2p spks..."
        assert ospks.shape == spks.shape, msg

    # save onedata (no assertions needed: load_s2p() handles shape checks, and this function already handled any mismatch between frameSamples and the suite2p output)
    self.saveone(frame_to_time, "mpci.times")
    self.saveone(LoadingRecipe("S2P", "F", transforms=["transpose"]), "mpci.roiActivityF")
    self.saveone(LoadingRecipe("S2P", "Fneu", transforms=["transpose"]), "mpci.roiNeuropilActivityF")
    self.saveone(LoadingRecipe("S2P", "spks", transforms=["transpose"]), "mpci.roiActivityDeconvolved")
    if "redcell" in self.get_value("available"):
        self.saveone(LoadingRecipe("S2P", "redcell", transforms=["idx_column1"]), "mpciROIs.redS2P")
    self.saveone(LoadingRecipe("S2P", "iscell"), "mpciROIs.isCell")
    self.saveone(self.get_roi_position(), "mpciROIs.stackPosition")
    if self.opts.oasis:
        self.saveone(ospks.T, "mpci.roiActivityDeconvolvedOasis")
    self.preprocessing.append("imaging")
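The mean-vs-median sample-period choice used above when extending frame_to_time can be demonstrated standalone:

```python
import numpy as np

def choose_sample_period(frame_times, tol=1.01):
    # If every inter-frame interval is within 1% of the median, the mean is a
    # slightly more accurate estimate; otherwise fall back to the robust median.
    dt = np.diff(frame_times)
    median_dt = np.median(dt)
    consistent = np.all(np.abs(np.log(dt / median_dt)) < np.log(tol))
    return np.mean(dt) if consistent else np.median(dt)

regular = np.arange(10) * 0.064                   # regular 64 ms frame times
glitchy = np.append(regular, regular[-1] + 0.2)   # one long gap at the end
print(choose_sample_period(regular))   # ~0.064 (mean, intervals consistent)
print(choose_sample_period(glitchy))   # ~0.064 (median ignores the outlier)
```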

process_red_cells()

Process red cell features for identification.

Computes red cell features (dot product, Pearson correlation, phase correlation) using RedCellProcessing and saves them to oneData. Initializes red cell index and manual assignment arrays.

Notes

This method is skipped if imaging or redCellProcessing is disabled in opts, or if redcell output is not available in suite2p. The computed features are saved to oneData for later use in red cell identification.

Source code in vrAnalysis/registration/register.py
def process_red_cells(self):
    """
    Process red cell features for identification.

    Computes red cell features (dot product, Pearson correlation, phase
    correlation) using RedCellProcessing and saves them to oneData.
    Initializes red cell index and manual assignment arrays.

    Notes
    -----
    This method is skipped if imaging or redCellProcessing is disabled in opts,
    or if redcell output is not available in suite2p. The computed features
    are saved to oneData for later use in red cell identification.
    """
    if not (self.opts.imaging) or not (self.opts.redCellProcessing):
        return  # if not requested, skip function
    # if imaging was processed and redCellProcessing was requested, then try to preprocess red cell features
    if "redcell" not in self.get_value("available"):
        print(f"In session {self.session_print()}, 'redcell' is not an available suite2p output, although 'redCellProcessing' was requested.")
        return

    # create RedCellProcessing object
    # b2session_of_self = B2Session.create(self.mouse_name, self.date, self.session_id, for_registration=True)
    red_cell_processing = RedCellProcessing(self)

    # compute red-features
    dot_parameters = {"lowcut": 12, "highcut": 250, "order": 3, "fs": 512}
    corr_parameters = {"width": 20, "lowcut": 12, "highcut": 250, "order": 3, "fs": 512}
    phase_parameters = {"width": 40, "eps": 1e6, "winFunc": "hamming"}

    print(f"Computing red cell features for {self.session_print()}... (usually takes 10-20 seconds)")
    dot_product = red_cell_processing.compute_dot(plane_idx=None, **dot_parameters)
    corr_coeff = red_cell_processing.compute_corr(plane_idx=None, **corr_parameters)
    phase_corr = red_cell_processing.cropped_phase_correlation(plane_idx=None, **phase_parameters)[3]

    # initialize annotations
    self.saveone(np.full(self.get_value("numROIs"), False), "mpciROIs.redCellIdx")
    self.saveone(np.full((2, self.get_value("numROIs")), False), "mpciROIs.redCellManualAssignments")

    # save oneData
    self.saveone(dot_product, "mpciROIs.redDotProduct")
    self.saveone(corr_coeff, "mpciROIs.redPearson")
    self.saveone(phase_corr, "mpciROIs.redPhaseCorrelation")
    self.saveone(np.array(dot_parameters), "parametersRedDotProduct.keyValuePairs")
    self.saveone(np.array(corr_parameters), "parametersRedPearson.keyValuePairs")
    self.saveone(np.array(phase_parameters), "parametersRedPhaseCorrelation.keyValuePairs")

process_timeline()

Process timeline data from rigbox.

Extracts timestamps, rotary encoder position, lick times, reward times, and trial start times from the Timeline.mat file. Processes photodiode signals to align trial starts with imaging frames.

Notes

This method:

- Loads timeline structure from Timeline.mat
- Converts rotary encoder to position
- Detects licks and rewards
- Processes photodiode signal to find trial start frames
- Saves timeline data to oneData

Source code in vrAnalysis/registration/register.py
def process_timeline(self):
    """
    Process timeline data from rigbox.

    Extracts timestamps, rotary encoder position, lick times, reward times,
    and trial start times from the Timeline.mat file. Processes photodiode
    signals to align trial starts with imaging frames.

    Notes
    -----
    This method:
    - Loads timeline structure from Timeline.mat
    - Converts rotary encoder to position
    - Detects licks and rewards
    - Processes photodiode signal to find trial start frames
    - Saves timeline data to oneData
    """
    # load these files for raw behavioral & timeline data
    self.load_timeline_structure()
    self.load_behavior_structure()

    # get time stamps, photodiode, trial start and end times, room position, lick times, trial idx, visual data visible
    mpepStartTimes = []
    for mt, me in zip(self.tl_file["mpepUDPTimes"], self.tl_file["mpepUDPEvents"]):
        if isinstance(me, str):
            if "TrialStart" in me:
                mpepStartTimes.append(mt)
            elif "StimStart" in me:
                mpepStartTimes.append(mt)

    mpepStartTimes = np.array(mpepStartTimes)
    timestamps = self.get_timeline_var("timestamps")  # load timestamps

    # Get rotary position -- (load timeline measurement of rotary encoder, which is a circular position counter, use vrExperiment function to convert to a running measurement of position)
    rotaryEncoder = self.get_timeline_var("rotaryEncoder")
    rotaryPosition = self.convert_rotary_encoder_to_position(rotaryEncoder, self.vr_file["rigInfo"])

    # Get Licks (uses an edge counter)
    lickDetector = self.get_timeline_var("lickDetector")  # load lick detector copy
    lickSamples = np.where(helpers.diffsame(lickDetector) == 1)[0].astype(np.uint64)  # timeline samples of lick times

    # Get Reward Commands (measures voltage of output -- assume it's either low or high)
    rewardCommand = self.get_timeline_var("rewardCommand")  # load reward command signal
    rewardCommand = np.round(rewardCommand / np.max(rewardCommand))
    rewardSamples = np.where(helpers.diffsame(rewardCommand) > 0.5)[0].astype(np.uint64)  # timeline samples when reward was delivered

    # Now process photodiode signal
    photodiode = self.get_timeline_var("photoDiode")  # load photodiode signal

    # Remove any slow trends
    pdDetrend = sp.signal.detrend(photodiode)
    pdDetrend = (pdDetrend - pdDetrend.min()) / pdDetrend.ptp()

    # median filter and take smooth derivative
    hfpd = 10
    refreshRate = 30  # hz
    refreshSamples = int(1.0 / refreshRate / np.mean(np.diff(timestamps)))
    pdMedFilt = sp.ndimage.median_filter(pdDetrend, size=refreshSamples)
    pdDerivative, pdIndex = helpers.fivePointDer(pdMedFilt, hfpd, returnIndex=True)
    pdDerivative = sp.stats.zscore(pdDerivative)
    pdDerTime = timestamps[pdIndex]

    # find upward and downward peaks, not perfect but in practice close enough
    locUp = sp.signal.find_peaks(pdDerivative, height=1, distance=refreshSamples / 2)
    locDn = sp.signal.find_peaks(-pdDerivative, height=1, distance=refreshSamples / 2)
    flipTimes = np.concatenate((pdDerTime[locUp[0]], pdDerTime[locDn[0]]))
    flipValue = np.concatenate((np.ones(len(locUp[0])), np.zeros(len(locDn[0]))))
    flipSortIdx = np.argsort(flipTimes)
    flipTimes = flipTimes[flipSortIdx]
    flipValue = flipValue[flipSortIdx]

    # Naive method: just look for flips before and after the trialstart/trialend mpep message.
    # A more sophisticated method uses the time of the photodiode ramps, but those are really just for safety and rare manual curation...
    firstFlipIndex = np.array([np.where(flipTimes >= mpepStart)[0][0] for mpepStart in mpepStartTimes])
    startTrialIndex = helpers.nearestpoint(flipTimes[firstFlipIndex], timestamps)[0]  # returns frame index of first photodiode flip in each trial

    # Check that first flip is always down -- all of the vrControl code prepares trials in this way
    if datetime.strptime(self.date, "%Y-%m-%d") >= datetime.strptime("2022-08-30", "%Y-%m-%d"):
        # But it didn't prepare it this way before august 30th :(
        assert np.all(flipValue[firstFlipIndex] == 0), f"In session {self.session_print()}, first flips in trial are not all down!!"

    # Check shapes of timeline arrays
    assert timestamps.ndim == 1, "timelineTimestamps is not a 1-d array!"
    assert timestamps.shape == rotaryPosition.shape, "timeline timestamps and rotary position arrays do not have the same shape!"

    # Save timeline oneData
    self.saveone(timestamps, "wheelPosition.times")
    self.saveone(rotaryPosition, "wheelPosition.position")
    self.saveone(timestamps[lickSamples], "licks.times")
    self.saveone(timestamps[rewardSamples], "rewards.times")
    self.saveone(timestamps[startTrialIndex], "trials.startTimes")
    self.preprocessing.append("timeline")
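The photodiode processing above (detrend, median filter, z-scored derivative, peak detection) can be sketched on a synthetic square-wave trace. `detect_photodiode_flips` is an illustrative helper, not part of the codebase, and it substitutes `np.gradient` for the five-point derivative used in the real pipeline:

```python
import numpy as np
from scipy import ndimage, signal, stats

def detect_photodiode_flips(photodiode, timestamps, refresh_rate=30):
    """Return times of up/down flips in a photodiode trace (illustrative sketch)."""
    # remove slow trends and normalize to [0, 1]
    pd = signal.detrend(photodiode)
    pd = (pd - pd.min()) / np.ptp(pd)
    # median filter over roughly one screen-refresh period to suppress noise
    refresh_samples = max(int(1.0 / refresh_rate / np.mean(np.diff(timestamps))), 1)
    pd = ndimage.median_filter(pd, size=refresh_samples)
    # flips show up as large positive/negative peaks in the z-scored derivative
    deriv = stats.zscore(np.gradient(pd))
    loc_up, _ = signal.find_peaks(deriv, height=1, distance=max(refresh_samples // 2, 1))
    loc_dn, _ = signal.find_peaks(-deriv, height=1, distance=max(refresh_samples // 2, 1))
    flip_idx = np.sort(np.concatenate((loc_up, loc_dn)))
    return timestamps[flip_idx]

# synthetic trace: 1 kHz sampling, flips every 100 ms, plus a mild linear drift
t = np.arange(0.0, 1.0, 0.001)
trace = (np.floor(t / 0.1) % 2) + 0.05 * t
flips = detect_photodiode_flips(trace, t)
assert len(flips) == 9  # flips at 0.1 s, 0.2 s, ..., 0.9 s
```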

register()

Register the session by running all preprocessing steps.

This is the main entry point for registration. It runs all preprocessing steps (timeline, behavior, imaging, red cells, facecam, behavior-to-imaging) and saves session parameters.

See Also

do_preprocessing : Run all preprocessing steps.
save_session_prms : Save session parameters to oneData.

Source code in vrAnalysis/registration/register.py
def register(self):
    """
    Register the session by running all preprocessing steps.

    This is the main entry point for registration. It runs all preprocessing
    steps (timeline, behavior, imaging, red cells, facecam, behavior-to-imaging)
    and saves session parameters.

    See Also
    --------
    do_preprocessing : Run all preprocessing steps.
    save_session_prms : Save session parameters to oneData.
    """
    self.do_preprocessing()
    self.save_session_prms()

timeline_inputs(ignore_timestamps=False)

Get list of available timeline input names.

Parameters:

Name Type Description Default
ignore_timestamps bool

If True, return only hardware input names. If False, include "timestamps" as the first element. Default is False.

False

Returns:

Type Description
list of str

List of timeline input names.

Source code in vrAnalysis/registration/register.py
def timeline_inputs(self, ignore_timestamps=False):
    """
    Get list of available timeline input names.

    Parameters
    ----------
    ignore_timestamps : bool, optional
        If True, return only hardware input names. If False, include
        "timestamps" as the first element. Default is False.

    Returns
    -------
    list of str
        List of timeline input names.
    """
    if not hasattr(self, "tl_file"):
        self.load_timeline_structure()
    hw_inputs = [hwInput["name"] for hwInput in self.tl_file["hw"]["inputs"]]
    if ignore_timestamps:
        return hw_inputs
    return ["timestamps", *hw_inputs]
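The `tl_file["hw"]["inputs"]` structure that `timeline_inputs` iterates over is a list of per-input records with a `"name"` field. A minimal mock (using input names that appear in `process_timeline` above) shows the two return shapes:

```python
# mock of the Timeline "hw.inputs" structure (an assumption about its shape,
# using input names referenced elsewhere in this module)
tl_file = {"hw": {"inputs": [
    {"name": "rotaryEncoder"},
    {"name": "lickDetector"},
    {"name": "photoDiode"},
]}}

hw_inputs = [hw_input["name"] for hw_input in tl_file["hw"]["inputs"]]

# ignore_timestamps=True -> hardware input names only
assert hw_inputs == ["rotaryEncoder", "lickDetector", "photoDiode"]

# ignore_timestamps=False (default) -> "timestamps" prepended
assert ["timestamps", *hw_inputs] == ["timestamps", "rotaryEncoder", "lickDetector", "photoDiode"]
```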

RedCellProcessing

Handle red cell processing for B2Registration sessions.

This class processes red cell data from suite2p outputs, computes features for identifying red cells (S2P, dot product, Pearson correlation, phase correlation), and provides methods for updating red cell indices based on cutoff criteria.

Parameters:

Name Type Description Default
b2session B2Session

The B2Session object containing the session data.

required
um_per_pixel float

Micrometers per pixel for spatial measurements. Default is 1.3.

1.3
autoload bool

If True, automatically load reference images and masks on initialization. Default is True.

True

Attributes:

Name Type Description
b2session B2Session

The B2Session object containing the session data.

feature_names list of str

Standard names of features used to determine red cell criterion.

num_planes int

Number of imaging planes in the session.

um_per_pixel float

Micrometers per pixel for spatial measurements.

data_loaded bool

Whether reference images and masks have been loaded.

reference list of np.ndarray

Reference images for each plane (loaded when data_loaded is True).

lx, ly int

Dimensions of reference images.

lam list of np.ndarray

Weights of each pixel in ROI masks.

ypix, xpix list of np.ndarray

Pixel indices for each ROI mask.

roi_plane_idx ndarray

Plane index for each ROI.

red_s2p ndarray

Suite2p red cell values for each ROI.

Source code in vrAnalysis/registration/redcell.py
class RedCellProcessing:
    """
    Handle red cell processing for B2Registration sessions.

    This class processes red cell data from suite2p outputs, computes features
    for identifying red cells (S2P, dot product, Pearson correlation, phase
    correlation), and provides methods for updating red cell indices based on
    cutoff criteria.

    Parameters
    ----------
    b2session : B2Session
        The B2Session object containing the session data.
    um_per_pixel : float, optional
        Micrometers per pixel for spatial measurements. Default is 1.3.
    autoload : bool, optional
        If True, automatically load reference images and masks on initialization.
        Default is True.

    Attributes
    ----------
    b2session : B2Session
        The B2Session object containing the session data.
    feature_names : list of str
        Standard names of features used to determine red cell criterion.
    num_planes : int
        Number of imaging planes in the session.
    um_per_pixel : float
        Micrometers per pixel for spatial measurements.
    data_loaded : bool
        Whether reference images and masks have been loaded.
    reference : list of np.ndarray
        Reference images for each plane (loaded when data_loaded is True).
    lx, ly : int
        Dimensions of reference images.
    lam : list of np.ndarray
        Weights of each pixel in ROI masks.
    ypix, xpix : list of np.ndarray
        Pixel indices for each ROI mask.
    roi_plane_idx : np.ndarray
        Plane index for each ROI.
    red_s2p : np.ndarray
        Suite2p red cell values for each ROI.
    """

    def __init__(
        self,
        b2session: "B2Session",
        um_per_pixel: float = 1.3,
        autoload: bool = True,
    ):
        """
        Initialize RedCellProcessing object.

        Parameters
        ----------
        b2session : B2Session
            The B2Session object containing the session data.
        um_per_pixel : float, optional
            Micrometers per pixel for spatial measurements. Default is 1.3.
        autoload : bool, optional
            If True, automatically load reference images and masks on initialization.
            Default is True.

        Raises
        ------
        AssertionError
            If redcell is not available in suite2p outputs.
        """

        # Make sure redcell is available...
        msg = "redcell is not an available suite2p output, so you can't do redCellProcessing."
        assert "redcell" in b2session.get_value("available"), msg

        self.b2session = b2session

        # standard names of the features used to determine red cell criterion
        self.feature_names = ["S2P", "dotProduct", "pearson", "phaseCorrelation"]

        # load some critical values for easy readable access
        self.num_planes = len(self.b2session.get_value("planeNames"))
        self.um_per_pixel = um_per_pixel  # store this for generating correct axes and measuring distances

        self.data_loaded = False  # initialize to false in case data isn't loaded
        if autoload:
            self.load_reference_and_masks()  # prepare reference images and ROI mask data

    # ------------------------------
    # -- initialization functions --
    # ------------------------------
    def load_reference_and_masks(self):
        """
        Load reference images and ROI masks from suite2p outputs.

        Loads the mean image for channel 2 (red channel) for each plane, along
        with ROI mask data (lam, ypix, xpix) and ROI plane indices. Also loads
        suite2p red cell values and creates supporting variables for spatial
        measurements.

        Raises
        ------
        AssertionError
            If reference images do not all have the same shape.
        """
        # load reference images
        ops = self.b2session.load_s2p("ops")
        self.reference = [op["meanImg_chan2"] for op in ops]
        self.lx, self.ly = self.reference[0].shape
        for ref in self.reference:
            msg = "reference images do not all have the same shape"
            assert (self.lx, self.ly) == ref.shape, msg

        # load masks (lam=weight of each pixel, xpix & ypix=index of each pixel in ROI mask)
        stat = self.b2session.load_s2p("stat")
        self.lam = [s["lam"] for s in stat]
        self.ypix = [s["ypix"] for s in stat]
        self.xpix = [s["xpix"] for s in stat]
        self.roi_plane_idx = self.b2session.loadone("mpciROIs.stackPosition")[:, 2]

        # load S2P red cell value
        self.red_s2p = self.b2session.loadone("mpciROIs.redS2P")  # (preloaded, will never change in this function)

        # create supporting variables for mapping locations and axes
        self.y_base_ref = np.arange(self.ly)
        self.x_base_ref = np.arange(self.lx)
        self.y_dist_ref = self.create_centered_axis(self.ly, self.um_per_pixel)
        self.x_dist_ref = self.create_centered_axis(self.lx, self.um_per_pixel)

        # update data_loaded field
        self.data_loaded = True

    # ---------------------------------
    # -- updating one data functions --
    # ---------------------------------
    def one_name_feature_cutoffs(self, name):
        """
        Generate oneData name for feature cutoff parameters.

        Parameters
        ----------
        name : str
            Feature name (e.g., "S2P", "dotProduct", "pearson", "phaseCorrelation").

        Returns
        -------
        str
            OneData name for the feature cutoff parameter, formatted as
            "parametersRed{Name}.minMaxCutoff" where {Name} is the capitalized
            feature name.
        """
        return f"parametersRed{name[0].upper()}{name[1:]}.minMaxCutoff"

    def update_red_idx(self, s2p_cutoff=None, dot_product_cutoff=None, corr_coef_cutoff=None, phase_corr_cutoff=None):
        """
        Update red cell index based on feature cutoff values.

        Updates the red cell index by applying minimum and maximum cutoffs to
        each feature (S2P, dot product, Pearson correlation, phase correlation).
        Only features with non-NaN cutoff values are applied. The red cell index
        is updated to include only ROIs that meet all specified criteria.

        Parameters
        ----------
        s2p_cutoff : array-like of float, length 2, optional
            [min, max] cutoff values for suite2p red cell feature. NaN values
            indicate the cutoff should not be applied. Default is None.
        dot_product_cutoff : array-like of float, length 2, optional
            [min, max] cutoff values for dot product feature. Default is None.
        corr_coef_cutoff : array-like of float, length 2, optional
            [min, max] cutoff values for Pearson correlation feature.
            Default is None.
        phase_corr_cutoff : array-like of float, length 2, optional
            [min, max] cutoff values for phase correlation feature.
            Default is None.

        Raises
        ------
        ValueError
            If any cutoff is not a numpy array or list, or if any cutoff does
            not have exactly 2 elements.

        Notes
        -----
        Cutoff values are saved to oneData for future reference. The red cell
        index is updated in place and saved to oneData.
        """
        # create initial all true red cell idx
        red_cell_idx = np.full(self.b2session.loadone("mpciROIs.redCellIdx").shape, True)

        # load feature values for each ROI
        red_s2p = self.b2session.loadone("mpciROIs.redS2P")
        dot_product = self.b2session.loadone("mpciROIs.redDotProduct")
        corr_coef = self.b2session.loadone("mpciROIs.redPearson")
        phase_corr = self.b2session.loadone("mpciROIs.redPhaseCorrelation")

        # create lists for zipping through each feature/cutoff combination
        features = [red_s2p, dot_product, corr_coef, phase_corr]
        cutoffs = [s2p_cutoff, dot_product_cutoff, corr_coef_cutoff, phase_corr_cutoff]
        usecutoff = [[False, False] for _ in range(len(cutoffs))]

        # check validity of each cutoff and identify whether it should be used
        for name, use, cutoff in zip(self.feature_names, usecutoff, cutoffs):
            if not isinstance(cutoff, (np.ndarray, list)):
                raise ValueError(f"Expecting a numpy array or a list for {name} cutoff, got {type(cutoff)}")
            if len(cutoff) != 2:
                raise ValueError(f"{name} cutoff does not have 2 elements")
            if not (np.isnan(cutoff[0])):
                use[0] = True
            if not (np.isnan(cutoff[1])):
                use[1] = True

        # add feature cutoffs to redCellIdx (sets any to False that don't meet the cutoff)
        for feature, use, cutoff in zip(features, usecutoff, cutoffs):
            if use[0]:
                red_cell_idx &= feature >= cutoff[0]
            if use[1]:
                red_cell_idx &= feature <= cutoff[1]

        # save new red cell index to one data
        self.b2session.saveone(red_cell_idx, "mpciROIs.redCellIdx")

        # save feature cutoffs to one data
        for idx, name in enumerate(self.feature_names):
            self.b2session.saveone(cutoffs[idx], self.one_name_feature_cutoffs(name))
        print(f"Red Cell curation choices are saved for session {self.b2session.session_print()}")

    def update_from_session(self, red_cell, force_update=False):
        """
        Update red cell cutoffs from another session.

        Copies red cell cutoff parameters from another RedCellProcessing object
        and applies them to this session.

        Parameters
        ----------
        red_cell : RedCellProcessing
            Another RedCellProcessing object to copy cutoffs from.
        force_update : bool, optional
            If False, only allows copying from sessions with the same mouse name.
            If True, allows copying from any session. Default is False.

        Raises
        ------
        AssertionError
            If red_cell is not a RedCellProcessing object, or if force_update is
            False and the mouse names don't match.
        """
        assert isinstance(red_cell, RedCellProcessing), "red_cell is not a RedCellProcessing object"
        if not (force_update):
            assert (
                red_cell.b2session.mouse_name == self.b2session.mouse_name
            ), "session to copy from is from a different mouse, this isn't allowed without the force_update=True input"
        cutoffs = [red_cell.b2session.loadone(red_cell.one_name_feature_cutoffs(name)) for name in self.feature_names]
        self.update_red_idx(s2p_cutoff=cutoffs[0], dot_product_cutoff=cutoffs[1], corr_coef_cutoff=cutoffs[2], phase_corr_cutoff=cutoffs[3])

    def cropped_phase_correlation(self, plane_idx=None, width=40, eps=1e6, winFunc=lambda x: np.hamming(x.shape[-1])):
        """
        Compute phase correlation of cropped masks with cropped reference images.

        Returns the phase correlation of each ROI mask (cropped around the ROI
        centroid) with the corresponding cropped reference image. This is used
        as a feature for identifying red cells.

        Parameters
        ----------
        plane_idx : int or array-like of int, optional
            Plane indices to process. If None, processes all planes. Default is None.
        width : float, optional
            Width in micrometers of the cropped region around each ROI centroid.
            Default is 40.
        eps : float, optional
            Regularization constant added to the denominator to avoid division
            by zero in the phase correlation normalization. Default is 1e6.
        winFunc : callable or str, optional
            Window function to apply before computing phase correlation. If "hamming",
            uses Hamming window. Otherwise should be a callable that takes an array
            and returns a windowed array. Default is Hamming window.

        Returns
        -------
        refStack : np.ndarray
            Stack of cropped reference images, shape (num_rois, height, width).
        maskStack : np.ndarray
            Stack of cropped ROI masks, shape (num_rois, height, width).
        pxcStack : np.ndarray
            Stack of phase correlation maps, shape (num_rois, height, width).
        phase_corr_values : np.ndarray
            Phase correlation values at the center of each correlation map,
            shape (num_rois,). This is the feature value used for red cell identification.

        Notes
        -----
        The default parameters (width=40um, eps=1e6, and a Hamming window function)
        were tested on a few sessions and are subjective. Manual curation and
        parameter adjustment may be necessary for optimal results.
        """
        if not (self.data_loaded):
            self.load_reference_and_masks()
        if winFunc == "hamming":
            winFunc = lambda x: np.hamming(x.shape[-1])
        refStack = self.centered_reference_stack(plane_idx=plane_idx, width=width)  # get stack of reference image centered on each ROI
        maskStack = self.centered_mask_stack(plane_idx=plane_idx, width=width)  # get stack of mask value centered on each ROI
        window = winFunc(refStack)  # create a window function
        pxcStack = np.stack(
            [helpers.phaseCorrelation(ref, mask, eps=eps, window=window) for (ref, mask) in zip(refStack, maskStack)]
        )  # measure phase correlation
        pxcCenterPixel = int((pxcStack.shape[2] - 1) / 2)
        return refStack, maskStack, pxcStack, pxcStack[:, pxcCenterPixel, pxcCenterPixel]

    def compute_dot(self, plane_idx=None, lowcut=12, highcut=250, order=3, fs=512):
        """
        Compute normalized dot product between filtered reference and ROI masks.

        Computes the dot product between each ROI mask and a Butterworth-filtered
        reference image. This is used as a feature for identifying red cells.

        Parameters
        ----------
        plane_idx : int or array-like of int, optional
            Plane indices to process. If None, processes all planes. Default is None.
        lowcut : float, optional
            Low cutoff frequency for Butterworth bandpass filter in Hz.
            Default is 12.
        highcut : float, optional
            High cutoff frequency for Butterworth bandpass filter in Hz.
            Default is 250.
        order : int, optional
            Order of the Butterworth filter. Default is 3.
        fs : float, optional
            Sampling frequency for the filter in Hz. Default is 512.

        Returns
        -------
        np.ndarray
            Normalized dot product values for each ROI, shape (num_rois,).
        """
        if plane_idx is None:
            plane_idx = np.arange(self.num_planes)
        if isinstance(plane_idx, (int, np.integer)):
            plane_idx = (plane_idx,)  # make plane_idx iterable
        if not (self.data_loaded):
            self.load_reference_and_masks()

        dot_prod = []
        for plane in plane_idx:
            t = time.time()
            c_roi_idx = np.where(self.roi_plane_idx == plane)[0]  # index of ROIs in this plane
            bwReference = helpers.butterworthbpf(self.reference[plane], lowcut, highcut, order=order, fs=fs)  # filtered reference image
            bwReference /= np.linalg.norm(bwReference)  # adjust to norm for straightforward cosine angle
            # compute normalized dot product for each ROI
            dot_prod.append(
                np.array([bwReference[self.ypix[roi], self.xpix[roi]] @ self.lam[roi] / np.linalg.norm(self.lam[roi]) for roi in c_roi_idx])
            )

        return np.concatenate(dot_prod)

    def compute_corr(self, plane_idx=None, width=20, lowcut=12, highcut=250, order=3, fs=512):
        """
        Compute Pearson correlation between filtered reference and ROI masks.

        Computes the Pearson correlation coefficient between each ROI mask and
        a Butterworth-filtered reference image within a cropped region around
        each ROI. This is used as a feature for identifying red cells.

        Parameters
        ----------
        plane_idx : int or array-like of int, optional
            Plane indices to process. If None, processes all planes. Default is None.
        width : float, optional
            Width in micrometers of the cropped region around each ROI centroid.
            Default is 20.
        lowcut : float, optional
            Low cutoff frequency for Butterworth bandpass filter in Hz.
            Default is 12.
        highcut : float, optional
            High cutoff frequency for Butterworth bandpass filter in Hz.
            Default is 250.
        order : int, optional
            Order of the Butterworth filter. Default is 3.
        fs : float, optional
            Sampling frequency for the filter in Hz. Default is 512.

        Returns
        -------
        np.ndarray
            Pearson correlation coefficients for each ROI, shape (num_rois,).
        """
        if plane_idx is None:
            plane_idx = np.arange(self.num_planes)
        if isinstance(plane_idx, (int, np.integer)):
            plane_idx = (plane_idx,)  # make plane_idx iterable
        if not (self.data_loaded):
            self.load_reference_and_masks()

        corr_coef = []
        for plane in plane_idx:
            num_rois = self.b2session.get_value("roiPerPlane")[plane]
            c_ref_stack = np.reshape(
                self.centered_reference_stack(plane_idx=plane, width=width, fill=np.nan, filtPrms=(lowcut, highcut, order, fs)),
                (num_rois, -1),
            )
            c_mask_stack = np.reshape(self.centered_mask_stack(plane_idx=plane, width=width, fill=0), (num_rois, -1))
            c_mask_stack[np.isnan(c_ref_stack)] = np.nan

            # Measure mean and standard deviation (and number of non-nan datapoints)
            u_ref = np.nanmean(c_ref_stack, axis=1, keepdims=True)
            u_mask = np.nanmean(c_mask_stack, axis=1, keepdims=True)
            s_ref = np.nanstd(c_ref_stack, axis=1)
            s_mask = np.nanstd(c_mask_stack, axis=1)
            N = np.sum(~np.isnan(c_ref_stack), axis=1)

            # compute correlation coefficient and add to storage variable
            corr_coef.append(np.nansum((c_ref_stack - u_ref) * (c_mask_stack - u_mask), axis=1) / N / s_ref / s_mask)

        return np.concatenate(corr_coef)

    # --------------------------
    # -- supporting functions --
    # --------------------------
    def create_centered_axis(self, numElements, scale=1):
        """
        Create a centered axis array.

        Parameters
        ----------
        numElements : int
            Number of elements in the axis.
        scale : float, optional
            Scaling factor for the axis. Default is 1.

        Returns
        -------
        np.ndarray
            Centered axis array, shape (numElements,), with values ranging from
            -scale*(numElements-1)/2 to scale*(numElements-1)/2.
        """
        return scale * (np.arange(numElements) - (numElements - 1) / 2)

    def getyref(self, yCenter):
        """
        Get y-axis reference coordinates relative to a center point.

        Parameters
        ----------
        yCenter : float
            Y-coordinate of the center point in pixels.

        Returns
        -------
        np.ndarray
            Y-axis coordinates in micrometers relative to yCenter, shape (ly,).
        """
        if not (self.data_loaded):
            self.load_reference_and_masks()
        return self.um_per_pixel * (self.y_base_ref - yCenter)

    def getxref(self, xCenter):
        """
        Get x-axis reference coordinates relative to a center point.

        Parameters
        ----------
        xCenter : float
            X-coordinate of the center point in pixels.

        Returns
        -------
        np.ndarray
            X-axis coordinates in micrometers relative to xCenter, shape (lx,).
        """
        if not (self.data_loaded):
            self.load_reference_and_masks()
        return self.um_per_pixel * (self.x_base_ref - xCenter)

    def get_roi_centroid(self, idx, mode="weightedmean"):
        """
        Get the centroid of an ROI.

        Parameters
        ----------
        idx : int
            Index of the ROI.
        mode : str, optional
            Method for computing centroid. "weightedmean" uses pixel weights (lam),
            "median" uses median pixel coordinates. Default is "weightedmean".

        Returns
        -------
        yc : float
            Y-coordinate of the centroid in pixels.
        xc : float
            X-coordinate of the centroid in pixels.
        """
        if not (self.data_loaded):
            self.load_reference_and_masks()

        if mode == "weightedmean":
            yc = np.sum(self.lam[idx] * self.ypix[idx]) / np.sum(self.lam[idx])
            xc = np.sum(self.lam[idx] * self.xpix[idx]) / np.sum(self.lam[idx])
        elif mode == "median":
            yc = int(np.median(self.ypix[idx]))
            xc = int(np.median(self.xpix[idx]))

        return yc, xc

    def get_roi_range(self, idx):
        """
        Get the range (peak-to-peak) of x and y pixels for an ROI.

        Parameters
        ----------
        idx : int
            Index of the ROI.

        Returns
        -------
        yr : int
            Range of y-pixels (peak-to-peak).
        xr : int
            Range of x-pixels (peak-to-peak).
        """
        if not (self.data_loaded):
            self.load_reference_and_masks()
        # get range of x and y pixels for a particular ROI
        yr = np.ptp(self.ypix[idx])
        xr = np.ptp(self.xpix[idx])
        return yr, xr

    def get_roi_in_plane_idx(self, idx):
        """
        Get the index of an ROI within its own plane.

        Parameters
        ----------
        idx : int
            Global ROI index.

        Returns
        -------
        int
            Index of the ROI within its plane (0-indexed within that plane).
        """
        if not (self.data_loaded):
            self.load_reference_and_masks()
        # return index of ROI within its own plane
        plane_idx = self.roi_plane_idx[idx]
        return idx - np.sum(self.roi_plane_idx < plane_idx)

    def centered_reference_stack(self, plane_idx=None, width=15, fill=0.0, filtPrms=None):
        """
        Create a stack of reference images centered on each ROI.

        Returns a stack of reference images cropped around each ROI centroid
        within a specified width. Optionally applies a Butterworth filter to
        the reference images before cropping.

        Parameters
        ----------
        plane_idx : int or array-like of int, optional
            Plane indices to process. If None, processes all planes. Default is None.
        width : float, optional
            Width in micrometers of the cropped region around each ROI centroid.
            Default is 15.
        fill : float, optional
            Value to use for background pixels outside the image bounds.
            Should be 0.0 or np.nan. Default is 0.0.
        filtPrms : tuple of 4 floats, optional
            Parameters for Butterworth filter: (lowcut, highcut, order, fs).
            If None, no filtering is applied. Default is None.

        Returns
        -------
        np.ndarray
            Stack of centered reference images, shape (num_rois, height, width),
            where height = width = 2 * round(width / um_per_pixel) + 1.
        """
        if plane_idx is None:
            plane_idx = np.arange(self.num_planes)
        if isinstance(plane_idx, (int, np.integer)):
            plane_idx = (plane_idx,)  # make plane_idx iterable
        if not (self.data_loaded):
            self.load_reference_and_masks()
        num_pixels = int(np.round(width / self.um_per_pixel))  # numPixels to each side around the centroid
        ref_stack = []
        for plane in plane_idx:
            c_reference = self.reference[plane]
            if filtPrms is not None:
                # filtered reference image
                c_reference = helpers.butterworthbpf(c_reference, filtPrms[0], filtPrms[1], order=filtPrms[2], fs=filtPrms[3])
            idx_roi_in_plane = np.where(self.roi_plane_idx == plane)[0]
            ref_stack.append(np.full((len(idx_roi_in_plane), 2 * num_pixels + 1, 2 * num_pixels + 1), fill))
            # fill the reference stack with the reference image
            for idx, idx_roi in enumerate(idx_roi_in_plane):
                yc, xc = self.get_roi_centroid(idx_roi, mode="median")
                yUse = (np.maximum(yc - num_pixels, 0), np.minimum(yc + num_pixels + 1, self.ly))
                xUse = (np.maximum(xc - num_pixels, 0), np.minimum(xc + num_pixels + 1, self.lx))
                yMissing = (
                    -np.minimum(yc - num_pixels, 0),
                    -np.minimum(self.ly - (yc + num_pixels + 1), 0),
                )
                xMissing = (
                    -np.minimum(xc - num_pixels, 0),
                    -np.minimum(self.lx - (xc + num_pixels + 1), 0),
                )
                ref_stack[-1][
                    idx,
                    yMissing[0] : 2 * num_pixels + 1 - yMissing[1],
                    xMissing[0] : 2 * num_pixels + 1 - xMissing[1],
                ] = c_reference[yUse[0] : yUse[1], xUse[0] : xUse[1]]
        return np.concatenate(ref_stack, axis=0).astype(np.float32)

    def centered_mask_stack(self, plane_idx=None, width=15, fill=0.0):
        """
        Create a stack of ROI masks centered on each ROI.

        Returns a stack of ROI masks cropped around each ROI centroid within
        a specified width. Mask values (lam) are placed at the appropriate
        positions in the centered stack.

        Parameters
        ----------
        plane_idx : int or array-like of int, optional
            Plane indices to process. If None, processes all planes. Default is None.
        width : float, optional
            Width in micrometers of the cropped region around each ROI centroid.
            Default is 15.
        fill : float, optional
            Value to use for background pixels outside the ROI mask.
            Should be 0.0 or np.nan. Default is 0.0.

        Returns
        -------
        np.ndarray
            Stack of centered ROI masks, shape (num_rois, height, width),
            where height = width = 2 * round(width / um_per_pixel) + 1.
        """
        if plane_idx is None:
            plane_idx = np.arange(self.num_planes)
        if isinstance(plane_idx, (int, np.integer)):
            plane_idx = (plane_idx,)  # make plane_idx iterable
        if not (self.data_loaded):
            self.load_reference_and_masks()
        num_pixels = int(np.round(width / self.um_per_pixel))  # numPixels to each side around the centroid
        mask_stack = []
        for plane in plane_idx:
            idx_roi_in_plane = np.where(self.roi_plane_idx == plane)[0]
            mask_stack.append(np.full((len(idx_roi_in_plane), 2 * num_pixels + 1, 2 * num_pixels + 1), fill))
            for idx, idx_roi in enumerate(idx_roi_in_plane):
                yc, xc = self.get_roi_centroid(idx_roi, mode="median")
                # centered y&x pixels of ROI
                cyidx = self.ypix[idx_roi] - yc + num_pixels
                cxidx = self.xpix[idx_roi] - xc + num_pixels
                # index of pixels still within width of stack
                idx_use_points = (cyidx >= 0) & (cyidx < 2 * num_pixels + 1) & (cxidx >= 0) & (cxidx < 2 * num_pixels + 1)
                mask_stack[-1][idx, cyidx[idx_use_points], cxidx[idx_use_points]] = self.lam[idx_roi][idx_use_points]
        return np.concatenate(mask_stack, axis=0).astype(np.float32)

    def compute_volume(self, plane_idx=None):
        """
        Compute full-volume ROI masks for specified planes.

        Creates a 3D array where each ROI mask is placed at its original
        position in the full image plane. This is useful for visualization
        or volume-based operations.

        Parameters
        ----------
        plane_idx : int or array-like of int, optional
            Plane indices to process. If None, processes all planes. Default is None.

        Returns
        -------
        np.ndarray
            Volume array of ROI masks, shape (num_rois, ly, lx), where ly and lx
            are the dimensions of the reference images.

        Raises
        ------
        AssertionError
            If any plane index is out of range.
        """
        if plane_idx is None:
            plane_idx = np.arange(self.num_planes)
        if isinstance(plane_idx, (int, np.integer)):
            plane_idx = (plane_idx,)  # make plane_idx iterable
        msg = f"in session: {self.b2session.session_print()}, there are only {self.num_planes} planes!"
        assert all([0 <= plane < self.num_planes for plane in plane_idx]), msg
        if not (self.data_loaded):
            self.load_reference_and_masks()
        roi_mask_volume = []
        for plane in plane_idx:
            roi_mask_volume.append(np.zeros((self.b2session.get_value("roiPerPlane")[plane], self.ly, self.lx)))
            idx_roi_in_plane = np.where(self.roi_plane_idx == plane)[0]
            for roi in range(self.b2session.get_value("roiPerPlane")[plane]):
                c_roi_idx = idx_roi_in_plane[roi]
                roi_mask_volume[-1][roi, self.ypix[c_roi_idx], self.xpix[c_roi_idx]] = self.lam[c_roi_idx]
        return np.concatenate(roi_mask_volume, axis=0)

__init__(b2session, um_per_pixel=1.3, autoload=True)

Initialize RedCellProcessing object.

Parameters:

- b2session (B2Session, required): The B2Session object containing the session data.
- um_per_pixel (float, default 1.3): Micrometers per pixel for spatial measurements.
- autoload (bool, default True): If True, automatically load reference images and masks on initialization.

Raises:

- AssertionError: If redcell is not available in suite2p outputs.

Source code in vrAnalysis/registration/redcell.py
def __init__(
    self,
    b2session: "B2Session",
    um_per_pixel: float = 1.3,
    autoload: bool = True,
):
    """
    Initialize RedCellProcessing object.

    Parameters
    ----------
    b2session : B2Session
        The B2Session object containing the session data.
    um_per_pixel : float, optional
        Micrometers per pixel for spatial measurements. Default is 1.3.
    autoload : bool, optional
        If True, automatically load reference images and masks on initialization.
        Default is True.

    Raises
    ------
    AssertionError
        If redcell is not available in suite2p outputs.
    """

    # Make sure redcell is available...
    msg = "redcell is not an available suite2p output, so you can't do redCellProcessing."
    assert "redcell" in b2session.get_value("available"), msg

    self.b2session = b2session

    # standard names of the features used to determine red cell criterion
    self.feature_names = ["S2P", "dotProduct", "pearson", "phaseCorrelation"]

    # load some critical values for easy readable access
    self.num_planes = len(self.b2session.get_value("planeNames"))
    self.um_per_pixel = um_per_pixel  # store this for generating correct axes and measuring distances

    self.data_loaded = False  # initialize to false in case data isn't loaded
    if autoload:
        self.load_reference_and_masks()  # prepare reference images and ROI mask data

centered_mask_stack(plane_idx=None, width=15, fill=0.0)

Create a stack of ROI masks centered on each ROI.

Returns a stack of ROI masks cropped around each ROI centroid within a specified width. Mask values (lam) are placed at the appropriate positions in the centered stack.

Parameters:

- plane_idx (int or array-like of int, default None): Plane indices to process. If None, processes all planes.
- width (float, default 15): Width in micrometers of the cropped region around each ROI centroid.
- fill (float, default 0.0): Value to use for background pixels outside the ROI mask. Should be 0.0 or np.nan.

Returns:

- np.ndarray: Stack of centered ROI masks, shape (num_rois, height, width), where height = width = 2 * round(width / um_per_pixel) + 1.

Source code in vrAnalysis/registration/redcell.py
def centered_mask_stack(self, plane_idx=None, width=15, fill=0.0):
    """
    Create a stack of ROI masks centered on each ROI.

    Returns a stack of ROI masks cropped around each ROI centroid within
    a specified width. Mask values (lam) are placed at the appropriate
    positions in the centered stack.

    Parameters
    ----------
    plane_idx : int or array-like of int, optional
        Plane indices to process. If None, processes all planes. Default is None.
    width : float, optional
        Width in micrometers of the cropped region around each ROI centroid.
        Default is 15.
    fill : float, optional
        Value to use for background pixels outside the ROI mask.
        Should be 0.0 or np.nan. Default is 0.0.

    Returns
    -------
    np.ndarray
        Stack of centered ROI masks, shape (num_rois, height, width),
        where height = width = 2 * round(width / um_per_pixel) + 1.
    """
    if plane_idx is None:
        plane_idx = np.arange(self.num_planes)
    if isinstance(plane_idx, (int, np.integer)):
        plane_idx = (plane_idx,)  # make plane_idx iterable
    if not (self.data_loaded):
        self.load_reference_and_masks()
    num_pixels = int(np.round(width / self.um_per_pixel))  # numPixels to each side around the centroid
    mask_stack = []
    for plane in plane_idx:
        idx_roi_in_plane = np.where(self.roi_plane_idx == plane)[0]
        mask_stack.append(np.full((len(idx_roi_in_plane), 2 * num_pixels + 1, 2 * num_pixels + 1), fill))
        for idx, idx_roi in enumerate(idx_roi_in_plane):
            yc, xc = self.get_roi_centroid(idx_roi, mode="median")
            # centered y&x pixels of ROI
            cyidx = self.ypix[idx_roi] - yc + num_pixels
            cxidx = self.xpix[idx_roi] - xc + num_pixels
            # index of pixels still within width of stack
            idx_use_points = (cyidx >= 0) & (cyidx < 2 * num_pixels + 1) & (cxidx >= 0) & (cxidx < 2 * num_pixels + 1)
            mask_stack[-1][idx, cyidx[idx_use_points], cxidx[idx_use_points]] = self.lam[idx_roi][idx_use_points]
    return np.concatenate(mask_stack, axis=0).astype(np.float32)
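
The centering arithmetic above can be illustrated in isolation. The following is a minimal, self-contained sketch (not the library API; `center_mask` is a hypothetical helper) of how sparse ROI pixels and weights are shifted so that the ROI's median pixel lands at the center of a fixed-size crop:

```python
import numpy as np

def center_mask(ypix, xpix, lam, num_pixels, fill=0.0):
    # place sparse ROI weights (lam) at coordinates (ypix, xpix) into a
    # (2*num_pixels+1)-square crop centered on the ROI's median pixel
    size = 2 * num_pixels + 1
    crop = np.full((size, size), fill, dtype=np.float32)
    yc, xc = int(np.median(ypix)), int(np.median(xpix))
    # shift ROI pixels so the median pixel lands at the crop center
    cy = ypix - yc + num_pixels
    cx = xpix - xc + num_pixels
    keep = (cy >= 0) & (cy < size) & (cx >= 0) & (cx < size)  # inside crop
    crop[cy[keep], cx[keep]] = lam[keep]
    return crop

ypix = np.array([10, 10, 11])
xpix = np.array([20, 21, 20])
lam = np.array([0.5, 0.3, 0.2])
crop = center_mask(ypix, xpix, lam, num_pixels=2)
```

Pixels that fall outside the crop are simply dropped, which is why distant ROI pixels never appear in the centered stack.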

centered_reference_stack(plane_idx=None, width=15, fill=0.0, filtPrms=None)

Create a stack of reference images centered on each ROI.

Returns a stack of reference images cropped around each ROI centroid within a specified width. Optionally applies a Butterworth filter to the reference images before cropping.

Parameters:

- plane_idx (int or array-like of int, default None): Plane indices to process. If None, processes all planes.
- width (float, default 15): Width in micrometers of the cropped region around each ROI centroid.
- fill (float, default 0.0): Value to use for background pixels outside the image bounds. Should be 0.0 or np.nan.
- filtPrms (tuple of 4 floats, default None): Parameters for Butterworth filter: (lowcut, highcut, order, fs). If None, no filtering is applied.

Returns:

- np.ndarray: Stack of centered reference images, shape (num_rois, height, width), where height = width = 2 * round(width / um_per_pixel) + 1.

Source code in vrAnalysis/registration/redcell.py
def centered_reference_stack(self, plane_idx=None, width=15, fill=0.0, filtPrms=None):
    """
    Create a stack of reference images centered on each ROI.

    Returns a stack of reference images cropped around each ROI centroid
    within a specified width. Optionally applies a Butterworth filter to
    the reference images before cropping.

    Parameters
    ----------
    plane_idx : int or array-like of int, optional
        Plane indices to process. If None, processes all planes. Default is None.
    width : float, optional
        Width in micrometers of the cropped region around each ROI centroid.
        Default is 15.
    fill : float, optional
        Value to use for background pixels outside the image bounds.
        Should be 0.0 or np.nan. Default is 0.0.
    filtPrms : tuple of 4 floats, optional
        Parameters for Butterworth filter: (lowcut, highcut, order, fs).
        If None, no filtering is applied. Default is None.

    Returns
    -------
    np.ndarray
        Stack of centered reference images, shape (num_rois, height, width),
        where height = width = 2 * round(width / um_per_pixel) + 1.
    """
    if plane_idx is None:
        plane_idx = np.arange(self.num_planes)
    if isinstance(plane_idx, (int, np.integer)):
        plane_idx = (plane_idx,)  # make plane_idx iterable
    if not (self.data_loaded):
        self.load_reference_and_masks()
    num_pixels = int(np.round(width / self.um_per_pixel))  # numPixels to each side around the centroid
    ref_stack = []
    for plane in plane_idx:
        c_reference = self.reference[plane]
        if filtPrms is not None:
            # filtered reference image
            c_reference = helpers.butterworthbpf(c_reference, filtPrms[0], filtPrms[1], order=filtPrms[2], fs=filtPrms[3])
        idx_roi_in_plane = np.where(self.roi_plane_idx == plane)[0]
        ref_stack.append(np.full((len(idx_roi_in_plane), 2 * num_pixels + 1, 2 * num_pixels + 1), fill))
        # fill the reference stack with the reference image
        for idx, idx_roi in enumerate(idx_roi_in_plane):
            yc, xc = self.get_roi_centroid(idx_roi, mode="median")
            yUse = (np.maximum(yc - num_pixels, 0), np.minimum(yc + num_pixels + 1, self.ly))
            xUse = (np.maximum(xc - num_pixels, 0), np.minimum(xc + num_pixels + 1, self.lx))
            yMissing = (
                -np.minimum(yc - num_pixels, 0),
                -np.minimum(self.ly - (yc + num_pixels + 1), 0),
            )
            xMissing = (
                -np.minimum(xc - num_pixels, 0),
                -np.minimum(self.lx - (xc + num_pixels + 1), 0),
            )
            ref_stack[-1][
                idx,
                yMissing[0] : 2 * num_pixels + 1 - yMissing[1],
                xMissing[0] : 2 * num_pixels + 1 - xMissing[1],
            ] = c_reference[yUse[0] : yUse[1], xUse[0] : xUse[1]]
    return np.concatenate(ref_stack, axis=0).astype(np.float32)
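
The boundary bookkeeping in the method above (yUse/xUse versus yMissing/xMissing) can be sketched on its own. This is a minimal illustration (not the library API; `centered_crop` is a hypothetical helper) of cropping a fixed-size window around a point and padding with `fill` wherever the window runs off the image:

```python
import numpy as np

def centered_crop(img, yc, xc, num_pixels, fill=0.0):
    # crop a (2*num_pixels+1)-square window around (yc, xc), padding with
    # `fill` where the window extends past the image edge
    ly, lx = img.shape
    size = 2 * num_pixels + 1
    out = np.full((size, size), fill, dtype=img.dtype)
    # valid source range, clipped to the image bounds
    y0, y1 = max(yc - num_pixels, 0), min(yc + num_pixels + 1, ly)
    x0, x1 = max(xc - num_pixels, 0), min(xc + num_pixels + 1, lx)
    # how many rows/cols of the window fall outside the image
    y_miss = (max(num_pixels - yc, 0), max(yc + num_pixels + 1 - ly, 0))
    x_miss = (max(num_pixels - xc, 0), max(xc + num_pixels + 1 - lx, 0))
    out[y_miss[0] : size - y_miss[1], x_miss[0] : size - x_miss[1]] = img[y0:y1, x0:x1]
    return out

img = np.arange(25, dtype=float).reshape(5, 5)
crop = centered_crop(img, yc=0, xc=0, num_pixels=1, fill=np.nan)
```

With the centroid at a corner, the top row and left column of the crop stay at the fill value while the valid quadrant is copied in.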

compute_corr(plane_idx=None, width=20, lowcut=12, highcut=250, order=3, fs=512)

Compute Pearson correlation between filtered reference and ROI masks.

Computes the Pearson correlation coefficient between each ROI mask and a Butterworth-filtered reference image within a cropped region around each ROI. This is used as a feature for identifying red cells.

Parameters:

- plane_idx (int or array-like of int, default None): Plane indices to process. If None, processes all planes.
- width (float, default 20): Width in micrometers of the cropped region around each ROI centroid.
- lowcut (float, default 12): Low cutoff frequency for Butterworth bandpass filter in Hz.
- highcut (float, default 250): High cutoff frequency for Butterworth bandpass filter in Hz.
- order (int, default 3): Order of the Butterworth filter.
- fs (float, default 512): Sampling frequency for the filter in Hz.

Returns:

- np.ndarray: Pearson correlation coefficients for each ROI, shape (num_rois,).

Source code in vrAnalysis/registration/redcell.py
def compute_corr(self, plane_idx=None, width=20, lowcut=12, highcut=250, order=3, fs=512):
    """
    Compute Pearson correlation between filtered reference and ROI masks.

    Computes the Pearson correlation coefficient between each ROI mask and
    a Butterworth-filtered reference image within a cropped region around
    each ROI. This is used as a feature for identifying red cells.

    Parameters
    ----------
    plane_idx : int or array-like of int, optional
        Plane indices to process. If None, processes all planes. Default is None.
    width : float, optional
        Width in micrometers of the cropped region around each ROI centroid.
        Default is 20.
    lowcut : float, optional
        Low cutoff frequency for Butterworth bandpass filter in Hz.
        Default is 12.
    highcut : float, optional
        High cutoff frequency for Butterworth bandpass filter in Hz.
        Default is 250.
    order : int, optional
        Order of the Butterworth filter. Default is 3.
    fs : float, optional
        Sampling frequency for the filter in Hz. Default is 512.

    Returns
    -------
    np.ndarray
        Pearson correlation coefficients for each ROI, shape (num_rois,).
    """
    if plane_idx is None:
        plane_idx = np.arange(self.num_planes)
    if isinstance(plane_idx, (int, np.integer)):
        plane_idx = (plane_idx,)  # make plane_idx iterable
    if not (self.data_loaded):
        self.load_reference_and_masks()

    corr_coef = []
    for plane in plane_idx:
        num_rois = self.b2session.get_value("roiPerPlane")[plane]
        c_ref_stack = np.reshape(
            self.centered_reference_stack(plane_idx=plane, width=width, fill=np.nan, filtPrms=(lowcut, highcut, order, fs)),
            (num_rois, -1),
        )
        c_mask_stack = np.reshape(self.centered_mask_stack(plane_idx=plane, width=width, fill=0), (num_rois, -1))
        c_mask_stack[np.isnan(c_ref_stack)] = np.nan

        # Measure mean and standard deviation (and number of non-nan datapoints)
        u_ref = np.nanmean(c_ref_stack, axis=1, keepdims=True)
        u_mask = np.nanmean(c_mask_stack, axis=1, keepdims=True)
        s_ref = np.nanstd(c_ref_stack, axis=1)
        s_mask = np.nanstd(c_mask_stack, axis=1)
        N = np.sum(~np.isnan(c_ref_stack), axis=1)

        # compute correlation coefficient and add to storage variable
        corr_coef.append(np.nansum((c_ref_stack - u_ref) * (c_mask_stack - u_mask), axis=1) / N / s_ref / s_mask)

    return np.concatenate(corr_coef)
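
The correlation arithmetic inside the loop is a nan-aware Pearson correlation over rows, where nan marks out-of-bounds pixels in the crop. A minimal self-contained sketch of that computation (`nan_pearson_rows` is a hypothetical helper, not the library API):

```python
import numpy as np

def nan_pearson_rows(a, b):
    # Pearson correlation per row, ignoring nan entries (assumed to occur
    # at the same positions in a and b, as compute_corr arranges)
    u_a = np.nanmean(a, axis=1, keepdims=True)
    u_b = np.nanmean(b, axis=1, keepdims=True)
    s_a = np.nanstd(a, axis=1)
    s_b = np.nanstd(b, axis=1)
    n = np.sum(~np.isnan(a), axis=1)  # count of valid pixels per row
    return np.nansum((a - u_a) * (b - u_b), axis=1) / n / s_a / s_b

a = np.array([[1.0, 2.0, 3.0, np.nan]])
b = np.array([[2.0, 4.0, 6.0, np.nan]])
r = nan_pearson_rows(a, b)
```

Since the valid entries of `b` are an exact linear rescaling of `a`, the correlation for that row is 1.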

compute_dot(plane_idx=None, lowcut=12, highcut=250, order=3, fs=512)

Compute normalized dot product between filtered reference and ROI masks.

Computes the dot product between each ROI mask and a Butterworth-filtered reference image. This is used as a feature for identifying red cells.

Parameters:

- plane_idx (int or array-like of int, default None): Plane indices to process. If None, processes all planes.
- lowcut (float, default 12): Low cutoff frequency for Butterworth bandpass filter in Hz.
- highcut (float, default 250): High cutoff frequency for Butterworth bandpass filter in Hz.
- order (int, default 3): Order of the Butterworth filter.
- fs (float, default 512): Sampling frequency for the filter in Hz.

Returns:

- np.ndarray: Normalized dot product values for each ROI, shape (num_rois,).

Source code in vrAnalysis/registration/redcell.py
def compute_dot(self, plane_idx=None, lowcut=12, highcut=250, order=3, fs=512):
    """
    Compute normalized dot product between filtered reference and ROI masks.

    Computes the dot product between each ROI mask and a Butterworth-filtered
    reference image. This is used as a feature for identifying red cells.

    Parameters
    ----------
    plane_idx : int or array-like of int, optional
        Plane indices to process. If None, processes all planes. Default is None.
    lowcut : float, optional
        Low cutoff frequency for Butterworth bandpass filter in Hz.
        Default is 12.
    highcut : float, optional
        High cutoff frequency for Butterworth bandpass filter in Hz.
        Default is 250.
    order : int, optional
        Order of the Butterworth filter. Default is 3.
    fs : float, optional
        Sampling frequency for the filter in Hz. Default is 512.

    Returns
    -------
    np.ndarray
        Normalized dot product values for each ROI, shape (num_rois,).
    """
    if plane_idx is None:
        plane_idx = np.arange(self.num_planes)
    if isinstance(plane_idx, (int, np.integer)):
        plane_idx = (plane_idx,)  # make plane_idx iterable
    if not (self.data_loaded):
        self.load_reference_and_masks()

    dot_prod = []
    for plane in plane_idx:
        t = time.time()
        c_roi_idx = np.where(self.roi_plane_idx == plane)[0]  # index of ROIs in this plane
        bwReference = helpers.butterworthbpf(self.reference[plane], lowcut, highcut, order=order, fs=fs)  # filtered reference image
        bwReference /= np.linalg.norm(bwReference)  # adjust to norm for straightforward cosine angle
        # compute normalized dot product for each ROI
        dot_prod.append(
            np.array([bwReference[self.ypix[roi], self.xpix[roi]] @ self.lam[roi] / np.linalg.norm(self.lam[roi]) for roi in c_roi_idx])
        )

    return np.concatenate(dot_prod)
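
Because both the filtered reference and each ROI weight vector are unit-normalized, the value computed per ROI is a cosine similarity sampled at the ROI's pixels. A minimal sketch of that reduction (`roi_cosine` is a hypothetical helper, not the library API; the filtering step is omitted):

```python
import numpy as np

def roi_cosine(reference, ypix, xpix, lam):
    # cosine-style similarity: unit-normalize the whole reference image,
    # sample it at the ROI's pixels, and dot with the unit-normalized weights
    ref = reference / np.linalg.norm(reference)
    return ref[ypix, xpix] @ (lam / np.linalg.norm(lam))

reference = np.zeros((4, 4))
reference[1, 1] = 3.0
reference[1, 2] = 4.0  # image norm is 5
ypix = np.array([1, 1])
xpix = np.array([1, 2])
lam = np.array([3.0, 4.0])  # proportional to the reference at those pixels
sim = roi_cosine(reference, ypix, xpix, lam)
```

Here the reference is zero outside the ROI and proportional to `lam` inside it, so the similarity is exactly 1; real reference images have energy outside the ROI, which shrinks the value.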

compute_volume(plane_idx=None)

Compute full-volume ROI masks for specified planes.

Creates a 3D array where each ROI mask is placed at its original position in the full image plane. This is useful for visualization or volume-based operations.

Parameters:

- plane_idx (int or array-like of int, default None): Plane indices to process. If None, processes all planes.

Returns:

- np.ndarray: Volume array of ROI masks, shape (num_rois, ly, lx), where ly and lx are the dimensions of the reference images.

Raises:

- AssertionError: If any plane index is out of range.

Source code in vrAnalysis/registration/redcell.py
def compute_volume(self, plane_idx=None):
    """
    Compute full-volume ROI masks for specified planes.

    Creates a 3D array where each ROI mask is placed at its original
    position in the full image plane. This is useful for visualization
    or volume-based operations.

    Parameters
    ----------
    plane_idx : int or array-like of int, optional
        Plane indices to process. If None, processes all planes. Default is None.

    Returns
    -------
    np.ndarray
        Volume array of ROI masks, shape (num_rois, ly, lx), where ly and lx
        are the dimensions of the reference images.

    Raises
    ------
    AssertionError
        If any plane index is out of range.
    """
    if plane_idx is None:
        plane_idx = np.arange(self.num_planes)
    if isinstance(plane_idx, (int, np.integer)):
        plane_idx = (plane_idx,)  # make plane_idx iterable
    msg = f"in session: {self.b2session.session_print()}, there are only {self.num_planes} planes!"
    assert all([0 <= plane < self.num_planes for plane in plane_idx]), msg
    if not (self.data_loaded):
        self.load_reference_and_masks()
    roi_mask_volume = []
    for plane in plane_idx:
        roi_mask_volume.append(np.zeros((self.b2session.get_value("roiPerPlane")[plane], self.ly, self.lx)))
        idx_roi_in_plane = np.where(self.roi_plane_idx == plane)[0]
        for roi in range(self.b2session.get_value("roiPerPlane")[plane]):
            c_roi_idx = idx_roi_in_plane[roi]
            roi_mask_volume[-1][roi, self.ypix[c_roi_idx], self.xpix[c_roi_idx]] = self.lam[c_roi_idx]
    return np.concatenate(roi_mask_volume, axis=0)
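
The core operation per plane is scattering each ROI's sparse pixel weights into a dense (ly, lx) image. A minimal self-contained sketch (`rois_to_volume` is a hypothetical helper, not the library API):

```python
import numpy as np

def rois_to_volume(rois, ly, lx):
    # each ROI is (ypix, xpix, lam); scatter its weights into its own
    # dense plane-sized slice of the volume
    volume = np.zeros((len(rois), ly, lx))
    for i, (ypix, xpix, lam) in enumerate(rois):
        volume[i, ypix, xpix] = lam
    return volume

rois = [
    (np.array([0, 1]), np.array([0, 1]), np.array([0.7, 0.3])),
    (np.array([2]), np.array([2]), np.array([1.0])),
]
volume = rois_to_volume(rois, ly=3, lx=3)
```

Each slice of the result is mostly zeros with the ROI's lam values at its original pixel coordinates, which is what makes the volume convenient for overlays and visualization.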

create_centered_axis(numElements, scale=1)

Create a centered axis array.

Parameters:

- numElements (int, required): Number of elements in the axis.
- scale (float, default 1): Scaling factor for the axis.

Returns:

- np.ndarray: Centered axis array, shape (numElements,), with values ranging from -scale*(numElements-1)/2 to scale*(numElements-1)/2.

Source code in vrAnalysis/registration/redcell.py
def create_centered_axis(self, numElements, scale=1):
    """
    Create a centered axis array.

    Parameters
    ----------
    numElements : int
        Number of elements in the axis.
    scale : float, optional
        Scaling factor for the axis. Default is 1.

    Returns
    -------
    np.ndarray
        Centered axis array, shape (numElements,), with values ranging from
        -scale*(numElements-1)/2 to scale*(numElements-1)/2.
    """
    return scale * (np.arange(numElements) - (numElements - 1) / 2)
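
The same arithmetic in a standalone sketch: with `scale` set to the micrometers-per-pixel factor, this produces a symmetric axis in micrometers centered on zero (hypothetical helper name, not the library API):

```python
import numpy as np

def centered_axis(num_elements, scale=1):
    # num_elements points spaced by `scale`, symmetric about zero
    return scale * (np.arange(num_elements) - (num_elements - 1) / 2)

axis = centered_axis(5, scale=1.3)  # e.g. 1.3 um per pixel
```

For an odd number of elements the middle entry is exactly zero; for an even number the axis straddles zero at half-spacing offsets.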

cropped_phase_correlation(plane_idx=None, width=40, eps=1000000.0, winFunc=lambda x: np.hamming(x.shape[-1]))

Compute phase correlation of cropped masks with cropped reference images.

Returns the phase correlation of each ROI mask (cropped around the ROI centroid) with the corresponding cropped reference image. This is used as a feature for identifying red cells.

Parameters:

- plane_idx (int or array-like of int, default None): Plane indices to process. If None, processes all planes.
- width (float, default 40): Width in micrometers of the cropped region around each ROI centroid.
- eps (float, default 1e6): Small value added to avoid division by zero in phase correlation.
- winFunc (callable or str, default lambda x: np.hamming(x.shape[-1])): Window function to apply before computing phase correlation. If "hamming", uses a Hamming window. Otherwise should be a callable that takes an array and returns a windowed array.

Returns:

- refStack (np.ndarray): Stack of cropped reference images, shape (num_rois, height, width).
- maskStack (np.ndarray): Stack of cropped ROI masks, shape (num_rois, height, width).
- pxcStack (np.ndarray): Stack of phase correlation maps, shape (num_rois, height, width).
- phase_corr_values (np.ndarray): Phase correlation values at the center of each correlation map, shape (num_rois,). This is the feature value used for red cell identification.

Notes

The default parameters (width=40 um, eps=1e6, and a Hamming window function) were tested on a few sessions and are subjective. Manual curation and parameter adjustment may be necessary for optimal results.

Source code in vrAnalysis/registration/redcell.py
def cropped_phase_correlation(self, plane_idx=None, width=40, eps=1e6, winFunc=lambda x: np.hamming(x.shape[-1])):
    """
    Compute phase correlation of cropped masks with cropped reference images.

    Returns the phase correlation of each ROI mask (cropped around the ROI
    centroid) with the corresponding cropped reference image. This is used
    as a feature for identifying red cells.

    Parameters
    ----------
    plane_idx : int or array-like of int, optional
        Plane indices to process. If None, processes all planes. Default is None.
    width : float, optional
        Width in micrometers of the cropped region around each ROI centroid.
        Default is 40.
    eps : float, optional
        Small value added to avoid division by zero in phase correlation.
        Default is 1e6.
    winFunc : callable or str, optional
        Window function to apply before computing phase correlation. If "hamming",
        uses Hamming window. Otherwise should be a callable that takes an array
        and returns a windowed array. Default is Hamming window.

    Returns
    -------
    refStack : np.ndarray
        Stack of cropped reference images, shape (num_rois, height, width).
    maskStack : np.ndarray
        Stack of cropped ROI masks, shape (num_rois, height, width).
    pxcStack : np.ndarray
        Stack of phase correlation maps, shape (num_rois, height, width).
    phase_corr_values : np.ndarray
        Phase correlation values at the center of each correlation map,
        shape (num_rois,). This is the feature value used for red cell identification.

    Notes
    -----
    The default parameters (width=40um, eps=1e6, and a Hamming window function)
    were tested on a few sessions and are subjective. Manual curation and
    parameter adjustment may be necessary for optimal results.
    """
    if not (self.data_loaded):
        self.load_reference_and_masks()
    if winFunc == "hamming":
        winFunc = lambda x: np.hamming(x.shape[-1])
    refStack = self.centered_reference_stack(plane_idx=plane_idx, width=width)  # get stack of reference image centered on each ROI
    maskStack = self.centered_mask_stack(plane_idx=plane_idx, width=width)  # get stack of mask value centered on each ROI
    window = winFunc(refStack)  # create a window function
    pxcStack = np.stack(
        [helpers.phaseCorrelation(ref, mask, eps=eps, window=window) for (ref, mask) in zip(refStack, maskStack)]
    )  # measure phase correlation
    pxcCenterPixel = int((pxcStack.shape[2] - 1) / 2)
    return refStack, maskStack, pxcStack, pxcStack[:, pxcCenterPixel, pxcCenterPixel]
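The phase correlation here is the inverse transform of the unit-normalized cross-power spectrum, evaluated at zero lag (the center pixel after shifting). A minimal standalone sketch of that computation follows; `phase_corr_map` is a hypothetical stand-in for `helpers.phaseCorrelation` (whose exact implementation is assumed), and the windowing step is omitted:

```python
import numpy as np

def phase_corr_map(ref, mask, eps=1e-6):
    # Normalized cross-power spectrum, inverse-transformed and shifted
    # so that zero lag sits at the center pixel.
    F = np.fft.fft2(ref) * np.conj(np.fft.fft2(mask))
    pxc = np.fft.ifft2(F / (np.abs(F) + eps))
    return np.fft.fftshift(pxc.real)

img = np.random.default_rng(0).normal(size=(41, 41))
pxc = phase_corr_map(img, img)
center = (pxc.shape[-1] - 1) // 2
# identical inputs correlate perfectly at zero shift, i.e. at the center pixel
```

The center-pixel value of each map is what the method returns as the red cell feature.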

get_roi_centroid(idx, mode='weightedmean')

Get the centroid of an ROI.

Parameters:

Name Type Description Default
idx int

Index of the ROI.

required
mode str

Method for computing centroid. "weightedmean" uses pixel weights (lam), "median" uses median pixel coordinates. Default is "weightedmean".

'weightedmean'

Returns:

Name Type Description
yc float

Y-coordinate of the centroid in pixels.

xc float

X-coordinate of the centroid in pixels.

Source code in vrAnalysis/registration/redcell.py
def get_roi_centroid(self, idx, mode="weightedmean"):
    """
    Get the centroid of an ROI.

    Parameters
    ----------
    idx : int
        Index of the ROI.
    mode : str, optional
        Method for computing centroid. "weightedmean" uses pixel weights (lam),
        "median" uses median pixel coordinates. Default is "weightedmean".

    Returns
    -------
    yc : float
        Y-coordinate of the centroid in pixels.
    xc : float
        X-coordinate of the centroid in pixels.
    """
    if not (self.data_loaded):
        self.load_reference_and_masks()

    if mode == "weightedmean":
        yc = np.sum(self.lam[idx] * self.ypix[idx]) / np.sum(self.lam[idx])
        xc = np.sum(self.lam[idx] * self.xpix[idx]) / np.sum(self.lam[idx])
    elif mode == "median":
        yc = int(np.median(self.ypix[idx]))
        xc = int(np.median(self.xpix[idx]))
    else:
        raise ValueError(f"Unrecognized centroid mode: {mode}")

    return yc, xc
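The weighted-mean mode is a lam-weighted average of the ROI's pixel coordinates. A self-contained sketch with a hypothetical three-pixel ROI:

```python
import numpy as np

# Hypothetical ROI, mirroring the suite2p stat fields used by get_roi_centroid:
# lam = per-pixel weights, ypix/xpix = pixel coordinates.
lam = np.array([1.0, 2.0, 1.0])
ypix = np.array([10, 11, 12])
xpix = np.array([5, 5, 7])

yc = np.sum(lam * ypix) / np.sum(lam)  # weighted mean of y coordinates
xc = np.sum(lam * xpix) / np.sum(lam)  # weighted mean of x coordinates
```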

get_roi_in_plane_idx(idx)

Get the index of an ROI within its own plane.

Parameters:

Name Type Description Default
idx int

Global ROI index.

required

Returns:

Type Description
int

Index of the ROI within its plane (0-indexed within that plane).

Source code in vrAnalysis/registration/redcell.py
def get_roi_in_plane_idx(self, idx):
    """
    Get the index of an ROI within its own plane.

    Parameters
    ----------
    idx : int
        Global ROI index.

    Returns
    -------
    int
        Index of the ROI within its plane (0-indexed within that plane).
    """
    if not (self.data_loaded):
        self.load_reference_and_masks()
    # return index of ROI within its own plane
    plane_idx = self.roi_plane_idx[idx]
    return idx - np.sum(self.roi_plane_idx < plane_idx)
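The in-plane index is the global index minus the number of ROIs in earlier planes, which assumes ROIs are sorted by plane. A sketch with a hypothetical plane assignment:

```python
import numpy as np

# Five ROIs: the first two in plane 0, the rest in plane 1 (sorted by plane).
roi_plane_idx = np.array([0, 0, 1, 1, 1])

idx = 3  # global ROI index
plane = roi_plane_idx[idx]
in_plane = idx - np.sum(roi_plane_idx < plane)  # second ROI of plane 1
```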

get_roi_range(idx)

Get the range (peak-to-peak) of x and y pixels for an ROI.

Parameters:

Name Type Description Default
idx int

Index of the ROI.

required

Returns:

Name Type Description
yr int

Range of y-pixels (peak-to-peak).

xr int

Range of x-pixels (peak-to-peak).

Source code in vrAnalysis/registration/redcell.py
def get_roi_range(self, idx):
    """
    Get the range (peak-to-peak) of x and y pixels for an ROI.

    Parameters
    ----------
    idx : int
        Index of the ROI.

    Returns
    -------
    yr : int
        Range of y-pixels (peak-to-peak).
    xr : int
        Range of x-pixels (peak-to-peak).
    """
    if not (self.data_loaded):
        self.load_reference_and_masks()
    # get range of x and y pixels for a particular ROI
    yr = np.ptp(self.ypix[idx])
    xr = np.ptp(self.xpix[idx])
    return yr, xr
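The range is simply the peak-to-peak spread of the ROI's pixel coordinates; with hypothetical coordinates:

```python
import numpy as np

ypix = np.array([10, 11, 15])
xpix = np.array([4, 5, 6])
yr, xr = np.ptp(ypix), np.ptp(xpix)  # max minus min along each axis
```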

getxref(xCenter)

Get x-axis reference coordinates relative to a center point.

Parameters:

Name Type Description Default
xCenter float

X-coordinate of the center point in pixels.

required

Returns:

Type Description
ndarray

X-axis coordinates in micrometers relative to xCenter, shape (lx,).

Source code in vrAnalysis/registration/redcell.py
def getxref(self, xCenter):
    """
    Get x-axis reference coordinates relative to a center point.

    Parameters
    ----------
    xCenter : float
        X-coordinate of the center point in pixels.

    Returns
    -------
    np.ndarray
        X-axis coordinates in micrometers relative to xCenter, shape (lx,).
    """
    if not (self.data_loaded):
        self.load_reference_and_masks()
    return self.um_per_pixel * (self.x_base_ref - xCenter)

getyref(yCenter)

Get y-axis reference coordinates relative to a center point.

Parameters:

Name Type Description Default
yCenter float

Y-coordinate of the center point in pixels.

required

Returns:

Type Description
ndarray

Y-axis coordinates in micrometers relative to yCenter, shape (ly,).

Source code in vrAnalysis/registration/redcell.py
def getyref(self, yCenter):
    """
    Get y-axis reference coordinates relative to a center point.

    Parameters
    ----------
    yCenter : float
        Y-coordinate of the center point in pixels.

    Returns
    -------
    np.ndarray
        Y-axis coordinates in micrometers relative to yCenter, shape (ly,).
    """
    if not (self.data_loaded):
        self.load_reference_and_masks()
    return self.um_per_pixel * (self.y_base_ref - yCenter)
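Both `getxref` and `getyref` convert pixel coordinates into a micrometer axis centered on the given point. A sketch with an assumed scale factor and image width:

```python
import numpy as np

um_per_pixel = 1.3            # assumed micrometers per pixel
x_base_ref = np.arange(512)   # pixel coordinates of the reference image
xCenter = 256.0

# axis in micrometers, zero at the center point
xref = um_per_pixel * (x_base_ref - xCenter)
```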

load_reference_and_masks()

Load reference images and ROI masks from suite2p outputs.

Loads the mean image for channel 2 (red channel) for each plane, along with ROI mask data (lam, ypix, xpix) and ROI plane indices. Also loads suite2p red cell values and creates supporting variables for spatial measurements.

Raises:

Type Description
AssertionError

If reference images do not all have the same shape.

Source code in vrAnalysis/registration/redcell.py
def load_reference_and_masks(self):
    """
    Load reference images and ROI masks from suite2p outputs.

    Loads the mean image for channel 2 (red channel) for each plane, along
    with ROI mask data (lam, ypix, xpix) and ROI plane indices. Also loads
    suite2p red cell values and creates supporting variables for spatial
    measurements.

    Raises
    ------
    AssertionError
        If reference images do not all have the same shape.
    """
    # load reference images
    ops = self.b2session.load_s2p("ops")
    self.reference = [op["meanImg_chan2"] for op in ops]
    self.lx, self.ly = self.reference[0].shape
    for ref in self.reference:
        msg = "reference images do not all have the same shape"
        assert (self.lx, self.ly) == ref.shape, msg

    # load masks (lam=weight of each pixel, xpix & ypix=index of each pixel in ROI mask)
    stat = self.b2session.load_s2p("stat")
    self.lam = [s["lam"] for s in stat]
    self.ypix = [s["ypix"] for s in stat]
    self.xpix = [s["xpix"] for s in stat]
    self.roi_plane_idx = self.b2session.loadone("mpciROIs.stackPosition")[:, 2]

    # load S2P red cell value
    self.red_s2p = self.b2session.loadone("mpciROIs.redS2P")  # (preloaded, will never change in this function)

    # create supporting variables for mapping locations and axes
    self.y_base_ref = np.arange(self.ly)
    self.x_base_ref = np.arange(self.lx)
    self.y_dist_ref = self.create_centered_axis(self.ly, self.um_per_pixel)
    self.x_dist_ref = self.create_centered_axis(self.lx, self.um_per_pixel)

    # update data_loaded field
    self.data_loaded = True
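The suite2p fields consumed here can be mimicked with plain dicts; a hypothetical in-memory sketch of the same bookkeeping (real suite2p outputs contain many more keys):

```python
import numpy as np

# Stand-ins for suite2p outputs: one ops dict per plane, one stat dict per ROI.
ops = [{"meanImg_chan2": np.zeros((512, 512))} for _ in range(2)]
stat = [{"lam": np.ones(3), "ypix": np.array([1, 2, 3]), "xpix": np.array([4, 5, 6])}]

reference = [op["meanImg_chan2"] for op in ops]
lam = [s["lam"] for s in stat]

# every plane's reference image must share the same shape
assert all(ref.shape == reference[0].shape for ref in reference)
```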

one_name_feature_cutoffs(name)

Generate oneData name for feature cutoff parameters.

Parameters:

Name Type Description Default
name str

Feature name (e.g., "S2P", "dotProduct", "pearson", "phaseCorrelation").

required

Returns:

Type Description
str

OneData name for the feature cutoff parameter, formatted as "parametersRed{Name}.minMaxCutoff" where {Name} is the capitalized feature name.

Source code in vrAnalysis/registration/redcell.py
def one_name_feature_cutoffs(self, name):
    """
    Generate oneData name for feature cutoff parameters.

    Parameters
    ----------
    name : str
        Feature name (e.g., "S2P", "dotProduct", "pearson", "phaseCorrelation").

    Returns
    -------
    str
        OneData name for the feature cutoff parameter, formatted as
        "parametersRed{Name}.minMaxCutoff" where {Name} is the capitalized
        feature name.
    """
    return "parameters" + "Red" + name[0].upper() + name[1:] + ".minMaxCutoff"
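Called on the documented feature names, the formatter capitalizes the first letter and wraps the name:

```python
def one_name_feature_cutoffs(name):
    # same formatting rule as the method above
    return "parameters" + "Red" + name[0].upper() + name[1:] + ".minMaxCutoff"

one_name_feature_cutoffs("dotProduct")
one_name_feature_cutoffs("S2P")
```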

update_from_session(red_cell, force_update=False)

Update red cell cutoffs from another session.

Copies red cell cutoff parameters from another RedCellProcessing object and applies them to this session.

Parameters:

Name Type Description Default
red_cell RedCellProcessing

Another RedCellProcessing object to copy cutoffs from.

required
force_update bool

If False, only allows copying from sessions with the same mouse name. If True, allows copying from any session. Default is False.

False

Raises:

Type Description
AssertionError

If red_cell is not a RedCellProcessing object, or if force_update is False and the mouse names don't match.

Source code in vrAnalysis/registration/redcell.py
def update_from_session(self, red_cell, force_update=False):
    """
    Update red cell cutoffs from another session.

    Copies red cell cutoff parameters from another RedCellProcessing object
    and applies them to this session.

    Parameters
    ----------
    red_cell : RedCellProcessing
        Another RedCellProcessing object to copy cutoffs from.
    force_update : bool, optional
        If False, only allows copying from sessions with the same mouse name.
        If True, allows copying from any session. Default is False.

    Raises
    ------
    AssertionError
        If red_cell is not a RedCellProcessing object, or if force_update is
        False and the mouse names don't match.
    """
    assert isinstance(red_cell, RedCellProcessing), "red_cell is not a RedCellProcessing object"
    if not (force_update):
        assert (
            red_cell.b2session.mouse_name == self.b2session.mouse_name
        ), "session to copy from is from a different mouse, this isn't allowed without the force_update=True input"
    cutoffs = [red_cell.b2session.loadone(red_cell.one_name_feature_cutoffs(name)) for name in self.feature_names]
    self.update_red_idx(s2p_cutoff=cutoffs[0], dot_product_cutoff=cutoffs[1], corr_coef_cutoff=cutoffs[2], phase_corr_cutoff=cutoffs[3])

update_red_idx(s2p_cutoff=None, dot_product_cutoff=None, corr_coef_cutoff=None, phase_corr_cutoff=None)

Update red cell index based on feature cutoff values.

Updates the red cell index by applying minimum and maximum cutoffs to each feature (S2P, dot product, Pearson correlation, phase correlation). Only features with non-NaN cutoff values are applied. The red cell index is updated to include only ROIs that meet all specified criteria.

Parameters:

Name Type Description Default
s2p_cutoff array-like of float, length 2

[min, max] cutoff values for suite2p red cell feature. NaN values indicate the cutoff should not be applied. Default is None.

None
dot_product_cutoff array-like of float, length 2

[min, max] cutoff values for dot product feature. Default is None.

None
corr_coef_cutoff array-like of float, length 2

[min, max] cutoff values for Pearson correlation feature. Default is None.

None
phase_corr_cutoff array-like of float, length 2

[min, max] cutoff values for phase correlation feature. Default is None.

None

Raises:

Type Description
ValueError

If any cutoff is not a numpy array or list, or if any cutoff does not have exactly 2 elements.

Notes

Cutoff values are saved to oneData for future reference. The red cell index is updated in place and saved to oneData.

Source code in vrAnalysis/registration/redcell.py
def update_red_idx(self, s2p_cutoff=None, dot_product_cutoff=None, corr_coef_cutoff=None, phase_corr_cutoff=None):
    """
    Update red cell index based on feature cutoff values.

    Updates the red cell index by applying minimum and maximum cutoffs to
    each feature (S2P, dot product, Pearson correlation, phase correlation).
    Only features with non-NaN cutoff values are applied. The red cell index
    is updated to include only ROIs that meet all specified criteria.

    Parameters
    ----------
    s2p_cutoff : array-like of float, length 2, optional
        [min, max] cutoff values for suite2p red cell feature. NaN values
        indicate the cutoff should not be applied. Default is None.
    dot_product_cutoff : array-like of float, length 2, optional
        [min, max] cutoff values for dot product feature. Default is None.
    corr_coef_cutoff : array-like of float, length 2, optional
        [min, max] cutoff values for Pearson correlation feature.
        Default is None.
    phase_corr_cutoff : array-like of float, length 2, optional
        [min, max] cutoff values for phase correlation feature.
        Default is None.

    Raises
    ------
    ValueError
        If any cutoff is not a numpy array or list, or if any cutoff does
        not have exactly 2 elements.

    Notes
    -----
    Cutoff values are saved to oneData for future reference. The red cell
    index is updated in place and saved to oneData.
    """
    # create initial all true red cell idx
    red_cell_idx = np.full(self.b2session.loadone("mpciROIs.redCellIdx").shape, True)

    # load feature values for each ROI
    red_s2p = self.b2session.loadone("mpciROIs.redS2P")
    dot_product = self.b2session.loadone("mpciROIs.redDotProduct")
    corr_coef = self.b2session.loadone("mpciROIs.redPearson")
    phase_corr = self.b2session.loadone("mpciROIs.redPhaseCorrelation")

    # create lists for zipping through each feature/cutoff combination
    features = [red_s2p, dot_product, corr_coef, phase_corr]
    cutoffs = [s2p_cutoff, dot_product_cutoff, corr_coef_cutoff, phase_corr_cutoff]
    usecutoff = [[False, False] for _ in range(len(cutoffs))]

    # check validity of each cutoff and identify whether it should be used
    for name, use, cutoff in zip(self.feature_names, usecutoff, cutoffs):
        if not isinstance(cutoff, (np.ndarray, list)):
            raise ValueError(f"Expecting a numpy array or a list for {name} cutoff, got {type(cutoff)}")
        if len(cutoff) != 2:
            raise ValueError(f"{name} cutoff does not have 2 elements")
        if not (np.isnan(cutoff[0])):
            use[0] = True
        if not (np.isnan(cutoff[1])):
            use[1] = True

    # add feature cutoffs to redCellIdx (sets any to False that don't meet the cutoff)
    for feature, use, cutoff in zip(features, usecutoff, cutoffs):
        if use[0]:
            red_cell_idx &= feature >= cutoff[0]
        if use[1]:
            red_cell_idx &= feature <= cutoff[1]

    # save new red cell index to one data
    self.b2session.saveone(red_cell_idx, "mpciROIs.redCellIdx")

    # save feature cutoffs to one data
    for idx, name in enumerate(self.feature_names):
        self.b2session.saveone(cutoffs[idx], self.one_name_feature_cutoffs(name))
    print(f"Red Cell curation choices are saved for session {self.b2session.session_print()}")
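The cutoff rule reduces to a pair of optional bounds per feature, with NaN meaning "leave this bound unused". A standalone sketch of the masking for a single hypothetical feature:

```python
import numpy as np

feature = np.array([0.1, 0.5, 0.9, 1.5])
cutoff = np.array([0.3, np.nan])  # [min, max]; max bound unused here

red_idx = np.full(feature.shape, True)
if not np.isnan(cutoff[0]):
    red_idx &= feature >= cutoff[0]  # apply minimum bound
if not np.isnan(cutoff[1]):
    red_idx &= feature <= cutoff[1]  # apply maximum bound
```

In the method, this masking is intersected across all four features.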

behavior

Behavior processing functions for B2Registration.

This module contains functions for processing behavioral data from different versions of the vrControl software. Each function produces the same output structure regardless of how the data were collected.

BEHAVIOR_PROCESSING = {1: standard_behavior, 2: cr_hippocannula_behavior} module-attribute

Dictionary of behavior processing functions.

These reflect the different versions of the vrControl software that were used to collect the behavior data. Because the behavioral data were collected in different ways, each version needs its own processing to produce the same results structure.

Keys:

  • 1: Standard behavior processing function.
  • 2: CR hippocannula behavior processing function.

cr_hippocannula_behavior(b2registration)

Process behavior data from CR hippocannula version of vrControl.

Extracts behavioral data from TRIAL and EXP structures, processes timestamps, positions, rewards, and licks, and saves them to oneData format. Aligns behavioral timestamps to the timeline using photodiode flips.

Parameters:

Name Type Description Default
b2registration B2Registration

The B2Registration object containing the session data to process.

required

Returns:

Type Description
B2Registration

The B2Registration object with behavior data processed and saved.

Notes

This function processes behavior data from the CR hippocannula version of vrControl. The data structure differs from the standard version, requiring different field names and processing steps.

Source code in vrAnalysis/registration/behavior.py
def cr_hippocannula_behavior(b2registration: "B2Registration") -> "B2Registration":
    """
    Process behavior data from CR hippocannula version of vrControl.

    Extracts behavioral data from TRIAL and EXP structures, processes
    timestamps, positions, rewards, and licks, and saves them to oneData format.
    Aligns behavioral timestamps to the timeline using photodiode flips.

    Parameters
    ----------
    b2registration : B2Registration
        The B2Registration object containing the session data to process.

    Returns
    -------
    B2Registration
        The B2Registration object with behavior data processed and saved.

    Notes
    -----
    This function processes behavior data from the CR hippocannula version of
    vrControl. The data structure differs from the standard version, requiring
    different field names and processing steps.
    """
    trialInfo = b2registration.vr_file["TRIAL"]
    expInfo = b2registration.vr_file["EXP"]

    numTrials = trialInfo.info.no
    nonNanSamples = np.sum(~np.isnan(trialInfo.time[:, 0]))
    assert numTrials == nonNanSamples, f"# trials {trialInfo.info.no} isn't equal to non-nan first time samples {nonNanSamples}"
    b2registration.set_value("numTrials", numTrials)

    # trialInfo contains sparse matrices of size (maxTrials, maxSamples), where numTrials<maxTrials and numSamples<maxSamples
    nzindex = b2registration.create_index(b2registration.convert_dense(trialInfo.time))
    timeStamps = b2registration.get_vr_data(b2registration.convert_dense(trialInfo.time), nzindex)
    roomPosition = b2registration.get_vr_data(b2registration.convert_dense(trialInfo.roomPosition), nzindex)

    # oneData with behave prefix is a (numBehavioralSamples, ) shaped array conveying information about the state of VR
    numTimeStamps = np.array([len(t) for t in timeStamps])  # list of number of behavioral timestamps in each trial
    behaveTimeStamps = np.concatenate(timeStamps)  # time stamp associated with each behavioral sample
    behavePosition = np.concatenate(roomPosition)  # virtual position associated with each behavioral sample
    b2registration.set_value("numBehaveTimestamps", len(behaveTimeStamps))

    # Check shapes and sizes
    assert behaveTimeStamps.ndim == 1, "behaveTimeStamps is not a 1-d array!"
    assert behaveTimeStamps.shape == behavePosition.shape, "behave oneData arrays do not have the same shape!"

    # oneData with trial prefix is a (numTrials,) shaped array conveying information about the state on each trial
    trialStartFrame = np.array([0, *np.cumsum(numTimeStamps)[:-1]]).astype(np.int64)
    trialEnvironmentIndex = (
        b2registration.convert_dense(trialInfo.vrEnvIdx).astype(np.int16)
        if "vrEnvIdx" in trialInfo._fieldnames
        else -1 * np.ones(b2registration.get_value("numTrials"), dtype=np.int16)
    )
    trialRoomLength = np.ones(b2registration.get_value("numTrials")) * expInfo.roomLength
    trialMovementGain = np.ones(b2registration.get_value("numTrials"))  # mvmt gain always one
    trialRewardPosition = b2registration.convert_dense(trialInfo.trialRewPos)
    trialRewardTolerance = b2registration.convert_dense(expInfo.rewPosTolerance * np.ones(b2registration.get_value("numTrials")))
    trialRewardAvailability = b2registration.convert_dense(trialInfo.trialRewAvailable).astype(np.bool_)
    rewardDelivery = b2registration.convert_dense(trialInfo.trialRewDelivery)
    rewardDelivery[np.isnan(rewardDelivery)] = 0  # becomes -1 after the subtraction below, indicating no reward delivered
    rewardDelivery = rewardDelivery.astype(np.int64) - 1  # get reward delivery frame (frame within trial) first (will be -1 if no reward delivered)

    # adjust frame count to behave arrays
    trialRewardDelivery = np.array(
        [
            (rewardDelivery + np.sum(numTimeStamps[:trialIdx]) if rewardDelivery >= 0 else rewardDelivery)
            for (trialIdx, rewardDelivery) in enumerate(rewardDelivery)
        ]
    )
    trialActiveLicking = b2registration.convert_dense(trialInfo.trialActiveLicking).astype(np.bool_)
    trialActiveStopping = b2registration.convert_dense(trialInfo.trialActiveStopping).astype(np.bool_)

    # Check shapes and sizes
    assert trialEnvironmentIndex.ndim == 1 and len(trialEnvironmentIndex) == b2registration.get_value(
        "numTrials"
    ), "trialEnvironmentIndex is not a (numTrials,) shaped array!"
    assert (
        trialStartFrame.shape
        == trialEnvironmentIndex.shape
        == trialRoomLength.shape
        == trialMovementGain.shape
        == trialRewardPosition.shape
        == trialRewardTolerance.shape
        == trialRewardAvailability.shape
        == trialRewardDelivery.shape
        == trialActiveLicking.shape
        == trialActiveStopping.shape
    ), "trial oneData arrays do not have the same shape!"

    # oneData with lick prefix is a (numLicks,) shaped array containing information about each lick during VR behavior
    licks = [vrd.astype(np.int16) for vrd in b2registration.get_vr_data(b2registration.convert_dense(trialInfo.lick), nzindex)]
    lickFrames = [np.nonzero(lick)[0] for lick in licks]
    lickCounts = np.concatenate([lick[lf] for (lick, lf) in zip(licks, lickFrames)])
    lickTrials = np.concatenate([tidx * np.ones_like(lf) for (tidx, lf) in enumerate(lickFrames)])
    lickFrames = np.concatenate(lickFrames)
    if np.sum(lickCounts) > 0:
        lickFramesRepeat = np.concatenate([lf * np.ones(lc, dtype=np.uint8) for (lf, lc) in zip(lickFrames, lickCounts)])
        lickTrialsRepeat = np.concatenate([lt * np.ones(lc, dtype=np.uint8) for (lt, lc) in zip(lickTrials, lickCounts)])
        lickCountsRepeat = np.concatenate([lc * np.ones(lc, dtype=np.uint8) for lc in lickCounts])
        lickBehaveSample = lickFramesRepeat + np.array([np.sum(numTimeStamps[:trialIdx]) for trialIdx in lickTrialsRepeat])

        assert len(lickBehaveSample) == np.sum(
            lickCounts
        ), "the number of licks counted by vrBehavior is not equal to the length of the lickBehaveSample vector!"
        assert lickBehaveSample.ndim == 1, "lickBehaveIndex is not a 1-d array!"
        assert (
            0 <= np.max(lickBehaveSample) <= len(behaveTimeStamps)
        ), "lickBehaveSample contains index outside range of possible indices for behaveTimeStamps"
    else:
        # No licks found -- create empty array
        lickBehaveSample = np.array([], dtype=np.uint8)

    # Align behavioral timestamp data to timeline -- shift each trial's timestamps so that they start at the time of the first photodiode flip (which is reliably detected)
    trialStartOffsets = behaveTimeStamps[trialStartFrame] - b2registration.loadone("trials.startTimes")  # get offset
    behaveTimeStamps = np.concatenate(
        [bts - trialStartOffsets[tidx] for (tidx, bts) in enumerate(b2registration.group_behave_by_trial(behaveTimeStamps, trialStartFrame))]
    )

    # Save behave onedata
    b2registration.saveone(behaveTimeStamps, "positionTracking.times")
    b2registration.saveone(behavePosition, "positionTracking.position")

    # Save trial onedata
    b2registration.saveone(trialStartFrame, "trials.positionTracking")
    b2registration.saveone(trialEnvironmentIndex, "trials.environmentIndex")
    b2registration.saveone(trialRoomLength, "trials.roomlength")
    b2registration.saveone(trialMovementGain, "trials.movementGain")
    b2registration.saveone(trialRewardPosition, "trials.rewardPosition")
    b2registration.saveone(trialRewardTolerance, "trials.rewardZoneHalfwidth")
    b2registration.saveone(trialRewardAvailability, "trials.rewardAvailability")
    b2registration.saveone(trialRewardDelivery, "trials.rewardPositionTracking")
    b2registration.saveone(trialActiveLicking, "trials.activeLicking")
    b2registration.saveone(trialActiveStopping, "trials.activeStopping")

    # Save lick onedata
    b2registration.saveone(lickBehaveSample, "licksTracking.positionTracking")

    return b2registration
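The `trialStartFrame` bookkeeping used above (and shared with `standard_behavior`) is an exclusive cumulative sum of per-trial sample counts: each trial starts where the previous trials' samples end. With hypothetical counts:

```python
import numpy as np

# number of behavioral samples recorded in each of three hypothetical trials
numTimeStamps = np.array([3, 4, 2])

# index of each trial's first sample in the concatenated behave arrays
trialStartFrame = np.array([0, *np.cumsum(numTimeStamps)[:-1]]).astype(np.int64)
```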

register_behavior(b2registration, behavior_type)

Register behavior for a given behavior type.

This is a dispatcher function that calls the appropriate behavior processing function based on the behavior type.

Parameters:

Name Type Description Default
b2registration B2Registration

The B2Registration object containing the session data to process.

required
behavior_type int

The behavior type to register. Must be a key in BEHAVIOR_PROCESSING.

required

Returns:

Type Description
B2Registration

The B2Registration object with behavior registered.

Raises:

Type Description
ValueError

If behavior_type is not supported.

See Also

BEHAVIOR_PROCESSING : Dictionary mapping behavior types to processing functions.

Source code in vrAnalysis/registration/behavior.py
def register_behavior(b2registration: "B2Registration", behavior_type: int) -> "B2Registration":
    """
    Register behavior for a given behavior type.

    This is a dispatcher function that calls the appropriate behavior processing
    function based on the behavior type.

    Parameters
    ----------
    b2registration : B2Registration
        The B2Registration object containing the session data to process.
    behavior_type : int
        The behavior type to register. Must be a key in BEHAVIOR_PROCESSING.

    Returns
    -------
    B2Registration
        The B2Registration object with behavior registered.

    Raises
    ------
    ValueError
        If behavior_type is not supported.

    See Also
    --------
    BEHAVIOR_PROCESSING : Dictionary mapping behavior types to processing functions.
    """
    if behavior_type not in BEHAVIOR_PROCESSING.keys():
        raise ValueError(f"Behavior type {behavior_type} not supported. Supported types are: {list(BEHAVIOR_PROCESSING.keys())}.")
    return BEHAVIOR_PROCESSING[behavior_type](b2registration)
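The dispatch pattern reduces to a dict lookup with validation. A toy sketch with stand-in handlers (the real handlers take and return a B2Registration):

```python
# hypothetical stand-ins for the two processing functions
def standard(x):
    return f"standard:{x}"

def hippocannula(x):
    return f"cr:{x}"

BEHAVIOR_PROCESSING = {1: standard, 2: hippocannula}

def register(x, behavior_type):
    # reject unknown keys before dispatching
    if behavior_type not in BEHAVIOR_PROCESSING:
        raise ValueError(f"Behavior type {behavior_type} not supported.")
    return BEHAVIOR_PROCESSING[behavior_type](x)
```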

standard_behavior(b2registration)

Process standard behavior data from vrControl.

Extracts behavioral data from trialInfo and expInfo structures, processes timestamps, positions, rewards, and licks, and saves them to oneData format. Aligns behavioral timestamps to the timeline using photodiode flips.

Parameters:

Name Type Description Default
b2registration B2Registration

The B2Registration object containing the session data to process.

required

Returns:

Type Description
B2Registration

The B2Registration object with behavior data processed and saved.

Notes

This function processes behavior data from the standard vrControl format. It extracts trial-level and sample-level behavioral data and aligns timestamps to the imaging timeline.

Source code in vrAnalysis/registration/behavior.py
def standard_behavior(b2registration: "B2Registration") -> "B2Registration":
    """
    Process standard behavior data from vrControl.

    Extracts behavioral data from trialInfo and expInfo structures, processes
    timestamps, positions, rewards, and licks, and saves them to oneData format.
    Aligns behavioral timestamps to the timeline using photodiode flips.

    Parameters
    ----------
    b2registration : B2Registration
        The B2Registration object containing the session data to process.

    Returns
    -------
    B2Registration
        The B2Registration object with behavior data processed and saved.

    Notes
    -----
    This function processes behavior data from the standard vrControl format.
    It extracts trial-level and sample-level behavioral data and aligns timestamps
    to the imaging timeline.
    """
    expInfo = b2registration.vr_file["expInfo"]
    trialInfo = b2registration.vr_file["trialInfo"]
    num_values_per_trial = np.diff(trialInfo.time.tocsr().indptr)
    valid_trials = np.where(num_values_per_trial > 0)[0]
    numTrials = len(valid_trials)
    assert np.array_equal(valid_trials, np.arange(numTrials)), "valid_trials is not a range from 0 to numTrials"
    b2registration.set_value("numTrials", numTrials)

    # trialInfo contains sparse matrices of size (maxTrials, maxSamples), where numTrials<maxTrials and numSamples<maxSamples
    nzindex = b2registration.create_index(b2registration.convert_dense(trialInfo.time))
    timeStamps = b2registration.get_vr_data(b2registration.convert_dense(trialInfo.time), nzindex)
    roomPosition = b2registration.get_vr_data(b2registration.convert_dense(trialInfo.roomPosition), nzindex)

    # oneData with behave prefix is a (numBehavioralSamples, ) shaped array conveying information about the state of VR
    numTimeStamps = np.array([len(t) for t in timeStamps])  # list of number of behavioral timestamps in each trial
    behaveTimeStamps = np.concatenate(timeStamps)  # time stamp associated with each behavioral sample
    behavePosition = np.concatenate(roomPosition)  # virtual position associated with each behavioral sample
    b2registration.set_value("numBehaveTimestamps", len(behaveTimeStamps))

    # Check shapes and sizes
    assert behaveTimeStamps.ndim == 1, "behaveTimeStamps is not a 1-d array!"
    assert behaveTimeStamps.shape == behavePosition.shape, "behave oneData arrays do not have the same shape!"

    # oneData with trial prefix is a (numTrials,) shaped array conveying information about the state on each trial
    trialStartFrame = np.array([0, *np.cumsum(numTimeStamps)[:-1]]).astype(np.int64)
    trialEnvironmentIndex = (
        b2registration.convert_dense(trialInfo.vrEnvIdx).astype(np.int16)
        if "vrEnvIdx" in trialInfo._fieldnames
        else -1 * np.ones(b2registration.get_value("numTrials"), dtype=np.int16)
    )
    trialRoomLength = expInfo.roomLength[: b2registration.get_value("numTrials")]
    trialMovementGain = expInfo.mvmtGain[: b2registration.get_value("numTrials")]
    trialRewardPosition = b2registration.convert_dense(trialInfo.rewardPosition)
    trialRewardTolerance = b2registration.convert_dense(trialInfo.rewardTolerance)
    trialRewardAvailability = b2registration.convert_dense(trialInfo.rewardAvailable).astype(np.bool_)
    rewardDelivery = (
        b2registration.convert_dense(trialInfo.rewardDeliveryFrame).astype(np.int64) - 1
    )  # get reward delivery frame (frame within trial) first (will be -1 if no reward delivered)
    # adjust frame count to behave arrays
    trialRewardDelivery = np.array(
        [
            (rewardDelivery + np.sum(numTimeStamps[:trialIdx]) if rewardDelivery >= 0 else rewardDelivery)
            for (trialIdx, rewardDelivery) in enumerate(rewardDelivery)
        ]
    )
    trialActiveLicking = b2registration.convert_dense(trialInfo.activeLicking).astype(np.bool_)
    trialActiveStopping = b2registration.convert_dense(trialInfo.activeStopping).astype(np.bool_)

    # Check shapes and sizes
    assert trialEnvironmentIndex.ndim == 1 and len(trialEnvironmentIndex) == b2registration.get_value(
        "numTrials"
    ), "trialEnvironmentIndex is not a (numTrials,) shaped array!"
    assert (
        trialStartFrame.shape
        == trialEnvironmentIndex.shape
        == trialRoomLength.shape
        == trialMovementGain.shape
        == trialRewardPosition.shape
        == trialRewardTolerance.shape
        == trialRewardAvailability.shape
        == trialRewardDelivery.shape
        == trialActiveLicking.shape
        == trialActiveStopping.shape
    ), "trial oneData arrays do not have the same shape!"

    # oneData with lick prefix is a (numLicks,) shaped array containing information about each lick during VR behavior
    licks = b2registration.get_vr_data(b2registration.convert_dense(trialInfo.lick), nzindex)
    lickFrames = [np.nonzero(licks)[0] for licks in licks]
    lickCounts = np.concatenate([licks[lickFrames] for (licks, lickFrames) in zip(licks, lickFrames)])
    lickTrials = np.concatenate([tidx * np.ones_like(lickFrames) for (tidx, lickFrames) in enumerate(lickFrames)])
    lickFrames = np.concatenate(lickFrames)
    if np.sum(lickCounts) > 0:
        lickFramesRepeat = np.concatenate([lf * np.ones(lc, dtype=np.uint8) for (lf, lc) in zip(lickFrames, lickCounts)])
        lickTrialsRepeat = np.concatenate([lt * np.ones(lc, dtype=np.uint8) for (lt, lc) in zip(lickTrials, lickCounts)])
        lickCountsRepeat = np.concatenate([lc * np.ones(lc, dtype=np.uint8) for lc in lickCounts])
        lickBehaveSample = lickFramesRepeat + np.array([np.sum(numTimeStamps[:trialIdx]) for trialIdx in lickTrialsRepeat])

        assert len(lickBehaveSample) == np.sum(
            lickCounts
        ), "the number of licks counted by vrBehavior is not equal to the length of the lickBehaveSample vector!"
        assert lickBehaveSample.ndim == 1, "lickBehaveSample is not a 1-d array!"
        assert (
            0 <= np.max(lickBehaveSample) < len(behaveTimeStamps)
        ), "lickBehaveSample contains index outside range of possible indices for behaveTimeStamps"
    else:
        # No licks found -- create empty array
        lickBehaveSample = np.array([], dtype=np.uint8)

    # Align behavioral timestamp data to timeline -- shift each trial's timestamps so that they start at the time of the first photodiode flip (which is reliably detected)
    trialStartOffsets = behaveTimeStamps[trialStartFrame] - b2registration.loadone("trials.startTimes")  # get offset
    behaveTimeStamps = np.concatenate(
        [bts - trialStartOffsets[tidx] for (tidx, bts) in enumerate(b2registration.group_behave_by_trial(behaveTimeStamps, trialStartFrame))]
    )

    # Save behave onedata
    b2registration.saveone(behaveTimeStamps, "positionTracking.times")
    b2registration.saveone(behavePosition, "positionTracking.position")

    # Save trial onedata
    b2registration.saveone(trialStartFrame, "trials.positionTracking")
    b2registration.saveone(trialEnvironmentIndex, "trials.environmentIndex")
    b2registration.saveone(trialRoomLength, "trials.roomlength")
    b2registration.saveone(trialMovementGain, "trials.movementGain")
    b2registration.saveone(trialRewardPosition, "trials.rewardPosition")
    b2registration.saveone(trialRewardTolerance, "trials.rewardZoneHalfwidth")
    b2registration.saveone(trialRewardAvailability, "trials.rewardAvailability")
    b2registration.saveone(trialRewardDelivery, "trials.rewardPositionTracking")
    b2registration.saveone(trialActiveLicking, "trials.activeLicking")
    b2registration.saveone(trialActiveStopping, "trials.activeStopping")

    # Save lick onedata
    b2registration.saveone(lickBehaveSample, "licksTracking.positionTracking")

    return b2registration
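The repeat-and-offset pattern used for licks above (expanding each lick frame index by its lick count, then shifting it into the concatenated behavior timeline) can be expressed compactly with `np.repeat`. A minimal sketch with toy data; the variable values here are illustrative, not taken from a real session:

```python
import numpy as np

# Toy per-trial data: frame indices where licks occurred and the lick count at each frame
lickFrames = np.array([2, 5, 9])   # frames within the trial
lickCounts = np.array([1, 3, 2])   # licks detected at each of those frames
trialOffset = 100                  # cumulative frame count from preceding trials

# Expand each frame index by its lick count, then shift into the session-wide timeline
lickFramesRepeat = np.repeat(lickFrames, lickCounts)
lickBehaveSample = lickFramesRepeat + trialOffset

print(lickBehaveSample)  # [102 105 105 105 109 109]
```

The result has one entry per individual lick, matching the `len(lickBehaveSample) == np.sum(lickCounts)` assertion in the source.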

oasis

oasis_deconvolution(fcorr, g, num_processes=cpu_count() - 1)

Perform oasis deconvolution on a batch of fluorescence traces.

Processes fluorescence traces in parallel using multiple processes to compute deconvolved spike estimates using the OASIS algorithm.

Parameters:

Name Type Description Default
fcorr ndarray

The fluorescence traces to process, shape (num_rois, num_frames).

required
g float

The g parameter for oasis deconvolution (decay constant).

required
num_processes int

The number of processes to use for parallel processing. Default is cpu_count() - 1.

cpu_count() - 1

Returns:

Type Description
list of np.ndarray

List of deconvolved traces, one per ROI. Each trace has negative values clipped to zero.

Raises:

Type Description
ValueError

If fcorr is not a 2D array or if num_processes is less than 1.

ImportError

If the oasis package cannot be imported.

Source code in vrAnalysis/registration/oasis.py
def oasis_deconvolution(fcorr: np.ndarray, g: float, num_processes: int = cpu_count() - 1) -> list[np.ndarray]:
    """
    Perform oasis deconvolution on a batch of fluorescence traces.

    Processes fluorescence traces in parallel using multiple processes to
    compute deconvolved spike estimates using the OASIS algorithm.

    Parameters
    ----------
    fcorr : np.ndarray
        The fluorescence traces to process, shape (num_rois, num_frames).
    g : float
        The g parameter for oasis deconvolution (decay constant).
    num_processes : int, optional
        The number of processes to use for parallel processing.
        Default is cpu_count() - 1.

    Returns
    -------
    list of np.ndarray
        List of deconvolved traces, one per ROI. Each trace has negative
        values clipped to zero.

    Raises
    ------
    ValueError
        If fcorr is not a 2D array or if num_processes is less than 1.
    ImportError
        If the oasis package cannot be imported.
    """
    if fcorr.ndim != 2:
        raise ValueError("fcorr must be a 2D numpy array.")
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1.")

    # Lazy import of deconvolve method from oasis to not break registration
    # if oasis_deconvolution isn't used.
    try:
        from oasis.functions import deconvolve
    except ImportError as error:
        raise ImportError("Failed to import deconvolve from oasis.") from error

    # Create partial function with fixed parameters
    process_func = partial(_process_fc, g=g, deconvolve=deconvolve)

    with Pool(num_processes) as pool:
        results = tqdm(pool.imap(process_func, fcorr), total=len(fcorr))
        return list(results)