

Python Data.set_data Method Code Examples

This article collects typical usage examples of the Python method idtxl.data.Data.set_data. If you are wondering what Data.set_data does, how to call it, or what real-world uses look like, the curated code examples below should help. You can also browse further usage examples of the class this method belongs to, idtxl.data.Data.


The following presents 15 code examples of Data.set_data, ordered by popularity by default.
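Before the examples, here is a minimal sketch of the call pattern they all share. The array contents and shapes below are made up purely for illustration; only the Data constructor, set_data, and the dim_order string are taken from the examples in this article.

import numpy as np
from idtxl.data import Data

# Hypothetical array with 2 processes, 100 samples and 5 replications.
raw = np.random.randn(2, 100, 5)

# Create an empty container and load the array into it. The dim_order string
# declares how the axes of the input map onto p(rocesses), s(amples) and
# r(eplications).
data = Data(normalise=False)
data.set_data(raw, dim_order='psr')

print(data.data.shape)  # stored as processes x samples x replications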

Example 1: test_bivariate_te_one_realisation_per_replication

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_bivariate_te_one_realisation_per_replication():
    """Test boundary case of one realisation per replication."""
    # Create a data set where one pattern fits into the time series exactly
    # once. This way, we get one realisation per replication for each
    # variable, which is easier to assert/verify later. We also test
    # data.get_realisations this way.
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'n_perm_max_stat': 21,
        'max_lag_target': 5,
        'max_lag_sources': 5,
        'min_lag_sources': 4}
    target = 0
    data = Data(normalise=False)
    n_repl = 10
    n_procs = 2
    n_points = n_procs * (settings['max_lag_sources'] + 1) * n_repl
    data.set_data(np.arange(n_points).reshape(
                                        n_procs,
                                        settings['max_lag_sources'] + 1,
                                        n_repl), 'psr')
    nw = BivariateTE()
    nw._initialise(settings, data, 'all', target)
    assert (not nw.selected_vars_full)
    assert (not nw.selected_vars_sources)
    assert (not nw.selected_vars_target)
    assert ((nw._replication_index == np.arange(n_repl)).all())
    assert (nw._current_value == (target, max(
           settings['max_lag_sources'], settings['max_lag_target'])))
    assert (nw._current_value_realisations[:, 0] ==
            data.data[target, -1, :]).all()
Author: SimonStreicher, Project: IDTxl, Lines of code: 33, Source: test_bivariate_te.py

Example 2: ft2idtxlconverter

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def ft2idtxlconverter(filename, FTstructname, fileversion):
    # TODO: This will need better error handling!
    if fileversion == "v7.3":
#        try:
        print('Creating Python dictionary from FT data structure: ' + FTstructname)
        NPData = _ft_trial_2_numpyarray(filename, FTstructname)
        label = _ft_label_2_list(filename, FTstructname)
        NPfsample = _ft_fsample_2_float(filename, FTstructname)
        NPtime = _ft_time_2_numpyarray(filename, FTstructname)
        # convert data into IDTxl's Data class
        d = Data()
        # FieldTrip stores data as "channel x timesamples", but numpy sees
        # the data as stored internally in the HDF5 file as
        # "timesamples x channel". We collected the replications in the
        # third dimension --> dimensions are:
        # s(amples) x p(rocesses) x r(eplications) = 'spr'
        d.set_data(NPData, 'spr')
        TXLdata = {"dataset" : d , "label" : label,
                   "time" : NPtime, "fsample" : NPfsample}

#        except(OSError, RuntimeError):
#            print('incorrect file version, the given file was not a MATLAB'
#                  ' m-file version 7.3')
#            return
    else:
        print('At present only m-files in format 7.3 are supported, '
              'please consider reopening and resaving your m-file in that '
              'version.')
        TXLdata = None  # nothing to return for unsupported file versions
    return TXLdata
Author: mwibral, Project: IDTxl, Lines of code: 32, Source: ft2idtxl.py

Example 3: test_data_type

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_data_type():
    """Test if data class always returns the correct data type."""
    # Change data type for the same object instance.
    d_int = np.random.randint(0, 10, size=(3, 50))
    orig_type = type(d_int[0][0])
    data = Data(d_int, dim_order='ps', normalise=False)
    # The concrete type depends on the platform:
    # https://mail.scipy.org/pipermail/numpy-discussion/2011-November/059261.html
    # Hence, compare against the type automatically assigned by Python or
    # against np.integer
    assert data.data_type is orig_type, 'Data type did not change.'
    assert issubclass(type(data.data[0, 0, 0]), np.integer), (
        'Data type is not an int.')
    d_float = np.random.randn(3, 50)
    data.set_data(d_float, dim_order='ps')
    assert data.data_type is np.float64, 'Data type did not change.'
    assert issubclass(type(data.data[0, 0, 0]), np.float), (
        'Data type is not a float.')

    # Check if data returned by the object have the correct type.
    d_int = np.random.randint(0, 10, size=(3, 50, 5))
    data = Data(d_int, dim_order='psr', normalise=False)
    real = data.get_realisations((0, 5), [(1, 1), (1, 3)])[0]
    assert issubclass(type(real[0, 0]), np.integer), (
        'Realisations type is not an int.')
    sl = data._get_data_slice(0)[0]
    assert issubclass(type(sl[0, 0]), np.integer), (
        'Data slice type is not an int.')
    settings = {'perm_type': 'random'}
    sl_perm = data.slice_permute_samples(0, settings)[0]
    assert issubclass(type(sl_perm[0, 0]), np.integer), (
        'Permuted data slice type is not an int.')
    samples = data.permute_samples((0, 5), [(1, 1), (1, 3)], settings)[0]
    assert issubclass(type(samples[0, 0]), np.integer), (
        'Permuted samples type is not an int.')
Author: SimonStreicher, Project: IDTxl, Lines of code: 37, Source: test_data.py

Example 4: ft2idtxlconverter

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def ft2idtxlconverter(filename, FTstructname, fileversion):
    """Convert FieldTrip-style MATLAB-file into an IDTxl Data object.

    Import a MATLAB structure with fields "trial" (data), "label" (channel
    labels), "time" (time stamps for data samples), and "fsample" (sampling
    rate). This structure is the standard file format of the MATLAB toolbox
    FieldTrip and is commonly used to represent neurophysiological data (see
    also http://www.fieldtriptoolbox.org/). The function reads a mat-file
    from disk and returns a dictionary containing the information in the
    mat-file. Data is represented as an IDTxl Data object.

    Args:
        filename : string
            full (matlab) filename on disk
        FTstructname : string
            variable name of the MATLAB structure that is in FieldTrip format
            (autodetect will hopefully be possible later ...)
        fileversion : string
            version of the file, e.g. "v7.3" for MATLAB's 7.3 format

    Returns:
        dict
            "dataset": instance of IDTxl Data object; "label": list of channel
            labels; "time": numpy array of time stamps; "fsample": sampling
            rate
    """
    # TODO: This will need better error handling!
    if fileversion == "v7.3":
        print('Creating Python dictionary from FT data structure: ' +
              FTstructname)
        NPData = _ft_trial_2_numpyarray(filename, FTstructname)
        label = _ft_label_2_list(filename, FTstructname)
        NPfsample = _ft_fsample_2_float(filename, FTstructname)
        NPtime = _ft_time_2_numpyarray(filename, FTstructname)
        # convert data into IDTxl's Data class
        d = Data()
        # FieldTrip stores data as "channel x timesamples", but numpy sees
        # the data as stored internally in the HDF5 file as
        # "timesamples x channel". We collected the replications in the
        # third dimension --> dimensions are:
        # s(amples) x p(rocesses) x r(eplications) = 'spr'
        d.set_data(NPData, 'spr')
        TXLdata = {"dataset": d,
                   "label": label,
                   "time": NPtime,
                   "fsample": NPfsample}

#        except(OSError, RuntimeError):
#            print('incorrect file version, the given file was not a MATLAB'
#                  ' m-file version 7.3')
#            return
    else:
        print('At present only m-files in format 7.3 are supported, '
              'please consider reopening and resaving your m-file in that '
              'version.')
        # TODO: we could write a fallback option using numpy's loadmat?
        TXLdata = None  # nothing to return for unsupported file versions
    return TXLdata
Author: finnconor, Project: IDTxl, Lines of code: 60, Source: ft2idtxl.py
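A minimal usage sketch for ft2idtxlconverter as documented in Example 4 above. The file name 'my_recording.mat' and the structure name 'ft_data' are placeholders for an actual FieldTrip structure saved with MATLAB's "-v7.3" flag; the returned keys are those listed in the docstring.

# Hypothetical call with placeholder file and structure names.
converted = ft2idtxlconverter('my_recording.mat', 'ft_data', 'v7.3')

d = converted['dataset']        # IDTxl Data object (internally 'spr' layout)
labels = converted['label']     # list of channel labels
timestamps = converted['time']  # numpy array of time stamps
fsample = converted['fsample']  # sampling rate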

Example 5: test_multivariate_te_initialise

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_initialise():
    """Test if all values are set correctly in _initialise()."""
    # Create a data set where one pattern fits into the time series exactly
    # once. This way, we get one realisation per replication for each
    # variable, which is easier to assert/verify later. We also test
    # data.get_realisations this way.
    analysis_opts = {'cmi_calc_name': 'jidt_kraskov'}
    max_lag_target = 5
    max_lag_sources = max_lag_target
    min_lag_sources = 4
    target = 1
    dat = Data(normalise=False)
    n_repl = 30
    n_procs = 2
    n_points = n_procs * (max_lag_sources + 1) * n_repl
    dat.set_data(np.arange(n_points).reshape(n_procs, max_lag_sources + 1,
                                             n_repl), 'psr')
    nw_0 = Multivariate_te(max_lag_sources, min_lag_sources, max_lag_target,
                           analysis_opts)
    nw_0._initialise(dat, 'all', target)
    assert (not nw_0.selected_vars_full)
    assert (not nw_0.selected_vars_sources)
    assert (not nw_0.selected_vars_target)
    assert ((nw_0._replication_index == np.arange(n_repl)).all())
    assert (nw_0._current_value == (target, max(max_lag_sources,
                                                max_lag_target)))
    assert ((nw_0._current_value_realisations ==
             np.arange(n_points - n_repl, n_points).reshape(n_repl, 1)).all())

    # Check if the Faes method is working
    analysis_opts['add_conditionals'] = 'faes'
    nw_1 = Multivariate_te(max_lag_sources, min_lag_sources, max_lag_target,
                           analysis_opts)
    dat.generate_mute_data()
    sources = [1, 2, 3]
    target = [0]
    nw_1._initialise(dat, sources, target)
    assert (nw_1._selected_vars_sources ==
            [i for i in it.product(sources, [nw_1.current_value[1]])]), (
                'Did not add correct additional conditioning vars.')

    # Adding a variable that is not in the data set.
    analysis_opts['add_conditionals'] = (8, 0)
    nw_1 = Multivariate_te(max_lag_sources, min_lag_sources, max_lag_target,
                           analysis_opts)
    dat.generate_mute_data()
    sources = [1, 2, 3]
    target = [0]
    with pytest.raises(IndexError):
        nw_1._initialise(dat, sources, target)
Author: finnconor, Project: IDTxl, Lines of code: 52, Source: test_multivariate_te.py

Example 6: test_data_normalisation

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_data_normalisation():
    """Test if data are normalised correctly when stored in a Data instance."""
    a_1 = 100
    a_2 = 1000
    source = np.random.randint(a_1, size=1000)
    target = np.random.randint(a_2, size=1000)

    data = Data(normalise=True)
    data.set_data(np.vstack((source.T, target.T)), 'ps')

    source_std = utils.standardise(source)
    target_std = utils.standardise(target)
    assert (source_std == data.data[0, :, 0]).all(), ('Standardising the '
                                                      'source did not work.')
    assert (target_std == data.data[1, :, 0]).all(), ('Standardising the '
                                                      'target did not work.')
Author: finnconor, Project: IDTxl, Lines of code: 18, Source: test_data.py
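For reference, a minimal sketch of the z-scoring the test above compares against. This assumes utils.standardise standardises each signal to zero mean and unit standard deviation along the sample axis; the exact degrees-of-freedom convention used for the standard deviation may differ from numpy's default.

import numpy as np

def standardise_sketch(x):
    # Hypothetical re-implementation for illustration only: subtract the
    # mean and divide by the standard deviation of the signal.
    return (x - x.mean()) / x.std()

source = np.random.randint(100, size=1000).astype(float)
z = standardise_sketch(source)
print(z.mean(), z.std())  # approximately 0 and 1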

Example 7: test_multivariate_te_lagged_copies

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_lagged_copies():
    """Test multivariate TE estimation on a lagged copy of random data.

    Run the multivariate TE algorithm on two sets of random data, where the
    second set is a lagged copy of the first. This test should find no
    significant conditionals at all (neither in the target's nor in the
    source's past).

    Note:
        This test takes several hours and may take one to two days on some
        machines.
    """
    lag = 3
    d_0 = np.random.rand(1, 1000, 20)
    d_1 = np.hstack((np.random.rand(1, lag, 20), d_0[:, lag:, :]))

    dat = Data()
    dat.set_data(np.vstack((d_0, d_1)), 'psr')
    analysis_opts = {
        'cmi_calc_name': 'jidt_discrete',
        'discretise_method': 'max_ent',
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_omnibus': 500,
        'n_perm_max_seq': 500,
        }
    random_analysis = Multivariate_te(max_lag_sources=5, options=analysis_opts)
    # Assert that there are no significant conditionals in either direction
    # other than the mandatory single sample in the target's past (which
    # ensures that we calculate a proper TE at any time in the algorithm).
    for target in range(2):
        res = random_analysis.analyse_single_target(dat, target)
        assert (len(res['conditional_full']) == 1), ('Conditional contains '
                                                     'more/less than 1 '
                                                     'variables.')
        assert (not res['conditional_sources']), ('Conditional sources is not '
                                                  'empty.')
        assert (len(res['conditional_target']) == 1), ('Conditional target '
                                                       'contains more/less '
                                                       'than 1 variable.')
        assert (res['cond_sources_pval'] is None), ('Conditional p-value is '
                                                    'not None.')
        assert (res['omnibus_pval'] is None), ('Omnibus p-value is not None.')
        assert (res['omnibus_sign'] is None), ('Omnibus significance is not '
                                               'None.')
        assert (res['conditional_sources_te'] is None), ('Conditional TE '
                                                         'values is not None.')
Author: pwollstadt, Project: IDTxl, Lines of code: 49, Source: systemtest_multivariate_te_discrete.py

Example 8: test_multivariate_te_lagged_copies

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_lagged_copies():
    """Test multivariate TE estimation on a lagged copy of random data.

    Run the multivariate TE algorithm on two sets of random data, where the
    second set is a lagged copy of the first. This test should find no
    significant conditionals at all (neither in the target's nor in the
    source's past).

    Note:
        This test takes several hours and may take one to two days on some
        machines.
    """
    lag = 3
    d_0 = np.random.rand(1, 1000, 20)
    d_1 = np.hstack((np.random.rand(1, lag, 20), d_0[:, lag:, :]))

    data = Data()
    data.set_data(np.vstack((d_0, d_1)), 'psr')
    settings = {
        'cmi_estimator':  'JidtKraskovCMI',
        'max_lag_sources': 5,
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_omnibus': 500,
        'n_perm_max_seq': 500,
        }
    random_analysis = MultivariateTE()
    # Assert that there are no significant conditionals in either direction
    # other than the mandatory single sample in the target's past (which
    # ensures that we calculate a proper TE at any time in the algorithm).
    for t in range(2):
        results = random_analysis.analyse_single_target(settings, data, t)
        assert len(results.get_single_target(t, fdr=False).selected_vars_full) == 1, (
                    'Conditional contains more/less than 1 variables.')
        assert not results.get_single_target(t, fdr=False).selected_vars_sources.size, (
                    'Conditional sources is not empty.')
        assert len(results.get_single_target(t, fdr=False).selected_vars_target) == 1, (
            'Conditional target contains more/less than 1 variable.')
        assert results.get_single_target(t, fdr=False).selected_sources_pval is None, (
            'Conditional p-value is not None.')
        assert results.get_single_target(t, fdr=False).omnibus_pval is None, (
            'Omnibus p-value is not None.')
        assert results.get_single_target(t, fdr=False).omnibus_sign is None, (
            'Omnibus significance is not None.')
        assert results.get_single_target(t, fdr=False).selected_sources_te is None, (
            'Conditional TE values is not None.')
Author: SimonStreicher, Project: IDTxl, Lines of code: 48, Source: systemtest_multivariate_te.py

Example 9: test_multivariate_te_lorenz_2

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_lorenz_2():
    """Test multivariate TE estimation on bivariately couled Lorenz systems.

    Run the multivariate TE algorithm on two Lorenz systems with a coupling
    from first to second system with delay u = 45 samples. Both directions are
    analyzed, the algorithm should not find a coupling from system two to one.

    Note:
        This test takes several hours and may take one to two days on some
        machines.
    """
    # load simulated data from 2 coupled Lorenz systems 1->2, u = 45 ms
    d = np.load(os.path.join(os.path.dirname(__file__),
                'data/lorenz_2_exampledata.npy'))
    data = Data()
    data.set_data(d, 'psr')
    settings = {
        'cmi_estimator':  'JidtKraskovCMI',
        'max_lag_sources': 47,
        'min_lag_sources': 42,
        'max_lag_target': 20,
        'tau_target': 2,
        'n_perm_max_stat': 21,  # 200
        'n_perm_min_stat': 21,  # 200
        'n_perm_omnibus': 21,
        'n_perm_max_seq': 21,  # this should be equal to the min stats b/c we
                               # reuse the surrogate table from the min stats
        }
    lorenz_analysis = MultivariateTE()
    # FOR DEBUGGING: add the whole history for k = 20, tau = 2 to the
    # estimation, this makes things faster, b/c these don't have to be
    # tested again. Note conditionals are specified using lags.
    settings['add_conditionals'] = [(1, 19), (1, 17), (1, 15), (1, 13),
                                    (1, 11), (1, 9), (1, 7), (1, 5), (1, 3),
                                    (1, 1)]

    settings['max_lag_sources'] = 60
    settings['min_lag_sources'] = 31
    settings['tau_sources'] = 2
    settings['max_lag_target'] = 1
    settings['tau_target'] = 1

    # Just analyse the direction of coupling
    results = lorenz_analysis.analyse_single_target(settings, data, target=1)
    print(results._single_target)
    print(results.get_adjacency_matrix('binary'))
Author: SimonStreicher, Project: IDTxl, Lines of code: 48, Source: systemtest_multivariate_te.py

Example 10: test_data_type

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_data_type():
    """Test if stats always returns surrogates with the correct data type."""
    # Change data type for the same object instance.
    d_int = np.random.randint(0, 10, size=(3, 50))
    orig_type = type(d_int[0][0])
    data = Data(d_int, dim_order='ps', normalise=False)
    # The concrete type depends on the platform:
    # https://mail.scipy.org/pipermail/numpy-discussion/2011-November/059261.html
    assert data.data_type is orig_type, 'Data type did not change.'
    assert issubclass(type(data.data[0, 0, 0]), np.integer), (
        'Data type is not an int.')
    settings = {'permute_in_time': True, 'perm_type': 'random'}
    surr = stats._get_surrogates(data=data,
                                 current_value=(0, 5),
                                 idx_list=[(1, 3), (2, 4)],
                                 n_perm=20,
                                 perm_settings=settings)
    assert issubclass(type(surr[0, 0]), np.integer), (
        'Realisations type is not an int.')
    surr = stats._generate_spectral_surrogates(data=data,
                                               scale=1,
                                               n_perm=20,
                                               perm_settings=settings)
    assert issubclass(type(surr[0, 0, 0]), np.integer), (
        'Realisations type is not an int.')

    d_float = np.random.randn(3, 50)
    data.set_data(d_float, dim_order='ps')
    assert data.data_type is np.float64, 'Data type did not change.'
    assert issubclass(type(data.data[0, 0, 0]), np.float), (
        'Data type is not a float.')
    surr = stats._get_surrogates(data=data,
                                 current_value=(0, 5),
                                 idx_list=[(1, 3), (2, 4)],
                                 n_perm=20,
                                 perm_settings=settings)
    assert issubclass(type(surr[0, 0]), np.float), (
        'Realisations type is not a float.')
    surr = stats._generate_spectral_surrogates(data=data,
                                               scale=1,
                                               n_perm=20,
                                               perm_settings=settings)
    assert issubclass(type(surr[0, 0, 0]), np.float), ('Realisations type is '
                                                       'not a float.')
Author: SimonStreicher, Project: IDTxl, Lines of code: 46, Source: test_stats.py

Example 11: test_multivariate_te_lorenz_2

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_lorenz_2():
    """Test multivariate TE estimation on bivariately couled Lorenz systems.

    Run the multivariate TE algorithm on two Lorenz systems with a coupling
    from first to second system with delay u = 45 samples. Both directions are
    analyzed, the algorithm should not find a coupling from system two to one.

    Note:
        This test takes several hours and may take one to two days on some
        machines.
    """

    d = np.load(os.path.join(os.path.dirname(__file__),
                'data/lorenz_2_exampledata.npy'))
    dat = Data()
    dat.set_data(d, 'psr')
    analysis_opts = {
        'cmi_calc_name': 'jidt_discrete',
        'discretise_method': 'max_ent',
        'n_perm_max_stat': 21,  # 200
        'n_perm_min_stat': 21,  # 200
        'n_perm_omnibus': 21,
        'n_perm_max_seq': 21,  # this should be equal to the min stats b/c we
                               # reuse the surrogate table from the min stats
        }
    lorenz_analysis = Multivariate_te(max_lag_sources=47, min_lag_sources=42,
                                      max_lag_target=20, tau_target=2,
                                      options=analysis_opts)
    # FOR DEBUGGING: add the whole history for k = 20, tau = 2 to the
    # estimation, this makes things faster, b/c these don't have to be
    # tested again.
    analysis_opts['add_conditionals'] = [(1, 44), (1, 42), (1, 40), (1, 38),
                                         (1, 36), (1, 34), (1, 32), (1, 30),
                                         (1, 28)]
    lorenz_analysis = Multivariate_te(max_lag_sources=60, min_lag_sources=31,
                                      tau_sources=2,
                                      max_lag_target=0, tau_target=1,
                                      options=analysis_opts)
    # res = lorenz_analysis.analyse_network(dat)
    # res_0 = lorenz_analysis.analyse_single_target(dat, 0)  # no coupling
    # print(res_0)
    res_1 = lorenz_analysis.analyse_single_target(dat, 1)  # coupling
    print(res_1)
Author: pwollstadt, Project: IDTxl, Lines of code: 45, Source: systemtest_multivariate_te_discrete.py

Example 12: test_multivariate_te_random

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_random():
    """Test multivariate TE estimation on two random data sets.

    Run the multivariate TE algorithm on two sets of random data with no
    coupling. This test should find no significant conditionals at all (neither
    in the target's nor in the source's past).

    Note:
        This test takes several hours and may take one to two days on some
        machines.
    """
    d = np.random.rand(2, 1000, 20)
    dat = Data()
    dat.set_data(d, 'psr')
    analysis_opts = {
        'cmi_calc_name': 'jidt_kraskov',
        'n_perm_max_stat': 200,
        'n_perm_min_stat': 200,
        'n_perm_omnibus': 500,
        'n_perm_max_seq': 500,
        }
    random_analysis = Multivariate_te(max_lag_sources=5, options=analysis_opts)
    # Assert that there are no significant conditionals in either direction
    # other than the mandatory single sample in the target's past (which
    # ensures that we calculate a proper TE at any time in the algorithm).
    for target in range(2):
        res = random_analysis.analyse_single_target(dat, target)
        assert (len(res['conditional_full']) == 1), ('Conditional contains '
                                                     'more/less than 1 '
                                                     'variables.')
        assert (not res['conditional_sources']), ('Conditional sources is not '
                                                  'empty.')
        assert (len(res['conditional_target']) == 1), ('Conditional target '
                                                       'contains more/less '
                                                       'than 1 variable.')
        assert (res['cond_sources_pval'] is None), ('Conditional p-value is '
                                                    'not None.')
        assert (res['omnibus_pval'] is None), ('Omnibus p-value is not None.')
        assert (res['omnibus_sign'] is None), ('Omnibus significance is not '
                                               'None.')
        assert (res['conditional_sources_te'] is None), ('Conditional TE '
                                                         'values is not None.')
Author: pwollstadt, Project: IDTxl, Lines of code: 44, Source: systemtest_multivariate_te.py

Example 13: test_set_data

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_set_data():
    """Test if data is written correctly into a Data instance."""
    source = np.expand_dims(np.repeat(1, 30), axis=1)
    target = np.expand_dims(np.arange(30), axis=1)

    data = Data(normalise=False)
    data.set_data(np.vstack((source.T, target.T)), 'ps')

    assert (data.data[0, :].T == source.T).all(), ('Class data does not match '
                                                   'input (source).')
    assert (data.data[1, :].T == target.T).all(), ('Class data does not match '
                                                   'input (target).')

    d = Data()
    dat = np.arange(10000).reshape((2, 1000, 5))  # random data with correct
    d = Data(dat, dim_order='psr')                # order of dimensions
    assert (d.data.shape[0] == 2), ('Class data does not match input, number '
                                    'of processes wrong.')
    assert (d.data.shape[1] == 1000), ('Class data does not match input, '
                                       'number of samples wrong.')
    assert (d.data.shape[2] == 5), ('Class data does not match input, number '
                                    'of replications wrong.')
    dat = np.arange(3000).reshape((3, 1000))  # random data with incorrect
    d = Data(dat, dim_order='ps')            # order of dimensions
    assert (d.data.shape[0] == 3), ('Class data does not match input, number '
                                    'of processes wrong.')
    assert (d.data.shape[1] == 1000), ('Class data does not match input, '
                                       'number of samples wrong.')
    assert (d.data.shape[2] == 1), ('Class data does not match input, number '
                                    'of replications wrong.')
    dat = np.arange(5000)
    d.set_data(dat, 's')
    assert (d.data.shape[0] == 1), ('Class data does not match input, number '
                                    'of processes wrong.')
    assert (d.data.shape[1] == 5000), ('Class data does not match input, '
                                       'number of samples wrong.')
    assert (d.data.shape[2] == 1), ('Class data does not match input, number '
                                    'of replications wrong.')
Author: finnconor, Project: IDTxl, Lines of code: 40, Source: test_data.py
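As the assertions in Example 13 suggest, the Data class appears to always store its contents as a 3D array ordered processes x samples x replications, padding any missing dimension with size 1. A minimal sketch of that behaviour, reusing only calls that appear in the test above:

import numpy as np
from idtxl.data import Data

d = Data(normalise=False)
d.set_data(np.arange(5000), 's')  # 1D input: samples only
print(d.data.shape)               # expected: (1, 5000, 1)

d = Data(np.arange(3000).reshape((3, 1000)), dim_order='ps', normalise=False)
print(d.data.shape)               # expected: (3, 1000, 1)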

Example 14: test_calculate_cmi_all_links

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_calculate_cmi_all_links():
    """Test if the CMI is estimated correctly."""
    data = Data()
    n = 1000
    cov = 0.4
    source = [rn.normalvariate(0, 1) for r in range(n)]  # correlated src
    target = [0] + [sum(pair) for pair in zip(
        [cov * y for y in source[0:n - 1]],
        [(1 - cov) * y for y in
            [rn.normalvariate(0, 1) for r in range(n - 1)]])]
    data.set_data(np.vstack((source, target)), 'ps')
    res_0 = np.load(os.path.join(os.path.dirname(__file__),
                    'data/mute_results_0.p'))
    comp_settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'n_perm_max_stat': 50,
        'n_perm_min_stat': 50,
        'n_perm_omnibus': 200,
        'n_perm_max_seq': 50,
        'tail': 'two',
        'n_perm_comp': 6,
        'alpha_comp': 0.2,
        'stats_type': 'dependent'
        }
    comp = NetworkComparison()
    comp._initialise(comp_settings)
    comp._create_union(res_0)
    comp.union._single_target[1]['selected_vars_sources'] = [(0, 4)]
    cmi = comp._calculate_cmi_all_links(data)
    corr_expected = cov / (1 * np.sqrt(cov**2 + (1-cov)**2))
    expected_cmi = calculate_mi(corr_expected)
    print('correlated Gaussians: TE result {0:.4f} bits; expected to be '
          '{1:0.4f} bit for the copy'.format(cmi[1][0], expected_cmi))
    np.testing.assert_almost_equal(
                   cmi[1][0], expected_cmi, decimal=1,
                   err_msg='when calculating cmi for correlated Gaussians.')
Author: SimonStreicher, Project: IDTxl, Lines of code: 38, Source: test_network_comparison.py
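The expected value in the test above comes from the closed-form mutual information of two correlated Gaussian variables. Assuming calculate_mi implements this standard formula in bits (an assumption; only its call signature is visible in the test), a minimal sketch:

import numpy as np

def gaussian_mi_bits(rho):
    # Mutual information of a bivariate Gaussian with correlation rho:
    # I(X; Y) = -0.5 * log2(1 - rho**2), in bits.
    return -0.5 * np.log2(1 - rho ** 2)

cov = 0.4
corr_expected = cov / np.sqrt(cov ** 2 + (1 - cov) ** 2)
print(gaussian_mi_bits(corr_expected))  # reference value for the estimated CMI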

Example 15: test_multivariate_te_corr_gaussian

# Required import: from idtxl.data import Data [as alias]
# Or: from idtxl.data.Data import set_data [as alias]
def test_multivariate_te_corr_gaussian(estimator=None):
    """Test multivariate TE estimation on correlated Gaussians.

    Run the multivariate TE algorithm on two sets of random Gaussian data with
    a given covariance. The second data set is shifted by one sample creating
    a source-target delay of one sample. This example is modeled after the
    JIDT demo 4 for transfer entropy. The resulting TE can be compared to the
    analytical result (but expect some error in the estimate).

    The simulated delay is 1 sample, i.e., the algorithm should find
    significant TE from sample (0, 1), a sample in process 0 with lag/delay 1.
    The final target sample should always be (1, 1), the mandatory sample at
    lag 1, because there is no memory in the process.

    Note:
        This test runs considerably faster than other system tests.
        This produces strange small values for non-coupled sources.  TODO
    """
    if estimator is None:
        estimator = 'jidt_kraskov'

    n = 1000
    cov = 0.4
    source_1 = [rn.normalvariate(0, 1) for r in range(n)]  # correlated src
    # source_2 = [rn.normalvariate(0, 1) for r in range(n)]  # uncorrelated src
    target = [sum(pair) for pair in zip(
        [cov * y for y in source_1],
        [(1 - cov) * y for y in [rn.normalvariate(0, 1) for r in range(n)]])]
    # Cast everything to numpy so the idtxl estimator understands it.
    source_1 = np.expand_dims(np.array(source_1), axis=1)
    # source_2 = np.expand_dims(np.array(source_2), axis=1)
    target = np.expand_dims(np.array(target), axis=1)

    dat = Data(normalise=True)
    dat.set_data(np.vstack((source_1[1:].T, target[:-1].T)), 'ps')
    analysis_opts = {
        'cmi_calc_name': estimator,
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_omnibus': 21,
        'n_perm_max_seq': 21,
        }
    random_analysis = Multivariate_te(max_lag_sources=5, min_lag_sources=1,
                                      max_lag_target=5, options=analysis_opts)
    # res = random_analysis.analyse_network(dat)  # full network
    # utils.print_dict(res)
    res_1 = random_analysis.analyse_single_target(dat, 1)  # coupled direction
    # Assert that there are significant conditionals from the source for target
    # 1. For 500 repetitions I got mean errors of 0.02097686 and 0.01454073 for
    # examples 1 and 2 respectively. The maximum errors were 0.093841 and
    # 0.05833172 respectively. This inspired the following error boundaries.
    expected_res = np.log(1 / (1 - np.power(cov, 2)))
    diff = np.abs(max(res_1['cond_sources_te']) - expected_res)
    print('Expected source sample: (0, 1)\nExpected target sample: (1, 1)')
    print(('Estimated TE: {0:5.4f}, analytical result: {1:5.4f}, error:'
           '{2:2.2f} % ').format(max(res_1['cond_sources_te']), expected_res,
                                 diff / expected_res))
    assert (diff < 0.1), ('Multivariate TE calculation for correlated '
                          'Gaussians failed (error larger 0.1: {0}, expected: '
                          '{1}, actual: {2}).'.format(diff,
                                                      expected_res,
                                                      res_1['cond_sources_te']
                                                      ))
Author: pwollstadt, Project: IDTxl, Lines of code: 65, Source: systemtest_multivariate_te.py


Note: The idtxl.data.Data.set_data method examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets were selected from open-source projects contributed by various developers; copyright of the source code remains with the original authors. For distribution and use, please refer to the corresponding project's license. Do not reproduce without permission.