当前位置: 首页>>代码示例>>Python>>正文


Python Data.copy方法代码示例

本文整理汇总了Python中wyrm.types.Data.copy方法的典型用法代码示例。如果您正苦于以下问题:Python Data.copy方法的具体用法?Python Data.copy怎么用?Python Data.copy使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在wyrm.types.Data的用法示例。


在下文中一共展示了Data.copy方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: TestClearMarkers

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestClearMarkers(unittest.TestCase):
    """Tests for ``clear_markers``."""

    def setUp(self):
        ones = np.ones((10, 5))
        # cnt block of 30 samples x 5 channels: ten rows each of 1s, 2s, 3s
        cnt = np.append(ones, ones*2, axis=0)
        cnt = np.append(cnt, ones*3, axis=0)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        # 30 time points from -1000 ms (inclusive) to 2000 ms (exclusive)
        time = np.linspace(-1000, 2000, 30, endpoint=False)
        # markers that clear_markers is expected to keep ...
        self.good_markers = [[-1000, 'a'], [-999, 'b'], [0, 'c'], [1999.9999999, 'd']]
        # ... and markers it is expected to remove (just outside the range)
        bad_markers = [[-1001, 'x'], [2000, 'x']]
        markers = self.good_markers[:]
        markers.extend(bad_markers)
        classes = [0, 1, 2, 1]
        # four epochs: the cnt block scaled by 0, 1, 2 and 0
        data = np.array([cnt * 0, cnt * 1, cnt * 2, cnt * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.markers = markers
        self.dat.fs = 10

    def test_clear_markers(self):
        """Clear markers."""
        dat = clear_markers(self.dat)
        self.assertEqual(dat.markers, self.good_markers)

    def test_clear_emtpy_markers(self):
        """Clearing empty markers has no effect."""
        dat = self.dat.copy()
        dat.markers = []
        dat2 = clear_markers(dat)
        self.assertEqual(dat, dat2)

    def test_clear_nonexisting_markers(self):
        """Clearing nonexisting markers has no effect."""
        dat = self.dat.copy()
        # remove the attribute entirely (not just empty it)
        del dat.markers
        dat2 = clear_markers(dat)
        self.assertEqual(dat, dat2)

    def test_clear_markers_w_empty_data(self):
        """Clearing empty dat should remove all markers."""
        dat = self.dat.copy()
        dat.data = np.array([])
        dat2 = clear_markers(dat)
        self.assertEqual(dat2.markers, [])

    def test_clear_markes_swapaxes(self):
        """clear_markers must work with nonstandard timeaxis."""
        # clear on data with time on axis 2, swap back, compare to default
        dat = clear_markers(swapaxes(self.dat, 1, 2), timeaxis=2)
        dat = swapaxes(dat, 1, 2)
        dat2 = clear_markers(self.dat)
        self.assertEqual(dat, dat2)

    def test_clear_markers_copy(self):
        """clear_markers must not modify argument."""
        cpy = self.dat.copy()
        clear_markers(self.dat)
        self.assertEqual(self.dat, cpy)
开发者ID:awakenting,项目名称:wyrm,代码行数:60,代码来源:test_clear_markers.py

示例2: test_copy

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
 def test_copy(self):
     """Copy must work."""
     d1 = Data(self.data, self.axes, self.names, self.units)
     d2 = d1.copy()
     self.assertEqual(d1, d2)
     # we can't really check all references to be different
     # recursively in depth, so we only check on the first level
     for k in d1.__dict__:
         # msg=k names the attribute that was shared instead of copied
         self.assertIsNot(getattr(d1, k), getattr(d2, k), msg=k)
     # keyword arguments passed to copy() become attributes of the copy
     d2 = d1.copy(foo='bar')
     self.assertEqual(d2.foo, 'bar')
开发者ID:awakenting,项目名称:wyrm,代码行数:13,代码来源:test_data.py

示例3: TestAppendEpo

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestAppendEpo(unittest.TestCase):
    """Tests for ``append_epo``."""

    def setUp(self):
        ones = np.ones((10, 5))
        # cnt with 1, 2, 3
        cnt = np.append(ones, ones*2, axis=0)
        cnt = np.append(cnt, ones*3, axis=0)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 3000, 30, endpoint=False)
        classes = [0, 1, 2, 1]
        # four epochs: the cnt block scaled by 0, 1, 2 and 0
        data = np.array([cnt * 0, cnt * 1, cnt * 2, cnt * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = ['zero', 'one', 'two']

    def test_append_epo(self):
        """append_epo."""
        dat = append_epo(self.dat, self.dat)
        self.assertEqual(dat.data.shape[0], 2*self.dat.data.shape[0])
        self.assertEqual(len(dat.axes[0]), 2*len(self.dat.axes[0]))
        np.testing.assert_array_equal(dat.data, np.concatenate([self.dat.data, self.dat.data], axis=0))
        np.testing.assert_array_equal(dat.axes[0], np.concatenate([self.dat.axes[0], self.dat.axes[0]]))
        self.assertEqual(dat.class_names, self.dat.class_names)

    def test_append_epo_with_extra(self):
        """append_epo with extra must work with list and ndarrays."""
        self.dat.a = list(range(10))
        self.dat.b = np.arange(10)
        dat = append_epo(self.dat, self.dat, extra=['a', 'b'])
        self.assertEqual(dat.a, list(range(10)) + list(range(10)))
        np.testing.assert_array_equal(dat.b, np.concatenate([np.arange(10), np.arange(10)]))

    def test_append_epo_with_different_class_names(self):
        """append_epo must raise a ValueError if class_names are different."""
        a = self.dat.copy()
        a.class_names = a.class_names[:-1]
        # one context per call: inside a single assertRaises block the
        # first raising call would skip the second one entirely
        with self.assertRaises(ValueError):
            append_epo(a, self.dat)
        with self.assertRaises(ValueError):
            append_epo(self.dat, a)

    def test_append_epo_swapaxes(self):
        """append_epo must work with nonstandard classaxis."""
        dat = append_epo(swapaxes(self.dat, 0, 2), swapaxes(self.dat, 0, 2), classaxis=2)
        dat = swapaxes(dat, 0, 2)
        dat2 = append_epo(self.dat, self.dat)
        self.assertEqual(dat, dat2)

    def test_append_epo_copy(self):
        """append_epo must not modify argument."""
        cpy = self.dat.copy()
        append_epo(self.dat, self.dat)
        self.assertEqual(self.dat, cpy)
开发者ID:awakenting,项目名称:wyrm,代码行数:54,代码来源:test_append_epo.py

示例4: TestCorrectForBaseline

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestCorrectForBaseline(unittest.TestCase):
    """Tests for ``correct_for_baseline``."""

    def setUp(self):
        ones = np.ones((10, 5))
        classes = [0, 0, 0]
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        # 10 time points: -1000, -900, ..., -100 ms
        time = np.linspace(-1000, 0, 10, endpoint=False)
        # three epochs, each constant over time: 1s, -1s, and 0s
        data = np.array([ones, ones * -1, ones * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channels'], ['#', 'ms', '#'])

    def test_correct_for_baseline_epo(self):
        """Test baselining w/ epo like."""
        # normal case: each epoch is constant over time, so removing its
        # baseline must leave all zeros
        dat = correct_for_baseline(self.dat, [-500, 0])
        np.testing.assert_array_equal(np.zeros((3, 10, 5)), dat.data)
        # the full dat interval (taken from the fixture, not the result)
        time = self.dat.axes[-2]
        dat = correct_for_baseline(self.dat, [time[0], time[-1]])
        np.testing.assert_array_equal(np.zeros((3, 10, 5)), dat.data)

    def test_correct_for_baseline_cnt(self):
        """Test baselining w/ cnt like."""
        # flatten the three epochs into one 30 x 5 cnt: ten samples of 1s,
        # then ten of -1s, then ten of 0s, at -1000, -900, ..., 1900 ms
        data = self.dat.data.reshape(30, 5)
        axes = [np.linspace(-1000, 2000, 30, endpoint=False), self.dat.axes[-1]]
        units = self.dat.units[1:]
        names = self.dat.names[1:]
        dat = self.dat.copy(data=data, axes=axes, names=names, units=units)
        # [-1000, 0] only covers the leading 1s, so the baseline is 1
        dat2 = correct_for_baseline(dat, [-1000, 0])
        np.testing.assert_array_equal(dat2.data, dat.data - 1)

    def test_ival_checks(self):
        """Test for malformed ival parameter."""
        # start after end
        with self.assertRaises(AssertionError):
            correct_for_baseline(self.dat, [0, -1])
        # start before the first time point
        with self.assertRaises(AssertionError):
            correct_for_baseline(self.dat, [self.dat.axes[-2][0]-1, 0])
        # NOTE(review): axes[-2][1] + 1 is -899, so this ival again has its
        # start (0) after its end and trips the same check as the first
        # case; it does not exercise an out-of-range upper bound. Confirm
        # the intended bound before changing it.
        with self.assertRaises(AssertionError):
            correct_for_baseline(self.dat, [0, self.dat.axes[-2][1]+1])

    def test_correct_for_baseline_copy(self):
        """Correct for baseline must not modify dat argument."""
        cpy = self.dat.copy()
        correct_for_baseline(self.dat, [-1000, 0])
        self.assertEqual(cpy, self.dat)

    def test_correct_for_baseline_swapaxes(self):
        """Correct for baseline must work with nonstandard timeaxis."""
        dat = correct_for_baseline(swapaxes(self.dat, 0, 1), [-1000, 0], timeaxis=0)
        dat = swapaxes(dat, 0, 1)
        dat2 = correct_for_baseline(self.dat, [-1000, 0])
        self.assertEqual(dat, dat2)
开发者ID:alistairwalsh,项目名称:wyrm,代码行数:53,代码来源:test_correct_for_baseline.py

示例5: test_eq_and_ne

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
 def test_eq_and_ne(self):
     """Check if __ne__ is properly implemented."""
     d1 = Data(self.data, self.axes, self.names, self.units)
     d2 = d1.copy()
     # Use the bare operators (not assertEqual/assertNotEqual) so __eq__
     # and __ne__ are each exercised on their own. If __eq__ is
     # implemented and __ne__ is not, `d1 != d2` can evaluate to True
     # (Python 2 falls back to an identity comparison). Checking both
     # separately also fails when __eq__ itself is broken, which the
     # combined `d1 == d2 and d1 != d2` check would let pass.
     self.assertTrue(d1 == d2)
     self.assertFalse(d1 != d2)
开发者ID:awakenting,项目名称:wyrm,代码行数:9,代码来源:test_data.py

示例6: TestSelectChannels

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestSelectChannels(unittest.TestCase):
    """Tests for ``select_channels``."""

    def setUp(self):
        # 4 samples x 5 channels holding the values 0..19
        raw = np.arange(20).reshape(4, 5)
        channels = ["ca1", "ca2", "cb1", "cb2", "cc1"]
        time = np.arange(4)
        self.dat = Data(raw, [time, channels], ["time", "channels"], ["ms", "#"])

    def test_select_channels(self):
        """Selecting channels with an array of regexes."""
        orig_data = self.dat.data.copy()
        dat = select_channels(self.dat, ["ca.*", "cc1"])
        np.testing.assert_array_equal(dat.axes[-1], np.array(["ca1", "ca2", "cc1"]))
        # ca1, ca2 and cc1 are the columns 0, 1 and -1 (the last one)
        np.testing.assert_array_equal(dat.data, orig_data[:, np.array([0, 1, -1])])

    def test_select_channels_inverse(self):
        """Removing channels with an array of regexes."""
        orig_data = self.dat.data.copy()
        dat = select_channels(self.dat, ["ca.*", "cc1"], invert=True)
        np.testing.assert_array_equal(dat.axes[-1], np.array(["cb1", "cb2"]))
        # only cb1 and cb2 (columns 2 and 3) remain
        np.testing.assert_array_equal(dat.data, orig_data[:, np.array([2, 3])])

    def test_select_channels_copy(self):
        """Select channels must not change the original parameter."""
        cpy = self.dat.copy()
        select_channels(self.dat, ["ca.*"])
        self.assertEqual(cpy, self.dat)

    def test_select_channels_swapaxis(self):
        """Select channels works with non default chanaxis."""
        dat1 = select_channels(swapaxes(self.dat, 0, 1), ["ca.*"], chanaxis=0)
        dat1 = swapaxes(dat1, 0, 1)
        dat2 = select_channels(self.dat, ["ca.*"])
        self.assertEqual(dat1, dat2)
开发者ID:usmanayubsh,项目名称:wyrm,代码行数:35,代码来源:test_select_channels.py

示例7: TestSwapaxes

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestSwapaxes(unittest.TestCase):
    """Tests for ``swapaxes``."""

    def setUp(self):
        # 400 samples x 5 channels
        raw = np.arange(2000).reshape(-1, 5)
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 4000, 400, endpoint=False)
        fs = 100
        marker = [[100, 'foo'], [200, 'bar']]
        self.dat = Data(raw, [time, channels], ['time', 'channels'], ['ms', '#'])
        self.dat.fs = fs
        self.dat.markers = marker

    def test_swapaxes(self):
        """Swapping axes."""
        new = swapaxes(self.dat, 0, 1)
        # assert_array_equal also reports shape mismatches and the
        # differing elements, unlike assertTrue((a == b).all())
        np.testing.assert_array_equal(new.axes[0], self.dat.axes[1])
        np.testing.assert_array_equal(new.axes[1], self.dat.axes[0])
        self.assertEqual(new.names[0], self.dat.names[1])
        self.assertEqual(new.names[1], self.dat.names[0])
        self.assertEqual(new.units[0], self.dat.units[1])
        self.assertEqual(new.units[1], self.dat.units[0])
        self.assertEqual(new.data.shape[::-1], self.dat.data.shape)
        np.testing.assert_array_equal(new.data.swapaxes(0, 1), self.dat.data)

    def test_swapaxes_copy(self):
        """Swapaxes must not modify argument."""
        cpy = self.dat.copy()
        swapaxes(self.dat, 0, 1)
        self.assertEqual(cpy, self.dat)

    def test_swapaxes_twice(self):
        """Swapping the same axes twice must result in original."""
        dat = swapaxes(self.dat, 0, 1)
        dat = swapaxes(dat, 0, 1)
        self.assertEqual(dat, self.dat)
开发者ID:alistairwalsh,项目名称:wyrm,代码行数:37,代码来源:test_swapaxes.py

示例8: test_equality

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
 def test_equality(self):
     """Test the various (in)equalities."""
     d1 = Data(self.data, self.axes, self.names, self.units)
     # known extra attributes
     d1.markers = [[123, 'foo'], [234, 'bar']]
     d1.fs = 100
     # unknown extra attribute
     d1.foo = 'bar'
     # so far, so equal
     d2 = d1.copy()
     self.assertEqual(d1, d2)
     # different shape
     d2 = d1.copy()
     d2.data = np.arange(20).reshape(5, 4)
     self.assertNotEqual(d1, d2)
     # different data
     d2 = d1.copy()
     d2.data[0, 0] = 42
     self.assertNotEqual(d1, d2)
     # different axes
     d2 = d1.copy()
     d2.axes[0] = np.arange(100)
     self.assertNotEqual(d1, d2)
     # different names
     d2 = d1.copy()
     d2.names[0] = 'baz'
     self.assertNotEqual(d1, d2)
     # different units
     d2 = d1.copy()
     d2.units[0] = 'u3'
     self.assertNotEqual(d1, d2)
     # different known extra attribute (markers)
     d2 = d1.copy()
     d2.markers[0] = [123, 'baz']
     self.assertNotEqual(d1, d2)
     # different known extra attribute (fs)
     d2 = d1.copy()
     d2.fs = 10
     self.assertNotEqual(d1, d2)
     # different value of the existing unknown extra attribute
     d2 = d1.copy()
     d2.foo = 'baz'
     self.assertNotEqual(d1, d2)
     # different new unknown extra attribute
     d2 = d1.copy()
     d2.bar = 42
     self.assertNotEqual(d1, d2)
开发者ID:awakenting,项目名称:wyrm,代码行数:49,代码来源:test_data.py

示例9: TestSelectEpochs

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestSelectEpochs(unittest.TestCase):
    """Tests for ``select_epochs``."""

    def setUp(self):
        ones = np.ones((10, 5))
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        time = np.linspace(0, 1000, 10, endpoint=False)
        classes = [0, 1, 2, 1]
        class_names = ['zeros', 'ones', 'twoes']
        # four epochs filled with 0s, 1s, 2s, and 0s
        data = np.array([ones * 0, ones * 1, ones * 2, ones * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.class_names = class_names

    def test_select_epochs(self):
        """Selecting Epochs."""
        # normal case
        dat = select_epochs(self.dat, [0])
        self.assertEqual(dat.data.shape[0], 1)
        np.testing.assert_array_equal(dat.data, self.dat.data[[0]])
        np.testing.assert_array_equal(dat.axes[0], self.dat.axes[0][0])
        # every second epoch
        dat = select_epochs(self.dat, [0, 2])
        self.assertEqual(dat.data.shape[0], 2)
        np.testing.assert_array_equal(dat.data, self.dat.data[::2])
        np.testing.assert_array_equal(dat.axes[0], self.dat.axes[0][::2])
        # the full epo
        dat = select_epochs(self.dat, list(range(self.dat.data.shape[0])))
        np.testing.assert_array_equal(dat.data, self.dat.data)
        np.testing.assert_array_equal(dat.axes[0], self.dat.axes[0])
        # remove one
        dat = select_epochs(self.dat, [0], invert=True)
        self.assertEqual(dat.data.shape[0], 3)
        np.testing.assert_array_equal(dat.data, self.dat.data[1:])
        np.testing.assert_array_equal(dat.axes[0], self.dat.axes[0][1:])
        # remove every second epoch
        dat = select_epochs(self.dat, [0, 2], invert=True)
        self.assertEqual(dat.data.shape[0], 2)
        np.testing.assert_array_equal(dat.data, self.dat.data[1::2])
        np.testing.assert_array_equal(dat.axes[0], self.dat.axes[0][1::2])

    def test_select_epochs_with_cnt(self):
        """Select epochs must raise an exception if called with cnt argument."""
        # NOTE(review): only class_names is removed (data stays 3-d);
        # presumably select_epochs treats its absence as cnt — confirm.
        del self.dat.class_names
        with self.assertRaises(AssertionError):
            select_epochs(self.dat, [0, 1])

    def test_select_epochs_swapaxes(self):
        """Select epochs must work with nonstandard classaxis."""
        dat = select_epochs(swapaxes(self.dat, 0, 2), [0, 1], classaxis=2)
        dat = swapaxes(dat, 0, 2)
        dat2 = select_epochs(self.dat, [0, 1])
        self.assertEqual(dat, dat2)

    def test_select_epochs_copy(self):
        """Select Epochs must not modify argument."""
        cpy = self.dat.copy()
        select_epochs(self.dat, [0, 1])
        self.assertEqual(self.dat, cpy)
开发者ID:awakenting,项目名称:wyrm,代码行数:60,代码来源:test_select_epochs.py

示例10: TestRereference

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestRereference(unittest.TestCase):
    """Tests for ``rereference``."""

    def setUp(self):
        dat = np.zeros((SAMPLES, CHANS))
        # chan0 ramps over [-SAMPLES/2, SAMPLES/2) in steps of 1; all other
        # channels are 0
        dat[:, 0] = np.arange(SAMPLES) - SAMPLES/2
        channels = ['chan{i}'.format(i=i) for i in range(CHANS)]
        time = np.arange(SAMPLES)
        self.cnt = Data(dat, [time, channels], ['time', 'channels'], ['ms', '#'])
        # construct epo: epoch i is the cnt data shifted by i
        epo_dat = np.array([dat + i for i in range(EPOS)])
        classes = ['class{i}'.format(i=i) for i in range(EPOS)]
        self.epo = Data(epo_dat, [classes, time, channels], ['class', 'time', 'channels'], ['#', 'ms', '#'])

    def _expected_cnt_data(self):
        """Return the expected cnt data after rereferencing to chan0."""
        # every channel minus chan0: the zero channels become the negated
        # ramp and chan0 itself becomes 0
        ramp_r = np.linspace(SAMPLES/2, -SAMPLES/2, SAMPLES, endpoint=False)
        dat_r = np.tile(ramp_r[:, np.newaxis], (1, CHANS))
        dat_r[:, 0] = 0
        return dat_r

    def test_rereference_cnt(self):
        """Rereference channels (cnt)."""
        cnt_r = rereference(self.cnt, 'chan0')
        np.testing.assert_array_equal(cnt_r.data, self._expected_cnt_data())

    def test_rereference_epo(self):
        """Rereference channels (epo)."""
        epo_r = rereference(self.epo, 'chan0')
        # the per-epoch offset i is shared by all channels and cancels out,
        # so every epoch matches the rereferenced cnt
        dat_r = np.tile(self._expected_cnt_data(), (EPOS, 1, 1))
        np.testing.assert_array_equal(epo_r.data, dat_r)

    def test_raise_value_error(self):
        """Raise ValueError if channel not found."""
        with self.assertRaises(ValueError):
            rereference(self.cnt, 'foo')

    def test_case_insensitivity(self):
        """rereference should not care about case."""
        try:
            rereference(self.cnt, 'ChAN0')
        except ValueError:
            self.fail()

    def test_rereference_copy(self):
        """rereference must not modify arguments."""
        cpy = self.cnt.copy()
        rereference(self.cnt, 'chan0')
        self.assertEqual(self.cnt, cpy)

    def test_rereference_swapaxes(self):
        """rereference must work with nonstandard chanaxis."""
        dat = rereference(swapaxes(self.epo, 1, 2), 'chan0', chanaxis=1)
        dat = swapaxes(dat, 1, 2)
        dat2 = rereference(self.epo, 'chan0')
        self.assertEqual(dat, dat2)
开发者ID:awakenting,项目名称:wyrm,代码行数:59,代码来源:test_rereference.py

示例11: TestVariance

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestVariance(unittest.TestCase):
    """Unit tests for the ``variance`` function."""

    def setUp(self):
        block = np.ones((10, 5))
        # three epochs, constant over time, filled with 0s, 1s and 2s
        epochs = np.array([0*block, block, 2*block])
        chan_names = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        timeline = np.linspace(0, 1000, 10, endpoint=False)
        labels = [0, 1, 2]
        self.dat = Data(epochs, [labels, timeline, chan_names], ['class', 'time', 'channel'], ['#', 'ms', '#'])

    def test_variance(self):
        """Variance."""
        result = variance(self.dat)
        # the middle (time) axis must be gone from the data ...
        expected_shape = self.dat.data.shape[::2]
        self.assertEqual(result.data.shape, expected_shape)
        # ... every epoch is constant, so all per-epoch variances are 0 and
        # so is their variance
        self.assertEqual(result.data.var(), 0)
        # ... and from the axes
        self.assertEqual(len(result.axes), len(self.dat.axes)-1)

    def test_variance_with_cnt(self):
        """variance must work with cnt argument."""
        src = self.dat
        # take epoch 1 as a 2-d cnt and drop the class axis metadata
        cnt = src.copy(data=src.data[1], axes=src.axes[1:], names=src.names[1:], units=src.units[1:])
        result = variance(cnt)
        self.assertEqual(result.data.var(), 0)
        self.assertEqual(len(result.axes), len(src.axes)-2)

    def test_variance_swapaxes(self):
        """variance must work with nonstandard timeaxis."""
        swapped = swapaxes(self.dat, 1, 2)
        result = variance(swapped, timeaxis=2)
        # no swap back: variance removes the time axis
        reference = variance(self.dat)
        self.assertEqual(result, reference)

    def test_variance_copy(self):
        """variance must not modify argument."""
        snapshot = self.dat.copy()
        variance(self.dat)
        self.assertEqual(self.dat, snapshot)
开发者ID:awakenting,项目名称:wyrm,代码行数:46,代码来源:test_variance.py

示例12: TestRectifytChannels

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestRectifytChannels(unittest.TestCase):
    """Unit tests for ``rectify_channels``."""

    def setUp(self):
        values = np.arange(20).reshape(4, 5)
        chan_names = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        samples = np.arange(4)
        self.dat = Data(values, [samples, chan_names], ['time', 'channels'], ['ms', '#'])

    def test_rectify_channels(self):
        """Rectify channels of positive and negative data must be equal."""
        negated = self.dat.copy(data=-self.dat.data)
        self.assertEqual(rectify_channels(negated), rectify_channels(self.dat))

    def test_rectify_channels_copy(self):
        """Rectify channels must not change the original parameter."""
        snapshot = self.dat.copy()
        rectify_channels(self.dat)
        self.assertEqual(snapshot, self.dat)
开发者ID:alistairwalsh,项目名称:wyrm,代码行数:21,代码来源:test_rectify_channels.py

示例13: TestSelectIval

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestSelectIval(unittest.TestCase):
    """Tests for ``select_ival``."""

    def setUp(self):
        ones = np.ones((10, 5))
        channels = ['ca1', 'ca2', 'cb1', 'cb2', 'cc1']
        # 10 time points: -1000, -900, ..., -100 ms
        time = np.linspace(-1000, 0, 10, endpoint=False)
        classes = [0, 0, 0]
        # three epochs: 1s, -1s, and 0s
        data = np.array([ones, ones * -1, ones * 0])
        self.dat = Data(data, [classes, time, channels], ['class', 'time', 'channel'], ['#', 'ms', '#'])
        self.dat.fs = 10

    def test_select_ival(self):
        """Selecting Intervals."""
        # normal case
        dat = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat.axes[1][0], -500)
        self.assertEqual(dat.axes[1][-1], -100)
        # the full dat interval
        dat = select_ival(self.dat, [self.dat.axes[1][0], self.dat.axes[1][-1] + 1])
        self.assertEqual(dat.axes[1][0], self.dat.axes[1][0])
        self.assertEqual(dat.axes[1][-1], self.dat.axes[1][-1])
        np.testing.assert_array_equal(dat.data, self.dat.data)

    def test_select_ival_with_markers(self):
        """Selecting Intervals with markers."""
        # normal case: markers inside [-500, 0) are kept, others dropped.
        # Each marker is a [time, label] pair.
        good_markers = [[-499.99, 'x'], [-500, 'x'], [-0.0001, 'x']]
        bad_markers = [[501, 'y'], [0, 'y'], [1, 'y']]
        self.dat.markers = good_markers[:]
        self.dat.markers.extend(bad_markers)
        dat = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat.markers, good_markers)

    def test_ival_checks(self):
        """Test for malformed ival parameter."""
        # start after end
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [0, -1])
        # start before the first time point
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [self.dat.axes[1][0]-1, 0])
        # NOTE(review): axes[1][-1] + 1 is -99, so this ival again has its
        # start (0) after its end; it does not exercise an out-of-range
        # upper bound on its own — confirm intent.
        with self.assertRaises(AssertionError):
            select_ival(self.dat, [0, self.dat.axes[1][-1]+1])

    def test_select_ival_copy(self):
        """Select_ival must not modify the argument."""
        cpy = self.dat.copy()
        select_ival(self.dat, [-500, 0])
        self.assertEqual(cpy, self.dat)

    def test_select_ival_swapaxes(self):
        """select_ival must work with nonstandard timeaxis."""
        dat = select_ival(swapaxes(self.dat, 0, 1), [-500, 0], timeaxis=0)
        dat = swapaxes(dat, 0, 1)
        dat2 = select_ival(self.dat, [-500, 0])
        self.assertEqual(dat, dat2)
开发者ID:awakenting,项目名称:wyrm,代码行数:57,代码来源:test_select_ival.py

示例14: TestSpectrum

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestSpectrum(unittest.TestCase):
    """Unit tests for the ``spectrum`` function."""

    def setUp(self):
        fs = 100
        duration = 5
        self.freqs = [2, 7, 15]
        self.amps = [30, 10, 2]
        t = np.linspace(0, duration, fs*duration)
        # superposition of sines with the configured amplitudes/frequencies
        waves = [amp * np.sin(2*np.pi*t*freq) for amp, freq in zip(self.amps, self.freqs)]
        signal = np.sum(waves, axis=0)
        # the same signal on two channels
        data = np.column_stack([signal, signal])
        channel = np.array(['ch1', 'ch2'])
        self.dat = Data(data, [t, channel], ['time', 'channel'], ['s', '#'])
        self.dat.fs = fs

    def test_spectrum(self):
        """Calculate the spectrum."""
        spec = spectrum(self.dat)
        freq_axis = spec.axes[0]
        n_chans = spec.data.shape[1]
        # every injected frequency shows up with (almost) its amplitude
        for freq, amp in zip(self.freqs, self.amps):
            peak = freq_axis == freq
            for chan in range(n_chans):
                self.assertAlmostEqual(spec.data[peak, chan], amp, delta=.15)
        # all remaining frequency bins are (almost) empty
        off_peak = np.ones(freq_axis.shape, dtype=bool)
        for freq in self.freqs:
            off_peak &= freq_axis != freq
        self.assertFalse((spec.data[off_peak] > .8).any())
        # frequencies lie strictly between 0 and fs / 2
        self.assertGreater(min(freq_axis), 0)
        self.assertLess(max(freq_axis), self.dat.fs / 2)

    def test_spectrum_has_no_fs(self):
        """A spectrum has no sampling freq."""
        self.assertFalse(hasattr(spectrum(self.dat), 'fs'))

    def test_spectrum_copy(self):
        """spectrum must not modify argument."""
        snapshot = self.dat.copy()
        spectrum(self.dat)
        self.assertEqual(snapshot, self.dat)

    def test_spectrum_swapaxes(self):
        """spectrum must work with nonstandard timeaxis."""
        swapped = spectrum(swapaxes(self.dat, 0, 1), timeaxis=1)
        restored = swapaxes(swapped, 0, 1)
        self.assertEqual(restored, spectrum(self.dat))
开发者ID:awakenting,项目名称:wyrm,代码行数:49,代码来源:test_spectrum.py

示例15: TestFiltFilt

# 需要导入模块: from wyrm.types import Data [as 别名]
# 或者: from wyrm.types.Data import copy [as 别名]
class TestFiltFilt(unittest.TestCase):
    """Tests for ``filtfilt``."""

    def setUp(self):
        # create some data: sines at 2, 7 and 15 Hz with amplitudes 30, 10
        # and 2, summed and duplicated into two channels
        fs = 100
        dt = 5
        self.freqs = [2, 7, 15]
        amps = [30, 10, 2]
        t = np.linspace(0, dt, fs*dt)
        data = np.sum([a * np.sin(2*np.pi*t*f) for a, f in zip(amps, self.freqs)], axis=0)
        data = data[:, np.newaxis]
        data = np.concatenate([data, data], axis=1)
        channel = np.array(['ch1', 'ch2'])
        self.dat = Data(data, [t, channel], ['time', 'channel'], ['s', '#'])
        self.dat.fs = fs

    def _band_pass(self):
        """Return ``(b, a)`` of a 4th order Butterworth band pass for 6-8 Hz."""
        # butter expects the band edges normalized to the Nyquist frequency
        fn = self.dat.fs / 2
        return butter(4, [6 / fn, 8 / fn], btype='band')

    def test_bandpass(self):
        """Band pass filtering."""
        # bandpass around the middle frequency
        b, a = self._band_pass()
        ans = filtfilt(self.dat, b, a)
        # check if the desired band is not damped
        dat = spectrum(ans)
        mask = dat.axes[0] == 7
        self.assertTrue(mask.any())
        self.assertTrue((dat.data[mask] > 6.5).all())
        # check if the outer freqs are damped close to zero. The two
        # conditions must be OR-ed: AND-ing them selects no bin at all and
        # makes the assertion vacuously true.
        mask = (dat.axes[0] <= 6) | (dat.axes[0] > 8)
        self.assertTrue(mask.any())
        self.assertTrue((dat.data[mask] < .5).all())

    def test_filtfilt_copy(self):
        """filtfilt must not modify argument."""
        cpy = self.dat.copy()
        b, a = self._band_pass()
        filtfilt(self.dat, b, a)
        self.assertEqual(cpy, self.dat)

    def test_filtfilt_swapaxes(self):
        """filtfilt must work with nonstandard timeaxis."""
        b, a = self._band_pass()
        dat = filtfilt(swapaxes(self.dat, 0, 1), b, a, timeaxis=1)
        dat = swapaxes(dat, 0, 1)
        dat2 = filtfilt(self.dat, b, a)
        self.assertEqual(dat, dat2)
开发者ID:awakenting,项目名称:wyrm,代码行数:48,代码来源:test_filtfilt.py


注:本文中的wyrm.types.Data.copy方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。