当前位置: 首页>>代码示例>>Python>>正文


Python SVC.fit_generator方法代码示例

本文整理汇总了Python中sklearn.svm.SVC.fit_generator方法的典型用法代码示例。如果您正苦于以下问题:Python SVC.fit_generator方法的具体用法?Python SVC.fit_generator怎么用?Python SVC.fit_generator使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在sklearn.svm.SVC的用法示例。


在下文中一共展示了SVC.fit_generator方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from sklearn.svm import SVC [as 别名]
# 或者: from sklearn.svm.SVC import fit_generator [as 别名]
def main(hypes_file, data_dir, override):
    """Orchestrate."""
    with open(hypes_file, 'r') as f:
        hypes = json.load(f)

    model_file_path = os.path.abspath('%s.pkl' % hypes['model']['name'])

    color_changes = {0: (0, 0, 0, 0),
                     1: (0, 255, 0, 127),
                     'default': (0, 0, 0, 0)}

    if not os.path.isfile(model_file_path) or override:
        if not os.path.isfile(model_file_path):
            logging.info("Did not find '%s'. Start training...",
                         model_file_path)
        else:
            logging.info("Override '%s'. Start training...",
                         model_file_path)

        # Get data
        # x_files, y_files = inputs(hypes, None, 'train', data_dir)
        x_files, y_files = get_file_list(hypes, 'train')
        x_files, y_files = sklearn.utils.shuffle(x_files,
                                                 y_files,
                                                 random_state=0)

        x_train, y_train = get_traindata_single_file(hypes,
                                                     x_files[0],
                                                     y_files[0])

        nb_features = x_train[0].shape[0]
        logging.info("Input gets %i features", nb_features)

        # Make model
        from sklearn.svm import LinearSVC, SVC
        from sklearn.tree import DecisionTreeClassifier
        model = SVC(probability=False,  # cache_size=200,
                    kernel="linear", C=2.8, gamma=.0073)
        model = LinearSVC(C=2.8)
        model = DecisionTreeClassifier()

        print("Start fitting. This may take a while")

        generator = generate_training_data(hypes, x_files, y_files)
        t0 = time.time()

        if False:
            sep = hypes['solver']['samples_per_epoch']
            model.fit_generator(generator,
                                samples_per_epoch=sep,
                                nb_epoch=hypes['solver']['epochs'],
                                verbose=1,
                                # callbacks=[callb],
                                validation_data=(x_train, y_train))
        else:
            logging.info("Fit with .fit")
            x_train, y_train = inputs(hypes, None, 'train', data_dir)
            print(len(y_train))
            model.fit(x_train, y_train)
        t1 = time.time()
        print("Training Time: %0.4f" % (t1 - t0))

        # save as YAML
        joblib.dump(model, model_file_path)

        # Evaluate
        data = get_file_list(hypes, 'test')
        logging.info("Start segmentation")
        analyze.evaluate(hypes,
                         data,
                         data_dir,
                         model,
                         elements=[0, 1],
                         load_label_seg=load_label_seg,
                         color_changes=color_changes,
                         get_segmentation=get_segmentation)
    else:
        model = joblib.load(model_file_path)
        data = get_file_list(hypes, 'test')
        analyze.evaluate(hypes,
                         data,
                         data_dir,
                         model,
                         elements=[0, 1],
                         load_label_seg=load_label_seg,
                         color_changes=color_changes,
                         get_segmentation=get_segmentation)
开发者ID:TensorVision,项目名称:MediSeg,代码行数:89,代码来源:basic_local_classifier.py


注:本文中的sklearn.svm.SVC.fit_generator方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。