当前位置: 首页>>代码示例>>Python>>正文


Python testing.assert_equal方法代码示例

本文整理汇总了Python中numpy.testing.assert_equal方法的典型用法代码示例。如果您正苦于以下问题：Python testing.assert_equal方法的具体用法是什么？Python testing.assert_equal怎么用？有哪些使用testing.assert_equal的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块numpy.testing的用法示例。


下文一共展示了testing.assert_equal方法的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。

示例1: test_class

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_class(self):
        """Check the supercell TDProxy container against the reference TD-KRHF run.

        Compares the computed energies (and that they are real), the shape of
        the ``xy`` amplitude array, and the ``xy`` vectors of non-degenerate
        roots only.
        """
        reference = self.td_model_krhf
        model = kproxy_supercell.TDProxy(self.model_krhf, "hf", [self.k, 1, 1], density_fitting_hf)
        model.nroots = reference.nroots
        assert not model.fast
        model.kernel()

        energies = model.e
        testing.assert_allclose(energies, reference.e, atol=1e-5)
        # Energies must be real: imaginary parts vanish to 1e-8
        testing.assert_allclose(energies.imag, 0, atol=1e-8)

        # Occupied / virtual orbital counts for this test system (hard-coded)
        nocc = nvirt = 4
        expected_shape = (len(energies), 2, self.k, self.k, nocc, nvirt)
        testing.assert_equal(model.xy.shape, expected_shape)

        # A root counts as degenerate when it lies within 1e-8 of either
        # neighbour; vectors of degenerate roots are not unique, so skip them.
        close_to_next = abs(energies[1:] - energies[:-1]) < 1e-8
        degenerate = numpy.logical_or(
            numpy.concatenate(([False], close_to_next)),
            numpy.concatenate((close_to_next, [False])),
        )
        keep = numpy.logical_not(degenerate)
        assert_vectors_close(reference.xy[keep], model.xy[keep], atol=1e-5)
开发者ID:pyscf,项目名称:pyscf,代码行数:20,代码来源:test_kproxy_supercell_hf.py

示例2: assert_warns_msg

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def assert_warns_msg(expected_warning, func, msg, *args, **kwargs):
    r"""Assert a call leads to a warning with a specific message

    Test whether a function call leads to exactly one warning of type
    ``expected_warning`` whose message contains the string ``msg``.

    Parameters
    ----------
    expected_warning : warning class
        The class of warning to be checked; e.g., DeprecationWarning
    func : object
        The class, method, property, or function to be called as
        func(\*args, \*\*kwargs)
    msg : str or None
        The message or a substring of the message to test for. If None,
        only the warning type and the warning count are checked.
    \*args : positional arguments to ``func``
    \*\*kwargs: keyword arguments to ``func``

    """
    # The docstring is raw (r"""...""") so that "\*" is not parsed as an
    # (invalid) escape sequence, which warns on recent Python versions.
    with pytest.warns(expected_warning) as record:
        func(*args, **kwargs)
    npt.assert_equal(len(record), 1,
                     err_msg='expected exactly one warning, got {}'.format(
                         len(record)))
    if msg is not None:
        # str() works for any Warning instance, including one created without
        # arguments, where message.args[0] would raise IndexError.
        text = str(record[0].message)
        npt.assert_equal(msg in text, True,
                         err_msg='{!r} not found in warning message {!r}'.format(
                             msg, text))
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:26,代码来源:testing.py

示例3: test_conv

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_conv(mode, method):
    """convolution.conv must agree with np.convolve and reject bad options."""
    reload(convolution)

    # Stimulus time axis: 0.5 s sampled every 1e-6 s (roughly 500,000 samples)
    stim_dur = 0.5  # seconds
    tsample = 0.001 / 1000
    t = np.arange(0, stim_dur, tsample)

    # Pulse train: +1 every 1000 samples, then -1 100 samples later.
    # With tsample = 1e-6 s the period is 1 ms, i.e. a 1 kHz train.
    stim = np.zeros_like(t)
    stim[::1000] = 1
    stim[100::1000] = -1

    # Gamma kernel to convolve the stimulus with
    _, kernel = gamma(1, 0.005, tsample)

    # Same shape and values as np.convolve for the requested mode/method
    expected = np.convolve(stim, kernel, mode=mode)
    actual = convolution.conv(stim, kernel, mode=mode, method=method)
    npt.assert_equal(actual.shape, expected.shape)
    npt.assert_almost_equal(actual, expected)

    # Unknown mode or method names raise ValueError
    for bad_option in ({'mode': "invalid"}, {'method': "invalid"}):
        with pytest.raises(ValueError):
            convolution.conv(kernel, stim, **bad_option)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:27,代码来源:test_convolution.py

示例4: test_gamma

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_gamma():
    """Check gamma() input validation, time axis, normalization and peak."""
    tsample = 0.005 / 1000
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # use the new name when available and fall back on older NumPy.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    # Invalid n, tau, or sampling step are rejected
    with pytest.raises(ValueError):
        gamma(0, 0.1, tsample)
    with pytest.raises(ValueError):
        gamma(2, -0.1, tsample)
    with pytest.raises(ValueError):
        gamma(2, 0.1, -tsample)

    for tau in [0.001, 0.01, 0.1]:
        for n in [1, 2, 5]:
            t, g = gamma(n, tau, tsample)
            # Time axis is a regular grid starting at 0 with step tsample
            npt.assert_equal(np.arange(0, t[-1] + tsample / 2.0, tsample), t)
            if n > 1:
                npt.assert_equal(g[0], 0.0)

            # Make sure area under the curve is normalized
            npt.assert_almost_equal(trapezoid(np.abs(g), dx=tsample), 1.0,
                                    decimal=2)

            # Make sure peak sits correctly, at tau * (n - 1)
            npt.assert_almost_equal(g.argmax() * tsample, tau * (n - 1))
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:25,代码来源:test_base.py

示例5: test_Grid2D

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_Grid2D(x_range, y_range):
    """A unit-step rectangular Grid2D stores its arguments and spans the ranges."""
    grid = Grid2D(x_range, y_range, step=1, grid_type='rectangular')

    # Constructor arguments are kept as given
    for actual, expected in ((grid.x_range, x_range),
                             (grid.y_range, y_range),
                             (grid.step, 1),
                             (grid.type, 'rectangular')):
        npt.assert_equal(actual, expected)

    # Grid is created with indexing='xy': rows follow y, columns follow x
    n_rows = np.abs(np.diff(y_range)) + 1
    n_cols = np.abs(np.diff(x_range)) + 1
    npt.assert_equal(grid.x.shape, (n_rows, n_cols))
    npt.assert_equal(grid.x.shape, grid.y.shape)
    npt.assert_equal(grid.x.shape, grid.shape)

    # Corner values: x varies along columns only, y along rows only
    corners = (
        (grid.x, (0, 0), x_range[0]),
        (grid.x, (0, -1), x_range[1]),
        (grid.x, (-1, 0), x_range[0]),
        (grid.x, (-1, -1), x_range[1]),
        (grid.y, (0, 0), y_range[0]),
        (grid.y, (0, -1), y_range[0]),
        (grid.y, (-1, 0), y_range[1]),
        (grid.y, (-1, -1), y_range[1]),
    )
    for coords, index, expected in corners:
        npt.assert_almost_equal(coords[index], expected)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:23,代码来源:test_geometry.py

示例6: test_BaseModel

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_BaseModel():
    """Pretty-printing, parameter overrides, freezing and the build flag."""
    # PrettyPrint lists the default parameters
    npt.assert_equal(str(ValidBaseModel()), 'ValidBaseModel(a=1, b=2)')

    # Default values can be overridden at construction time
    model = ValidBaseModel(b=3)
    npt.assert_almost_equal(model.b, 3)

    # New attributes cannot be added after construction
    with pytest.raises(FreezeError):
        model.c = 3

    # is_built is False until build() runs; build() accepts overrides too
    npt.assert_equal(model.is_built, False)
    model.build(a=3)
    npt.assert_almost_equal(model.a, 3)
    npt.assert_equal(model.is_built, True)

    # Only parameters from `get_default_params` are accepted,
    # and is_built cannot be assigned directly
    with pytest.raises(AttributeError):
        ValidBaseModel(c=3)
    with pytest.raises(AttributeError):
        ValidBaseModel().is_built = True
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:26,代码来源:test_base.py

示例7: test_Model_set_params

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_Model_set_params():
    """set_params routes each parameter only to the sub-model that has it."""
    # Spatial model only:
    model = Model(spatial=ValidSpatialModel())
    model.set_params({'xystep': 2.33})
    for value in (model.xystep, model.spatial.xystep):
        npt.assert_almost_equal(value, 2.33)

    # Temporal model only:
    model = Model(temporal=ValidTemporalModel())
    model.set_params({'dt': 2.33})
    for value in (model.dt, model.temporal.dt):
        npt.assert_almost_equal(value, 2.33)

    # Both models, parameters set in a single call:
    model = Model(spatial=ValidSpatialModel(), temporal=ValidTemporalModel())
    model.set_params({'xystep': 5, 'dt': 2.33})
    # xystep reaches the spatial model but not the temporal one
    npt.assert_almost_equal(model.xystep, 5)
    npt.assert_almost_equal(model.spatial.xystep, 5)
    npt.assert_equal(hasattr(model.temporal, 'xystep'), False)
    # dt reaches the temporal model but not the spatial one
    npt.assert_almost_equal(model.dt, 2.33)
    npt.assert_almost_equal(model.temporal.dt, 2.33)
    npt.assert_equal(hasattr(model.spatial, 'dt'), False)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:25,代码来源:test_base.py

示例8: test_fetch_url

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_fetch_url(tmp_data_dir):
    """fetch_url downloads a file and enforces the given remote checksum.

    Requires network access: both URLs are fetched from the internet.
    """
    first_path = os.path.join(tmp_data_dir, 'paper1.pdf')
    first_url = 'https://www.nature.com/articles/s41598-019-45416-4.pdf'
    first_checksum = 'e8a2db25916cdd15a4b7be75081ef3e57328fa5f335fb4664d1fb7090dcd6842'
    fetch_url(first_url, first_path, remote_checksum=first_checksum)
    npt.assert_equal(os.path.exists(first_path), True)

    second_path = os.path.join(tmp_data_dir, 'paper2.pdf')
    second_url = 'https://bionicvisionlab.org/publication/2019-optimal-surgical-placement/2019-optimal-surgical-placement.pdf'
    second_checksum = 'e2d0cbecc9c2826f66f60576b44fe18ad6a635d394ae02c3f528b89cffcd9450'
    # A mismatching checksum (the first file's) must raise IOError
    with pytest.raises(IOError):
        fetch_url(second_url, second_path, remote_checksum=first_checksum)
    # With the matching checksum the download succeeds
    fetch_url(second_url, second_path, remote_checksum=second_checksum)
    npt.assert_equal(os.path.exists(second_path), True)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:18,代码来源:test_base.py

示例9: test_BVA24_stim

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_BVA24_stim():
    """Check the three ways of attaching a stimulus to a BVA24 implant."""
    # Assign a stimulus after construction:
    implant = BVA24()
    implant.stim = {'1': 1}
    npt.assert_equal(implant.stim.electrodes, ['1'])
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.data, [[1]])

    # You can also assign the stimulus in the constructor. Rebind `implant`
    # so the assertions below check the constructor-built object rather than
    # re-checking the implant from the previous block.
    implant = BVA24(stim={'1': 1})
    npt.assert_equal(implant.stim.electrodes, ['1'])
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.data, [[1]])

    # Set a stimulus via array (presumably one value per electrode, i.e.
    # BVA24 has 35 electrodes -- TODO confirm):
    implant = BVA24(stim=np.ones(35))
    npt.assert_equal(implant.stim.shape, (35, 1))
    npt.assert_almost_equal(implant.stim.data, 1)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:20,代码来源:test_bva.py

示例10: test_SquareElectrode

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_SquareElectrode():
    """Validation, attributes, __slots__ and plotting of SquareElectrode."""
    # Non-scalar size or coordinate raises TypeError
    for bad_args in ((0, 0, 0, [1, 2]), (0, np.array([0, 1]), 0, 1)):
        with pytest.raises(TypeError):
            SquareElectrode(*bad_args)
    # A negative size `a` raises ValueError
    with pytest.raises(ValueError):
        SquareElectrode(0, 0, 0, -5)

    # Positional arguments map to x, y, z, a
    electrode = SquareElectrode(0, 1, 2, 100)
    for attr, expected in zip(('x', 'y', 'z', 'a'), (0, 1, 2, 100)):
        npt.assert_almost_equal(getattr(electrode, attr), expected)

    # Instances use __slots__ and carry no per-instance __dict__
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)

    # plot() draws a single Rectangle patch and no text
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 1)
    npt.assert_equal(isinstance(ax.patches[0], Rectangle), True)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:24,代码来源:test_electrodes.py

示例11: test_HexElectrode

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_HexElectrode():
    """Validation, attributes, __slots__ and plotting of HexElectrode."""
    # Non-scalar size or coordinate raises TypeError
    for bad_args in ((0, 0, 0, [1, 2]), (0, np.array([0, 1]), 0, 1)):
        with pytest.raises(TypeError):
            HexElectrode(*bad_args)
    # A negative size `a` raises ValueError
    with pytest.raises(ValueError):
        HexElectrode(0, 0, 0, -5)

    # Positional arguments map to x, y, z, a
    electrode = HexElectrode(0, 1, 2, 100)
    for attr, expected in zip(('x', 'y', 'z', 'a'), (0, 1, 2, 100)):
        npt.assert_almost_equal(getattr(electrode, attr), expected)

    # Instances use __slots__ and carry no per-instance __dict__
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)

    # plot() draws a single RegularPolygon patch and no text
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 1)
    npt.assert_equal(isinstance(ax.patches[0], RegularPolygon), True)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:24,代码来源:test_electrodes.py

示例12: test_PhotovoltaicPixel

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_PhotovoltaicPixel():
    """Attributes, __slots__ and plotting of PhotovoltaicPixel."""
    electrode = PhotovoltaicPixel(0, 1, 2, 3, 4)
    # Positional arguments map to x, y, z, r, a
    npt.assert_almost_equal(electrode.x, 0)
    npt.assert_almost_equal(electrode.y, 1)
    npt.assert_almost_equal(electrode.z, 2)
    npt.assert_almost_equal(electrode.r, 3)
    npt.assert_almost_equal(electrode.a, 4)
    # Slots:
    npt.assert_equal(hasattr(electrode, '__slots__'), True)
    npt.assert_equal(hasattr(electrode, '__dict__'), False)
    # Plots: a RegularPolygon followed by a Circle, and no text
    ax = electrode.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.patches), 2)
    npt.assert_equal(isinstance(ax.patches[0], RegularPolygon), True)
    npt.assert_equal(isinstance(ax.patches[1], Circle), True)
    # (A trailing, discarded PhotovoltaicPixel(0, 1, 2, 3, 4) call was removed:
    # the identical constructor call at the top already exercises it.)
开发者ID:pulse2percept,项目名称:pulse2percept,代码行数:19,代码来源:test_prima.py

示例13: test_safe_binop

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_safe_binop():
    """Checked int64 add/sub/mul agree with Python arithmetic or overflow.

    For every pair of values, the checked routine must return Python's own
    result when that result fits in [INT64_MIN, INT64_MAX], and raise
    OverflowError otherwise.
    """
    # (Python operator, opcode passed to extint_safe_binop)
    ops = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3)
    ]

    with exc_iter(ops, INT64_VALUES, INT64_VALUES) as it:
        for (py_op, op_code), lhs, rhs in it:
            expected = py_op(lhs, rhs)

            if INT64_MIN <= expected <= INT64_MAX:
                got = mt.extint_safe_binop(lhs, rhs, op_code)
                # assert_equal is slow, so only call it on a mismatch
                if got != expected:
                    assert_equal(got, expected)
            else:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              lhs, rhs, op_code)
开发者ID:Frank-qlu,项目名称:recruit,代码行数:23,代码来源:test_extint128.py

示例14: test_cross_phase_1d

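# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
# 注意：本示例中的 numpy.testing 仅以 npt 形式用于 assert_almost_equal；xrt.assert_equal 来自另一个模块（推测为 xarray.testing）。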
def test_cross_phase_1d(self, dask):
        """cross_phase recovers a known phase shift between two 1-D cosines."""
        N = 32
        x = np.linspace(0, 1, num=N, endpoint=False)
        f = 6
        phase_offset = np.pi/2
        # f full cycles over x in [0, 1): the signal sits at frequency f (1/x units)
        signal1 = np.cos(2*np.pi*f*x)
        # Same frequency, lagging by phase_offset
        signal2 = np.cos(2*np.pi*f*x - phase_offset)
        da1 = xr.DataArray(data=signal1, name='a', dims=['x'], coords={'x': x})
        da2 = xr.DataArray(data=signal2, name='b', dims=['x'], coords={'x': x})

        if dask:
            # A single chunk covering the whole axis
            da1 = da1.chunk({'x': N})
            da2 = da2.chunk({'x': N})
        cp = xrft.cross_phase(da1, da2, dim=['x'])

        actual_phase_offset = cp.sel(freq_x=f).values
        npt.assert_almost_equal(actual_phase_offset, phase_offset)
        assert cp.name == 'a_b_phase'

        # Omitting `dim` must give the same result as dim=['x']
        xrt.assert_equal(xrft.cross_phase(da1, da2), cp)

        # Mismatched inputs are rejected: a 0-d array without the x coordinate
        # (drop_vars replaces the deprecated DataArray.drop for variables) ...
        with pytest.raises(ValueError):
            xrft.cross_phase(da1, da2.isel(x=0).drop_vars('x'))

        # ... and an array along a differently named dimension
        with pytest.raises(ValueError):
            xrft.cross_phase(da1, da2.rename({'x': 'y'}))
开发者ID:xgcm,项目名称:xrft,代码行数:28,代码来源:test_xrft.py

示例15: test_isotropic_ps

# 需要导入模块: from numpy import testing [as 别名]
# 或者: from numpy.testing import assert_equal [as 别名]
def test_isotropic_ps():
    """Test data with extra coordinates"""
    # Seeded generator so the test input (and any failure) is reproducible
    rng = np.random.RandomState(0)
    da = xr.DataArray(rng.rand(2, 5, 16, 32),
                      dims=['time', 'z', 'y', 'x'],
                      coords={'time': np.array(['2019-04-18', '2019-04-19'],
                                               dtype='datetime64'),
                              'zz': ('z', np.arange(5)), 'z': np.arange(5),
                              'y': np.arange(16), 'x': np.arange(32)})
    # Requesting three dimensions raises ValueError
    with pytest.raises(ValueError):
        xrft.isotropic_power_spectrum(da, dim=['z', 'y', 'x'])
    iso_ps = xrft.isotropic_power_spectrum(da, dim=['y', 'x'])
    # Apart from the first radial bin, the spectrum contains no NaN or inf
    npt.assert_equal(
        np.ma.masked_invalid(iso_ps.isel(freq_r=slice(1, None))).mask.sum(),
        0.)
开发者ID:xgcm,项目名称:xrft,代码行数:16,代码来源:test_xrft.py


注：本文中的numpy.testing.assert_equal方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。