本文整理汇总了 Python 中 solver.Solver.fit 方法的典型用法代码示例。如果您想了解 Solver.fit 方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 solver.Solver 的用法示例。
下文共展示了 Solver.fit 方法的 4 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。
示例1: main
# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
def _parse_args(argv=None):
    """Build the training command-line interface and parse it.

    Args:
        argv: List of command-line tokens to parse. ``None`` (the default)
            makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the fields ``lr``, ``batch_size``,
        ``max_epochs``, ``z_dim``, ``print_every``, ``result_dir``,
        ``ckpt_dir`` and ``data_root``.
    """
    # Local import keeps the snippet self-contained: the example header only
    # lists `from solver import Solver` as a required import.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.0003)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--z_dim", type=int, default=100)
    parser.add_argument("--print_every", type=int, default=1)
    parser.add_argument("--result_dir", type=str, default="./result")
    parser.add_argument("--ckpt_dir", type=str, default="./checkpoint")
    parser.add_argument("--data_root", type=str, default="./data")
    return parser.parse_args(argv)


def main(argv=None):
    """Parse the hyper-parameters and run ``Solver(args).fit()``.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            reads ``sys.argv``, so calling ``main()`` with no arguments
            behaves as before.
    """
    args = _parse_args(argv)
    # Solver receives the whole namespace; which fields it reads is decided
    # inside the solver module (not shown here).
    solver = Solver(args)
    solver.fit()
示例2: main
# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
def main():
    """Train a two-class FCN (32s / 16s / 8s) model with ``Solver.fit``.

    This function does not create two of the names it reads; both must
    already exist at module scope when it is called:

    * ``args`` -- presumably the parsed command-line namespace. The fields
      read here are ``model``, ``prefix``, ``epoch``, ``retrain`` and
      ``init_type``.
    * ``ctx`` -- the MXNet device context passed to the init_fcnxs helpers
      and to ``Solver``.

    Checkpoints are written with a prefix under ``model_pascal/`` that
    depends on ``args.model``.
    """
    # FCN-32s is the default; "fcn16s" / "fcn8s" replace the symbol, the
    # checkpoint prefix and (some of) the optimiser settings below.
    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=2, workspace_default=1536)
    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
    num_epoch = 1
    learning_rate = 1e-3
    if args.model == "fcn16s":
        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=2, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
        learning_rate = 1e-5
    elif args.model == "fcn8s":
        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=2, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"
        num_epoch = 30
        learning_rate = 1e-7

    # load_checkpoint returns (symbol, arg_params, aux_params). Only the
    # parameters are kept; the graph used for training is `fcnxs` above.
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
    if not args.retrain:
        # Re-derive the starting parameters for `fcnxs` with the init_fcnxs
        # helper chosen by init_type; any other value keeps the loaded ones.
        if args.init_type == "vgg16":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        elif args.init_type == "fcnxs":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs)

    # rgb_mean looks like the usual VGG/ImageNet per-channel means; how
    # FileIter applies it is defined in the data module (not shown here).
    train_dataiter = FileIter(
        root_dir=".",
        flist_name="train.lst",
        # cut_off_size=400,
        rgb_mean=(123.68, 116.779, 103.939),
    )
    # Validation is disabled in this variant. To enable it, build a second
    # FileIter over "val.lst" and pass it to fit() as `eval_data`.
    model = Solver(
        ctx=ctx,
        symbol=fcnxs,
        begin_epoch=0,
        num_epoch=num_epoch,
        arg_params=fcnxs_args,
        aux_params=fcnxs_auxs,
        learning_rate=learning_rate,
        momentum=0.9,
        wd=0.0001,
    )
    model.fit(
        train_data=train_dataiter,
        # Speedometer(batch_size=1, frequent=10): log speed every 10 batches.
        batch_end_callback=mx.callback.Speedometer(1, 10),
        # Save a checkpoint under fcnxs_model_prefix at the end of each epoch.
        epoch_end_callback=mx.callback.do_checkpoint(fcnxs_model_prefix),
    )
示例3: main
# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
def main():
    """Train a 21-class FCN (32s / 16s / 8s) model on ./VOC2012 with validation.

    Reads a module-level ``args`` namespace that this function does not
    create. The fields used here are ``gpu``, ``model``, ``prefix``,
    ``epoch``, ``retrain`` and ``init_type``. Checkpoints are written with
    a prefix under ``model_pascal/`` that depends on ``args.model``.
    """
    # Compare with None instead of testing truthiness: `not args.gpu` is also
    # true for GPU index 0, which silently fell back to the CPU.
    # NOTE(review): assumes --gpu defaults to None when no GPU is wanted --
    # confirm against the argument parser.
    ctx = mx.cpu() if args.gpu is None else mx.gpu(args.gpu)

    # FCN-32s is the default; "fcn16s" / "fcn8s" swap the symbol and the
    # checkpoint prefix.
    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=21, workspace_default=1536)
    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
    if args.model == "fcn16s":
        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=21, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
    elif args.model == "fcn8s":
        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=21, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"

    # load_checkpoint returns (symbol, arg_params, aux_params). Only the
    # parameters are kept; the graph used for training is `fcnxs` above.
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
    if not args.retrain:
        # Re-derive the starting parameters for `fcnxs` with the init_fcnxs
        # helper chosen by init_type; any other value keeps the loaded ones.
        if args.init_type == "vgg16":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        elif args.init_type == "fcnxs":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs)

    # rgb_mean looks like the usual VGG/ImageNet per-channel means; how
    # FileIter applies it is defined in the data module (not shown here).
    train_dataiter = FileIter(
        root_dir="./VOC2012",
        flist_name="train.lst",
        # cut_off_size=400,
        rgb_mean=(123.68, 116.779, 103.939),
    )
    val_dataiter = FileIter(
        root_dir="./VOC2012",
        flist_name="val.lst",
        rgb_mean=(123.68, 116.779, 103.939),
    )
    model = Solver(
        ctx=ctx,
        symbol=fcnxs,
        begin_epoch=0,
        num_epoch=50,
        arg_params=fcnxs_args,
        aux_params=fcnxs_auxs,
        learning_rate=1e-10,
        momentum=0.99,
        wd=0.0005,
    )
    model.fit(
        train_data=train_dataiter,
        eval_data=val_dataiter,
        # Speedometer(batch_size=1, frequent=10): log speed every 10 batches.
        batch_end_callback=mx.callback.Speedometer(1, 10),
        # Save a checkpoint under fcnxs_model_prefix at the end of each epoch.
        epoch_end_callback=mx.callback.do_checkpoint(fcnxs_model_prefix),
    )
示例4：模块级脚本片段（FileIter + Solver.fit）
# 需要导入模块: from solver import Solver [as 别名]
# 或者: from solver.Solver import fit [as 别名]
# Script fragment: build the network, the train/validation iterators and the
# solver, then train. `args` (batch_size, num_epochs, model_prefix) and `mx`
# must already be defined earlier in the script; they are not created here.
# NOTE: `symbol` must resolve to the project's symbol.py. Python < 3.10 also
# shipped a stdlib `symbol` module, so sys.path order matters there.
import symbol
from data import FileIter
from solver import Solver

net = symbol.get_symbol(output_dim=30)

# Iterator options shared by both splits; only is_val and batch_size differ.
_iter_opts = dict(eval_ratio=0.2, data_name="data", label_name="lr_label")
train = FileIter(is_val=False, batch_size=args.batch_size, **_iter_opts)
val = FileIter(is_val=True, batch_size=1, **_iter_opts)

model = Solver(symbol=net, num_epoch=args.num_epochs)
model.fit(
    train,
    val,
    batch_end_callback=mx.callback.Speedometer(1, 10),
    epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix),
)