当前位置: 首页>>代码示例>>Python>>正文


Python gluon.Trainer方法代码示例

本文整理汇总了Python中mxnet.gluon.Trainer方法的典型用法代码示例。如果您正苦于以下问题:Python gluon.Trainer方法的具体用法?Python gluon.Trainer怎么用?Python gluon.Trainer使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块mxnet.gluon的用法示例。


在下文中一共展示了gluon.Trainer方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def train(net, X_train, y_train, epochs, verbose_epoch, learning_rate,
          weight_decay, batch_size):
    """Train ``net`` on mini-batches and return the last training loss.

    The network parameters are force re-initialized, then optimized with the
    'adam' optimizer using ``square_loss`` on shuffled mini-batches.

    Parameters
    ----------
    net : network to (re-)initialize and train.
    X_train, y_train : training inputs and labels, wrapped in an ArrayDataset.
    epochs : number of passes over the training data; must be at least 1.
    verbose_epoch : progress is printed only for epochs strictly greater
        than this value.
    learning_rate : passed to the optimizer as ``learning_rate``.
    weight_decay : passed to the optimizer as ``wd``.
    batch_size : mini-batch size, also passed to ``trainer.step``.

    Returns
    -------
    The value of ``get_rmse_log(net, X_train, y_train)`` after the final epoch.

    Raises
    ------
    ValueError
        If ``epochs`` is smaller than 1 (there would be no loss to return).
    """
    if epochs < 1:
        # Fail early with a clear message instead of an UnboundLocalError on
        # ``avg_loss`` at the end of the function.
        raise ValueError("epochs must be at least 1, got %r" % (epochs,))

    dataset_train = gluon.data.ArrayDataset(X_train, y_train)
    data_iter_train = gluon.data.DataLoader(dataset_train, batch_size,
                                            shuffle=True)
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': learning_rate,
                             'wd': weight_decay})
    net.initialize(force_reinit=True)
    for epoch in range(epochs):
        for data, label in data_iter_train:
            with autograd.record():
                output = net(data)
                loss = square_loss(output, label)
            loss.backward()
            trainer.step(batch_size)
        # Evaluate on the full training set once per epoch: only the value
        # after the last batch is ever printed or returned, so recomputing it
        # after every mini-batch was wasted work.
        avg_loss = get_rmse_log(net, X_train, y_train)
        if epoch > verbose_epoch:
            print("Epoch %d, train loss: %f" % (epoch, avg_loss))
    return avg_loss
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:23,代码来源:kaggle_k_fold_cross_validation.py

示例2: test_sparse_parameter

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def test_sparse_parameter():
    """Check row_sparse parameter access before and after a Trainer is attached."""
    first_ctx, second_ctx = mx.cpu(0), mx.cpu(1)
    param = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse',
                            grad_stype='row_sparse')
    param.initialize(init='xavier', ctx=[first_ctx, second_ctx])
    all_row_ids = mx.nd.arange(0, 10, ctx=second_ctx)
    assert len(param.list_grad()) == 2

    # Without a trainer, asking for row_sparse data must raise.
    assertRaises(RuntimeError, param.list_row_sparse_data, all_row_ids)

    # Constructing the trainer is what makes the row_sparse accessors usable.
    trainer = mx.gluon.Trainer([param], 'sgd')
    assert len(param.list_row_sparse_data(all_row_ids)) == 2

    retained = param.row_sparse_data(all_row_ids)
    assert retained.context == second_ctx
    assert retained.shape == (10, 10)
    assert retained.stype == 'row_sparse'

    expected_storage_attr = str(_STORAGE_TYPE_STR_TO_ID['row_sparse'])
    assert param.var().name == 'weight'
    assert param.var().attr('__storage_type__') == expected_storage_attr
    assert param.grad(first_ctx).stype == 'row_sparse'

    # Moving the parameter to a new context list must be reflected by list_ctx.
    moved_ctxs = [mx.cpu(1), mx.cpu(2)]
    param.reset_ctx(ctx=moved_ctxs)
    assert param.list_ctx() == moved_ctxs
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:21,代码来源:test_gluon.py

示例3: test_parameter_row_sparse_data

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def test_parameter_row_sparse_data():
    """Compare row_sparse_data / list_row_sparse_data output with sparse.retain."""
    first_ctx, second_ctx = mx.cpu(1), mx.cpu(2)
    num_rows = 4
    param = gluon.Parameter('x', shape=(num_rows, 2), stype='row_sparse')
    param.initialize(init='xavier', ctx=[first_ctx, second_ctx])
    trainer = gluon.Trainer([param], 'sgd')

    # Snapshot of the parameter data used as the reference for every check.
    reference = param._data[0].copy()
    assert reference.stype == 'row_sparse'

    # Case 1: a subset of rows requested on the first context.
    subset_ids = mx.nd.array([0,1], ctx=first_ctx)
    subset_rows = param.row_sparse_data(subset_ids)
    subset_expected = mx.nd.sparse.retain(reference, subset_ids.as_in_context(first_ctx))
    mx.test_utils.assert_almost_equal(subset_rows.asnumpy(), subset_expected.asnumpy())
    assert subset_rows.context == first_ctx

    # Case 2: every row requested on the second context; expect the full data.
    full_ids = mx.nd.arange(0, num_rows, ctx=second_ctx)
    full_rows = param.row_sparse_data(full_ids)
    mx.test_utils.assert_almost_equal(full_rows.asnumpy(), reference.asnumpy())
    assert full_rows.context == second_ctx

    # Case 3: list accessor with row ids created without an explicit context.
    default_ids = mx.nd.array([0,1,2])
    listed_rows = param.list_row_sparse_data(default_ids)
    default_expected = mx.nd.sparse.retain(reference, default_ids.as_in_context(first_ctx))
    mx.test_utils.assert_almost_equal(listed_rows[0].asnumpy(), default_expected.asnumpy())
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:25,代码来源:test_gluon.py

示例4: test_constant

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def test_constant():
    """Verify a constant parameter keeps its value across a trainer step."""
    class Test(gluon.HybridBlock):
        """Block that adds a fixed 2x2 constant to its input."""

        def __init__(self, **kwargs):
            super(Test, self).__init__(**kwargs)
            self.value = np.asarray([[1,2], [3,4]])
            self.const = self.params.get_constant('const', self.value)

        def hybrid_forward(self, F, x, const):
            return x + const

    block = Test()
    block.initialize()
    sgd_settings = {'learning_rate': 1.0, 'momentum': 0.5}
    trainer = gluon.Trainer(block.collect_params(), 'sgd', sgd_settings)

    with mx.autograd.record():
        inputs = mx.nd.ones((2,2))
        inputs.attach_grad()
        outputs = block(inputs)
        outputs.backward()

    trainer.step(1)

    # The constant must still equal its initial value after the update ...
    const_after_step = block.const.data().asnumpy()
    assert (const_after_step == block.value).all()
    # ... while the gradient w.r.t. the input is all ones.
    input_grad = inputs.grad.asnumpy()
    assert (input_grad == 1).all()
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:27,代码来源:test_gluon.py

示例5: test_trainer_sparse_save_load

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def test_trainer_sparse_save_load():
    """Check the optimizer still follows the parameter after a state save/load.

    A row_sparse parameter is stepped once, the trainer states are saved and
    loaded back, and then ``lr_mult`` is changed on the parameter: the
    optimizer's effective learning rate must reflect the new multiplier.
    """
    # Imported locally so the block does not rely on file-level imports.
    import os
    import shutil
    import tempfile

    x = gluon.Parameter('x', shape=(10, 1), lr_mult=1.0, stype='row_sparse')
    x.initialize(ctx=[mx.cpu(0)], init='zeros')
    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
    all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
    with mx.autograd.record():
        for w in x.list_row_sparse_data(all_rows):
            y = w * 1
            y.backward()
    trainer.step(1)
    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1

    # Write the states into a throw-away directory so the test does not leave
    # a '.states' file behind in the current working directory.
    states_dir = tempfile.mkdtemp()
    try:
        states_file = os.path.join(states_dir,
                                   'test_trainer_sparse_save_load.states')
        trainer.save_states(states_file)
        trainer.load_states(states_file)
    finally:
        shutil.rmtree(states_dir, ignore_errors=True)

    x.lr_mult = 2.0
    # check if parameter dict is correctly associated with optimizer after load_state
    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:18,代码来源:test_gluon_trainer.py

示例6: train

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def train(net, epoch, ctx_list):
    """Train ``net`` with SGD across ``ctx_list`` and print accuracy per epoch.

    :param net: network to initialize (Xavier) and train
    :param epoch: number of passes over the outer-scope ``train_data`` iterator
    :param ctx_list: contexts each batch is split across
    """
    params = net.collect_params()
    params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx_list)
    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.5})
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch_idx in range(epoch):
        train_data.reset()
        for batch in train_data:
            batch_inputs = batch.data[0]
            # One slice of data/labels per context.
            datas = gluon.utils.split_and_load(batch_inputs, ctx_list, batch_axis=0)
            labels = gluon.utils.split_and_load(batch.label[0], ctx_list, batch_axis=0)
            outputs = []
            with autograd.record():
                for data_slice, label_slice in zip(datas, labels):
                    prediction = net(data_slice)
                    slice_loss = loss(prediction, label_slice)
                    slice_loss.backward()
                    outputs.append(prediction)
            # Step size is the full (unsplit) batch size.
            trainer.step(batch_inputs.shape[0])
            metric.update(labels, outputs)
        name, acc = metric.get()
        metric.reset()
        print('training acc at epoch %d: %s=%f'%(epoch_idx, name, acc))
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:25,代码来源:test_autograd.py

示例7: create_trainer

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def create_trainer(self, inference):
        """
        Create a trainer for the given network using the configured optimizer.

        Reads ``self.args.optim`` to pick the optimizer and builds its
        parameters from ``self.args.lr`` and ``self.args.wd`` (plus
        ``self.args.mom`` as momentum for 'sgd').

        :param inference: network whose ``collect_params()`` are optimized
        :return: trainer
        :raises NotImplementedError: if ``self.args.optim`` is neither 'sgd' nor 'adam'
        """
        optim_name = self.args.optim

        if optim_name == 'sgd':
            optim_params = {'learning_rate': self.args.lr, 'wd': self.args.wd, 'momentum': self.args.mom}
        elif optim_name == 'adam':
            optim_params = {'learning_rate': self.args.lr, 'wd': self.args.wd}
        else:
            # Name the rejected value so a misconfiguration is easy to spot.
            raise NotImplementedError(
                "Unsupported optimizer %r; expected 'sgd' or 'adam'" % (optim_name,))

        trainer = Trainer(inference.collect_params(), optimizer=optim_name,
                          optimizer_params=optim_params)
        return trainer
开发者ID:aws-samples,项目名称:d-SNE,代码行数:19,代码来源:training_sda.py

示例8: save_params

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def save_params(file_stem, net, trainer):
    """
    Persist the model parameters and the trainer states side by side.

    Parameters:
    ----------
    file_stem : str
        Common file stem (with path); '.params' and '.states' are appended.
    net : HybridBlock
        Model whose parameters are saved.
    trainer : Trainer
        Trainer whose states are saved.
    """
    params_path = file_stem + ".params"
    states_path = file_stem + ".states"
    net.save_parameters(params_path)
    trainer.save_states(states_path)
开发者ID:osmr,项目名称:imgclsmob,代码行数:19,代码来源:train_gl.py

示例9: train

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def train(net, epochs, ctx, train_data, test_data,
          margin_loss, reconstructions_loss,
          batch_size, scale_factor):
    """Train ``net`` with SGD on a margin loss plus a scaled reconstruction loss.

    After every epoch the mean training loss per batch and the accuracy
    returned by ``test(test_data, net, ctx)`` are printed.
    """
    num_classes = 10
    optimizer_settings = {'learning_rate': 0.05, 'wd': 5e-4}
    trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_settings)

    for epoch in range(epochs):
        num_batches = len(train_data)
        train_loss = 0.0
        progress = tqdm(enumerate(train_data), total=num_batches, ncols=70, leave=False, unit='b')
        for _, (data, label) in progress:
            # Move the batch onto the training context.
            label = label.as_in_context(ctx)
            data = data.as_in_context(ctx)
            with autograd.record():
                prob, X_l2norm, reconstructions = net(data, label)
                classification_term = margin_loss(data, num_classes,  label, X_l2norm)
                reconstruction_term = reconstructions_loss(reconstructions, data)
                loss = classification_term + scale_factor * reconstruction_term
                loss.backward()
            trainer.step(batch_size)
            train_loss += nd.mean(loss).asscalar()
        test_acc = test(test_data, net, ctx)
        mean_train_loss = train_loss / num_batches
        print('Epoch:{}, TrainLoss:{:.5f}, TestAcc:{}'.format(epoch, mean_train_loss, test_acc))
开发者ID:tonysy,项目名称:CapsuleNet-Gluon,代码行数:24,代码来源:main.py

示例10: gluon_random_data_run

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def gluon_random_data_run():
    """Fit a small dense network inside an MLflow run and return that run's record."""
    mlflow.gluon.autolog()

    with mlflow.start_run() as run:
        train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")
        val_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

        # Two hidden ReLU layers followed by a 10-unit output layer.
        network = HybridSequential()
        for units, activation in ((64, "relu"), (64, "relu"), (10, None)):
            network.add(Dense(units, activation=activation))
        network.initialize()
        network.hybridize()

        adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
        trainer = Trainer(network.collect_params(), "adam",
                          optimizer_params=adam_settings)
        est = estimator.Estimator(net=network, loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(), trainer=trainer)

        # Silence warnings emitted while fitting.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(train_loader, epochs=3, val_data=val_loader)

    tracking_client = mlflow.tracking.MlflowClient()
    finished_run_id = run.info.run_id
    return tracking_client.get_run(finished_run_id)
开发者ID:mlflow,项目名称:mlflow,代码行数:25,代码来源:test_gluon_autolog.py

示例11: test_autolog_ends_auto_created_run

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def test_autolog_ends_auto_created_run():
    """After fitting without an explicit run, no MLflow run must remain active."""
    mlflow.gluon.autolog()

    train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    # Two hidden ReLU layers followed by a 10-unit output layer.
    network = HybridSequential()
    for units, activation in ((64, "relu"), (64, "relu"), (10, None)):
        network.add(Dense(units, activation=activation))
    network.initialize()
    network.hybridize()

    adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
    trainer = Trainer(network.collect_params(), "adam",
                      optimizer_params=adam_settings)
    est = estimator.Estimator(net=network, loss=SoftmaxCrossEntropyLoss(),
                              metrics=Accuracy(), trainer=trainer)

    # Silence warnings emitted while fitting.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(train_loader, epochs=3)

    still_active = mlflow.active_run()
    assert still_active is None
开发者ID:mlflow,项目名称:mlflow,代码行数:24,代码来源:test_gluon_autolog.py

示例12: test_autolog_persists_manually_created_run

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def test_autolog_persists_manually_created_run():
    """A run opened by the caller must still be the active run after fitting."""
    mlflow.gluon.autolog()

    train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as run:

        # Two hidden ReLU layers followed by a 10-unit output layer.
        network = HybridSequential()
        for units, activation in ((64, "relu"), (64, "relu"), (10, None)):
            network.add(Dense(units, activation=activation))
        network.initialize()
        network.hybridize()

        adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
        trainer = Trainer(network.collect_params(), "adam",
                          optimizer_params=adam_settings)
        est = estimator.Estimator(net=network, loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(), trainer=trainer)

        # Silence warnings emitted while fitting.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(train_loader, epochs=3)

        active_run_id = mlflow.active_run().info.run_id
        assert active_run_id == run.info.run_id
开发者ID:mlflow,项目名称:mlflow,代码行数:25,代码来源:test_gluon_autolog.py

示例13: gluon_model

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def gluon_model(model_data):
    """Build a three-layer dense network, fit it on ``model_data`` and return it.

    ``model_data`` is unpacked as ``(train_data, train_label, _)``; the third
    element is ignored.
    """
    features, labels, _ = model_data
    samples = list(zip(features, labels))
    train_loader = DataLoader(samples, batch_size=128, last_batch="discard")

    # Hidden layers of 128 and 64 ReLU units, then a 10-unit output layer.
    network = HybridSequential()
    for units, activation in ((128, "relu"), (64, "relu"), (10, None)):
        network.add(Dense(units, activation=activation))
    network.initialize()
    network.hybridize()

    adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
    trainer = Trainer(network.collect_params(), "adam",
                      optimizer_params=adam_settings)
    est = estimator.Estimator(net=network, loss=SoftmaxCrossEntropyLoss(),
                              metrics=Accuracy(), trainer=trainer)

    # Silence warnings emitted while fitting.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(train_loader, epochs=3)
    return network
开发者ID:mlflow,项目名称:mlflow,代码行数:20,代码来源:test_gluon_model_export.py

示例14: set_session

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def set_session(self, sess) -> None:
        """
        Initialize the model parameters and, for trainable networks, create the trainer.

        NOTE: Session for mxnet backend must be None.

        :param sess: must be None
        """
        assert sess is None
        model = self.model
        # FIXME Add initializer
        model.collect_params().initialize(ctx=self._devices)

        # Hybridize the model first, then every loss
        model.hybridize()
        for loss_fn in self.losses:
            loss_fn.hybridize()

        # A forward pass on correctly shaped dummy inputs triggers shape inference
        # and completes parameter initialization
        dummy_inputs = self._dummy_model_inputs()
        model(*dummy_inputs)

        # Non-trainable networks do not get a trainer
        if not self.network_is_trainable:
            return
        self.trainer = gluon.Trainer(
            model.collect_params(), optimizer=self.optimizer, update_on_kvstore=False)
开发者ID:NervanaSystems,项目名称:coach,代码行数:22,代码来源:architecture.py

示例15: train

# 需要导入模块: from mxnet import gluon [as 别名]
# 或者: from mxnet.gluon import Trainer [as 别名]
def train(epoch, ctx):
    """Train the outer-scope ``net`` with Adam and an L2 loss, then save its parameters.

    :param epoch: number of passes over the outer-scope ``train_data`` iterator
    :param ctx: a single ``mx.Context`` or a list of contexts to split batches across
    """
    # Normalize a single context into a one-element list.
    ctx_list = [ctx] if isinstance(ctx, mx.Context) else ctx
    net.initialize(mx.init.Orthogonal(), ctx=ctx_list)
    # re-initialize conv4's weight to be Orthogonal
    net.conv4.initialize(mx.init.Orthogonal(scale=1), force_reinit=True, ctx=ctx_list)
    adam_settings = {'learning_rate': opt.lr}
    trainer = gluon.Trainer(net.collect_params(), 'adam', adam_settings)
    loss = gluon.loss.L2Loss()

    for epoch_idx in range(epoch):
        train_data.reset()
        for batch in train_data:
            batch_inputs = batch.data[0]
            # One slice of data/labels per context.
            data = gluon.utils.split_and_load(batch_inputs, ctx_list=ctx_list, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx_list, batch_axis=0)
            outputs = []
            with ag.record():
                for data_slice, label_slice in zip(data, label):
                    prediction = net(data_slice)
                    slice_loss = loss(prediction, label_slice)
                    slice_loss.backward()
                    outputs.append(prediction)
            # Step size is the full (unsplit) batch size.
            trainer.step(batch_inputs.shape[0])
            metric.update(label, outputs)

        name, acc = metric.get()
        metric.reset()
        print('training mse at epoch %d: %s=%f'%(epoch_idx, name, acc))
        test(ctx_list)

    net.save_parameters('superres.params')
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:32,代码来源:super_resolution.py


注:本文中的mxnet.gluon.Trainer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。