本文整理汇总了Python中utils.weights_init方法的典型用法代码示例。如果您正苦于以下问题：Python utils.weights_init方法的具体用法？Python utils.weights_init怎么用？Python utils.weights_init使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils的用法示例。
在下文中一共展示了utils.weights_init方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: setup
# 需要导入模块: import utils [as 别名]
# 或者: from utils import weights_init [as 别名]
def setup(model, opt):
if opt.criterion == "l1":
criterion = nn.L1Loss().cuda()
elif opt.criterion == "mse":
criterion = nn.MSELoss().cuda()
elif opt.criterion == "crossentropy":
criterion = nn.CrossEntropyLoss().cuda()
elif opt.criterion == "hingeEmbedding":
criterion = nn.HingeEmbeddingLoss().cuda()
elif opt.criterion == "tripletmargin":
criterion = nn.TripletMarginLoss(margin = opt.margin, swap = opt.anchorswap).cuda()
parameters = filter(lambda p: p.requires_grad, model.parameters())
if opt.optimType == 'sgd':
optimizer = optim.SGD(parameters, lr = opt.lr, momentum = opt.momentum, nesterov = opt.nesterov, weight_decay = opt.weightDecay)
elif opt.optimType == 'adam':
optimizer = optim.Adam(parameters, lr = opt.maxlr, weight_decay = opt.weightDecay)
if opt.weight_init:
utils.weights_init(model, opt)
return model, criterion, optimizer
示例2: __init__
# 需要导入模块: import utils [as 别名]
# 或者: from utils import weights_init [as 别名]
def __init__(self, num_inputs, action_space):
    """Create the 1-D convolutional A3C actor-critic network.

    Four ``Conv1d`` layers, each followed by a ``LeakyReLU``, feed an
    ``LSTMCell(1600, 128)``. The 128-d hidden size is the input width of a
    scalar critic head and of two actor heads with ``action_space.shape[0]``
    outputs each.

    Args:
        num_inputs: Number of input channels of ``conv1``.
        action_space: Object exposing ``shape``; ``shape[0]`` sets the
            output width of both actor heads.
    """
    super(A3C_CONV, self).__init__()

    # One slope shared by the activations and the init gain below.
    # calculate_gain('leaky_relu') without ``param`` assumes the default
    # slope of 0.01, which did not match the 0.1 used by these layers.
    negative_slope = 0.1

    # Module creation order is unchanged, so parameter registration order
    # and default-initialization RNG draws match the previous version.
    self.conv1 = nn.Conv1d(num_inputs, 32, 3, stride=1, padding=1)
    self.lrelu1 = nn.LeakyReLU(negative_slope)
    self.conv2 = nn.Conv1d(32, 32, 3, stride=1, padding=1)
    self.lrelu2 = nn.LeakyReLU(negative_slope)
    self.conv3 = nn.Conv1d(32, 64, 2, stride=1, padding=1)
    self.lrelu3 = nn.LeakyReLU(negative_slope)
    self.conv4 = nn.Conv1d(64, 64, 1, stride=1)
    self.lrelu4 = nn.LeakyReLU(negative_slope)

    # NOTE(review): conv1/conv2 keep the sequence length, conv3 (k=2, p=1)
    # adds one and conv4 keeps it, so 1600 = 64 x 25 corresponds to an input
    # length of 24 if forward() flattens conv4's output directly — confirm.
    self.lstm = nn.LSTMCell(1600, 128)

    num_outputs = action_space.shape[0]
    self.critic_linear = nn.Linear(128, 1)
    self.actor_linear = nn.Linear(128, num_outputs)
    self.actor_linear2 = nn.Linear(128, num_outputs)

    # Generic initialization first; the layer-specific tweaks below
    # overwrite it for the conv weights, heads and LSTM biases.
    self.apply(weights_init)

    lrelu_gain = nn.init.calculate_gain('leaky_relu', negative_slope)
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        conv.weight.data.mul_(lrelu_gain)

    # Head weights are replaced by norm_col_init(weight, scale) (0.01 for
    # the actors, 1.0 for the critic) and biases zeroed. The order matches
    # the original in case norm_col_init consumes RNG state.
    for head, scale in ((self.actor_linear, 0.01),
                        (self.actor_linear2, 0.01),
                        (self.critic_linear, 1.0)):
        head.weight.data = norm_col_init(head.weight.data, scale)
        head.bias.data.fill_(0)

    self.lstm.bias_ih.data.fill_(0)
    self.lstm.bias_hh.data.fill_(0)
    self.train()
示例3: init_weights
# 需要导入模块: import utils [as 别名]
# 或者: from utils import weights_init [as 别名]
def init_weights(self):
    """Initialize the weights of all sub-networks.

    ``weights_init`` is applied, via each sub-network's ``apply`` method,
    to the graph embedding, graph propagation and the three agents (add
    node, add edge, choose destination). The message functions of the
    graph propagation are then initialized with
    ``dgmg_message_weight_init``.
    """
    # Imported inside the method, as before.
    from utils import weights_init, dgmg_message_weight_init

    submodules = (
        self.graph_embed,
        self.graph_prop,
        self.add_node_agent,
        self.add_edge_agent,
        self.choose_dest_agent,
    )
    for submodule in submodules:
        submodule.apply(weights_init)

    # Runs last, so its initialization is what the message functions end up
    # with (presumably they were also visited by graph_prop.apply above —
    # confirm).
    self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
示例4: __init__
# 需要导入模块: import utils [as 别名]
# 或者: from utils import weights_init [as 别名]
def __init__(self, num_inputs, action_space):
    """Create the 2-D convolutional A3C actor-critic network with an LSTM.

    Four ``Conv2d`` layers, each paired with a ``MaxPool2d(2, 2)``, feed an
    ``LSTMCell(1024, 512)``. The 512-d hidden size is the input width of a
    scalar critic head and an actor head with ``action_space.n`` outputs.

    Args:
        num_inputs: Number of input channels of ``conv1``.
        action_space: Object exposing ``n``, the actor head's output width.
    """
    super(A3Clstm, self).__init__()

    # Creation order is kept so module registration order and the RNG draws
    # of default layer initialization stay the same.
    self.conv1 = nn.Conv2d(num_inputs, 32, 5, stride=1, padding=2)
    self.maxp1 = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(32, 32, 5, stride=1, padding=1)
    self.maxp2 = nn.MaxPool2d(2, 2)
    self.conv3 = nn.Conv2d(32, 64, 4, stride=1, padding=1)
    self.maxp3 = nn.MaxPool2d(2, 2)
    self.conv4 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
    self.maxp4 = nn.MaxPool2d(2, 2)

    # NOTE(review): 1024 must equal the flattened features given to the
    # LSTM; e.g. an 80x80 input through this conv/pool stack gives
    # 64 x 4 x 4 = 1024. Depends on forward() and the input size — confirm.
    self.lstm = nn.LSTMCell(1024, 512)

    n_actions = action_space.n
    self.critic_linear = nn.Linear(512, 1)
    self.actor_linear = nn.Linear(512, n_actions)

    # Generic initialization, then the layer-specific overrides below.
    self.apply(weights_init)

    gain = nn.init.calculate_gain('relu')
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        conv.weight.data.mul_(gain)

    # Head weights pass through norm_col_init(weight, scale); biases are
    # zeroed. Actor first, then critic, as before.
    for head, scale in ((self.actor_linear, 0.01), (self.critic_linear, 1.0)):
        head.weight.data = norm_col_init(head.weight.data, scale)
        head.bias.data.fill_(0)

    for lstm_bias in (self.lstm.bias_ih, self.lstm.bias_hh):
        lstm_bias.data.fill_(0)

    self.train()