当前位置: 首页>>代码示例>>Python>>正文


Python QuickBundles.find_closest方法代码示例

本文整理汇总了Python中dipy.segment.clustering.QuickBundles.find_closest方法的典型用法代码示例。如果您正苦于以下问题:Python QuickBundles.find_closest方法的具体用法?Python QuickBundles.find_closest怎么用?Python QuickBundles.find_closest使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在dipy.segment.clustering.QuickBundles的用法示例。


在下文中一共展示了QuickBundles.find_closest方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from dipy.segment.clustering import QuickBundles [as 别名]
# 注意: find_closest 是 QuickBundles 类的方法,不能单独导入;请先创建 QuickBundles 实例,再调用 qb.find_closest(...)
def main():
    """Sweep assignment thresholds for QuickBundles and report the best one.

    Steps, driven by the parsed command-line arguments:

    1. Load the full tractogram, the model tractogram and the model mask.
    2. Move both tractograms into the mask's voxel space (corner-origin).
    3. Resample both sets of streamlines to ``args.nb_points_resampling``
       points and cluster the model streamlines with QuickBundles, using a
       fixed-seed shuffled ordering.
    4. For each candidate threshold around ``args.qb_threshold``, assign the
       full tractogram's streamlines to the model clusters with
       ``find_closest``, build a binary map from the assigned streamlines and
       score it against the model mask via ``_compute_overlap``,
       ``_compute_overreach`` and ``_compute_f1_score``.

    One summary line is printed per threshold, then the threshold with the
    highest score. Nothing is returned.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    full_tfile = nib.streamlines.load(args.full_tfile)
    model_tfile = nib.streamlines.load(args.model_tfile)
    model_mask = nib.load(args.model_mask)

    # Bring streamlines to voxel space and where coordinate (0,0,0) represents the corner of a voxel.
    world_to_voxel = np.linalg.inv(model_mask.affine)
    model_tfile.tractogram.apply_affine(world_to_voxel)
    model_tfile.streamlines._data += 0.5  # Shift of half a voxel
    full_tfile.tractogram.apply_affine(world_to_voxel)
    full_tfile.streamlines._data += 0.5  # Shift of half a voxel

    # The mask data and its voxel count do not change during the sweep, so
    # fetch them once instead of on every loop iteration.
    # NOTE(review): get_data() is deprecated in recent nibabel releases; it is
    # kept here because get_fdata() returns a different dtype - confirm before
    # migrating.
    model_mask_data = model_mask.get_data()
    nb_model_voxels = model_mask_data.sum()

    # Sanity check: the binary map rebuilt from the model streamlines must
    # cover exactly as many voxels as the provided model mask.
    assert nb_model_voxels == create_binary_map(model_tfile.streamlines, model_mask).sum(), (
        "Model mask does not match the binary map of the model streamlines.")

    # Resample streamlines
    full_streamlines = set_number_of_points(full_tfile.streamlines, args.nb_points_resampling)
    model_streamlines = set_number_of_points(model_tfile.streamlines, args.nb_points_resampling)

    # Segment model (fixed seed so the shuffled clustering order is reproducible)
    rng = np.random.RandomState(42)
    clustering_order = np.arange(len(model_streamlines))
    rng.shuffle(clustering_order)
    qb = QuickBundles(args.qb_threshold)
    clusters = qb.cluster(model_streamlines, ordering=clustering_order)

    # Try to find optimal assignment threshold
    best_threshold = None
    best_f1_score = -np.inf
    thresholds = np.arange(-2, 10, 0.2) + args.qb_threshold
    for threshold in thresholds:
        # find_closest marks streamlines without a close enough cluster as -1.
        assignments = qb.find_closest(clusters, full_streamlines, threshold=threshold)
        is_assigned = assignments != -1
        nb_assignments = np.sum(is_assigned)

        # The map is built from the original (non-resampled) streamlines.
        mask = create_binary_map(full_tfile.streamlines[is_assigned], model_mask)

        overlap_per_bundle = _compute_overlap(model_mask_data, mask)
        overreach_per_bundle = _compute_overreach(model_mask_data, mask)
        # overreach_norm_gt_per_bundle = _compute_overreach_normalize_gt(model_mask_data, mask)
        f1_score = _compute_f1_score(overlap_per_bundle, overreach_per_bundle)
        if best_f1_score < f1_score:
            best_threshold = threshold
            best_f1_score = f1_score

        # NOTE(review): the assignment ratio is relative to the number of
        # *model* streamlines, not the full tractogram - confirm this is the
        # intended denominator.
        print("{}:\t {}/{} ({:.1%}) {:.1%}/{:.1%} ({:.1%}) {}/{}".format(
            threshold,
            nb_assignments, len(model_streamlines), nb_assignments/len(model_streamlines),
            overlap_per_bundle, overreach_per_bundle, f1_score,
            mask.sum(), nb_model_voxels))

        # Stop sweeping once the overlap value reaches 1; larger thresholds
        # are not evaluated after that.
        if overlap_per_bundle >= 1:
            break

    print("Best threshold: {} with F1-Score of {}".format(best_threshold, best_f1_score))
开发者ID:ppoulin91,项目名称:learn2track,代码行数:58,代码来源:auto_find_qb_threshold.py

示例2: auto_extract

# 需要导入模块: from dipy.segment.clustering import QuickBundles [as 别名]
# 注意: find_closest 是 QuickBundles 类的方法,不能单独导入;请先创建 QuickBundles 实例,再调用 qb.find_closest(...)
def auto_extract(model_cluster_map, rstreamlines,
                 number_pts_per_str=NB_POINTS_RESAMPLE,
                 close_centroids_thr=20,
                 clean_thr=7.,
                 disp=False, verbose=False,
                 ordering=None):
    """Select the streamlines that ``find_closest`` matches to a model cluster.

    Parameters
    ----------
    model_cluster_map
        Cluster map handed to ``QuickBundles.find_closest`` as the reference.
    rstreamlines
        Streamlines to test against the model clusters.
    number_pts_per_str, close_centroids_thr, disp, verbose
        Accepted for interface compatibility; not used by this body.
    clean_thr
        Passed to ``find_closest`` as its third positional argument.
    ordering
        Indices forwarded to ``find_closest``; defaults to
        ``np.arange(len(rstreamlines))`` when ``None``.

    Returns
    -------
    The entries of ``ordering`` at the positions where ``find_closest``
    returned a value ``>= 0``.
    """
    streamline_order = np.arange(len(rstreamlines)) if ordering is None else ordering

    bundler = QuickBundles(threshold=REF_BUNDLES_THRESHOLD,
                           metric=AveragePointwiseEuclideanMetric())
    assignments = bundler.find_closest(
        model_cluster_map, rstreamlines, clean_thr, ordering=streamline_order)

    # Positions with a non-negative result were matched to some cluster.
    matched_positions = np.nonzero(assignments >= 0)[0]
    return streamline_order[matched_positions]
开发者ID:ppoulin91,项目名称:learn2track,代码行数:15,代码来源:learn2track_metrics.py

示例3: _auto_extract_VCs

# 需要导入模块: from dipy.segment.clustering import QuickBundles [as 别名]
# 注意: find_closest 是 QuickBundles 类的方法,不能单独导入;请先创建 QuickBundles 实例,再调用 qb.find_closest(...)
def _auto_extract_VCs(streamlines, ref_bundles):
    # Streamlines = list of all streamlines

    # TODO check what is neede
    # VC = 0
    VC_idx = set()

    found_vbs_info = {}
    for bundle in ref_bundles:
        found_vbs_info[bundle['name']] = {'nb_streamlines': 0,
                                          'streamlines_indices': set()}

    # TODO probably not needed
    # already_assigned_streamlines_idx = set()

    # Need to bookkeep because we chunk for big datasets
    processed_strl_count = 0
    chunk_size = len(streamlines)
    chunk_it = 0

    # nb_bundles = len(ref_bundles)
    # bundles_found = [False] * nb_bundles
    #bundles_potential_VCWP = [set()] * nb_bundles

    logging.debug("Starting scoring VCs")

    # Start loop here for big datasets
    while processed_strl_count < len(streamlines):
        if processed_strl_count > 0:
            raise NotImplementedError("Not supposed to have more than one chunk!")

        logging.debug("Starting chunk: {0}".format(chunk_it))

        strl_chunk = streamlines[chunk_it * chunk_size: (chunk_it + 1) * chunk_size]

        processed_strl_count += len(strl_chunk)

        # Already resample and run quickbundles on the submission chunk,
        # to avoid doing it at every call of auto_extract
        rstreamlines = set_number_of_points(nib.streamlines.ArraySequence(strl_chunk), NB_POINTS_RESAMPLE)

        # qb.cluster had problem with f8
        # rstreamlines = [s.astype('f4') for s in rstreamlines]

        # chunk_cluster_map = qb.cluster(rstreamlines)
        # chunk_cluster_map.refdata = strl_chunk

        # # Merge clusters
        # all_bundles = ClusterMapCentroid()
        # cluster_id_to_bundle_id = []
        # for bundle_idx, ref_bundle in enumerate(ref_bundles):
        #     clusters = ref_bundle["cluster_map"]
        #     cluster_id_to_bundle_id.extend([bundle_idx] * len(clusters))
        #     all_bundles.add_cluster(*clusters)

        # logging.debug("Starting VC identification through auto_extract")
        # qb = QuickBundles(threshold=10, metric=AveragePointwiseEuclideanMetric())
        # closest_bundles = qb.find_closest(all_bundles, rstreamlines, threshold=7)

        # print("Unassigned streamlines: {}".format(np.sum(closest_bundles == -1)))

        # for cluster_id, bundle_id in enumerate(cluster_id_to_bundle_id):
        #     indices = np.where(closest_bundles == cluster_id)[0]
        #     print("{}/{} ({}) Found {}".format(cluster_id, len(cluster_id_to_bundle_id), ref_bundles[bundle_id]['name'], len(indices)))
        #     if len(indices) == 0:
        #         continue

        #     vb_info = found_vbs_info.get(ref_bundles[bundle_id]['name'])
        #     indices = set(indices)
        #     vb_info['nb_streamlines'] += len(indices)
        #     vb_info['streamlines_indices'] |= indices
        #     VC_idx |= indices

        qb = QuickBundles(threshold=10, metric=AveragePointwiseEuclideanMetric())
        ordering = np.arange(len(rstreamlines))
        logging.debug("Starting VC identification through auto_extract")
        for bundle_idx, ref_bundle in enumerate(ref_bundles):
            print(ref_bundle['name'], ref_bundle['threshold'], len(ref_bundle['cluster_map']))
            # The selected indices are from [0, len(strl_chunk)]
            # selected_streamlines_indices = auto_extract(ref_bundle['cluster_map'],
            #                                             rstreamlines,
            #                                             clean_thr=ref_bundle['threshold'],
            #                                             ordering=ordering)

            closest_bundles = qb.find_closest(ref_bundle['cluster_map'], rstreamlines[ordering], ref_bundle['threshold'])
            selected_streamlines_indices = ordering[closest_bundles >= 0]
            ordering = ordering[closest_bundles == -1]

            # Remove duplicates, when streamlines are assigned to multiple VBs.
            # TODO better handling of this case
            # selected_streamlines_indices = set(selected_streamlines_indices) - cur_chunk_VC_idx
            # cur_chunk_VC_idx |= selected_streamlines_indices

            nb_selected_streamlines = len(selected_streamlines_indices)
            print("{} assigned".format(nb_selected_streamlines))

            if nb_selected_streamlines:
                # bundles_found[bundle_idx] = True
                # VC += nb_selected_streamlines

#.........这里部分代码省略.........
开发者ID:ppoulin91,项目名称:learn2track,代码行数:103,代码来源:learn2track_metrics.py


注:本文中的dipy.segment.clustering.QuickBundles.find_closest方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。