当前位置: 首页>>代码示例>>Python>>正文


Python Solver.fit方法代码示例

本文整理汇总了Python中solver.Solver.fit方法的典型用法代码示例。如果您正苦于以下问题:Python Solver.fit方法的具体用法?Python Solver.fit怎么用?Python Solver.fit使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在solver.Solver的用法示例。


在下文中一共展示了Solver.fit方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.0003)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--z_dim", type=int, default=100)
    
    parser.add_argument("--print_every", type=int, default=1)
    parser.add_argument("--result_dir", type=str, default="./result")
    parser.add_argument("--ckpt_dir", type=str, default="./checkpoint")
    parser.add_argument("--data_root", type=str, default="./data")

    args = parser.parse_args()
    solver = Solver(args)
    solver.fit()
开发者ID:muncok,项目名称:pytorch-exercise,代码行数:17,代码来源:train.py

示例2: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
def main():
    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=2, workspace_default=1536)
    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
    num_epoch = 1
    learning_rate = 1e-3
    if args.model == "fcn16s":
        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=2, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
        learning_rate = 1e-5
    elif args.model == "fcn8s":
        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=2, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"
        num_epoch = 30
        learning_rate = 1e-7
    arg_names = fcnxs.list_arguments()
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
    if not args.retrain:
        if args.init_type == "vgg16":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        elif args.init_type == "fcnxs":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
    train_dataiter = FileIter(
        root_dir             = ".",
        flist_name           = "train.lst",
        # cut_off_size         = 400,
        rgb_mean             = (123.68, 116.779, 103.939),
        )
#     val_dataiter = FileIter(
#         root_dir             = "/home/zw/dataset/VOC2012Segmentation/VOC2012",
#         flist_name           = "val.lst",
#         rgb_mean             = (123.68, 116.779, 103.939),
#         )
    model = Solver(
        ctx                 = ctx,
        symbol              = fcnxs,
        begin_epoch         = 0,
        num_epoch           = num_epoch,
        arg_params          = fcnxs_args,
        aux_params          = fcnxs_auxs,
        learning_rate       = learning_rate,
        momentum            = 0.9,
        wd                  = 0.0001)
    model.fit(
        train_data          = train_dataiter,
#         eval_data           = val_dataiter,
        batch_end_callback  = mx.callback.Speedometer(1, 10),
        epoch_end_callback  = mx.callback.do_checkpoint(fcnxs_model_prefix))
开发者ID:zhaw,项目名称:wine_private,代码行数:49,代码来源:fcn_xs.py

示例3: main

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
def main():
    ctx = mx.cpu() if not args.gpu else mx.gpu(args.gpu)
    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=21, workspace_default=1536)
    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
    if args.model == "fcn16s":
        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=21, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
    elif args.model == "fcn8s":
        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=21, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"
    arg_names = fcnxs.list_arguments()
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
    if not args.retrain:
        if args.init_type == "vgg16":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        elif args.init_type == "fcnxs":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
    train_dataiter = FileIter(
        root_dir             = "./VOC2012",
        flist_name           = "train.lst",
        # cut_off_size         = 400,
        rgb_mean             = (123.68, 116.779, 103.939),
        )
    val_dataiter = FileIter(
        root_dir             = "./VOC2012",
        flist_name           = "val.lst",
        rgb_mean             = (123.68, 116.779, 103.939),
        )
    model = Solver(
        ctx                 = ctx,
        symbol              = fcnxs,
        begin_epoch         = 0,
        num_epoch           = 50,
        arg_params          = fcnxs_args,
        aux_params          = fcnxs_auxs,
        learning_rate       = 1e-10,
        momentum            = 0.99,
        wd                  = 0.0005)
    model.fit(
        train_data          = train_dataiter,
        eval_data           = val_dataiter,
        batch_end_callback  = mx.callback.Speedometer(1, 10),
        epoch_end_callback  = mx.callback.do_checkpoint(fcnxs_model_prefix))
开发者ID:luobao-intel,项目名称:incubator-mxnet,代码行数:45,代码来源:fcn_xs.py

示例4: FileIter

# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
import symbol
net = symbol.get_symbol(output_dim = 30)

from data import FileIter

train = FileIter(
         eval_ratio = 0.2, 
         is_val = False,
         data_name = "data",
         batch_size = args.batch_size,
         label_name = "lr_label"
        )

val = FileIter(
     eval_ratio = 0.2, 
     is_val = True,
     data_name = "data",
     batch_size = 1,
     label_name = "lr_label"
    )


from solver import Solver
model = Solver(
    symbol = net,
    num_epoch = args.num_epochs
    )
model.fit(train, val, 
    batch_end_callback = mx.callback.Speedometer(1, 10),
    epoch_end_callback = mx.callback.do_checkpoint(args.model_prefix))
开发者ID:Rescape,项目名称:kaggle_Facial_Keypoints_Detection_mxnet,代码行数:32,代码来源:train.py


注:本文中的solver.Solver.fit方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。