Processors API Reference

The processors module provides data processing pipelines that transform session data into analysis-ready formats. The main component is the SpkmapProcessor, which creates spatial maps of neural activity, behavioral occupancy, and speed.

Overview

The processors module contains:

  • SpkmapProcessor: Main class for processing spike maps from session data
  • Maps: Container class for occupancy, speed, and spike maps
  • Reliability: Container class for reliability measurements
  • SpkmapParams: Configuration parameters for spike map processing
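The caching scheme used throughout this module keys cache files by a hash of the parameters each data type depends on. A minimal sketch of how that key is built (mirroring `_params_hash`; the parameter names come from `cached_dependencies("raw_maps")`, and the values shown here are illustrative only):

```python
import hashlib
import json

# Parameters that "raw_maps" depends on (values are illustrative only)
dependent_params = {
    "dist_step": 1.0,
    "speed_threshold": 1.0,
    "speed_max_allowed": 100.0,
    "standardize_spks": True,
}

# Same recipe as SpkmapProcessor._params_hash: stable JSON, then SHA256
cache_key = hashlib.sha256(
    json.dumps(dependent_params, sort_keys=True).encode()
).hexdigest()

# show_cache() displays the first 8 characters of this hash per cache entry
print(cache_key[:8])
```

Because `sort_keys=True` makes the JSON serialization order-independent, the same parameter values always map to the same cache files.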

Core Classes

SpkmapProcessor dataclass

Class for processing and caching spike maps from session data.

NOTES ON ENGINEERING: The variables required for processing spkmaps should be properties (@property) backed by hidden attributes for caching, so each property method can perform whatever processing its attribute needs.

The register_spkmaps method is almost working again (not yet tested), and the dataclass refactoring raises several open questions:

1. Make it possible to separate the occmap from the spkmap loading by making the preliminary variables cached properties.
2. Implement smoothing and map-correction functionality in a way that allows iterating over different parameterizations without rerunning the whole pipeline.
3. Decide how and when to compute reliability measures. In PCSS they are computed inline with get_spkmaps, but they are not always necessary and can take time; saving per-neuron reliability scores would also require an independent parameter-saving system.
4. Related to the above: decide whether the one.data loading system is ideal, or whether a more explicit, dedicated SpkmapProcessor saving/loading system would be better.

Source code in vrAnalysis/processors/spkmaps.py
@dataclass
class SpkmapProcessor:
    """Class for processing and caching spike maps from session data

    NOTES ON ENGINEERING:
    I want the variables required for processing spkmaps to be properties (@property)
    that have hidden attributes for caching. Therefore, we can use the property method
    to get the attribute and each property method can do whatever processing is needed
    for that attribute. (Uh, duh). Time to get modern. lol.

    Right now I've almost got the register_spkmaps method working again (not tested yet)
    but now is when the dataclass refactoring comes in.
    1. Make it possible to separate the occmap from the spkmap loading.
       - do so by making the preliminary variables properties with caching
    2. Consider how to implement smoothing then correctMap functionality -- it should
       be possible to do this in a way that allows me to iteratively try different
       parameterizations without having to go through the whole pipeline again.
    3. Consider how / when to implement reliability measures. In PCSS, they're done all
       right there with get_spkmaps. But it's probably not always necessary and can
       actually take a bit of time? It would also be nice to save reliability scores for
       the neurons... but then we'd also need an independent params saving system for them.
    4. Re: the point above, I wonder if the one.data loading system is ideal or if I should
       use a more explicit and dedicated SpkmapProcessor saving / loading system.
    """

    session: Union[SessionData, B2Session, SessionToSpkmapProtocol]
    params: SpkmapParams = field(default_factory=SpkmapParams, repr=False)
    data_cache: dict = field(default_factory=dict, repr=False, init=False)

    def __post_init__(self):
        # Check if the session provided is compatible with SpkmapProcessing
        if not isinstance(self.session, SessionData):
            raise ValueError(f"session must be a SessionData instance, not {type(self.session)}")
        # (Don't check if it's a SessionToSpkmapProtocol because hasattr() will call properties which loads data...)

        # We need to handle the case where params is a dictionary of partial updates to the default params
        self.params = helpers.resolve_dataclass(self.params, SpkmapParams)

    def cached_dependencies(self, data_type: str) -> List[str]:
        """Get the parameter dependencies for a given data type.

        Parameters
        ----------
        data_type : str
            Type of cached data ("raw_maps", "processed_maps", "env_maps", or "reliability").

        Returns
        -------
        list of str
            List of parameter names that affect the cache validity for this data type.
        """
        if data_type == "raw_maps":
            return ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks"]
        elif data_type == "processed_maps":
            return ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks", "smooth_width"]
        elif data_type == "env_maps":
            return ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks", "smooth_width", "full_trial_flexibility"]
        elif data_type == "reliability":
            return [
                "dist_step",
                "speed_threshold",
                "speed_max_allowed",
                "standardize_spks",
                "smooth_width",
                "full_trial_flexibility",
                "reliability_method",
            ]
        # Otherwise just return all params
        return list(self.params.__dict__.keys())

    def show_cache(self, data_type: Optional[str] = None) -> None:
        """Helper function that scrapes the cache directory and shows cached files

        Parameters
        ----------
        data_type : str, optional
            Indicate a data type to filter which parts of the cache to show

        Notes
        -----
        Prints a formatted table showing cache information including data_type, size,
        parameters, and modification date. If no cache directory exists, prints a message.
        """
        from datetime import datetime

        # Get the base cache directory
        base_cache_dir = self.cache_directory()

        if not base_cache_dir.exists():
            print(f"No cache directory found at: {base_cache_dir}")
            return

        # Collect information about all cache files
        cache_info = []

        # Define the data types to check
        if data_type is not None:
            data_types_to_check = [data_type]
        else:
            data_types_to_check = ["raw_maps", "processed_maps", "env_maps", "reliability"]

        for dt in data_types_to_check:
            cache_dir = self.cache_directory(dt)
            if not cache_dir.exists():
                continue

            # Find all parameter files (they define what caches exist)
            param_files = list(cache_dir.glob("params_*.npz"))

            for param_file in param_files:
                # Extract the hash from the filename
                params_hash = param_file.stem.replace("params_", "")

                # Load the parameters
                try:
                    cached_params = dict(np.load(param_file))
                    param_str = ", ".join([f"{k}={v}" for k, v in cached_params.items()])
                except Exception as e:
                    param_str = f"Error loading params: {e}"

                # Get file modification time
                mod_time = datetime.fromtimestamp(param_file.stat().st_mtime)
                date_str = mod_time.strftime("%Y-%m-%d %H:%M:%S")

                # Calculate total size of all related cache files
                total_size = param_file.stat().st_size

                if dt in ["raw_maps", "processed_maps"]:
                    # For maps, look for data files for each map type
                    for mapname in ["occmap", "speedmap", "spkmap"]:
                        data_file = cache_dir / f"data_{mapname}_{params_hash}.npy"
                        if data_file.exists():
                            total_size += data_file.stat().st_size

                elif dt == "env_maps":
                    # For env_maps, look for environment file and individual environment data files
                    env_file = cache_dir / f"data_environments_{params_hash}.npy"
                    if env_file.exists():
                        total_size += env_file.stat().st_size
                        # Load environments to find all data files
                        try:
                            environments = np.load(env_file)
                            for env in environments:
                                for mapname in ["occmap", "speedmap", "spkmap"]:
                                    data_file = cache_dir / f"data_{mapname}_{env}_{params_hash}.npy"
                                    if data_file.exists():
                                        total_size += data_file.stat().st_size
                        except Exception:
                            pass  # Continue even if we can't load environments

                elif dt == "reliability":
                    # For reliability, look for environments and reliability data files
                    env_file = cache_dir / f"data_environments_{params_hash}.npy"
                    rel_file = cache_dir / f"data_reliability_{params_hash}.npy"
                    if env_file.exists():
                        total_size += env_file.stat().st_size
                    if rel_file.exists():
                        total_size += rel_file.stat().st_size

                # Convert size to human readable format
                size_str = self._format_file_size(total_size)

                cache_info.append(
                    {
                        "data_type": dt,
                        "size": size_str,
                        "parameters": param_str,
                        "date": date_str,
                        "hash": params_hash[:8],  # Show first 8 chars of hash
                    }
                )

        if not cache_info:
            print("No cache files found.")
            return

        # Format the output as a table
        output_lines = []
        output_lines.append("Cache Files Summary")
        output_lines.append("=" * 80)
        output_lines.append(f"{'Data Type':<15} {'Size':<10} {'Date':<20} {'Hash':<10} {'Parameters'}")
        output_lines.append("-" * 80)

        for info in cache_info:
            output_lines.append(f"{info['data_type']:<15} {info['size']:<10} {info['date']:<20} {info['hash']:<10} {info['parameters']}")

        output_lines.append("-" * 80)
        output_lines.append(f"Total cache entries: {len(cache_info)}")

        result = "\n".join(output_lines)
        print(result)

    def _format_file_size(self, size_bytes: int) -> str:
        """Convert bytes to human-readable format.

        Parameters
        ----------
        size_bytes : int
            Size in bytes.

        Returns
        -------
        str
            Human-readable size string (e.g., "1.5 MB").
        """
        if size_bytes == 0:
            return "0 B"

        size_names = ["B", "KB", "MB", "GB", "TB"]
        import math

        i = int(math.floor(math.log(size_bytes, 1024)))
        p = math.pow(1024, i)
        s = round(size_bytes / p, 2)
        return f"{s} {size_names[i]}"

    def cache_directory(self, data_type: Optional[str] = None) -> Path:
        """Get the cache directory path for a given data type.

        Parameters
        ----------
        data_type : str, optional
            Type of cached data. If None, returns the base cache directory.
            Default is None.

        Returns
        -------
        Path
            Path to the cache directory for the specified data type.
        """
        if data_type is None:
            return self.session.data_path / "spkmaps"
        else:
            folder_name = f"{data_type}_{self.session.spks_type}"
            return self.session.data_path / "spkmaps" / folder_name

    def dependent_params(self, data_type: str) -> dict:
        """Get the dependent parameters for a given data type as a dictionary.

        Parameters
        ----------
        data_type : str
            Type of cached data.

        Returns
        -------
        dict
            Dictionary mapping parameter names to their values for the given data type.
        """
        return {k: getattr(self.params, k) for k in self.cached_dependencies(data_type)}

    def _params_hash(self, data_type: str) -> str:
        """Get the hash of the dependent parameters for a given data type.

        Parameters
        ----------
        data_type : str
            Type of cached data.

        Returns
        -------
        str
            SHA256 hash of the dependent parameters (as hexadecimal string).
        """
        return hashlib.sha256(json.dumps(self.dependent_params(data_type), sort_keys=True).encode()).hexdigest()

    def save_cache(self, data_type: str, data: Union[Maps, Reliability]) -> None:
        """Save the cached parameters and data for a given data type.

        Parameters
        ----------
        data_type : str
            Type of data being cached ("raw_maps", "processed_maps", "env_maps", or "reliability").
        data : Maps or Reliability
            The data object to cache.

        Notes
        -----
        Creates the cache directory if it doesn't exist. Saves parameters as an NPZ file
        and data as NPY files, using a hash of the parameters in the filenames.
        """
        cache_dir = self.cache_directory(data_type)
        params_hash = self._params_hash(data_type)
        cache_param_path = cache_dir / f"params_{params_hash}.npz"
        if not cache_dir.exists():
            cache_dir.mkdir(parents=True, exist_ok=True)
        np.savez(cache_param_path, **self.dependent_params(data_type))
        if data_type == "raw_maps" or data_type == "processed_maps":
            for mapname in Maps.map_types():
                cache_data_path = cache_dir / f"data_{mapname}_{params_hash}.npy"
                np.save(cache_data_path, getattr(data, mapname))
        elif data_type == "env_maps":
            environments = data.environments
            np.save(cache_dir / f"data_environments_{params_hash}.npy", environments)
            for ienv, env in enumerate(environments):
                for mapname in Maps.map_types():
                    cache_data_path = cache_dir / f"data_{mapname}_{env}_{params_hash}.npy"
                    np.save(cache_data_path, getattr(data, mapname)[ienv])
        elif data_type == "reliability":
            values = data.values
            environments = data.environments
            # don't need data.method because it's in params...
            np.save(cache_dir / f"data_environments_{params_hash}.npy", environments)
            np.save(cache_dir / f"data_reliability_{params_hash}.npy", values)
        else:
            raise ValueError(f"Unknown data type: {data_type}")

    def load_from_cache(self, data_type: str) -> Tuple[Union[Maps, Reliability, None], bool]:
        """Load cached parameters and data for a given data type.

        Parameters
        ----------
        data_type : str
            Type of cached data to load.

        Returns
        -------
        tuple
            A tuple containing:
            - The cached data (Maps or Reliability), or None if not found
            - A boolean indicating whether valid cache was found
        """
        cache_dir = self.cache_directory(data_type)
        if cache_dir.exists():
            # If the directory exists, check if there are any cached params that match the expected hash
            params_hash = self._params_hash(data_type)
            cached_params_path = cache_dir / f"params_{params_hash}.npz"
            if cached_params_path.exists():
                cached_params = dict(np.load(cached_params_path))
                # Check if the cached params match the dependent params
                if self.check_params_match(cached_params):
                    return self._load_from_cache(data_type, params_hash, params=cached_params), True
        return None, False

    def check_params_match(self, cached_params: dict) -> bool:
        """Check if the cached params and the current params are the same.

        Parameters
        ----------
        cached_params : dict
            The cached params to check against the current params

        Returns
        -------
        bool
            True if the cached params are nonempty and match the current params, False otherwise
        """
        return bool(cached_params) and all(cached_params[k] == getattr(self.params, k) for k in cached_params)

    def _load_from_cache(self, data_type: str, params_hash: str, params: Optional[Dict[str, Any]] = None) -> Union[Maps, Reliability]:
        """Load cached data from disk using a parameter hash.

        Parameters
        ----------
        data_type : str
            Type of cached data to load.
        params_hash : str
            Hash string identifying the cached parameters.
        params : dict, optional
            Dictionary of cached parameters. Used for reliability method.
            Default is None.

        Returns
        -------
        Maps or Reliability
            The loaded cached data object.

        Raises
        ------
        ValueError
            If data_type is not recognized.
        """
        cache_dir = self.cache_directory(data_type)
        if data_type == "raw_maps" or data_type == "processed_maps":
            cached_data = {}
            for name in Maps.map_types():
                cached_data[name] = np.load(cache_dir / f"data_{name}_{params_hash}.npy", mmap_mode="r")
            if data_type == "raw_maps":
                return Maps.create_raw_maps(**cached_data)
            elif data_type == "processed_maps":
                return Maps.create_processed_maps(**cached_data)
        elif data_type == "env_maps":
            environments = np.load(cache_dir / f"data_environments_{params_hash}.npy")
            cached_data = dict(environments=environments)
            for name in Maps.map_types():
                cached_data[name] = []
                for env in environments:
                    cached_data[name].append(np.load(cache_dir / f"data_{name}_{env}_{params_hash}.npy", mmap_mode="r"))
            return Maps.create_environment_maps(**cached_data)
        elif data_type == "reliability":
            environments = np.load(cache_dir / f"data_environments_{params_hash}.npy")
            values = np.load(cache_dir / f"data_reliability_{params_hash}.npy")
            method = params["reliability_method"]
            return Reliability(values, environments, method)
        else:
            raise ValueError(f"Unknown data type: {data_type}")

    @manage_one_cache
    def _filter_environments(
        self,
        envnum: Union[int, Iterable[int], None] = None,
        clear_one_cache: bool = True,
    ) -> np.ndarray:
        """Filter the session data to only include trials from certain environments.

        Parameters
        ----------
        envnum : int, iterable of int, or None, optional
            Environment number(s) to filter. If None, returns all trials.
            Default is None.
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after filtering. Default is True.

        Returns
        -------
        np.ndarray
            Boolean array indicating which trials belong to the specified environment(s).

        Notes
        -----
        This assumes that the trials are in order. We might want to use the third
        output of session.positions to get the "real" trial numbers which aren't
        always contiguous and 0 indexed.
        """
        if envnum is None:
            envnum = self.session.environments
        envnum = helpers.check_iterable(envnum)
        return np.isin(self.session.trial_environment, envnum)

    @property
    def dist_edges(self) -> np.ndarray:
        """Distance edges for the position bins.

        Returns
        -------
        np.ndarray
            1D array of position bin edges. Shape is (num_positions + 1,).

        Raises
        ------
        ValueError
            If not all trials have the same environment length.

        Notes
        -----
        The number of position bins is determined by dividing the environment
        length by dist_step. This property caches the environment length
        internally after first access.
        """
        if not hasattr(self, "_env_length"):
            env_length = self.session.env_length
            if hasattr(env_length, "__len__"):
                if np.unique(env_length).size != 1:
                    msg = "SpkmapProcessor (currently) requires all trials to have the same env length!"
                    raise ValueError(msg)
                env_length = env_length[0]
            self._env_length = env_length

        num_positions = int(self._env_length / self.params.dist_step)
        return np.linspace(0, self._env_length, num_positions + 1)

    @property
    def dist_centers(self) -> np.ndarray:
        """Distance centers for the position bins.

        Returns
        -------
        np.ndarray
            1D array of position bin centers. Shape is (num_positions,).
        """
        return helpers.edge2center(self.dist_edges)

    @manage_one_cache
    def _idx_required_position_bins(self, clear_one_cache: bool = True) -> np.ndarray:
        """Get the indices of the position bins that are required for a full trial

        Parameters
        ----------
        clear_one_cache : bool, default=True
            Whether to clear the onefile cache after getting the indices

        Returns
        -------
        np.ndarray
            The indices of the position bins that are required for a trial to be considered full
        """
        num_position_bins = len(self.dist_centers)
        if self.params.full_trial_flexibility is None:
            idx_to_required_bins = np.arange(num_position_bins)
        else:
            start_idx = np.where(self.dist_edges >= self.params.full_trial_flexibility)[0][0]
            end_idx = np.where(self.dist_edges <= self.dist_edges[-1] - self.params.full_trial_flexibility)[0][-1]
            idx_to_required_bins = np.arange(start_idx, end_idx)
        return idx_to_required_bins

    @with_temp_params
    @manage_one_cache
    @cached_processor("raw_maps", disable=False)
    def get_raw_maps(
        self,
        force_recompute: bool = False,
        clear_one_cache: bool = True,
        params: Union[SpkmapParams, Dict[str, Any], None] = None,
    ) -> Maps:
        """Get raw maps (occupancy, speed, spkmap) from session data.

        This method processes session data to create spatial maps representing
        occupancy, speed, and neural activity across position bins. The maps
        are in raw format (not smoothed or normalized by occupancy).

        Parameters
        ----------
        force_recompute : bool, optional
            Whether to force recomputation even if cached data exists. Default is False.
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after processing. Default is True.
        params : SpkmapParams, dict, or None, optional
            Parameters for processing. If None, uses instance parameters.
            If a dict, updates instance parameters temporarily.
            Parameters are restored after method execution. Default is None.

        Returns
        -------
        Maps
            Maps instance containing raw occupancy, speed, and spike maps.
            Shape: (trials, positions) for occmap/speedmap,
            (trials, positions, rois) for spkmap.

        Notes
        -----
        The method:
        1. Bins positions according to dist_step
        2. Filters by speed threshold
        3. Computes occupancy, speed, and spike maps
        4. Sets unvisited position bins to NaN
        5. Optionally standardizes spike data

        Results are cached based on parameter hash for efficient reuse.
        """
        dist_edges = self.dist_edges
        dist_centers = self.dist_centers
        num_positions = len(dist_centers)

        # Get behavioral timestamps and positions
        timestamps, positions, trial_numbers, idx_behave_to_frame = self.session.positions

        # compute behavioral speed on each sample
        within_trial_sample = np.append(np.diff(trial_numbers) == 0, True)
        sample_duration = np.append(np.diff(timestamps), 0)
        speeds = np.append(np.diff(positions) / sample_duration[:-1], 0)
        # do this after division so no /0 errors
        sample_duration = sample_duration * within_trial_sample
        # speed 0 in last sample for each trial (it's undefined)
        speeds = speeds * within_trial_sample
        # Convert positions to position bins
        position_bin = np.digitize(positions, dist_edges) - 1

        # get imaging information
        frame_time_stamps = self.session.timestamps
        sampling_period = np.median(np.diff(frame_time_stamps))
        dist_cutoff = sampling_period / 2
        delay_position_to_imaging = frame_time_stamps[idx_behave_to_frame] - timestamps

        # get spiking information
        spks = self.session.spks
        num_rois = self.session.get_value("numROIs")

        # Do standardization
        if self.params.standardize_spks:
            spks = median_zscore(spks, median_subtract=not self.session.zero_baseline_spks)

        # Get high resolution occupancy and speed maps
        dtype = np.float32
        occmap = np.zeros((self.session.num_trials, num_positions), dtype=dtype)
        counts = np.zeros((self.session.num_trials, num_positions), dtype=dtype)
        speedmap = np.zeros((self.session.num_trials, num_positions), dtype=dtype)
        spkmap = np.zeros((self.session.num_trials, num_positions, num_rois), dtype=dtype)
        extra_counts = np.zeros((self.session.num_trials, num_positions), dtype=dtype)

        # Get maps -- doing this independently for each map allows for more
        # flexibility in which data to load (the occmap & speedmap are cheap
        # to compute, but the spkmap takes noticeably longer)
        get_summation_map(
            sample_duration,
            trial_numbers,
            position_bin,
            occmap,
            counts,
            speeds,
            self.params.speed_threshold,
            self.params.speed_max_allowed,
            delay_position_to_imaging,
            dist_cutoff,
            sample_duration,
            scale_by_sample_duration=False,
            use_sample_to_value_idx=False,
            sample_to_value_idx=idx_behave_to_frame,
        )
        get_summation_map(
            speeds,
            trial_numbers,
            position_bin,
            speedmap,
            counts,
            speeds,
            self.params.speed_threshold,
            self.params.speed_max_allowed,
            delay_position_to_imaging,
            dist_cutoff,
            sample_duration,
            scale_by_sample_duration=True,
            use_sample_to_value_idx=False,
            sample_to_value_idx=idx_behave_to_frame,
        )
        get_summation_map(
            spks,
            trial_numbers,
            position_bin,
            spkmap,
            extra_counts,
            speeds,
            self.params.speed_threshold,
            self.params.speed_max_allowed,
            delay_position_to_imaging,
            dist_cutoff,
            sample_duration,
            scale_by_sample_duration=True,
            use_sample_to_value_idx=True,
            sample_to_value_idx=idx_behave_to_frame,
        )

        # Figure out the valid range (outside of this range, set the maps to nan, because their values are not meaningful)
        position_bin_per_trial = [position_bin[trial_numbers == tnum] for tnum in range(self.session.num_trials)]

        # offsetting by 1 because there is a bug in the vrControl software where the first sample is always set
        # to the minimum position (which is 0), but if there is a built-up buffer in the rotary encoder, the position
        # will jump at the second sample. In general this will always work unless the mice have a truly ridiculous
        # speed at the beginning of the trial...
        first_valid_bin = [np.min(bpb[1:] if len(bpb) > 1 else bpb) for bpb in position_bin_per_trial]
        last_valid_bin = [np.max(bpb) for bpb in position_bin_per_trial]

        # set bins to nan when mouse didn't visit them
        occmap = replace_missing_data(occmap, first_valid_bin, last_valid_bin)
        speedmap = replace_missing_data(speedmap, first_valid_bin, last_valid_bin)
        spkmap = replace_missing_data(spkmap, first_valid_bin, last_valid_bin)

        return Maps.create_raw_maps(occmap, speedmap, spkmap)

    @with_temp_params
    @manage_one_cache
    @cached_processor("processed_maps", disable=False)
    def get_processed_maps(
        self,
        force_recompute: bool = False,
        clear_one_cache: bool = True,
        params: Union[SpkmapParams, Dict[str, Any], None] = None,
    ) -> Maps:
        """Get processed maps (smoothed and normalized by occupancy).

        This method creates processed maps by:
        1. Getting raw maps
        2. Optionally smoothing with a Gaussian kernel
        3. Normalizing speedmap and spkmap by occupancy
        4. Reorganizing spkmap to have ROIs as the first dimension

        Parameters
        ----------
        force_recompute : bool, optional
            Whether to force recomputation even if cached data exists. Default is False.
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after processing. Default is True.
        params : SpkmapParams, dict, or None, optional
            Parameters for processing. If None, uses instance parameters.
            If a dict, updates instance parameters temporarily.
            Parameters are restored after method execution. Default is None.

        Returns
        -------
        Maps
            Maps instance containing processed occupancy, speed, and spike maps.
            Shape: (trials, positions) for occmap/speedmap,
            (rois, trials, positions) for spkmap.
        """
        # Get the raw maps first (don't need to specify params because they're already set by this method)
        maps = self.get_raw_maps(
            force_recompute=force_recompute,
            clear_one_cache=clear_one_cache,
        )

        # Process the maps (smooth, divide by occupancy, and change to ROIs first)
        return maps.raw_to_processed(self.dist_centers, self.params.smooth_width)

    @with_temp_params
    @manage_one_cache
    @cached_processor("env_maps", disable=False)
    def get_env_maps(
        self,
        use_session_filters: bool = True,
        force_recompute: bool = False,
        clear_one_cache: bool = True,
        params: Union[SpkmapParams, Dict[str, Any], None] = None,
    ) -> Maps:
        """Get processed maps separated by environment.

        This method creates environment-separated maps by:
        1. Getting processed maps
        2. Filtering to include only full trials (based on full_trial_flexibility)
        3. Filtering ROIs if use_session_filters=True
        4. Grouping maps by environment

        Parameters
        ----------
        use_session_filters : bool, optional
            Whether to filter ROIs using session.idx_rois. Default is True.
        force_recompute : bool, optional
            Whether to force recomputation even if cached data exists. Default is False.
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after processing. Default is True.
        params : SpkmapParams, dict, or None, optional
            Parameters for processing. If None, uses instance parameters.
            If a dict, updates instance parameters temporarily.
            Parameters are restored after method execution. Default is None.

        Returns
        -------
        Maps
            Maps instance with by_environment=True, containing lists of maps
            for each environment. Shape per environment:
            (trials_in_env, positions) for occmap/speedmap,
            (rois, trials_in_env, positions) for spkmap.
        """
        # Make sure it's an iterable -- the output will always be a list
        envnum = helpers.check_iterable(self.session.environments)

        # Get the indices of the trials to each environment
        idx_each_environment = [self._filter_environments(env) for env in envnum]

        # Then get the indices of the position bins that are required for a full trial
        idx_required_position_bins = self._idx_required_position_bins(clear_one_cache)

        # Get the processed maps (don't need to specify params because they're already set by the decorator)
        maps = self.get_processed_maps(
            force_recompute=force_recompute,
            clear_one_cache=clear_one_cache,
        )

        # Add the list of environments to the maps
        maps.environments = envnum

        # Make a list of the maps we are processing
        maps_to_process = Maps.map_types()

        # Filter the maps to only include the ROIs we want
        if use_session_filters:
            idx_rois = np.where(self.session.idx_rois)[0]
        else:
            idx_rois = np.arange(self.session.get_value("numROIs"), dtype=int)

        # Filter the maps to only include the full trials
        full_trials = np.where(np.all(~np.isnan(maps.occmap[:, idx_required_position_bins]), axis=1))[0]

        # Implement trial & ROI filtering here
        for mapname in maps_to_process:
            if mapname == "spkmap":
                maps[mapname] = np.take(np.take(maps[mapname], idx_rois, axis=0), full_trials, axis=1)
            else:
                maps[mapname] = np.take(maps[mapname], full_trials, axis=0)

        # Filter the trial indices to only include full trials
        idx_each_environment = [np.where(np.take(idx, full_trials, axis=0))[0] for idx in idx_each_environment]

        # Then group each one by environment
        # -> this is now (trials_in_env, position_bins, ...(roi if spkmap)...)
        maps.by_environment = True
        for mapname in maps_to_process:
            if mapname == "spkmap":
                maps[mapname] = [np.take(maps[mapname], idx, axis=1) for idx in idx_each_environment]
            else:
                maps[mapname] = [np.take(maps[mapname], idx, axis=0) for idx in idx_each_environment]

        return maps

    @with_temp_params
    @manage_one_cache
    @cached_processor("reliability", disable=False)
    def get_reliability(
        self,
        use_session_filters: bool = True,
        force_recompute: bool = False,
        clear_one_cache: bool = True,
        params: Union[SpkmapParams, Dict[str, Any], None] = None,
    ) -> Reliability:
        """Calculate reliability of spike maps across trials.

        Reliability measures how consistent neural activity is across trials
        within each environment. Multiple methods are supported.

        Parameters
        ----------
        use_session_filters : bool, optional
            Whether to filter ROIs using session.idx_rois. Default is True.
        force_recompute : bool, optional
            Whether to force recomputation even if cached data exists. Default is False.
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after processing. Default is True.
        params : SpkmapParams, dict, or None, optional
            Parameters for processing. If None, uses instance parameters.
            If a dict, updates instance parameters temporarily.
            Parameters are restored after method execution. Default is None.

        Returns
        -------
        Reliability
            Reliability instance containing reliability values for each ROI
            in each environment. Shape: (num_environments, num_rois).

        Notes
        -----
        Supported reliability methods:
        - "leave_one_out": Leave-one-out cross-validation
        - "correlation": Correlation between trial pairs
        - "mse": Mean squared error between trial pairs

        All reliability measures require maps with no NaN positions.
        """
        envnum = helpers.check_iterable(self.session.environments)

        # A list of the requested environments (all if not specified)
        maps = self.get_env_maps(
            use_session_filters=use_session_filters,
            force_recompute=force_recompute,
            clear_one_cache=clear_one_cache,
            params={"autosave": False},  # Prevent saving in the case of a recompute
        )

        # All reliability measures require no NaNs
        maps.pop_nan_positions()

        if self.params.reliability_method == "leave_one_out":
            rel_values = [helpers.reliability_loo(spkmap) for spkmap in maps.spkmap]
        elif self.params.reliability_method in ("correlation", "mse"):
            rel_mse, rel_cor = helpers.named_transpose([helpers.measureReliability(spkmap) for spkmap in maps.spkmap])
            rel_values = rel_mse if self.params.reliability_method == "mse" else rel_cor
        else:
            raise ValueError(f"Method {self.params.reliability_method} not supported")

        return Reliability(
            np.stack(rel_values),
            environments=envnum,
            method=self.params.reliability_method,
        )

    # ------------------- convert between imaging and behavioral time -------------------
    @with_temp_params
    @manage_one_cache
    def get_frame_behavior(
        self,
        clear_one_cache: bool = True,
        params: Union[SpkmapParams, Dict[str, Any], None] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Get position and environment data for each imaging frame.

        This method aligns behavioral data (position, speed, environment, trial)
        to imaging frame timestamps. Returns NaN for frames where no position
        data is available (e.g., if the closest behavioral sample is further
        away in time than half the sampling period).

        Parameters
        ----------
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after processing. Default is True.
        params : SpkmapParams, dict, or None, optional
            Parameters for processing. If None, uses instance parameters.
            If a dict, updates instance parameters temporarily.
            Parameters are restored after method execution. Default is None.

        Returns
        -------
        tuple
            A tuple containing four arrays (all with shape (num_frames,)):
            - frame_position: Position for each frame (NaN if unavailable)
            - frame_speed: Speed for each frame (NaN if unavailable)
            - frame_environment: Environment number for each frame (NaN if unavailable)
            - frame_trial: Trial number for each frame (NaN if unavailable)
        """
        timestamps = self.session.loadone("positionTracking.times")
        position = self.session.loadone("positionTracking.position")
        idx_behave_to_frame = self.session.loadone("positionTracking.mpci")
        trial_start_index = self.session.loadone("trials.positionTracking")
        num_samples = len(position)
        trial_numbers = np.arange(len(trial_start_index))
        trial_lengths = np.append(np.diff(trial_start_index), num_samples - trial_start_index[-1])
        trial_numbers = np.repeat(trial_numbers, trial_lengths)
        trial_environment = self.session.loadone("trials.environmentIndex")
        trial_environment = np.repeat(trial_environment, trial_lengths)

        within_trial = np.append(np.diff(trial_numbers) == 0, True)
        sample_duration = np.append(np.diff(timestamps), 0)
        speed = np.append(np.diff(position) / sample_duration[:-1], 0)
        sample_duration = sample_duration * within_trial
        speed = speed * within_trial

        frame_timestamps = self.session.loadone("mpci.times")
        difference_timestamps = np.abs(timestamps - frame_timestamps[idx_behave_to_frame])
        sampling_period = np.median(np.diff(frame_timestamps))
        dist_cutoff = sampling_period / 2

        frame_position = np.zeros_like(frame_timestamps)
        count = np.zeros_like(frame_timestamps)
        helpers.get_average_frame_position(position, idx_behave_to_frame, difference_timestamps, dist_cutoff, frame_position, count)
        frame_position[count > 0] /= count[count > 0]
        frame_position[count == 0] = np.nan
        frame_speed = np.diff(frame_position) / np.diff(frame_timestamps)
        frame_speed = np.append(frame_speed, 0)

        # Get a map from frame to behavior time for quick lookup
        idx_frame_to_behave, dist_frame_to_behave = helpers.nearestpoint(frame_timestamps, timestamps)
        idx_get_position = dist_frame_to_behave < dist_cutoff

        frame_environment = np.full(len(frame_timestamps), np.nan)
        frame_environment[idx_get_position] = trial_environment[idx_frame_to_behave[idx_get_position]]
        frame_environment[count == 0] = np.nan

        frame_trial = np.full(len(frame_timestamps), np.nan)
        frame_trial[idx_get_position] = trial_numbers[idx_frame_to_behave[idx_get_position]]
        frame_trial[count == 0] = np.nan

        return frame_position, frame_speed, frame_environment, frame_trial

    @with_temp_params
    @manage_one_cache
    def get_placefield_prediction(
        self,
        use_session_filters: bool = True,
        spks_type: Union[str, None] = None,
        use_speed_threshold: bool = True,
        clear_one_cache: bool = True,
        params: Union[SpkmapParams, Dict[str, Any], None] = None,
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Predict neural activity from place field maps.

        This method uses averaged environment maps to predict neural activity
        at each imaging frame based on the animal's position and environment.

        Parameters
        ----------
        use_session_filters : bool, optional
            Whether to filter ROIs using session.idx_rois. Default is True.
        spks_type : str or None, optional
            Type of spike data to use. If None, uses session's current spks_type.
            Temporarily changes session.spks_type if provided. Default is None.
        use_speed_threshold : bool, optional
            Whether to only predict for frames where speed exceeds threshold.
            Default is True.
        clear_one_cache : bool, optional
            Whether to clear the onefile cache after processing. Default is True.
        params : SpkmapParams, dict, or None, optional
            Parameters for processing. If None, uses instance parameters.
            If a dict, updates instance parameters temporarily.
            Parameters are restored after method execution. Default is None.

        Returns
        -------
        tuple
            A tuple containing:
            - placefield_prediction: Predicted activity array with shape (frames, rois).
              NaN for frames where prediction is not possible.
            - extras: Dictionary with additional information:
              - frame_position_index: Position bin index for each frame
              - frame_environment_index: Environment index for each frame
              - idx_valid: Boolean array indicating valid predictions

        Notes
        -----
        Predictions are based on averaged trial maps. Frames where the animal
        is not moving (if use_speed_threshold=True) or where position/environment
        data is unavailable will have NaN predictions.
        """
        if spks_type is not None:
            _spks_type = self.session.spks_type
            self.session.params.spks_type = spks_type

        frame_position, frame_speed, frame_environment, _ = self.get_frame_behavior(clear_one_cache, params)
        idx_valid = ~np.isnan(frame_position)
        if use_speed_threshold:
            idx_valid = idx_valid & (frame_speed > self.params.speed_threshold)

        # Convert frame positions to bin indices
        frame_position_index = np.searchsorted(self.dist_edges, frame_position, side="right") - 1

        # Get the place field for each neuron
        env_maps = self.get_env_maps(use_session_filters=use_session_filters)
        env_maps.average_trials()

        # Convert frame environment to indices
        env_to_idx = {env: i for i, env in enumerate(env_maps.environments)}
        frame_environment_index = np.array([env_to_idx[env] if not np.isnan(env) else -1000 for env in frame_environment], dtype=int)

        # Get the original spks data
        spks = self.session.spks
        if use_session_filters:
            spks = spks[:, self.session.idx_rois]

        # Use a numba-accelerated single-pass algorithm to get the placefield prediction
        placefield_prediction = np.full(spks.shape, np.nan)
        spkmaps = np.stack(list(map(lambda x: x.T, env_maps.spkmap)))
        placefield_prediction = placefield_prediction_numba(
            placefield_prediction,
            spkmaps,
            frame_environment_index,
            frame_position_index,
            idx_valid,
        )

        # This will add samples for which a place field was not estimable (at the edges of the environment)
        idx_valid = np.all(~np.isnan(placefield_prediction), axis=1)

        # Reset spks_type
        if spks_type is not None:
            self.session.params.spks_type = _spks_type

        # Include extra details in a dictionary for forward compatibility
        extras = dict(
            frame_position_index=frame_position_index,
            frame_environment_index=frame_environment_index,
            idx_valid=idx_valid,
        )

        return placefield_prediction, extras

    def get_traversals(
        self,
        idx_roi: int,
        idx_env: int,
        width: int = 10,
        placefield_threshold: float = 5.0,
        fill_nan: bool = False,
        spks: Optional[np.ndarray] = None,
        spks_prediction: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Extract neural activity around place field peak during traversals.

        This method identifies trials where the animal passes through a neuron's
        place field peak and extracts activity windows around those moments.

        Parameters
        ----------
        idx_roi : int
            Index of the ROI (neuron) to analyze.
        idx_env : int
            Index of the environment to analyze (index into env_maps.environments).
        width : int, optional
            Number of frames on each side of the peak to include. Total window
            size is 2*width + 1. Default is 10.
        placefield_threshold : float, optional
            Maximum distance from place field peak to include a trial (in spatial units).
            Default is 5.0.
        fill_nan : bool, optional
            Whether to fill NaN values with 0. Default is False.
        spks : np.ndarray, optional
            Spike data array. If None, loads from session. Default is None.
        spks_prediction : np.ndarray, optional
            Place field prediction array. If None, computes it. Default is None.

        Returns
        -------
        tuple
            A tuple containing:
            - traversals: Array of shape (num_traversals, 2*width+1) containing
              actual neural activity around each traversal.
            - pred_travs: Array of shape (num_traversals, 2*width+1) containing
              predicted activity around each traversal.

        Notes
        -----
        Only includes trials where the animal passes within placefield_threshold
        of the place field peak. The peak is determined from the averaged spike
        map for the specified ROI and environment.
        """
        frame_position, _, frame_environment, frame_trial = self.get_frame_behavior()
        if spks_prediction is None:
            spks_prediction = self.get_placefield_prediction(use_session_filters=True)[0]
        if spks is None:
            spks = self.session.spks[:, self.session.idx_rois]

        if spks.shape != spks_prediction.shape:
            raise ValueError("spks and spks_prediction must have the same shape")

        env_maps = self.get_env_maps()
        pos_peak = self.dist_centers[np.nanargmax(np.nanmean(env_maps.spkmap[idx_env][idx_roi], axis=0))]
        envnum = env_maps.environments[idx_env]

        env_trials = np.unique(frame_trial[frame_environment == envnum])

        num_trials = len(env_trials)
        idx_traversal = -1 * np.ones(num_trials, dtype=int)
        for itrial, trialnum in enumerate(env_trials):
            idx_trial = frame_trial == trialnum
            idx_closest_pos = np.nanargmin(np.abs(frame_position - pos_peak) + 10000 * ~idx_trial)

            # Only include the trial if the closest position is within placefield threshold of the peak
            if np.abs(frame_position[idx_closest_pos] - pos_peak) < placefield_threshold:
                idx_traversal[itrial] = idx_closest_pos

        # Filter out trials that don't have a traversal
        idx_traversal = idx_traversal[idx_traversal != -1]

        # Get traversals through place field in requested environment
        traversals = np.zeros((len(idx_traversal), width * 2 + 1))
        pred_travs = np.zeros((len(idx_traversal), width * 2 + 1))
        for ii, it in enumerate(idx_traversal):
            istart = it - width
            iend = it + width + 1
            istartoffset = max(0, -istart)
            iendoffset = max(0, iend - spks.shape[0])
            traversals[ii, istartoffset : width * 2 + 1 - iendoffset] = spks[istart + istartoffset : iend - iendoffset, idx_roi]
            pred_travs[ii, istartoffset : width * 2 + 1 - iendoffset] = spks_prediction[istart + istartoffset : iend - iendoffset, idx_roi]

        if fill_nan:
            traversals[np.isnan(traversals)] = 0.0
            pred_travs[np.isnan(pred_travs)] = 0.0

        return traversals, pred_travs
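
The offset arithmetic at the end of get_traversals clips windows that would run past either edge of the recording and zero-pads the missing samples. A standalone sketch of that clipping logic (the helper name extract_window is illustrative, not part of the module):

```python
import numpy as np

def extract_window(trace: np.ndarray, center: int, width: int) -> np.ndarray:
    # Zero-padded window of length 2*width + 1 around `center`,
    # clipped where it would run off either end of `trace`.
    out = np.zeros(2 * width + 1)
    istart, iend = center - width, center + width + 1
    start_offset = max(0, -istart)
    end_offset = max(0, iend - len(trace))
    out[start_offset : 2 * width + 1 - end_offset] = trace[istart + start_offset : iend - end_offset]
    return out

trace = np.arange(10.0)
print(extract_window(trace, 8, 2))  # last two samples padded with a trailing zero
```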

dist_centers property

Distance centers for the position bins.

Returns:

    ndarray
        1D array of position bin centers. Shape is (num_positions,).

dist_edges property

Distance edges for the position bins.

Returns:

    ndarray
        1D array of position bin edges. Shape is (num_positions + 1,).

Raises:

    ValueError
        If not all trials have the same environment length.

Notes

The number of position bins is determined by dividing the environment length by dist_step. This property caches the environment length internally after first access.
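
To make the Notes concrete, here is a minimal sketch of how the edges and centers relate, assuming a hypothetical 200 cm environment and a dist_step of 5 cm (these values are illustrative, not read from any session):

```python
import numpy as np

# Hypothetical values: a 200 cm environment binned at dist_step = 5 cm.
env_length = 200
dist_step = 5

# num_positions bins require num_positions + 1 edges.
num_positions = int(env_length / dist_step)
dist_edges = np.linspace(0, env_length, num_positions + 1)
dist_centers = (dist_edges[:-1] + dist_edges[1:]) / 2
```

Each center sits halfway between adjacent edges, so dist_centers always has one fewer entry than dist_edges.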

cache_directory(data_type=None)

Get the cache directory path for a given data type.

Parameters:

    data_type : str, optional
        Type of cached data. If None, returns the base cache directory.
        Default is None.

Returns:

    Path
        Path to the cache directory for the specified data type.

Source code in vrAnalysis/processors/spkmaps.py
def cache_directory(self, data_type: Optional[str] = None) -> Path:
    """Get the cache directory path for a given data type.

    Parameters
    ----------
    data_type : str, optional
        Type of cached data. If None, returns the base cache directory.
        Default is None.

    Returns
    -------
    Path
        Path to the cache directory for the specified data type.
    """
    if data_type is None:
        return self.session.data_path / "spkmaps"
    else:
        folder_name = f"{data_type}_{self.session.spks_type}"
        return self.session.data_path / "spkmaps" / folder_name
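
The branching above can be exercised standalone; this sketch passes data_path and spks_type explicitly rather than reading them from a session (the "oasis" spike type and paths are made-up values):

```python
from pathlib import Path
from typing import Optional

def cache_dir(data_path: Path, spks_type: str, data_type: Optional[str] = None) -> Path:
    # Base cache directory when no data type is given, otherwise a
    # subfolder keyed by both the data type and the spike type.
    if data_type is None:
        return data_path / "spkmaps"
    return data_path / "spkmaps" / f"{data_type}_{spks_type}"

print(cache_dir(Path("/data/sess"), "oasis", "env_maps"))
```

Keying the folder name on spks_type means caches for different spike extraction methods never collide.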

cached_dependencies(data_type)

Get the parameter dependencies for a given data type.

Parameters:

    data_type : str
        Type of cached data ("raw_maps", "processed_maps", "env_maps", or "reliability").

Returns:

    list of str
        List of parameter names that affect the cache validity for this data type.

Source code in vrAnalysis/processors/spkmaps.py
def cached_dependencies(self, data_type: str) -> List[str]:
    """Get the parameter dependencies for a given data type.

    Parameters
    ----------
    data_type : str
        Type of cached data ("raw_maps", "processed_maps", "env_maps", or "reliability").

    Returns
    -------
    list of str
        List of parameter names that affect the cache validity for this data type.
    """
    if data_type == "raw_maps":
        return ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks"]
    elif data_type == "processed_maps":
        return ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks", "smooth_width"]
    elif data_type == "env_maps":
        return ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks", "smooth_width", "full_trial_flexibility"]
    elif data_type == "reliability":
        return [
            "dist_step",
            "speed_threshold",
            "speed_max_allowed",
            "standardize_spks",
            "smooth_width",
            "full_trial_flexibility",
            "reliability_method",
        ]
    # Otherwise just return all params
    return list(self.params.__dict__.keys())
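
Note that the dependency lists nest: each downstream data type extends the previous stage's list by one parameter, so changing an upstream parameter invalidates every downstream cache. A minimal restatement of the lists above:

```python
# Dependency lists copied from cached_dependencies; each stage adds one name.
RAW = ["dist_step", "speed_threshold", "speed_max_allowed", "standardize_spks"]
PROCESSED = RAW + ["smooth_width"]
ENV = PROCESSED + ["full_trial_flexibility"]
RELIABILITY = ENV + ["reliability_method"]

# Strict subset chain: a raw-map parameter change invalidates everything downstream.
print(set(RAW) < set(PROCESSED) < set(ENV) < set(RELIABILITY))
```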

check_params_match(cached_params)

Check if the cached params and the current params are the same.

Parameters:

    cached_params : dict
        The cached params to check against the current params.

Returns:

    bool
        True if the cached params are nonempty and match the current params, False otherwise.

Source code in vrAnalysis/processors/spkmaps.py
def check_params_match(self, cached_params: dict) -> bool:
    """Check if the cached params and the current params are the same.

    Parameters
    ----------
    cached_params : dict
        The cached params to check against the current params

    Returns
    -------
    bool
        True if the cached params are nonempty and match the current params, False otherwise
    """
    return cached_params and all(cached_params[k] == getattr(self.params, k) for k in cached_params)
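
One subtlety: `cached_params and all(...)` evaluates to the empty dict itself (falsy, but not the literal False) when the cache is empty. A self-contained sketch using a stand-in params object (FakeParams is illustrative, not the real SpkmapParams):

```python
from dataclasses import dataclass

@dataclass
class FakeParams:  # stand-in for SpkmapParams, hypothetical fields
    dist_step: float = 1.0
    smooth_width: float = 5.0

def check_params_match(params, cached_params: dict) -> bool:
    # bool() makes the empty-cache case return a strict False
    # rather than the (falsy) empty dict.
    return bool(cached_params) and all(cached_params[k] == getattr(params, k) for k in cached_params)

p = FakeParams()
```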

dependent_params(data_type)

Get the dependent parameters for a given data type as a dictionary.

Parameters:

    data_type : str
        Type of cached data.

Returns:

    dict
        Dictionary mapping parameter names to their values for the given data type.

Source code in vrAnalysis/processors/spkmaps.py
def dependent_params(self, data_type: str) -> dict:
    """Get the dependent parameters for a given data type as a dictionary.

    Parameters
    ----------
    data_type : str
        Type of cached data.

    Returns
    -------
    dict
        Dictionary mapping parameter names to their values for the given data type.
    """
    return {k: getattr(self.params, k) for k in self.cached_dependencies(data_type)}
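
A quick sketch of the snapshot this produces, with a stand-in params object and hypothetical parameter names (not the real SpkmapParams):

```python
class FakeParams:  # stand-in for SpkmapParams
    def __init__(self):
        self.dist_step = 1.0
        self.smooth_width = 5.0
        self.reliability_method = "leave_one_out"

def dependent_params(params, names):
    # Snapshot only the parameters a given data type depends on,
    # suitable for saving alongside the cached arrays.
    return {k: getattr(params, k) for k in names}

snapshot = dependent_params(FakeParams(), ["dist_step", "smooth_width"])
```

This snapshot is what later gets compared against the live parameters (via check_params_match) to decide whether a cache is still valid.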

get_env_maps(use_session_filters=True, force_recompute=False, clear_one_cache=True, params=None)

Get processed maps separated by environment.

This method creates environment-separated maps by:

1. Getting processed maps
2. Filtering to include only full trials (based on full_trial_flexibility)
3. Filtering ROIs if use_session_filters=True
4. Grouping maps by environment

Parameters:

    use_session_filters : bool, optional
        Whether to filter ROIs using session.idx_rois. Default is True.
    force_recompute : bool, optional
        Whether to force recomputation even if cached data exists. Default is False.
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters. If a dict,
        updates instance parameters temporarily. Parameters are restored after
        method execution. Default is None.

Returns:

    Maps
        Maps instance with by_environment=True, containing lists of maps for each
        environment. Shape per environment: (trials_in_env, positions) for
        occmap/speedmap, (rois, trials_in_env, positions) for spkmap.

Source code in vrAnalysis/processors/spkmaps.py
@with_temp_params
@manage_one_cache
@cached_processor("env_maps", disable=False)
def get_env_maps(
    self,
    use_session_filters: bool = True,
    force_recompute: bool = False,
    clear_one_cache: bool = True,
    params: Union[SpkmapParams, Dict[str, Any], None] = None,
) -> Maps:
    """Get processed maps separated by environment.

    This method creates environment-separated maps by:
    1. Getting processed maps
    2. Filtering to include only full trials (based on full_trial_flexibility)
    3. Filtering ROIs if use_session_filters=True
    4. Grouping maps by environment

    Parameters
    ----------
    use_session_filters : bool, optional
        Whether to filter ROIs using session.idx_rois. Default is True.
    force_recompute : bool, optional
        Whether to force recomputation even if cached data exists. Default is False.
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters.
        If a dict, updates instance parameters temporarily.
        Parameters are restored after method execution. Default is None.

    Returns
    -------
    Maps
        Maps instance with by_environment=True, containing lists of maps
        for each environment. Shape per environment:
        (trials_in_env, positions) for occmap/speedmap,
        (rois, trials_in_env, positions) for spkmap.
    """
    # Make sure it's an iterable -- the output will always be a list
    envnum = helpers.check_iterable(self.session.environments)

    # Get the indices of the trials to each environment
    idx_each_environment = [self._filter_environments(env) for env in envnum]

    # Then get the indices of the position bins that are required for a full trial
    idx_required_position_bins = self._idx_required_position_bins(clear_one_cache)

    # Get the processed maps (don't need to specify params because they're already set by the decorator)
    maps = self.get_processed_maps(
        force_recompute=force_recompute,
        clear_one_cache=clear_one_cache,
    )

    # Add the list of environments to the maps
    maps.environments = envnum

    # Make a list of the maps we are processing
    maps_to_process = Maps.map_types()

    # Filter the maps to only include the ROIs we want
    if use_session_filters:
        idx_rois = np.where(self.session.idx_rois)[0]
    else:
        idx_rois = np.arange(self.session.get_value("numROIs"), dtype=int)

    # Filter the maps to only include the full trials
    full_trials = np.where(np.all(~np.isnan(maps.occmap[:, idx_required_position_bins]), axis=1))[0]

    # Implement trial & ROI filtering here
    for mapname in maps_to_process:
        if mapname == "spkmap":
            maps[mapname] = np.take(np.take(maps[mapname], idx_rois, axis=0), full_trials, axis=1)
        else:
            maps[mapname] = np.take(maps[mapname], full_trials, axis=0)

    # Filter the trial indices to only include full trials
    idx_each_environment = [np.where(np.take(idx, full_trials, axis=0))[0] for idx in idx_each_environment]

    # Then group each one by environment
    # -> each map is now a list with one entry per environment, shaped
    # -> (trials_in_env, positions) for occmap/speedmap and (rois, trials_in_env, positions) for spkmap
    maps.by_environment = True
    for mapname in maps_to_process:
        if mapname == "spkmap":
            maps[mapname] = [np.take(maps[mapname], idx, axis=1) for idx in idx_each_environment]
        else:
            maps[mapname] = [np.take(maps[mapname], idx, axis=0) for idx in idx_each_environment]

    return maps
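The per-environment grouping above boils down to selecting trial rows with `np.take`. A minimal sketch with hypothetical data (a tiny trial-by-position map and invented environment labels):

```python
import numpy as np

# Hypothetical trial-wise occupancy map: 4 trials x 3 position bins
occmap = np.arange(12, dtype=float).reshape(4, 3)
trial_env = np.array([0, 1, 0, 1])  # environment label per trial
environments = [0, 1]

# Collect the trial rows belonging to each environment (axis=0 for occmap;
# the spkmap case uses axis=1 because ROIs come first)
per_env = [np.take(occmap, np.where(trial_env == env)[0], axis=0) for env in environments]
```

After this step each list entry only contains trials from a single environment, which is what `by_environment=True` signals on the returned `Maps`.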

get_frame_behavior(clear_one_cache=True, params=None)

Get position and environment data for each imaging frame.

This method aligns behavioral data (position, speed, environment, trial) to imaging frame timestamps. Returns NaN for frames where no position data is available (e.g., if the closest behavioral sample is further away in time than half the sampling period).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| clear_one_cache | bool | Whether to clear the onefile cache after processing. | True |
| params | SpkmapParams, dict, or None | Parameters for processing. If None, uses instance parameters. If a dict, updates instance parameters temporarily. Parameters are restored after method execution. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple | Four arrays, each with shape (num_frames,): frame_position, frame_speed, frame_environment, and frame_trial. Entries are NaN for frames where no data is available. |

Source code in vrAnalysis/processors/spkmaps.py
@with_temp_params
@manage_one_cache
def get_frame_behavior(
    self,
    clear_one_cache: bool = True,
    params: Union[SpkmapParams, Dict[str, Any], None] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Get position and environment data for each imaging frame.

    This method aligns behavioral data (position, speed, environment, trial)
    to imaging frame timestamps. Returns NaN for frames where no position
    data is available (e.g., if the closest behavioral sample is further
    away in time than half the sampling period).

    Parameters
    ----------
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters.
        If a dict, updates instance parameters temporarily.
        Parameters are restored after method execution. Default is None.

    Returns
    -------
    tuple
        A tuple containing four arrays (all with shape (num_frames,)):
        - frame_position: Position for each frame (NaN if unavailable)
        - frame_speed: Speed for each frame (NaN if unavailable)
        - frame_environment: Environment number for each frame (NaN if unavailable)
        - frame_trial: Trial number for each frame (NaN if unavailable)
    """
    timestamps = self.session.loadone("positionTracking.times")
    position = self.session.loadone("positionTracking.position")
    idx_behave_to_frame = self.session.loadone("positionTracking.mpci")
    trial_start_index = self.session.loadone("trials.positionTracking")
    num_samples = len(position)
    trial_numbers = np.arange(len(trial_start_index))
    trial_lengths = np.append(np.diff(trial_start_index), num_samples - trial_start_index[-1])
    trial_numbers = np.repeat(trial_numbers, trial_lengths)
    trial_environment = self.session.loadone("trials.environmentIndex")
    trial_environment = np.repeat(trial_environment, trial_lengths)

    within_trial = np.append(np.diff(trial_numbers) == 0, True)
    sample_duration = np.append(np.diff(timestamps), 0)
    speed = np.append(np.diff(position) / sample_duration[:-1], 0)
    sample_duration = sample_duration * within_trial
    speed = speed * within_trial

    frame_timestamps = self.session.loadone("mpci.times")
    difference_timestamps = np.abs(timestamps - frame_timestamps[idx_behave_to_frame])
    sampling_period = np.median(np.diff(frame_timestamps))
    dist_cutoff = sampling_period / 2

    frame_position = np.zeros_like(frame_timestamps)
    count = np.zeros_like(frame_timestamps)
    helpers.get_average_frame_position(position, idx_behave_to_frame, difference_timestamps, dist_cutoff, frame_position, count)
    frame_position[count > 0] /= count[count > 0]
    frame_position[count == 0] = np.nan
    frame_speed = np.diff(frame_position) / np.diff(frame_timestamps)
    frame_speed = np.append(frame_speed, 0)

    # Get a map from frame to behavior time for quick lookup
    idx_frame_to_behave, dist_frame_to_behave = helpers.nearestpoint(frame_timestamps, timestamps)
    idx_get_position = dist_frame_to_behave < dist_cutoff

    frame_environment = np.full(len(frame_timestamps), np.nan)
    frame_environment[idx_get_position] = trial_environment[idx_frame_to_behave[idx_get_position]]
    frame_environment[count == 0] = np.nan

    frame_trial = np.full(len(frame_timestamps), np.nan)
    frame_trial[idx_get_position] = trial_numbers[idx_frame_to_behave[idx_get_position]]
    frame_trial[count == 0] = np.nan

    return frame_position, frame_speed, frame_environment, frame_trial
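The cutoff rule above (a frame only receives behavioral data when the nearest behavioral sample is within half the imaging sampling period) can be sketched with a pure-numpy stand-in for `helpers.nearestpoint`; timestamps here are invented:

```python
import numpy as np

frame_times = np.array([0.0, 0.1, 0.2, 0.3])   # imaging frames
behave_times = np.array([0.01, 0.11, 0.27])    # behavioral samples

sampling_period = np.median(np.diff(frame_times))
dist_cutoff = sampling_period / 2

# Nearest behavioral sample for each frame, and its time distance
idx_nearest = np.argmin(np.abs(frame_times[:, None] - behave_times[None, :]), axis=1)
dist = np.abs(frame_times - behave_times[idx_nearest])

# Frames whose nearest sample is too far in time get NaN downstream
valid = dist < dist_cutoff
```

Frame 0.2 s has no behavioral sample within 0.05 s, so it would be filled with NaN by the method above.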

get_placefield_prediction(use_session_filters=True, spks_type=None, use_speed_threshold=True, clear_one_cache=True, params=None)

Predict neural activity from place field maps.

This method uses averaged environment maps to predict neural activity at each imaging frame based on the animal's position and environment.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| use_session_filters | bool | Whether to filter ROIs using session.idx_rois. | True |
| spks_type | str or None | Type of spike data to use. If None, uses session's current spks_type. Temporarily changes session.spks_type if provided. | None |
| use_speed_threshold | bool | Whether to only predict for frames where speed exceeds threshold. | True |
| clear_one_cache | bool | Whether to clear the onefile cache after processing. | True |
| params | SpkmapParams, dict, or None | Parameters for processing. If None, uses instance parameters. If a dict, updates instance parameters temporarily. Parameters are restored after method execution. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple | placefield_prediction: predicted activity array with shape (frames, rois), NaN for frames where prediction is not possible; extras: dictionary with frame_position_index (position bin index for each frame), frame_environment_index (environment index for each frame), and idx_valid (boolean array indicating valid predictions). |

Notes

Predictions are based on averaged trial maps. Frames where the animal is not moving (if use_speed_threshold=True) or where position/environment data is unavailable will have NaN predictions.

Source code in vrAnalysis/processors/spkmaps.py
@with_temp_params
@manage_one_cache
def get_placefield_prediction(
    self,
    use_session_filters: bool = True,
    spks_type: Union[str, None] = None,
    use_speed_threshold: bool = True,
    clear_one_cache: bool = True,
    params: Union[SpkmapParams, Dict[str, Any], None] = None,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """Predict neural activity from place field maps.

    This method uses averaged environment maps to predict neural activity
    at each imaging frame based on the animal's position and environment.

    Parameters
    ----------
    use_session_filters : bool, optional
        Whether to filter ROIs using session.idx_rois. Default is True.
    spks_type : str or None, optional
        Type of spike data to use. If None, uses session's current spks_type.
        Temporarily changes session.spks_type if provided. Default is None.
    use_speed_threshold : bool, optional
        Whether to only predict for frames where speed exceeds threshold.
        Default is True.
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters.
        If a dict, updates instance parameters temporarily.
        Parameters are restored after method execution. Default is None.

    Returns
    -------
    tuple
        A tuple containing:
        - placefield_prediction: Predicted activity array with shape (frames, rois).
          NaN for frames where prediction is not possible.
        - extras: Dictionary with additional information:
          - frame_position_index: Position bin index for each frame
          - frame_environment_index: Environment index for each frame
          - idx_valid: Boolean array indicating valid predictions

    Notes
    -----
    Predictions are based on averaged trial maps. Frames where the animal
    is not moving (if use_speed_threshold=True) or where position/environment
    data is unavailable will have NaN predictions.
    """
    if spks_type is not None:
        _spks_type = self.session.spks_type
        self.session.params.spks_type = spks_type

    frame_position, frame_speed, frame_environment, _ = self.get_frame_behavior(clear_one_cache, params)
    idx_valid = ~np.isnan(frame_position)
    if use_speed_threshold:
        idx_valid = idx_valid & (frame_speed > self.params.speed_threshold)

    # Convert frame position to bins indices
    frame_position_index = np.searchsorted(self.dist_edges, frame_position, side="right") - 1

    # Get the place field for each neuron
    env_maps = self.get_env_maps(use_session_filters=use_session_filters)
    env_maps.average_trials()

    # Convert frame environment to indices
    env_to_idx = {env: i for i, env in enumerate(env_maps.environments)}
    frame_environment_index = np.array([env_to_idx[env] if not np.isnan(env) else -1000 for env in frame_environment], dtype=int)

    # Get the original spks data
    spks = self.session.spks
    if use_session_filters:
        spks = spks[:, self.session.idx_rois]

    # Use a numba speed up to get the placefield prediction (single pass simple algorithm)
    placefield_prediction = np.full(spks.shape, np.nan)
    spkmaps = np.stack(list(map(lambda x: x.T, env_maps.spkmap)))
    placefield_prediction = placefield_prediction_numba(
        placefield_prediction,
        spkmaps,
        frame_environment_index,
        frame_position_index,
        idx_valid,
    )

    # This will add samples for which a place field was not estimable (at the edges of the environment)
    idx_valid = np.all(~np.isnan(placefield_prediction), axis=1)

    # Reset spks_type
    if spks_type is not None:
        self.session.params.spks_type = _spks_type

    # Include extra details in a dictionary for forward compatibility
    extras = dict(
        frame_position_index=frame_position_index,
        frame_environment_index=frame_environment_index,
        idx_valid=idx_valid,
    )

    return placefield_prediction, extras
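The lookup performed by the numba kernel can be illustrated with a pure-numpy sketch (names and data hypothetical, not the actual `placefield_prediction_numba` implementation): each valid frame's prediction is the averaged map value at that frame's environment and position bin, and invalid frames stay NaN.

```python
import numpy as np

num_frames, num_rois, num_bins = 4, 2, 3
# Averaged maps stacked as (environments, positions, rois), as in the method
spkmaps = np.arange(num_bins * num_rois, dtype=float).reshape(1, num_bins, num_rois)

frame_env_index = np.array([0, 0, -1000, 0])  # -1000 marks unknown environment
frame_pos_index = np.array([0, 2, 1, 1])
idx_valid = frame_env_index >= 0

prediction = np.full((num_frames, num_rois), np.nan)
prediction[idx_valid] = spkmaps[frame_env_index[idx_valid], frame_pos_index[idx_valid]]
```

The real method additionally applies the speed threshold and position-availability masks before this lookup.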

get_processed_maps(force_recompute=False, clear_one_cache=True, params=None)

Get processed maps (smoothed and normalized by occupancy).

This method creates processed maps by:

1. Getting raw maps
2. Optionally smoothing with a Gaussian kernel
3. Normalizing speedmap and spkmap by occupancy
4. Reorganizing spkmap to have ROIs as the first dimension

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| force_recompute | bool | Whether to force recomputation even if cached data exists. | False |
| clear_one_cache | bool | Whether to clear the onefile cache after processing. | True |
| params | SpkmapParams, dict, or None | Parameters for processing. If None, uses instance parameters. If a dict, updates instance parameters temporarily. Parameters are restored after method execution. | None |

Returns:

| Type | Description |
| --- | --- |
| Maps | Maps instance containing processed occupancy, speed, and spike maps. Shape: (trials, positions) for occmap/speedmap, (rois, trials, positions) for spkmap. |

Source code in vrAnalysis/processors/spkmaps.py
@with_temp_params
@manage_one_cache
@cached_processor("processed_maps", disable=False)
def get_processed_maps(
    self,
    force_recompute: bool = False,
    clear_one_cache: bool = True,
    params: Union[SpkmapParams, Dict[str, Any], None] = None,
) -> Maps:
    """Get processed maps (smoothed and normalized by occupancy).

    This method creates processed maps by:
    1. Getting raw maps
    2. Optionally smoothing with a Gaussian kernel
    3. Normalizing speedmap and spkmap by occupancy
    4. Reorganizing spkmap to have ROIs as the first dimension

    Parameters
    ----------
    force_recompute : bool, optional
        Whether to force recomputation even if cached data exists. Default is False.
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters.
        If a dict, updates instance parameters temporarily.
        Parameters are restored after method execution. Default is None.

    Returns
    -------
    Maps
        Maps instance containing processed occupancy, speed, and spike maps.
        Shape: (trials, positions) for occmap/speedmap,
        (rois, trials, positions) for spkmap.
    """
    # Get the raw maps first (don't need to specify params because they're already set by this method)
    maps = self.get_raw_maps(
        force_recompute=force_recompute,
        clear_one_cache=clear_one_cache,
    )

    # Process the maps (smooth, divide by occupancy, and change to ROIs first)
    return maps.raw_to_processed(self.dist_centers, self.params.smooth_width)
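The occupancy-normalization step inside `raw_to_processed` can be sketched with hypothetical values: raw speed and spike maps accumulate value×duration per bin, so dividing by the time spent in each bin (the occupancy map) converts sums into rates.

```python
import numpy as np

# Hypothetical raw maps for one trial across three position bins
occmap = np.array([[0.5, 1.0, 2.0]])          # seconds spent per bin
raw_speedmap = np.array([[5.0, 10.0, 10.0]])  # summed speed * duration per bin

# Normalizing by occupancy yields the average speed in each bin
speedmap = raw_speedmap / occmap
```

The same division applies to the spike map; smoothing (when `smooth_width` is set) happens before this normalization so both numerator and denominator are smoothed consistently.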

get_raw_maps(force_recompute=False, clear_one_cache=True, params=None)

Get raw maps (occupancy, speed, spkmap) from session data.

This method processes session data to create spatial maps representing occupancy, speed, and neural activity across position bins. The maps are in raw format (not smoothed or normalized by occupancy).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| force_recompute | bool | Whether to force recomputation even if cached data exists. | False |
| clear_one_cache | bool | Whether to clear the onefile cache after processing. | True |
| params | SpkmapParams, dict, or None | Parameters for processing. If None, uses instance parameters. If a dict, updates instance parameters temporarily. Parameters are restored after method execution. | None |

Returns:

| Type | Description |
| --- | --- |
| Maps | Maps instance containing raw occupancy, speed, and spike maps. Shape: (trials, positions) for occmap/speedmap, (trials, positions, rois) for spkmap. |

Notes

The method:

1. Bins positions according to dist_step
2. Filters by speed threshold
3. Computes occupancy, speed, and spike maps
4. Sets unvisited position bins to NaN
5. Optionally standardizes spike data

Results are cached based on parameter hash for efficient reuse.

Source code in vrAnalysis/processors/spkmaps.py
@with_temp_params
@manage_one_cache
@cached_processor("raw_maps", disable=False)
def get_raw_maps(
    self,
    force_recompute: bool = False,
    clear_one_cache: bool = True,
    params: Union[SpkmapParams, Dict[str, Any], None] = None,
) -> Maps:
    """Get raw maps (occupancy, speed, spkmap) from session data.

    This method processes session data to create spatial maps representing
    occupancy, speed, and neural activity across position bins. The maps
    are in raw format (not smoothed or normalized by occupancy).

    Parameters
    ----------
    force_recompute : bool, optional
        Whether to force recomputation even if cached data exists. Default is False.
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters.
        If a dict, updates instance parameters temporarily.
        Parameters are restored after method execution. Default is None.

    Returns
    -------
    Maps
        Maps instance containing raw occupancy, speed, and spike maps.
        Shape: (trials, positions) for occmap/speedmap,
        (trials, positions, rois) for spkmap.

    Notes
    -----
    The method:
    1. Bins positions according to dist_step
    2. Filters by speed threshold
    3. Computes occupancy, speed, and spike maps
    4. Sets unvisited position bins to NaN
    5. Optionally standardizes spike data

    Results are cached based on parameter hash for efficient reuse.
    """
    dist_edges = self.dist_edges
    dist_centers = self.dist_centers
    num_positions = len(dist_centers)

    # Get behavioral timestamps and positions
    timestamps, positions, trial_numbers, idx_behave_to_frame = self.session.positions

    # compute behavioral speed on each sample
    within_trial_sample = np.append(np.diff(trial_numbers) == 0, True)
    sample_duration = np.append(np.diff(timestamps), 0)
    speeds = np.append(np.diff(positions) / sample_duration[:-1], 0)
    # do this after division so no /0 errors
    sample_duration = sample_duration * within_trial_sample
    # speed 0 in last sample for each trial (it's undefined)
    speeds = speeds * within_trial_sample
    # Convert positions to position bins
    position_bin = np.digitize(positions, dist_edges) - 1

    # get imaging information
    frame_time_stamps = self.session.timestamps
    sampling_period = np.median(np.diff(frame_time_stamps))
    dist_cutoff = sampling_period / 2
    delay_position_to_imaging = frame_time_stamps[idx_behave_to_frame] - timestamps

    # get spiking information
    spks = self.session.spks
    num_rois = self.session.get_value("numROIs")

    # Do standardization
    if self.params.standardize_spks:
        spks = median_zscore(spks, median_subtract=not self.session.zero_baseline_spks)

    # Get high resolution occupancy and speed maps
    dtype = np.float32
    occmap = np.zeros((self.session.num_trials, num_positions), dtype=dtype)
    counts = np.zeros((self.session.num_trials, num_positions), dtype=dtype)
    speedmap = np.zeros((self.session.num_trials, num_positions), dtype=dtype)
    spkmap = np.zeros((self.session.num_trials, num_positions, num_rois), dtype=dtype)
    extra_counts = np.zeros((self.session.num_trials, num_positions), dtype=dtype)

    # Get maps -- doing this independently for each map allows for more
    # flexibility in which data to load (basically the occmap & speedmap
    # are instantaneous, but the spkmap is a bit slower)
    get_summation_map(
        sample_duration,
        trial_numbers,
        position_bin,
        occmap,
        counts,
        speeds,
        self.params.speed_threshold,
        self.params.speed_max_allowed,
        delay_position_to_imaging,
        dist_cutoff,
        sample_duration,
        scale_by_sample_duration=False,
        use_sample_to_value_idx=False,
        sample_to_value_idx=idx_behave_to_frame,
    )
    get_summation_map(
        speeds,
        trial_numbers,
        position_bin,
        speedmap,
        counts,
        speeds,
        self.params.speed_threshold,
        self.params.speed_max_allowed,
        delay_position_to_imaging,
        dist_cutoff,
        sample_duration,
        scale_by_sample_duration=True,
        use_sample_to_value_idx=False,
        sample_to_value_idx=idx_behave_to_frame,
    )
    get_summation_map(
        spks,
        trial_numbers,
        position_bin,
        spkmap,
        extra_counts,
        speeds,
        self.params.speed_threshold,
        self.params.speed_max_allowed,
        delay_position_to_imaging,
        dist_cutoff,
        sample_duration,
        scale_by_sample_duration=True,
        use_sample_to_value_idx=True,
        sample_to_value_idx=idx_behave_to_frame,
    )

    # Figure out the valid range (outside of this range, set the maps to nan, because their values are not meaningful)
    position_bin_per_trial = [position_bin[trial_numbers == tnum] for tnum in range(self.session.num_trials)]

    # offsetting by 1 because there is a bug in the vrControl software where the first sample is always set
    # to the minimum position (which is 0), but if there is a built-up buffer in the rotary encoder, the position
    # will jump at the second sample. In general this will always work unless the mice have a truly ridiculous
    # speed at the beginning of the trial...
    first_valid_bin = [np.min(bpb[1:] if len(bpb) > 1 else bpb) for bpb in position_bin_per_trial]
    last_valid_bin = [np.max(bpb) for bpb in position_bin_per_trial]

    # set bins to nan when mouse didn't visit them
    occmap = replace_missing_data(occmap, first_valid_bin, last_valid_bin)
    speedmap = replace_missing_data(speedmap, first_valid_bin, last_valid_bin)
    spkmap = replace_missing_data(spkmap, first_valid_bin, last_valid_bin)

    return Maps.create_raw_maps(occmap, speedmap, spkmap)
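The position binning and accumulation behind the occupancy map can be sketched in plain numpy (hypothetical data; the real `get_summation_map` also handles speed thresholds and imaging-delay cutoffs):

```python
import numpy as np

dist_edges = np.array([0.0, 10.0, 20.0, 30.0])   # 3 position bins
positions = np.array([2.0, 12.0, 25.0, 5.0])     # behavioral samples
durations = np.array([0.1, 0.1, 0.2, 0.1])       # seconds per sample
trials = np.array([0, 0, 0, 1])                  # trial number per sample

# Digitize positions into bins, then accumulate each sample's duration
# into its (trial, bin) cell; np.add.at handles repeated indices correctly
position_bin = np.digitize(positions, dist_edges) - 1
occmap = np.zeros((2, len(dist_edges) - 1))
np.add.at(occmap, (trials, position_bin), durations)
```

Each row of `occmap` then holds the time spent in each position bin on that trial, which is the denominator used later when normalizing speed and spike maps.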

get_reliability(use_session_filters=True, force_recompute=False, clear_one_cache=True, params=None)

Calculate reliability of spike maps across trials.

Reliability measures how consistent neural activity is across trials within each environment. Multiple methods are supported.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| use_session_filters | bool | Whether to filter ROIs using session.idx_rois. | True |
| force_recompute | bool | Whether to force recomputation even if cached data exists. | False |
| clear_one_cache | bool | Whether to clear the onefile cache after processing. | True |
| params | SpkmapParams, dict, or None | Parameters for processing. If None, uses instance parameters. If a dict, updates instance parameters temporarily. Parameters are restored after method execution. | None |

Returns:

| Type | Description |
| --- | --- |
| Reliability | Reliability instance containing reliability values for each ROI in each environment. Shape: (num_environments, num_rois). |

Notes

Supported reliability methods:

- "leave_one_out": Leave-one-out cross-validation
- "correlation": Correlation between trial pairs
- "mse": Mean squared error between trial pairs

All reliability measures require maps with no NaN positions.

Source code in vrAnalysis/processors/spkmaps.py
@with_temp_params
@manage_one_cache
@cached_processor("reliability", disable=False)
def get_reliability(
    self,
    use_session_filters: bool = True,
    force_recompute: bool = False,
    clear_one_cache: bool = True,
    params: Union[SpkmapParams, Dict[str, Any], None] = None,
) -> Reliability:
    """Calculate reliability of spike maps across trials.

    Reliability measures how consistent neural activity is across trials
    within each environment. Multiple methods are supported.

    Parameters
    ----------
    use_session_filters : bool, optional
        Whether to filter ROIs using session.idx_rois. Default is True.
    force_recompute : bool, optional
        Whether to force recomputation even if cached data exists. Default is False.
    clear_one_cache : bool, optional
        Whether to clear the onefile cache after processing. Default is True.
    params : SpkmapParams, dict, or None, optional
        Parameters for processing. If None, uses instance parameters.
        If a dict, updates instance parameters temporarily.
        Parameters are restored after method execution. Default is None.

    Returns
    -------
    Reliability
        Reliability instance containing reliability values for each ROI
        in each environment. Shape: (num_environments, num_rois).

    Notes
    -----
    Supported reliability methods:
    - "leave_one_out": Leave-one-out cross-validation
    - "correlation": Correlation between trial pairs
    - "mse": Mean squared error between trial pairs

    All reliability measures require maps with no NaN positions.
    """
    envnum = helpers.check_iterable(self.session.environments)

    # A list of the requested environments (all if not specified)
    maps = self.get_env_maps(
        use_session_filters=use_session_filters,
        force_recompute=force_recompute,
        clear_one_cache=clear_one_cache,
        params={"autosave": False},  # Prevent saving in the case of a recompute
    )

    # All reliability measures require no NaNs
    maps.pop_nan_positions()

    if self.params.reliability_method == "leave_one_out":
        rel_values = [helpers.reliability_loo(spkmap) for spkmap in maps.spkmap]
    elif self.params.reliability_method == "correlation" or self.params.reliability_method == "mse":
        rel_mse, rel_cor = helpers.named_transpose([helpers.measureReliability(spkmap) for spkmap in maps.spkmap])
        rel_values = rel_mse if self.params.reliability_method == "mse" else rel_cor
    else:
        raise ValueError(f"Method {self.params.reliability_method} not supported")

    return Reliability(
        np.stack(rel_values),
        environments=envnum,
        method=self.params.reliability_method,
    )
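A hedged sketch of the leave-one-out idea (not the exact `helpers.reliability_loo` implementation): correlate each trial's spatial profile with the mean of the remaining trials, then average across trials.

```python
import numpy as np

def reliability_loo_sketch(spkmap):
    """Leave-one-out reliability sketch. spkmap: (rois, trials, positions), no NaNs."""
    num_rois, num_trials, _ = spkmap.shape
    rel = np.zeros(num_rois)
    for roi in range(num_rois):
        corrs = []
        for t in range(num_trials):
            # Mean spatial profile from all trials except the held-out one
            rest = np.delete(spkmap[roi], t, axis=0).mean(axis=0)
            corrs.append(np.corrcoef(spkmap[roi, t], rest)[0, 1])
        rel[roi] = np.mean(corrs)
    return rel
```

A neuron with identical tuning on every trial scores 1.0; this is also why `pop_nan_positions` is called first, since NaN bins would poison the correlations.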

get_traversals(idx_roi, idx_env, width=10, placefield_threshold=5.0, fill_nan=False, spks=None, spks_prediction=None)

Extract neural activity around place field peak during traversals.

This method identifies trials where the animal passes through a neuron's place field peak and extracts activity windows around those moments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| idx_roi | int | Index of the ROI (neuron) to analyze. | required |
| idx_env | int | Index of the environment to analyze (index into env_maps.environments). | required |
| width | int | Number of frames on each side of the peak to include. Total window size is 2*width + 1. | 10 |
| placefield_threshold | float | Maximum distance from place field peak to include a trial (in spatial units). | 5.0 |
| fill_nan | bool | Whether to fill NaN values with 0. | False |
| spks | ndarray | Spike data array. If None, loads from session. | None |
| spks_prediction | ndarray | Place field prediction array. If None, computes it. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple | traversals: array of shape (num_traversals, 2*width + 1) containing actual neural activity around each traversal; pred_travs: array of the same shape containing predicted activity around each traversal. |

Notes

Only includes trials where the animal passes within placefield_threshold of the place field peak. The peak is determined from the averaged spike map for the specified ROI and environment.

Source code in vrAnalysis/processors/spkmaps.py
def get_traversals(
    self,
    idx_roi: int,
    idx_env: int,
    width: int = 10,
    placefield_threshold: float = 5.0,
    fill_nan: bool = False,
    spks: np.ndarray = None,
    spks_prediction: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract neural activity around place field peak during traversals.

    This method identifies trials where the animal passes through a neuron's
    place field peak and extracts activity windows around those moments.

    Parameters
    ----------
    idx_roi : int
        Index of the ROI (neuron) to analyze.
    idx_env : int
        Index of the environment to analyze (index into env_maps.environments).
    width : int, optional
        Number of frames on each side of the peak to include. Total window
        size is 2*width + 1. Default is 10.
    placefield_threshold : float, optional
        Maximum distance from place field peak to include a trial (in spatial units).
        Default is 5.0.
    fill_nan : bool, optional
        Whether to fill NaN values with 0. Default is False.
    spks : np.ndarray, optional
        Spike data array. If None, loads from session. Default is None.
    spks_prediction : np.ndarray, optional
        Place field prediction array. If None, computes it. Default is None.

    Returns
    -------
    tuple
        A tuple containing:
        - traversals: Array of shape (num_traversals, 2*width+1) containing
          actual neural activity around each traversal.
        - pred_travs: Array of shape (num_traversals, 2*width+1) containing
          predicted activity around each traversal.

    Notes
    -----
    Only includes trials where the animal passes within placefield_threshold
    of the place field peak. The peak is determined from the averaged spike
    map for the specified ROI and environment.
    """
    frame_position, _, frame_environment, frame_trial = self.get_frame_behavior()
    if spks_prediction is None:
        spks_prediction = self.get_placefield_prediction(use_session_filters=True)[0]
    if spks is None:
        spks = self.session.spks[:, self.session.idx_rois]

    if spks.shape != spks_prediction.shape:
        raise ValueError("spks and spks_prediction must have the same shape")

    env_maps = self.get_env_maps()
    pos_peak = self.dist_centers[np.nanargmax(np.nanmean(env_maps.spkmap[idx_env][idx_roi], axis=0))]
    envnum = env_maps.environments[idx_env]

    env_trials = np.unique(frame_trial[frame_environment == envnum])

    num_trials = len(env_trials)
    idx_traversal = -1 * np.ones(num_trials, dtype=int)
    for itrial, trialnum in enumerate(env_trials):
        idx_trial = frame_trial == trialnum
        idx_closest_pos = np.nanargmin(np.abs(frame_position - pos_peak) + 10000 * ~idx_trial)

        # Only include the trial if the closest position is within placefield threshold of the peak
        if np.abs(frame_position[idx_closest_pos] - pos_peak) < placefield_threshold:
            idx_traversal[itrial] = idx_closest_pos

    # Filter out trials that don't have a traversal
    idx_traversal = idx_traversal[idx_traversal != -1]

    # Get traversals through place field in requested environment
    traversals = np.zeros((len(idx_traversal), width * 2 + 1))
    pred_travs = np.zeros((len(idx_traversal), width * 2 + 1))
    for ii, it in enumerate(idx_traversal):
        istart = it - width
        iend = it + width + 1
        istartoffset = max(0, -istart)
        iendoffset = max(0, iend - spks.shape[0])
        traversals[ii, istartoffset : width * 2 + 1 - iendoffset] = spks[istart + istartoffset : iend - iendoffset, idx_roi]
        pred_travs[ii, istartoffset : width * 2 + 1 - iendoffset] = spks_prediction[istart + istartoffset : iend - iendoffset, idx_roi]

    if fill_nan:
        traversals[np.isnan(traversals)] = 0.0
        pred_travs[np.isnan(pred_travs)] = 0.0

    return traversals, pred_travs
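The edge-clipped window extraction in the loop above can be isolated into a small helper (a sketch with hypothetical data, matching the offset arithmetic used in the method): grab `width` frames on each side of an event index, leaving zeros where the window runs off either end of the trace.

```python
import numpy as np

def extract_window(trace, center, width):
    """Return a (2*width + 1) window of trace around center, zero-padded at edges."""
    out = np.zeros(2 * width + 1)
    istart, iend = center - width, center + width + 1
    lo = max(0, -istart)              # clip at the start of the trace
    hi = max(0, iend - len(trace))    # clip at the end of the trace
    out[lo : 2 * width + 1 - hi] = trace[istart + lo : iend - hi]
    return out
```

With `fill_nan=False` the real method leaves NaNs from the activity trace in place; only out-of-range samples are zero here.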

load_from_cache(data_type)

Load cached parameters and data for a given data type.

Parameters:

Name Type Description Default
data_type str

Type of cached data to load.

required

Returns:

Type Description
tuple

A tuple containing:

- The cached data (Maps or Reliability), or None if not found
- A boolean indicating whether valid cache was found

Source code in vrAnalysis/processors/spkmaps.py
def load_from_cache(self, data_type: str) -> Tuple[Union[Maps, Reliability, None], bool]:
    """Load cached parameters and data for a given data type.

    Parameters
    ----------
    data_type : str
        Type of cached data to load.

    Returns
    -------
    tuple
        A tuple containing:
        - The cached data (Maps or Reliability), or None if not found
        - A boolean indicating whether valid cache was found
    """
    cache_dir = self.cache_directory(data_type)
    if cache_dir.exists():
        # If the directory exists, check if there are any cached params that match the expected hash
        params_hash = self._params_hash(data_type)
        cached_params_path = cache_dir / f"params_{params_hash}.npz"
        if cached_params_path.exists():
            cached_params = dict(np.load(cached_params_path))
            # Check if the cached params match the dependent params
            if self.check_params_match(cached_params):
                return self._load_from_cache(data_type, params_hash, params=cached_params), True
    return None, False
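The cache-hit check hinges on a deterministic hash of the dependent parameters: a matching `params_*.npz` file must exist and its contents must equal the expected parameters before data is loaded. A minimal self-contained sketch of that round trip (the `params_hash` and `load_if_cached` helpers here are stand-ins, not the module's `_params_hash` / `check_params_match`):

```python
import hashlib
import json
import tempfile
from pathlib import Path

import numpy as np

def params_hash(params: dict) -> str:
    """Deterministic short hash of a parameter dict (sorted keys)."""
    return hashlib.sha256(json.dumps(params, sort_keys=True).encode()).hexdigest()[:16]

def load_if_cached(cache_dir: Path, params: dict):
    """Return (data, True) when a params file matching the expected hash
    exists and its contents match, else (None, False)."""
    h = params_hash(params)
    params_path = cache_dir / f"params_{h}.npz"
    data_path = cache_dir / f"data_{h}.npy"
    if params_path.exists() and data_path.exists():
        # np.savez stores scalars as 0-d arrays; .item() recovers them
        cached = {k: v.item() for k, v in np.load(params_path).items()}
        if cached == params:
            return np.load(data_path), True
    return None, False

# Round-trip demo in a temporary directory
cache_dir = Path(tempfile.mkdtemp())
params = {"smooth_width": 5, "num_bins": 100}
h = params_hash(params)
np.savez(cache_dir / f"params_{h}.npz", **params)
np.save(cache_dir / f"data_{h}.npy", np.arange(10))
data, hit = load_if_cached(cache_dir, params)
miss_data, miss = load_if_cached(cache_dir, {"smooth_width": 7, "num_bins": 100})
```

Changing any parameter changes the hash, so stale caches are never silently reused.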

save_cache(data_type, data)

Save the cached parameters and data for a given data type.

Parameters:

Name Type Description Default
data_type str

Type of data being cached ("raw_maps", "processed_maps", "env_maps", or "reliability").

required
data Maps or Reliability

The data object to cache.

required
Notes

Creates the cache directory if it doesn't exist. Saves parameters as an NPZ file and data as NPY files, using a hash of the parameters in the filenames.

Source code in vrAnalysis/processors/spkmaps.py
def save_cache(self, data_type: str, data: Union[Maps, Reliability]) -> None:
    """Save the cached parameters and data for a given data type.

    Parameters
    ----------
    data_type : str
        Type of data being cached ("raw_maps", "processed_maps", "env_maps", or "reliability").
    data : Maps or Reliability
        The data object to cache.

    Notes
    -----
    Creates the cache directory if it doesn't exist. Saves parameters as an NPZ file
    and data as NPY files, using a hash of the parameters in the filenames.
    """
    cache_dir = self.cache_directory(data_type)
    params_hash = self._params_hash(data_type)
    cache_param_path = cache_dir / f"params_{params_hash}.npz"
    if not cache_dir.exists():
        cache_dir.mkdir(parents=True, exist_ok=True)
    np.savez(cache_param_path, **self.dependent_params(data_type))
    if data_type == "raw_maps" or data_type == "processed_maps":
        for mapname in Maps.map_types():
            cache_data_path = cache_dir / f"data_{mapname}_{params_hash}.npy"
            np.save(cache_data_path, getattr(data, mapname))
    elif data_type == "env_maps":
        environments = data.environments
        np.save(cache_dir / f"data_environments_{params_hash}.npy", environments)
        for ienv, env in enumerate(environments):
            for mapname in Maps.map_types():
                cache_data_path = cache_dir / f"data_{mapname}_{env}_{params_hash}.npy"
                np.save(cache_data_path, getattr(data, mapname)[ienv])
    elif data_type == "reliability":
        values = data.values
        environments = data.environments
        # don't need data.method because it's in params...
        np.save(cache_dir / f"data_environments_{params_hash}.npy", environments)
        np.save(cache_dir / f"data_reliability_{params_hash}.npy", values)
    else:
        raise ValueError(f"Unknown data type: {data_type}")
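For the `"raw_maps"` / `"processed_maps"` branch, the on-disk layout is one `params_{hash}.npz` plus one `data_{mapname}_{hash}.npy` per map type. A minimal sketch of that layout (the `save_maps_cache` helper is hypothetical):

```python
import tempfile
from pathlib import Path

import numpy as np

def save_maps_cache(cache_dir: Path, params_hash: str, params: dict, maps: dict) -> list[str]:
    """Write one params_{hash}.npz plus one data_{name}_{hash}.npy per map,
    mirroring the file layout used for raw and processed maps."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    np.savez(cache_dir / f"params_{params_hash}.npz", **params)
    written = [f"params_{params_hash}.npz"]
    for name, arr in maps.items():
        fname = f"data_{name}_{params_hash}.npy"
        np.save(cache_dir / fname, arr)
        written.append(fname)
    return written

cache_dir = Path(tempfile.mkdtemp())
maps = {name: np.zeros((4, 10)) for name in ("occmap", "speedmap", "spkmap")}
written = save_maps_cache(cache_dir, "abc123", {"num_bins": 10}, maps)
```

Embedding the parameter hash in every filename is what lets several parameterizations of the same data type coexist in one cache directory.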

show_cache(data_type=None)

Helper function that scrapes the cache directory and shows cached files

Parameters:

Name Type Description Default
data_type Optional[str]

Indicate a data type to filter which parts of the cache to show

None
Notes

Prints a formatted table showing cache information including data_type, size, parameters, and modification date. If no cache directory exists, prints a message.

Source code in vrAnalysis/processors/spkmaps.py
def show_cache(self, data_type: Optional[str] = None) -> None:
    """Helper function that scrapes the cache directory and shows cached files

    Parameters
    ----------
    data_type : Optional[str], optional
        Indicate a data type to filter which parts of the cache to show.
        Default is None.

    Notes
    -----
    Prints a formatted table showing cache information including data_type, size,
    parameters, and modification date. If no cache directory exists, prints a message.
    """
    from datetime import datetime

    # Get the base cache directory
    base_cache_dir = self.cache_directory()

    if not base_cache_dir.exists():
        print(f"No cache directory found at: {base_cache_dir}")
        return

    # Collect information about all cache files
    cache_info = []

    # Define the data types to check
    if data_type is not None:
        data_types_to_check = [data_type]
    else:
        data_types_to_check = ["raw_maps", "processed_maps", "env_maps", "reliability"]

    for dt in data_types_to_check:
        cache_dir = self.cache_directory(dt)
        if not cache_dir.exists():
            continue

        # Find all parameter files (they define what caches exist)
        param_files = list(cache_dir.glob("params_*.npz"))

        for param_file in param_files:
            # Extract the hash from the filename
            params_hash = param_file.stem.replace("params_", "")

            # Load the parameters
            try:
                cached_params = dict(np.load(param_file))
                param_str = ", ".join([f"{k}={v}" for k, v in cached_params.items()])
            except Exception as e:
                param_str = f"Error loading params: {e}"

            # Get file modification time
            mod_time = datetime.fromtimestamp(param_file.stat().st_mtime)
            date_str = mod_time.strftime("%Y-%m-%d %H:%M:%S")

            # Calculate total size of all related cache files
            total_size = param_file.stat().st_size

            if dt in ["raw_maps", "processed_maps"]:
                # For maps, look for data files for each map type
                for mapname in ["occmap", "speedmap", "spkmap"]:
                    data_file = cache_dir / f"data_{mapname}_{params_hash}.npy"
                    if data_file.exists():
                        total_size += data_file.stat().st_size

            elif dt == "env_maps":
                # For env_maps, look for environment file and individual environment data files
                env_file = cache_dir / f"data_environments_{params_hash}.npy"
                if env_file.exists():
                    total_size += env_file.stat().st_size
                    # Load environments to find all data files
                    try:
                        environments = np.load(env_file)
                        for env in environments:
                            for mapname in ["occmap", "speedmap", "spkmap"]:
                                data_file = cache_dir / f"data_{mapname}_{env}_{params_hash}.npy"
                                if data_file.exists():
                                    total_size += data_file.stat().st_size
                    except Exception:
                        pass  # Continue even if we can't load environments

            elif dt == "reliability":
                # For reliability, look for environments and reliability data files
                env_file = cache_dir / f"data_environments_{params_hash}.npy"
                rel_file = cache_dir / f"data_reliability_{params_hash}.npy"
                if env_file.exists():
                    total_size += env_file.stat().st_size
                if rel_file.exists():
                    total_size += rel_file.stat().st_size

            # Convert size to human readable format
            size_str = self._format_file_size(total_size)

            cache_info.append(
                {
                    "data_type": dt,
                    "size": size_str,
                    "parameters": param_str,
                    "date": date_str,
                    "hash": params_hash[:8],  # Show first 8 chars of hash
                }
            )

    if not cache_info:
        print("No cache files found.")
        return

    # Format the output as a table
    output_lines = []
    output_lines.append("Cache Files Summary")
    output_lines.append("=" * 80)
    output_lines.append(f"{'Data Type':<15} {'Size':<10} {'Date':<20} {'Hash':<10} {'Parameters'}")
    output_lines.append("-" * 80)

    for info in cache_info:
        output_lines.append(f"{info['data_type']:<15} {info['size']:<10} {info['date']:<20} " f"{info['hash']:<10} {info['parameters']}")

    output_lines.append("-" * 80)
    output_lines.append(f"Total cache entries: {len(cache_info)}")

    result = "\n".join(output_lines)
    print(result)
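The table rendering calls a `_format_file_size` helper that isn't shown in this section. A minimal sketch of one plausible implementation (the name `format_file_size` and the unit ladder are assumptions about what that helper does):

```python
def format_file_size(num_bytes: int) -> str:
    """Render a byte count as a human-readable string, e.g. '1.5 MB'."""
    size = float(num_bytes)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if size < 1024.0 or unit == "TB":
            return f"{size:.1f} {unit}"
        # Move to the next unit
        size /= 1024.0
```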


Maps dataclass

Container for occupancy, speed, and spike maps.

This class holds spatial maps representing neural activity, behavioral occupancy, and speed across position bins. Maps can be organized either as single arrays (all trials combined) or as lists of arrays (separated by environment).

Attributes:

Name Type Description
occmap np.ndarray or list of np.ndarray

Occupancy map(s) representing time spent in each position bin. Shape: (trials, positions) for single array, or list of (trials, positions) arrays when by_environment=True.

speedmap np.ndarray or list of np.ndarray

Speed map(s) representing average speed in each position bin. Shape: (trials, positions) for single array, or list of (trials, positions) arrays when by_environment=True.

spkmap np.ndarray or list of np.ndarray

Spike map(s) representing neural activity in each position bin. Shape depends on rois_first:

- If rois_first=True: (rois, trials, positions) or list of (rois, trials, positions)
- If rois_first=False: (trials, positions, rois) or list of (trials, positions, rois)

by_environment bool

Whether maps are separated by environment (True) or combined (False).

rois_first bool

Whether ROI dimension is first (True) or last (False) in spkmap arrays.

environments list of int, optional

List of environment numbers when by_environment=True. Default is None.

distcenters np.ndarray, optional

Center positions of distance bins. Default is None.

_averaged bool

Internal flag indicating whether trials have been averaged. Default is False.

Notes

The Maps class supports two organizational modes:

1. Single maps: All trials combined in single arrays (by_environment=False)
2. Environment-separated maps: Maps split by environment (by_environment=True)

The spkmap can have ROIs as the first or last dimension depending on rois_first. This allows flexibility in how data is organized for different processing steps.

Source code in vrAnalysis/processors/spkmaps.py
@dataclass
class Maps:
    """Container for occupancy, speed, and spike maps.

    This class holds spatial maps representing neural activity, behavioral occupancy,
    and speed across position bins. Maps can be organized either as single arrays
    (all trials combined) or as lists of arrays (separated by environment).

    Attributes
    ----------
    occmap : np.ndarray or list of np.ndarray
        Occupancy map(s) representing time spent in each position bin.
        Shape: (trials, positions) for single array, or list of (trials, positions)
        arrays when by_environment=True.
    speedmap : np.ndarray or list of np.ndarray
        Speed map(s) representing average speed in each position bin.
        Shape: (trials, positions) for single array, or list of (trials, positions)
        arrays when by_environment=True.
    spkmap : np.ndarray or list of np.ndarray
        Spike map(s) representing neural activity in each position bin.
        Shape depends on rois_first:
        - If rois_first=True: (rois, trials, positions) or list of (rois, trials, positions)
        - If rois_first=False: (trials, positions, rois) or list of (trials, positions, rois)
    by_environment : bool
        Whether maps are separated by environment (True) or combined (False).
    rois_first : bool
        Whether ROI dimension is first (True) or last (False) in spkmap arrays.
    environments : list of int, optional
        List of environment numbers when by_environment=True. Default is None.
    distcenters : np.ndarray, optional
        Center positions of distance bins. Default is None.
    _averaged : bool
        Internal flag indicating whether trials have been averaged. Default is False.

    Notes
    -----
    The Maps class supports two organizational modes:
    1. Single maps: All trials combined in single arrays (by_environment=False)
    2. Environment-separated maps: Maps split by environment (by_environment=True)

    The spkmap can have ROIs as the first or last dimension depending on rois_first.
    This allows flexibility in how data is organized for different processing steps.
    """

    occmap: np.ndarray | list[np.ndarray]
    speedmap: np.ndarray | list[np.ndarray]
    spkmap: np.ndarray | list[np.ndarray]
    by_environment: bool
    rois_first: bool
    environments: list[int] | None = None
    distcenters: np.ndarray | None = None
    _averaged: bool = field(default=False, init=False)

    def __post_init__(self):
        if self.occmap is None or self.speedmap is None or self.spkmap is None:
            raise ValueError("occmap, speedmap, and spkmap must be provided")

        if self.by_environment:
            if self.environments is None:
                raise ValueError("environments must be provided if by_environment is True")
            if not isinstance(self.occmap, list) or not isinstance(self.speedmap, list) or not isinstance(self.spkmap, list):
                raise ValueError("occmap, speedmap, and spkmap must be lists if by_environment is True")
        else:
            if isinstance(self.occmap, list) or isinstance(self.speedmap, list) or isinstance(self.spkmap, list):
                raise ValueError("occmap, speedmap, and spkmap must be single arrays if by_environment is False")

        if not self.by_environment:
            spkmap_shape = self.spkmap.shape[1:] if self.rois_first else self.spkmap.shape[:2]
            if not (self.occmap.shape == self.speedmap.shape == spkmap_shape):
                raise ValueError("occmap, speedmap, and spkmap must have the same shape")
        else:
            if not (len(self.occmap) == len(self.speedmap) == len(self.spkmap) == len(self.environments)):
                raise ValueError("occmap, speedmap, and spkmap must have the same number of environments")
            for i in range(len(self.environments)):
                spkmap_shape = self.spkmap[i].shape[1:] if self.rois_first else self.spkmap[i].shape[:2]
                if not (self.occmap[i].shape == self.speedmap[i].shape == spkmap_shape):
                    raise ValueError("occmap, speedmap, and spkmap must have the same shape for each environment")
            roi_axis = 0 if self.rois_first else -1
            rois_per_env = [spkmap.shape[roi_axis] for spkmap in self.spkmap]
            if not all([rpe == rois_per_env[0] for rpe in rois_per_env]):
                raise ValueError("All environments must have the same number of ROIs")

    def __repr__(self) -> str:
        # Get number of positions
        if self.by_environment:
            num_positions = self.occmap[0].shape[-1]
        else:
            num_positions = self.occmap.shape[-1]
        # Get number of trials
        if self._averaged:
            num_trials = "averaged"
        else:
            if self.by_environment:
                num_trials = [occmap.shape[0] for occmap in self.occmap]
                num_trials = "{" + ", ".join([str(nt) for nt in num_trials]) + "}"
            else:
                num_trials = self.occmap.shape[0]
        # Get number of ROIs
        if self.by_environment:
            num_rois = self.spkmap[0].shape[0] if self.rois_first else self.spkmap[0].shape[1]
        else:
            num_rois = self.spkmap.shape[0] if self.rois_first else self.spkmap.shape[1]
        environments = f", environments={{{', '.join([str(env) for env in self.environments])}}}" if self.by_environment else ""
        return f"Maps(num_trials={num_trials}, num_positions={num_positions}, num_rois={num_rois}{environments}, rois_first={self.rois_first})"

    @classmethod
    def create_raw_maps(cls, occmap: np.ndarray, speedmap: np.ndarray, spkmap: np.ndarray, distcenters: np.ndarray | None = None) -> "Maps":
        """Create a Maps instance from raw (unprocessed) map data.

        Parameters
        ----------
        occmap : np.ndarray
            Occupancy map with shape (trials, positions).
        speedmap : np.ndarray
            Speed map with shape (trials, positions).
        spkmap : np.ndarray
            Spike map with shape (trials, positions, rois).
        distcenters : np.ndarray, optional
            Center positions of distance bins. Default is None.

        Returns
        -------
        Maps
            Maps instance with by_environment=False and rois_first=False.
        """
        return cls(occmap=occmap, speedmap=speedmap, spkmap=spkmap, distcenters=distcenters, by_environment=False, rois_first=False)

    @classmethod
    def create_processed_maps(cls, occmap: np.ndarray, speedmap: np.ndarray, spkmap: np.ndarray, distcenters: np.ndarray | None = None) -> "Maps":
        """Create a Maps instance from processed map data.

        Parameters
        ----------
        occmap : np.ndarray
            Occupancy map with shape (trials, positions).
        speedmap : np.ndarray
            Speed map with shape (trials, positions).
        spkmap : np.ndarray
            Spike map with shape (rois, trials, positions).
        distcenters : np.ndarray, optional
            Center positions of distance bins. Default is None.

        Returns
        -------
        Maps
            Maps instance with by_environment=False and rois_first=True.
        """
        return cls(occmap=occmap, speedmap=speedmap, spkmap=spkmap, distcenters=distcenters, by_environment=False, rois_first=True)

    @classmethod
    def create_environment_maps(
        cls,
        occmap: list[np.ndarray],
        speedmap: list[np.ndarray],
        spkmap: list[np.ndarray],
        environments: list[int],
        distcenters: np.ndarray | None = None,
    ) -> "Maps":
        """Create a Maps instance with maps separated by environment.

        Parameters
        ----------
        occmap : list of np.ndarray
            List of occupancy maps, one per environment. Each with shape (trials, positions).
        speedmap : list of np.ndarray
            List of speed maps, one per environment. Each with shape (trials, positions).
        spkmap : list of np.ndarray
            List of spike maps, one per environment. Each with shape (rois, trials, positions).
        environments : list of int
            List of environment numbers corresponding to each map in the lists.
        distcenters : np.ndarray, optional
            Center positions of distance bins. Default is None.

        Returns
        -------
        Maps
            Maps instance with by_environment=True and rois_first=True.
        """
        return cls(
            occmap=occmap,
            speedmap=speedmap,
            spkmap=spkmap,
            distcenters=distcenters,
            environments=environments,
            by_environment=True,
            rois_first=True,
        )

    @classmethod
    def map_types(cls) -> List[str]:
        """Get the list of map type names.

        Returns
        -------
        list of str
            List containing ["occmap", "speedmap", "spkmap"].
        """
        return ["occmap", "speedmap", "spkmap"]

    def __getitem__(self, key: str) -> np.ndarray | list[np.ndarray]:
        """Get a map by name using dictionary-like access.

        Parameters
        ----------
        key : str
            Name of the map to retrieve ("occmap", "speedmap", or "spkmap").

        Returns
        -------
        np.ndarray or list of np.ndarray
            The requested map array(s).
        """
        return getattr(self, key)

    def __setitem__(self, key: str, value: np.ndarray | list[np.ndarray]) -> None:
        """Set a map by name using dictionary-like access.

        Parameters
        ----------
        key : str
            Name of the map to set ("occmap", "speedmap", or "spkmap").
        value : np.ndarray or list of np.ndarray
            The map array(s) to assign.
        """
        setattr(self, key, value)

    def _get_position_axis(self, mapname: str) -> int:
        """Get the axis index for the position dimension.

        Parameters
        ----------
        mapname : str
            Name of the map ("occmap", "speedmap", or "spkmap").

        Returns
        -------
        int
            Axis index for the position dimension. Typically -1 (last axis),
            except for spkmap when rois_first=False, where it's -2.

        Notes
        -----
        The only time the position axis isn't the last one is for spkmap when
        rois_first is False, where the shape is (trials, positions, rois).
        """
        # Negative indexing is unaffected by dropping the leading trial axis
        # after averaging, so the same negative index works in both cases.
        if mapname == "spkmap" and not self.rois_first:
            return -2
        else:
            return -1

    def filter_positions(self, idx_positions: np.ndarray) -> None:
        """Filter maps to keep only specified position bins.

        Parameters
        ----------
        idx_positions : np.ndarray
            Indices of position bins to keep. Must be a 1D array of integers.

        Notes
        -----
        This method modifies the maps in-place, keeping only the position bins
        specified by idx_positions. Also updates distcenters if present.
        """
        if self.distcenters is not None:
            self.distcenters = self.distcenters[idx_positions]
        for mapname in self.map_types():
            axis = self._get_position_axis(mapname)
            if self.by_environment:
                self[mapname] = [np.take(x, idx_positions, axis=axis) for x in self[mapname]]
            else:
                self[mapname] = np.take(self[mapname], idx_positions, axis=axis)

    def filter_rois(self, idx_rois: np.ndarray) -> None:
        """Filter spike maps to keep only specified ROIs.

        Parameters
        ----------
        idx_rois : np.ndarray
            Indices of ROIs to keep. Must be a 1D array of integers.

        Notes
        -----
        This method modifies the spkmap in-place, keeping only the ROIs
        specified by idx_rois. Only affects spkmap; occmap and speedmap
        are unchanged.
        """
        axis = 0 if self.rois_first else -1
        if self.by_environment:
            self.spkmap = [np.take(x, idx_rois, axis=axis) for x in self.spkmap]
        else:
            self.spkmap = np.take(self.spkmap, idx_rois, axis=axis)

    def filter_environments(self, environments: list[int]) -> None:
        """Filter maps to keep only specified environments.

        Parameters
        ----------
        environments : list of int
            List of environment numbers to keep.

        Raises
        ------
        ValueError
            If by_environment is False, since environments cannot be filtered
            when maps are not separated by environment.

        Notes
        -----
        This method modifies the maps in-place, keeping only the environments
        specified. Only works when by_environment=True.
        """
        if self.by_environment:
            idx_to_requested_env = [i for i, env in enumerate(self.environments) if env in environments]
            self.occmap = [self.occmap[i] for i in idx_to_requested_env]
            self.speedmap = [self.speedmap[i] for i in idx_to_requested_env]
            self.spkmap = [self.spkmap[i] for i in idx_to_requested_env]
            self.environments = [self.environments[i] for i in idx_to_requested_env]
        else:
            raise ValueError("Cannot filter environments when maps aren't separated by environment!")

    def pop_nan_positions(self) -> None:
        """Remove position bins that contain NaN values in any map.

        Notes
        -----
        This method identifies position bins that have NaN values in any of the
        maps (occmap, speedmap, or spkmap) and removes them from all maps.
        Useful for cleaning data before analysis.
        """
        if self.by_environment:
            idx_valid_positions = np.where(~np.any(np.stack([np.any(np.isnan(occmap), axis=0) for occmap in self.occmap], axis=0), axis=0))[0]
        else:
            idx_valid_positions = np.where(~np.any(np.isnan(self.occmap), axis=0))[0]
        self.filter_positions(idx_valid_positions)

    def smooth_maps(self, positions: np.ndarray, kernel_width: float) -> None:
        """Smooth the maps using a Gaussian kernel.

        Parameters
        ----------
        positions : np.ndarray
            Position values corresponding to the position bins. Used to compute
            the Gaussian kernel.
        kernel_width : float
            Width of the Gaussian smoothing kernel in spatial units.

        Notes
        -----
        This method applies Gaussian smoothing to all maps (occmap, speedmap, spkmap).
        NaN values are temporarily replaced with 0 during smoothing, then restored
        afterward. The smoothing is applied along the position dimension.
        """
        kernel = get_gauss_kernel(positions, kernel_width)

        # Replace nans with 0s
        if self.by_environment:
            idxnan = [np.isnan(occmap) for occmap in self.occmap]
        else:
            idxnan = np.isnan(self.occmap)

        if self.rois_first:
            # Move the rois axis to the last axis
            if self.by_environment:
                self.spkmap = [np.moveaxis(map, 0, -1) for map in self.spkmap]
            else:
                self.spkmap = np.moveaxis(self.spkmap, 0, -1)

        for mapname in self.map_types():
            if self.by_environment:
                for ienv, inanenv in enumerate(idxnan):
                    self[mapname][ienv][inanenv] = 0
            else:
                self[mapname][idxnan] = 0

        for mapname in self.map_types():
            # Since ROIs were moved to the last axis, position is axis=1 for all map types
            if self.by_environment:
                self[mapname] = [convolve_toeplitz(map, kernel, axis=1) for map in self[mapname]]
            else:
                self[mapname] = convolve_toeplitz(self[mapname], kernel, axis=1)

        # Put nans back in place
        for mapname in self.map_types():
            if self.by_environment:
                for ienv, inanenv in enumerate(idxnan):
                    self[mapname][ienv][inanenv] = np.nan
            else:
                self[mapname][idxnan] = np.nan

        # Move the rois axis back to the first axis
        if self.rois_first:
            if self.by_environment:
                self.spkmap = [np.moveaxis(map, -1, 0) for map in self.spkmap]
            else:
                self.spkmap = np.moveaxis(self.spkmap, -1, 0)

    def average_trials(self, keepdims: bool = False) -> None:
        """Average the trials within each environment.

        Parameters
        ----------
        keepdims : bool, optional
            Whether to keep the trial dimension with size 1 after averaging.
            Default is False.

        Notes
        -----
        This method computes the mean across trials for each map. After averaging,
        the _averaged flag is set to True to prevent redundant averaging.
        The trial dimension is removed unless keepdims=True.
        """
        if self._averaged:
            return
        for mapname in self.map_types():
            axis = 1 if mapname == "spkmap" and self.rois_first else 0
            if self.by_environment:
                self[mapname] = [ss.mean(map, axis=axis, keepdims=keepdims) for map in self[mapname]]
            else:
                self[mapname] = ss.mean(self[mapname], axis=axis, keepdims=keepdims)
        self._averaged = True

    def nbytes(self) -> int:
        """Calculate the total memory size of all maps in bytes.

        Returns
        -------
        int
            Total number of bytes used by all map arrays.
        """
        num_bytes = 0
        for name in self.map_types():
            if self.by_environment:
                num_bytes += sum(x.nbytes for x in getattr(self, name))
            else:
                num_bytes += getattr(self, name).nbytes
        return num_bytes

    def raw_to_processed(self, positions: np.ndarray, smooth_width: float | None = None) -> "Maps":
        """Convert raw maps to processed maps.

        Processing steps:
        1. Optionally smooth maps with a Gaussian kernel
        2. Divide speedmap and spkmap by occmap (correct_map)
        3. Reorganize spkmap to have ROIs as the first dimension

        Parameters
        ----------
        positions : np.ndarray
            Position values corresponding to the position bins.
        smooth_width : float, optional
            Width of the Gaussian smoothing kernel. If None, no smoothing is applied.
            Default is None.

        Returns
        -------
        Maps
            Self, with maps now in processed format (rois_first=True).

        Notes
        -----
        This method modifies the maps in-place. After processing, spkmap will
        have shape (rois, trials, positions) instead of (trials, positions, rois).
        """
        if smooth_width is not None:
            self.smooth_maps(positions, smooth_width)

        self.speedmap = correct_map(self.occmap, self.speedmap)
        self.spkmap = correct_map(self.occmap, self.spkmap)

        # Change spkmap to be ROIs first
        self.spkmap = np.moveaxis(self.spkmap, -1, 0)
        self.rois_first = True

        return self

__getitem__(key)

Get a map by name using dictionary-like access.

Parameters:

key : str (required)
    Name of the map to retrieve ("occmap", "speedmap", or "spkmap").

Returns:

np.ndarray or list of np.ndarray
    The requested map array(s).

Source code in vrAnalysis/processors/spkmaps.py
def __getitem__(self, key: str) -> np.ndarray | list[np.ndarray]:
    """Get a map by name using dictionary-like access.

    Parameters
    ----------
    key : str
        Name of the map to retrieve ("occmap", "speedmap", or "spkmap").

    Returns
    -------
    np.ndarray or list of np.ndarray
        The requested map array(s).
    """
    return getattr(self, key)

__setitem__(key, value)

Set a map by name using dictionary-like access.

Parameters:

key : str (required)
    Name of the map to set ("occmap", "speedmap", or "spkmap").
value : np.ndarray or list of np.ndarray (required)
    The map array(s) to assign.
Source code in vrAnalysis/processors/spkmaps.py
def __setitem__(self, key: str, value: np.ndarray | list[np.ndarray]) -> None:
    """Set a map by name using dictionary-like access.

    Parameters
    ----------
    key : str
        Name of the map to set ("occmap", "speedmap", or "spkmap").
    value : np.ndarray or list of np.ndarray
        The map array(s) to assign.
    """
    setattr(self, key, value)
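The `__getitem__`/`__setitem__` pair simply delegates to `getattr`/`setattr`, so any map attribute can be read or written with subscript syntax. A minimal stand-in sketch (the `MiniMaps` class is hypothetical, not part of vrAnalysis):

```python
import numpy as np

class MiniMaps:
    """Minimal stand-in for Maps illustrating dict-like attribute access."""
    def __init__(self, occmap, speedmap, spkmap):
        self.occmap = occmap
        self.speedmap = speedmap
        self.spkmap = spkmap

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

maps = MiniMaps(
    occmap=np.ones((5, 10)),     # (trials, positions)
    speedmap=np.ones((5, 10)),   # (trials, positions)
    spkmap=np.ones((5, 10, 3)),  # (trials, positions, rois) in the raw layout
)
# Subscript access returns the same object as attribute access
same_object = maps["occmap"] is maps.occmap
maps["speedmap"] = maps["speedmap"] * 2.0
```

This is why loops like `for mapname in self.map_types(): self[mapname] = ...` work uniformly across all three maps.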

average_trials(keepdims=False)

Average the trials within each environment.

Parameters:

keepdims : bool, default=False
    Whether to keep the trial dimension with size 1 after averaging.

Notes

This method computes the mean across trials for each map. After averaging, the _averaged flag is set to True to prevent redundant averaging. The trial dimension is removed unless keepdims=True.

Source code in vrAnalysis/processors/spkmaps.py
def average_trials(self, keepdims: bool = False) -> None:
    """Average the trials within each environment.

    Parameters
    ----------
    keepdims : bool, optional
        Whether to keep the trial dimension with size 1 after averaging.
        Default is False.

    Notes
    -----
    This method computes the mean across trials for each map. After averaging,
    the _averaged flag is set to True to prevent redundant averaging.
    The trial dimension is removed unless keepdims=True.
    """
    if self._averaged:
        return
    for mapname in self.map_types():
        axis = 1 if mapname == "spkmap" and self.rois_first else 0
        if self.by_environment:
            self[mapname] = [ss.mean(map, axis=axis, keepdims=keepdims) for map in self[mapname]]
        else:
            self[mapname] = ss.mean(self[mapname], axis=axis, keepdims=keepdims)
    self._averaged = True
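The axis logic above is the key detail: in the processed layout the spike map carries ROIs first, so its trial axis differs from the occupancy and speed maps. Using plain `np.mean` in place of the project's `ss.mean` helper, the equivalent operations are roughly:

```python
import numpy as np

rng = np.random.default_rng(0)
# Processed layout: spkmap is (rois, trials, positions), so trials are axis=1;
# occmap is (trials, positions), so trials are axis=0.
spkmap = rng.normal(size=(3, 8, 20))
occmap = rng.uniform(size=(8, 20))

avg_spkmap = np.mean(spkmap, axis=1)               # -> (rois, positions)
avg_occmap = np.mean(occmap, axis=0)               # -> (positions,)
avg_keep = np.mean(spkmap, axis=1, keepdims=True)  # -> (rois, 1, positions)
```

`ss.mean` is assumed here to behave like a (possibly NaN-aware) mean; only the axis selection is being illustrated.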

create_environment_maps(occmap, speedmap, spkmap, environments, distcenters=None) classmethod

Create a Maps instance with maps separated by environment.

Parameters:

occmap : list of np.ndarray (required)
    List of occupancy maps, one per environment, each with shape (trials, positions).
speedmap : list of np.ndarray (required)
    List of speed maps, one per environment, each with shape (trials, positions).
spkmap : list of np.ndarray (required)
    List of spike maps, one per environment, each with shape (rois, trials, positions).
environments : list of int (required)
    List of environment numbers corresponding to each map in the lists.
distcenters : np.ndarray, optional (default=None)
    Center positions of distance bins.

Returns:

Maps
    Maps instance with by_environment=True and rois_first=True.

Source code in vrAnalysis/processors/spkmaps.py
@classmethod
def create_environment_maps(
    cls,
    occmap: list[np.ndarray],
    speedmap: list[np.ndarray],
    spkmap: list[np.ndarray],
    environments: list[int],
    distcenters: np.ndarray | None = None,
) -> "Maps":
    """Create a Maps instance with maps separated by environment.

    Parameters
    ----------
    occmap : list of np.ndarray
        List of occupancy maps, one per environment. Each with shape (trials, positions).
    speedmap : list of np.ndarray
        List of speed maps, one per environment. Each with shape (trials, positions).
    spkmap : list of np.ndarray
        List of spike maps, one per environment. Each with shape (rois, trials, positions).
    environments : list of int
        List of environment numbers corresponding to each map in the lists.
    distcenters : np.ndarray, optional
        Center positions of distance bins. Default is None.

    Returns
    -------
    Maps
        Maps instance with by_environment=True and rois_first=True.
    """
    return cls(
        occmap=occmap,
        speedmap=speedmap,
        spkmap=spkmap,
        distcenters=distcenters,
        environments=environments,
        by_environment=True,
        rois_first=True,
    )

create_processed_maps(occmap, speedmap, spkmap, distcenters=None) classmethod

Create a Maps instance from processed map data.

Parameters:

occmap : np.ndarray (required)
    Occupancy map with shape (trials, positions).
speedmap : np.ndarray (required)
    Speed map with shape (trials, positions).
spkmap : np.ndarray (required)
    Spike map with shape (rois, trials, positions).
distcenters : np.ndarray, optional (default=None)
    Center positions of distance bins.

Returns:

Maps
    Maps instance with by_environment=False and rois_first=True.

Source code in vrAnalysis/processors/spkmaps.py
@classmethod
def create_processed_maps(cls, occmap: np.ndarray, speedmap: np.ndarray, spkmap: np.ndarray, distcenters: np.ndarray | None = None) -> "Maps":
    """Create a Maps instance from processed map data.

    Parameters
    ----------
    occmap : np.ndarray
        Occupancy map with shape (trials, positions).
    speedmap : np.ndarray
        Speed map with shape (trials, positions).
    spkmap : np.ndarray
        Spike map with shape (rois, trials, positions).
    distcenters : np.ndarray, optional
        Center positions of distance bins. Default is None.

    Returns
    -------
    Maps
        Maps instance with by_environment=False and rois_first=True.
    """
    return cls(occmap=occmap, speedmap=speedmap, spkmap=spkmap, distcenters=distcenters, by_environment=False, rois_first=True)

create_raw_maps(occmap, speedmap, spkmap, distcenters=None) classmethod

Create a Maps instance from raw (unprocessed) map data.

Parameters:

occmap : np.ndarray (required)
    Occupancy map with shape (trials, positions).
speedmap : np.ndarray (required)
    Speed map with shape (trials, positions).
spkmap : np.ndarray (required)
    Spike map with shape (trials, positions, rois).
distcenters : np.ndarray, optional (default=None)
    Center positions of distance bins.

Returns:

Maps
    Maps instance with by_environment=False and rois_first=False.

Source code in vrAnalysis/processors/spkmaps.py
@classmethod
def create_raw_maps(cls, occmap: np.ndarray, speedmap: np.ndarray, spkmap: np.ndarray, distcenters: np.ndarray | None = None) -> "Maps":
    """Create a Maps instance from raw (unprocessed) map data.

    Parameters
    ----------
    occmap : np.ndarray
        Occupancy map with shape (trials, positions).
    speedmap : np.ndarray
        Speed map with shape (trials, positions).
    spkmap : np.ndarray
        Spike map with shape (trials, positions, rois).
    distcenters : np.ndarray, optional
        Center positions of distance bins. Default is None.

    Returns
    -------
    Maps
        Maps instance with by_environment=False and rois_first=False.
    """
    return cls(occmap=occmap, speedmap=speedmap, spkmap=spkmap, distcenters=distcenters, by_environment=False, rois_first=False)
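The only difference between the raw and processed factory methods is the spike-map axis order. The conversion between the two layouts is a single `np.moveaxis` call, sketched here with placeholder shapes:

```python
import numpy as np

n_trials, n_positions, n_rois = 6, 50, 4

# Raw layout (rois_first=False): ROIs are the last axis
raw_spkmap = np.zeros((n_trials, n_positions, n_rois))

# Processed layout (rois_first=True): the ROI axis is moved to the front
processed_spkmap = np.moveaxis(raw_spkmap, -1, 0)
```

Keeping ROIs first in the processed layout makes per-neuron slicing (`spkmap[roi]`) cheap and contiguous for downstream analysis.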

filter_environments(environments)

Filter maps to keep only specified environments.

Parameters:

environments : list of int (required)
    List of environment numbers to keep.

Raises:

ValueError
    If by_environment is False, since environments cannot be filtered when maps are not separated by environment.

Notes

This method modifies the maps in-place, keeping only the environments specified. Only works when by_environment=True.

Source code in vrAnalysis/processors/spkmaps.py
def filter_environments(self, environments: list[int]) -> None:
    """Filter maps to keep only specified environments.

    Parameters
    ----------
    environments : list of int
        List of environment numbers to keep.

    Raises
    ------
    ValueError
        If by_environment is False, since environments cannot be filtered
        when maps are not separated by environment.

    Notes
    -----
    This method modifies the maps in-place, keeping only the environments
    specified. Only works when by_environment=True.
    """
    if self.by_environment:
        idx_to_requested_env = [i for i, env in enumerate(self.environments) if env in environments]
        self.occmap = [self.occmap[i] for i in idx_to_requested_env]
        self.speedmap = [self.speedmap[i] for i in idx_to_requested_env]
        self.spkmap = [self.spkmap[i] for i in idx_to_requested_env]
        self.environments = [self.environments[i] for i in idx_to_requested_env]
    else:
        raise ValueError("Cannot filter environments when maps aren't separated by environment!")

filter_positions(idx_positions)

Filter maps to keep only specified position bins.

Parameters:

idx_positions : np.ndarray (required)
    Indices of position bins to keep. Must be a 1D array of integers.

Notes

This method modifies the maps in-place, keeping only the position bins specified by idx_positions. Also updates distcenters if present.

Source code in vrAnalysis/processors/spkmaps.py
def filter_positions(self, idx_positions: np.ndarray) -> None:
    """Filter maps to keep only specified position bins.

    Parameters
    ----------
    idx_positions : np.ndarray
        Indices of position bins to keep. Must be a 1D array of integers.

    Notes
    -----
    This method modifies the maps in-place, keeping only the position bins
    specified by idx_positions. Also updates distcenters if present.
    """
    if self.distcenters is not None:
        self.distcenters = self.distcenters[idx_positions]
    for mapname in self.map_types():
        axis = self._get_position_axis(mapname)
        if self.by_environment:
            self[mapname] = [np.take(x, idx_positions, axis=axis) for x in self[mapname]]
        else:
            self[mapname] = np.take(self[mapname], idx_positions, axis=axis)

filter_rois(idx_rois)

Filter spike maps to keep only specified ROIs.

Parameters:

idx_rois : np.ndarray (required)
    Indices of ROIs to keep. Must be a 1D array of integers.

Notes

This method modifies the spkmap in-place, keeping only the ROIs specified by idx_rois. Only affects spkmap; occmap and speedmap are unchanged.

Source code in vrAnalysis/processors/spkmaps.py
def filter_rois(self, idx_rois: np.ndarray) -> None:
    """Filter spike maps to keep only specified ROIs.

    Parameters
    ----------
    idx_rois : np.ndarray
        Indices of ROIs to keep. Must be a 1D array of integers.

    Notes
    -----
    This method modifies the spkmap in-place, keeping only the ROIs
    specified by idx_rois. Only affects spkmap; occmap and speedmap
    are unchanged.
    """
    axis = 0 if self.rois_first else -1
    if self.by_environment:
        self.spkmap = [np.take(x, idx_rois, axis=axis) for x in self.spkmap]
    else:
        self.spkmap = np.take(self.spkmap, idx_rois, axis=axis)
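Both filtering methods reduce to `np.take` along the appropriate axis: ROIs are axis 0 when `rois_first=True` (otherwise the last axis), and positions are always the last axis of the spike map. A small sketch of that indexing, using toy shapes:

```python
import numpy as np

# Processed layout: (rois, trials, positions), rois_first=True
spkmap = np.arange(3 * 5 * 10).reshape(3, 5, 10)

# filter_rois: keep ROIs 0 and 2 along axis 0
idx_rois = np.array([0, 2])
filtered = np.take(spkmap, idx_rois, axis=0)

# filter_positions: keep position bins 1..3 along the last axis
idx_positions = np.array([1, 2, 3])
windowed = np.take(filtered, idx_positions, axis=-1)
```

`np.take` returns a copy, which is consistent with the in-place reassignment pattern used by these methods.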

map_types() classmethod

Get the list of map type names.

Returns:

list of str
    List containing ["occmap", "speedmap", "spkmap"].

Source code in vrAnalysis/processors/spkmaps.py
@classmethod
def map_types(cls) -> List[str]:
    """Get the list of map type names.

    Returns
    -------
    list of str
        List containing ["occmap", "speedmap", "spkmap"].
    """
    return ["occmap", "speedmap", "spkmap"]

nbytes()

Calculate the total memory size of all maps in bytes.

Returns:

int
    Total number of bytes used by all map arrays.

Source code in vrAnalysis/processors/spkmaps.py
def nbytes(self) -> int:
    """Calculate the total memory size of all maps in bytes.

    Returns
    -------
    int
        Total number of bytes used by all map arrays.
    """
    num_bytes = 0
    for name in self.map_types():
        if self.by_environment:
            num_bytes += sum(x.nbytes for x in getattr(self, name))
        else:
            num_bytes += getattr(self, name).nbytes
    return num_bytes

pop_nan_positions()

Remove position bins that contain NaN values in any map.

Notes

This method identifies position bins that have NaN values in any of the maps (occmap, speedmap, or spkmap) and removes them from all maps. Useful for cleaning data before analysis.

Source code in vrAnalysis/processors/spkmaps.py
def pop_nan_positions(self) -> None:
    """Remove position bins that contain NaN values in any map.

    Notes
    -----
    This method identifies position bins that have NaN values in any of the
    maps (occmap, speedmap, or spkmap) and removes them from all maps.
    Useful for cleaning data before analysis.
    """
    if self.by_environment:
        idx_valid_positions = np.where(~np.any(np.stack([np.any(np.isnan(occmap), axis=0) for occmap in self.occmap], axis=0), axis=0))[0]
    else:
        idx_valid_positions = np.where(~np.any(np.isnan(self.occmap), axis=0))[0]
    self.filter_positions(idx_valid_positions)
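The validity criterion is conservative: a position bin is dropped if any trial has a NaN occupancy there. The core of the non-environment branch looks like this in isolation:

```python
import numpy as np

occmap = np.ones((4, 6))  # (trials, positions)
occmap[2, 5] = np.nan     # bin 5 was never visited on trial 2

# A bin survives only if NO trial has a NaN in it
idx_valid = np.where(~np.any(np.isnan(occmap), axis=0))[0]
cleaned = np.take(occmap, idx_valid, axis=1)
```

With `by_environment=True`, the same per-position NaN mask is first computed for each environment and then OR-ed across environments, so all environments end up with an identical set of position bins.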

raw_to_processed(positions, smooth_width=None)

Convert raw maps to processed maps.

Processing steps: 1. Optionally smooth maps with a Gaussian kernel 2. Divide speedmap and spkmap by occmap (correct_map) 3. Reorganize spkmap to have ROIs as the first dimension

Parameters:

positions : np.ndarray (required)
    Position values corresponding to the position bins.
smooth_width : float, optional (default=None)
    Width of the Gaussian smoothing kernel. If None, no smoothing is applied.

Returns:

Maps
    Self, with maps now in processed format (rois_first=True).

Notes

This method modifies the maps in-place. After processing, spkmap will have shape (rois, trials, positions) instead of (trials, positions, rois).

Source code in vrAnalysis/processors/spkmaps.py
def raw_to_processed(self, positions: np.ndarray, smooth_width: float | None = None) -> "Maps":
    """Convert raw maps to processed maps.

    Processing steps:
    1. Optionally smooth maps with a Gaussian kernel
    2. Divide speedmap and spkmap by occmap (correct_map)
    3. Reorganize spkmap to have ROIs as the first dimension

    Parameters
    ----------
    positions : np.ndarray
        Position values corresponding to the position bins.
    smooth_width : float, optional
        Width of the Gaussian smoothing kernel. If None, no smoothing is applied.
        Default is None.

    Returns
    -------
    Maps
        Self, with maps now in processed format (rois_first=True).

    Notes
    -----
    This method modifies the maps in-place. After processing, spkmap will
    have shape (rois, trials, positions) instead of (trials, positions, rois).
    """
    if smooth_width is not None:
        self.smooth_maps(positions, smooth_width)

    self.speedmap = correct_map(self.occmap, self.speedmap)
    self.spkmap = correct_map(self.occmap, self.spkmap)

    # Change spkmap to be ROIs first
    self.spkmap = np.moveaxis(self.spkmap, -1, 0)
    self.rois_first = True

    return self
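Step 2 is the occupancy correction. Assuming `correct_map` divides each map by occupancy (broadcasting over the ROI axis), the raw-to-processed pipeline amounts to the following numpy sketch:

```python
import numpy as np

occmap = np.full((4, 10), 2.0)     # (trials, positions): time spent per bin
spkmap = np.full((4, 10, 3), 6.0)  # (trials, positions, rois): summed activity

# Hypothetical occupancy correction: activity per unit time in each bin
with np.errstate(invalid="ignore", divide="ignore"):
    rate_map = spkmap / occmap[..., np.newaxis]

# Reorganize to the processed layout: (rois, trials, positions)
rate_map = np.moveaxis(rate_map, -1, 0)
```

Bins with zero occupancy would produce inf/NaN under this division, which is one reason `pop_nan_positions` exists downstream.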

smooth_maps(positions, kernel_width)

Smooth the maps using a Gaussian kernel.

Parameters:

positions : np.ndarray (required)
    Position values corresponding to the position bins. Used to compute the Gaussian kernel.
kernel_width : float (required)
    Width of the Gaussian smoothing kernel in spatial units.

Notes

This method applies Gaussian smoothing to all maps (occmap, speedmap, spkmap). NaN values are temporarily replaced with 0 during smoothing, then restored afterward. The smoothing is applied along the position dimension.

Source code in vrAnalysis/processors/spkmaps.py
def smooth_maps(self, positions: np.ndarray, kernel_width: float) -> None:
    """Smooth the maps using a Gaussian kernel.

    Parameters
    ----------
    positions : np.ndarray
        Position values corresponding to the position bins. Used to compute
        the Gaussian kernel.
    kernel_width : float
        Width of the Gaussian smoothing kernel in spatial units.

    Notes
    -----
    This method applies Gaussian smoothing to all maps (occmap, speedmap, spkmap).
    NaN values are temporarily replaced with 0 during smoothing, then restored
    afterward. The smoothing is applied along the position dimension.
    """
    kernel = get_gauss_kernel(positions, kernel_width)

    # Replace nans with 0s
    if self.by_environment:
        idxnan = [np.isnan(occmap) for occmap in self.occmap]
    else:
        idxnan = np.isnan(self.occmap)

    if self.rois_first:
        # Move the rois axis to the last axis
        if self.by_environment:
            self.spkmap = [np.moveaxis(map, 0, -1) for map in self.spkmap]
        else:
            self.spkmap = np.moveaxis(self.spkmap, 0, -1)

    for mapname in self.map_types():
        if self.by_environment:
            for ienv, inanenv in enumerate(idxnan):
                self[mapname][ienv][inanenv] = 0
        else:
            self[mapname][idxnan] = 0

    for mapname in self.map_types():
        # Since we moved ROIs to the last axis position will be axis=1 for all map types
        if self.by_environment:
            self[mapname] = [convolve_toeplitz(map, kernel, axis=1) for map in self[mapname]]
        else:
            self[mapname] = convolve_toeplitz(self[mapname], kernel, axis=1)

    # Put nans back in place
    for mapname in self.map_types():
        if self.by_environment:
            for ienv, inanenv in enumerate(idxnan):
                self[mapname][ienv][inanenv] = np.nan
        else:
            self[mapname][idxnan] = np.nan

    # Move the rois axis back to the first axis
    if self.rois_first:
        if self.by_environment:
            self.spkmap = [np.moveaxis(map, -1, 0) for map in self.spkmap]
        else:
            self.spkmap = np.moveaxis(self.spkmap, -1, 0)


Reliability dataclass

Container for reliability values.

Attributes:

values : np.ndarray
    Reliability values for each neuron.
environments : np.ndarray
    Environments for which the reliability was computed.
method : str
    Method used to compute the reliability.

Source code in vrAnalysis/processors/spkmaps.py
@dataclass
class Reliability:
    """Container for reliability values.

    Attributes
    ----------
    values : np.ndarray
        Reliability values for each neuron
    environments : np.ndarray
        Environments for which the reliability was computed
    method : str
        Method used to compute the reliability
    """

    values: np.ndarray
    environments: np.ndarray
    method: str

    def __post_init__(self):
        if self.values.shape[0] != len(self.environments):
            raise ValueError("values must have one row per environment (values.shape[0] == len(environments))")

    def __repr__(self) -> str:
        return f"Reliability(num_rois={self.values.shape[1]}, environments={self.environments}, method={self.method})"

    def filter_rois(self, idx_rois: np.ndarray) -> "Reliability":
        """Filter reliability values to keep only specified ROIs.

        Parameters
        ----------
        idx_rois : np.ndarray
            Indices of ROIs to keep. Must be a 1D array of integers.

        Returns
        -------
        Reliability
            New Reliability instance with filtered ROI values.
        """
        return Reliability(self.values[:, idx_rois], self.environments, self.method)

    def filter_environments(self, idx_environments: np.ndarray) -> "Reliability":
        """Filter reliability values to keep only specified environments by index.

        Parameters
        ----------
        idx_environments : np.ndarray
            Indices of environments to keep. Must be a 1D array of integers.

        Returns
        -------
        Reliability
            New Reliability instance with filtered environment values.
        """
        return Reliability(self.values[idx_environments], self.environments[idx_environments], self.method)

    def filter_by_environment(self, environments: list[int]) -> "Reliability":
        """Filter reliability values to keep only specified environments by environment number.

        Parameters
        ----------
        environments : list of int
            List of environment numbers to keep.

        Returns
        -------
        Reliability
            New Reliability instance with filtered environment values.
        """
        idx_to_requested_env = [i for i, env in enumerate(self.environments) if env in environments]
        return Reliability(self.values[idx_to_requested_env], self.environments[idx_to_requested_env], self.method)
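Since `values` is shaped (environments, rois), all three filters are plain numpy indexing: columns for ROIs, rows for environments. A standalone sketch of that indexing:

```python
import numpy as np

values = np.arange(12, dtype=float).reshape(3, 4)  # (environments, rois)
environments = np.array([1, 2, 3])

# filter_rois: select columns
roi_subset = values[:, np.array([0, 3])]

# filter_by_environment: map environment numbers to row indices, then select rows
keep = [2, 3]
idx_env = [i for i, env in enumerate(environments) if env in keep]
env_subset = values[idx_env]
kept_envs = environments[idx_env]
```

Because each filter returns a new `Reliability` instance rather than mutating in place, filtered views can be derived freely without invalidating the original.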

filter_by_environment(environments)

Filter reliability values to keep only specified environments by environment number.

Parameters:

environments : list of int (required)
    List of environment numbers to keep.

Returns:

Reliability
    New Reliability instance with filtered environment values.

Source code in vrAnalysis/processors/spkmaps.py
def filter_by_environment(self, environments: list[int]) -> "Reliability":
    """Filter reliability values to keep only specified environments by environment number.

    Parameters
    ----------
    environments : list of int
        List of environment numbers to keep.

    Returns
    -------
    Reliability
        New Reliability instance with filtered environment values.
    """
    idx_to_requested_env = [i for i, env in enumerate(self.environments) if env in environments]
    return Reliability(self.values[idx_to_requested_env], self.environments[idx_to_requested_env], self.method)

filter_environments(idx_environments)

Filter reliability values to keep only specified environments by index.

Parameters:

idx_environments : np.ndarray (required)
    Indices of environments to keep. Must be a 1D array of integers.

Returns:

Reliability
    New Reliability instance with filtered environment values.

Source code in vrAnalysis/processors/spkmaps.py
def filter_environments(self, idx_environments: np.ndarray) -> "Reliability":
    """Filter reliability values to keep only specified environments by index.

    Parameters
    ----------
    idx_environments : np.ndarray
        Indices of environments to keep. Must be a 1D array of integers.

    Returns
    -------
    Reliability
        New Reliability instance with filtered environment values.
    """
    return Reliability(self.values[idx_environments], self.environments[idx_environments], self.method)

filter_rois(idx_rois)

Filter reliability values to keep only specified ROIs.

Parameters:

idx_rois : np.ndarray (required)
    Indices of ROIs to keep. Must be a 1D array of integers.

Returns:

Reliability
    New Reliability instance with filtered ROI values.

Source code in vrAnalysis/processors/spkmaps.py
def filter_rois(self, idx_rois: np.ndarray) -> "Reliability":
    """Filter reliability values to keep only specified ROIs.

    Parameters
    ----------
    idx_rois : np.ndarray
        Indices of ROIs to keep. Must be a 1D array of integers.

    Returns
    -------
    Reliability
        New Reliability instance with filtered ROI values.
    """
    return Reliability(self.values[:, idx_rois], self.environments, self.method)


SpkmapParams dataclass

Parameters for spike map processing.

Contains configuration settings that control how spike maps are processed, including distance steps, speed thresholds, and standardization options.

Parameters:

dist_step : float, default=1.0
    Step size for distance calculations in spatial units.
speed_threshold : float, default=1.0
    Minimum speed threshold for valid movement periods.
speed_max_allowed : float, default=np.inf
    Maximum speed allowed for valid movement periods (the default imposes no maximum; a finite value can be useful when the behavioral computer allows jumps in position, which are usually due to hardware issues).
full_trial_flexibility : float | None, default=3.0
    Flexibility parameter for trial alignment. If None, no flexibility.
standardize_spks : bool, default=True
    Whether to standardize spike counts by dividing by the standard deviation.
smooth_width : float | None, default=1.0
    Width of the Gaussian smoothing kernel to apply to the maps (width in spatial units).
reliability_method : str, default="leave_one_out"
    Method to use for calculating reliability.
autosave : bool, default=False
    Whether to save the cache automatically.
Source code in vrAnalysis/processors/spkmaps.py
@dataclass
class SpkmapParams:
    """Parameters for spike map processing.

    Contains configuration settings that control how spike maps are processed,
    including distance steps, speed thresholds, and standardization options.

    Parameters
    ----------
    dist_step : float, default=1.0
        Step size for distance calculations in spatial units
    speed_threshold : float, default=1.0
        Minimum speed threshold for valid movement periods
    speed_max_allowed : float, default=np.inf
        Maximum speed allowed for valid movement periods (the default imposes
        no maximum; a finite value can be useful when the behavioral computer
        allows jumps in position, which are usually due to hardware issues)
    full_trial_flexibility : float | None, default=3.0
        Flexibility parameter for trial alignment. If None, no flexibility
    standardize_spks : bool, default=True
        Whether to standardize spike counts by dividing by the standard deviation
    smooth_width : float | None, default=1.0
        Width of the Gaussian smoothing kernel to apply to the maps (width in spatial units)
    reliability_method : str, default="leave_one_out"
        Method to use for calculating reliability
    autosave : bool, default=False
        Whether to save the cache automatically
    """

    dist_step: float = 1.0
    speed_threshold: float = 1.0
    speed_max_allowed: float = np.inf
    full_trial_flexibility: Union[float, None] = 3.0
    standardize_spks: bool = True
    smooth_width: Union[float, None] = 1.0
    reliability_method: str = "leave_one_out"
    autosave: bool = False

    def __repr__(self) -> str:
        class_fields = fields(self)
        lines = []
        for field in class_fields:
            field_name = field.name
            field_value = getattr(self, field_name)
            lines.append(f"{field_name}={repr(field_value)}")

        class_name = self.__class__.__name__
        joined_lines = ",\n    ".join(lines)
        return f"{class_name}(\n    {joined_lines}\n)"

    @classmethod
    def from_dict(cls, params_dict: dict) -> "SpkmapParams":
        """Create a SpkmapParams instance from a dictionary.

        Parameters
        ----------
        params_dict : dict
            Dictionary of parameter names and values. Missing parameters will
            use default values from SpkmapParams.

        Returns
        -------
        SpkmapParams
            New SpkmapParams instance with values from the dictionary.
        """
        return cls(**{k: params_dict[k] for k in params_dict})

    @classmethod
    def from_path(cls, path: Path) -> "SpkmapParams":
        """Create a SpkmapParams instance from a JSON file.

        Parameters
        ----------
        path : Path
            Path to the JSON file containing parameter values.

        Returns
        -------
        SpkmapParams
            New SpkmapParams instance loaded from the JSON file.
        """
        with open(path, "r") as f:
            return cls.from_dict(json.load(f))

    def compare(self, other: "SpkmapParams", filter_keys: Optional[List[str]] = None) -> bool:
        """Compare two SpkmapParams instances.

        Parameters
        ----------
        other : SpkmapParams
            Another SpkmapParams instance to compare against.
        filter_keys : list of str, optional
            If provided, only compare the specified parameter keys.
            If None, compare all parameters. Default is None.

        Returns
        -------
        bool
            True if the parameters match (or specified keys match), False otherwise.
        """
        if filter_keys is None:
            return self == other
        else:
            return all(getattr(self, key) == getattr(other, key) for key in filter_keys)

    def save(self, path: Path) -> None:
        """Save the parameters to a JSON file.

        Parameters
        ----------
        path : Path
            Path where the JSON file will be saved.
        """
        with open(path, "w") as f:
            json.dump(asdict(self), f, sort_keys=True)

    def __post_init__(self):
        if self.dist_step <= 0:
            raise ValueError("dist_step must be positive")
        if self.speed_threshold <= 0:
            raise ValueError("speed_threshold must be positive")
        if self.full_trial_flexibility is not None and self.full_trial_flexibility < 0:
            raise ValueError("If used, full_trial_flexibility must be nonnegative (can also be None)")
        if self.smooth_width is not None and self.smooth_width <= 0:
            raise ValueError("smooth_width must be positive (can also be None)")
        # Coerce numeric fields to float (preserving None where allowed)
        self.dist_step = float(self.dist_step)
        self.speed_threshold = float(self.speed_threshold)
        self.speed_max_allowed = float(self.speed_max_allowed)
        self.full_trial_flexibility = float(self.full_trial_flexibility) if self.full_trial_flexibility is not None else None
        self.smooth_width = float(self.smooth_width) if self.smooth_width is not None else None
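The save/load/compare methods form a simple JSON roundtrip built on `dataclasses.asdict`. The pattern can be demonstrated with a minimal stand-in dataclass (`MiniParams` is hypothetical, with only two of the real fields):

```python
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from tempfile import TemporaryDirectory

@dataclass
class MiniParams:
    """Minimal stand-in for SpkmapParams illustrating the JSON roundtrip."""
    dist_step: float = 1.0
    speed_threshold: float = 1.0

    @classmethod
    def from_dict(cls, d: dict) -> "MiniParams":
        # Missing keys fall back to the dataclass defaults
        return cls(**d)

with TemporaryDirectory() as tmp:
    path = Path(tmp) / "params.json"
    params = MiniParams(dist_step=2.0)
    # save: serialize fields with stable key order, as SpkmapParams.save does
    path.write_text(json.dumps(asdict(params), sort_keys=True))
    # from_path: parse JSON and rebuild via from_dict
    loaded = MiniParams.from_dict(json.loads(path.read_text()))

partial = MiniParams.from_dict({"dist_step": 2.0})
```

Dataclass equality is field-wise, which is exactly what `compare` relies on when `filter_keys` is None; with `filter_keys`, only the named fields are checked.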

compare(other, filter_keys=None)

Compare two SpkmapParams instances.

Parameters:

other : SpkmapParams (required)
    Another SpkmapParams instance to compare against.
filter_keys : list of str, optional (default=None)
    If provided, only compare the specified parameter keys. If None, compare all parameters.

Returns:

bool
    True if the parameters match (or specified keys match), False otherwise.

Source code in vrAnalysis/processors/spkmaps.py
def compare(self, other: "SpkmapParams", filter_keys: Optional[List[str]] = None) -> bool:
    """Compare two SpkmapParams instances.

    Parameters
    ----------
    other : SpkmapParams
        Another SpkmapParams instance to compare against.
    filter_keys : list of str, optional
        If provided, only compare the specified parameter keys.
        If None, compare all parameters. Default is None.

    Returns
    -------
    bool
        True if the parameters match (or specified keys match), False otherwise.
    """
    if filter_keys is None:
        return self == other
    else:
        return all(getattr(self, key) == getattr(other, key) for key in filter_keys)

from_dict(params_dict) classmethod

Create a SpkmapParams instance from a dictionary.

Parameters:

params_dict : dict (required)
    Dictionary of parameter names and values. Missing parameters will use default values from SpkmapParams.

Returns:

SpkmapParams
    New SpkmapParams instance with values from the dictionary.

Source code in vrAnalysis/processors/spkmaps.py
@classmethod
def from_dict(cls, params_dict: dict) -> "SpkmapParams":
    """Create a SpkmapParams instance from a dictionary.

    Parameters
    ----------
    params_dict : dict
        Dictionary of parameter names and values. Missing parameters will
        use default values from SpkmapParams.

    Returns
    -------
    SpkmapParams
        New SpkmapParams instance with values from the dictionary.
    """
    return cls(**{k: params_dict[k] for k in params_dict})

from_path(path) classmethod

Create a SpkmapParams instance from a JSON file.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| path | Path | Path to the JSON file containing parameter values. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| SpkmapParams | New SpkmapParams instance loaded from the JSON file. |

Source code in vrAnalysis/processors/spkmaps.py
@classmethod
def from_path(cls, path: Path) -> "SpkmapParams":
    """Create a SpkmapParams instance from a JSON file.

    Parameters
    ----------
    path : Path
        Path to the JSON file containing parameter values.

    Returns
    -------
    SpkmapParams
        New SpkmapParams instance loaded from the JSON file.
    """
    with open(path, "r") as f:
        return cls.from_dict(json.load(f))

save(path)

Save the parameters to a JSON file.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| path | Path | Path where the JSON file will be saved. | required |
Source code in vrAnalysis/processors/spkmaps.py
def save(self, path: Path) -> None:
    """Save the parameters to a JSON file.

    Parameters
    ----------
    path : Path
        Path where the JSON file will be saved.
    """
    with open(path, "w") as f:
        json.dump(asdict(self), f, sort_keys=True)
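Taken together, save and from_path give a JSON round-trip for parameter sets. A self-contained sketch of that round-trip, using a hypothetical stand-in dataclass (`_Params`) and a temporary directory rather than the real SpkmapParams:

```python
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from tempfile import TemporaryDirectory


@dataclass
class _Params:
    """Hypothetical stand-in for SpkmapParams to illustrate the JSON round-trip."""

    dist_step: float = 1.0
    speed_threshold: float = 1.0

    @classmethod
    def from_path(cls, path: Path) -> "_Params":
        # Load the saved JSON and rebuild the dataclass from it.
        with open(path, "r") as f:
            return cls(**json.load(f))

    def save(self, path: Path) -> None:
        # Serialize all fields; sort_keys keeps the file layout stable.
        with open(path, "w") as f:
            json.dump(asdict(self), f, sort_keys=True)


with TemporaryDirectory() as tmp:
    path = Path(tmp) / "params.json"
    original = _Params(dist_step=2.0)
    original.save(path)
    restored = _Params.from_path(path)

print(restored == original)  # True
```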


Usage Examples

Basic Usage

from vrAnalysis.processors import SpkmapProcessor, SpkmapParams
from vrAnalysis.sessions import B2Session

# Create a session
session = B2Session("path/to/session")

# Create a processor with default parameters
processor = SpkmapProcessor(session)

# Get raw maps (unsmoothed, not normalized)
raw_maps = processor.get_raw_maps()

# Get processed maps (smoothed and normalized by occupancy)
processed_maps = processor.get_processed_maps()

# Get environment-separated maps
env_maps = processor.get_env_maps()

# Calculate reliability
reliability = processor.get_reliability()

Custom Parameters

# Create custom parameters
params = SpkmapParams(
    dist_step=2.0,  # 2 cm bins
    speed_threshold=2.0,  # 2 cm/s minimum speed
    smooth_width=5.0,  # 5 cm smoothing
    standardize_spks=True,
)

# Use custom parameters
processor = SpkmapProcessor(session, params=params)

# Or update parameters temporarily
maps = processor.get_processed_maps(params={"smooth_width": 10.0})

Working with Maps

import numpy as np

# Get processed maps
maps = processor.get_processed_maps()

# Filter to specific ROIs
maps.filter_rois([0, 1, 2, 3])

# Filter to specific positions
maps.filter_positions(np.arange(10, 50))

# Average across trials
maps.average_trials()

# Check memory usage
print(f"Maps use {maps.nbytes() / 1e6:.2f} MB")

Environment-Separated Maps

# Get maps separated by environment
env_maps = processor.get_env_maps()

# Access maps for a specific environment
env_idx = 0
occmap_env0 = env_maps.occmap[env_idx]
spkmap_env0 = env_maps.spkmap[env_idx]

# Filter to specific environments
env_maps.filter_environments([1, 2])  # Keep only environments 1 and 2

Reliability Analysis

# Calculate reliability with default method (leave_one_out)
reliability = processor.get_reliability()

# Access reliability values
# Shape: (num_environments, num_rois)
reliability_values = reliability.values

# Filter to specific ROIs
reliability_filtered = reliability.filter_rois([0, 1, 2])

# Filter to specific environments
reliability_env = reliability.filter_by_environment([1, 2])

Place Field Predictions

# Get place field predictions for each frame
prediction, extras = processor.get_placefield_prediction()

# prediction shape: (num_frames, num_rois)
# extras contains frame_position_index, frame_environment_index, idx_valid

# Use predictions to analyze neural activity
valid_predictions = prediction[extras["idx_valid"]]

Traversals Analysis

# Extract activity around place field peak
roi_idx = 5  # Neuron index
env_idx = 0  # Environment index

traversals, pred_travs = processor.get_traversals(
    idx_roi=roi_idx,
    idx_env=env_idx,
    width=10,  # 10 frames on each side
    placefield_threshold=5.0,  # 5 cm threshold
)

# traversals shape: (num_traversals, 21)  # 2*width + 1
# pred_travs shape: (num_traversals, 21)

Caching

The SpkmapProcessor uses intelligent caching to avoid recomputing maps when parameters haven't changed:

# Enable autosave to cache results
params = SpkmapParams(autosave=True)
processor = SpkmapProcessor(session, params=params)

# First call computes and caches
maps1 = processor.get_processed_maps()

# Second call loads from cache (much faster)
maps2 = processor.get_processed_maps()

# Force recomputation
maps3 = processor.get_processed_maps(force_recompute=True)

# View cache information
processor.show_cache()
processor.show_cache(data_type="processed_maps")

Protocol Interface

The SpkmapProcessor works with any session class that implements the SessionToSpkmapProtocol. This protocol defines the required properties:

  • spks: Spike data array
  • spks_type: Type of spike data
  • idx_rois: ROI filter mask
  • timestamps: Imaging frame timestamps
  • env_length: Environment length(s)
  • positions: Position data tuple
  • trial_environment: Environment for each trial
  • num_trials: Number of trials
  • zero_baseline_spks: Whether spikes are zero-baselined

See the protocol documentation for details on implementing custom session classes.
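The property list above can be captured as a typing.Protocol. The sketch below is hypothetical (the names come from the list, but the annotations and the `SessionProtocolSketch` / `MinimalSession` classes are placeholders, not the real SessionToSpkmapProtocol from vrAnalysis):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SessionProtocolSketch(Protocol):
    """Hypothetical sketch of the required session interface."""

    @property
    def spks(self) -> Any: ...                 # spike data array
    @property
    def spks_type(self) -> str: ...            # type of spike data
    @property
    def idx_rois(self) -> Any: ...             # ROI filter mask
    @property
    def timestamps(self) -> Any: ...           # imaging frame timestamps
    @property
    def env_length(self) -> Any: ...           # environment length(s)
    @property
    def positions(self) -> Any: ...            # position data tuple
    @property
    def trial_environment(self) -> Any: ...    # environment for each trial
    @property
    def num_trials(self) -> int: ...           # number of trials
    @property
    def zero_baseline_spks(self) -> bool: ...  # zero-baselined spikes?


class MinimalSession:
    """Toy session exposing every required attribute."""

    spks = [[0.0]]
    spks_type = "raw"
    idx_rois = [True]
    timestamps = [0.0]
    env_length = 100.0
    positions = ([0.0],)
    trial_environment = [0]
    num_trials = 1
    zero_baseline_spks = True


# runtime_checkable lets isinstance verify that all attributes are present.
print(isinstance(MinimalSession(), SessionProtocolSketch))  # True
```

A custom session class only needs to expose these attributes to be usable with SpkmapProcessor; it does not need to inherit from anything.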