This page collects typical usage examples of the Python method model.ActorCritic. If you are wondering what exactly model.ActorCritic does, how it is called, or where to find usage examples, the curated code samples below may help. You can also look further into other usages of the module model in which this method is defined.
A total of 3 code examples of model.ActorCritic are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
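The implementation of model.ActorCritic itself is not reproduced on this page. As a rough orientation for Examples 1 and 3 below (a discrete-action policy head and a scalar value head on top of a shared feature layer), a minimal sketch of such a module might look like the following; the hidden size and layer structure are assumptions, not the actual source, and Example 2 uses a different, recurrent (LSTM) variant whose forward pass also takes and returns hidden state.

import torch.nn as nn
import torch.nn.functional as F

class ActorCritic(nn.Module):
    """Minimal actor-critic sketch: shared trunk, policy head, value head (assumed layout)."""
    def __init__(self, num_inputs, num_actions, hidden_size=128):
        super().__init__()
        self.fc = nn.Linear(num_inputs, hidden_size)          # shared feature layer
        self.fc_policy = nn.Linear(hidden_size, num_actions)  # action logits
        self.fc_value = nn.Linear(hidden_size, 1)              # state value V(s)

    def forward(self, x):
        x = F.relu(self.fc(x))
        policy = F.softmax(self.fc_policy(x), dim=-1)          # action probabilities
        value = self.fc_value(x)
        return policy, value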
Example 1: main
# Required import: import model [as alias]
# Or: from model import ActorCritic [as alias]
def main():
    # args, device, and get_action are defined elsewhere in the source script.
    env = gym.make(args.env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    # Load the trained actor-critic network from the checkpoint saved by the training script.
    net = ActorCritic(num_inputs, num_actions)
    net.load_state_dict(torch.load(args.save_path + 'model.pth'))
    net.to(device)
    net.eval()

    running_score = 0
    steps = 0

    # Run five evaluation episodes with rendering.
    for e in range(5):
        done = False
        score = 0
        state = env.reset()
        state = torch.Tensor(state).to(device)
        state = state.unsqueeze(0)

        while not done:
            env.render()
            steps += 1

            policy, value = net(state)
            action = get_action(policy, num_actions)
            next_state, reward, done, _ = env.step(action)

            next_state = torch.Tensor(next_state).to(device)
            next_state = next_state.unsqueeze(0)

            score += reward
            state = next_state

        print('{} episode | score: {:.2f}'.format(e, score))
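The helper get_action used above (and again in Example 3) is defined elsewhere in the same script and is not shown on this page. A plausible stand-in, assuming policy is a (1, num_actions) tensor of action probabilities as in the sketch near the top, simply samples an action index from that distribution:

import numpy as np

def get_action(policy, num_actions):
    # policy: (1, num_actions) tensor of action probabilities (assumed shape)
    probs = policy.detach().cpu().numpy()[0]
    return np.random.choice(num_actions, p=probs)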
Example 2: test
# Required import: import model [as alias]
# Or: from model import ActorCritic [as alias]
def test(rank, args, shared_model, counter):
    # Evaluation worker for A3C: periodically copies the shared model's weights and
    # plays full episodes (create_atari_env, time, deque, and F come from the
    # surrounding source file's imports).
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=100)
    episode_length = 0
    while True:
        episode_length += 1
        # Sync with the shared model at the start of every episode
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = torch.zeros(1, 256)
            hx = torch.zeros(1, 256)
        else:
            cx = cx.detach()
            hx = hx.detach()

        with torch.no_grad():
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
        prob = F.softmax(logit, dim=-1)
        # Evaluation is greedy: always take the most probable action
        action = prob.max(1, keepdim=True)[1].numpy()

        state, reward, done, _ = env.step(action[0, 0])
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # a quick hack to prevent the agent from getting stuck:
        # end the episode if the same action has been chosen 100 times in a row
        actions.append(action[0, 0])
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                counter.value, counter.value / (time.time() - start_time),
                reward_sum, episode_length))
            reward_sum = 0
            episode_length = 0
            actions.clear()
            state = env.reset()
            time.sleep(60)

        state = torch.from_numpy(state)
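Example 2 is only the evaluation side of an A3C setup: shared_model lives in shared memory and counter is a global step counter incremented by the training workers. A minimal launch sketch under those assumptions (the rank value and the reuse of args and create_atari_env here mirror typical A3C scripts and are not taken verbatim from this page) could look like:

import torch.multiprocessing as mp

if __name__ == '__main__':
    env = create_atari_env(args.env_name)      # args / create_atari_env from the source script
    shared_model = ActorCritic(env.observation_space.shape[0], env.action_space)
    shared_model.share_memory()                # make parameters visible to all worker processes

    counter = mp.Value('i', 0)                 # global step counter shared with the trainers

    # Run the evaluator in its own process while training workers update shared_model.
    p = mp.Process(target=test, args=(args.num_processes, args, shared_model, counter))
    p.start()
    p.join()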
Example 3: main
# Required import: import model [as alias]
# Or: from model import ActorCritic [as alias]
def main():
    # args, device, get_action, and train_model are defined elsewhere in the source script.
    env = gym.make(args.env_name)
    env.seed(500)
    torch.manual_seed(500)

    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    net = ActorCritic(num_inputs, num_actions)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    writer = SummaryWriter('logs')

    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)

    net.to(device)
    net.train()
    running_score = 0

    for e in range(3000):
        done = False
        score = 0
        state = env.reset()
        state = torch.Tensor(state).to(device)
        state = state.unsqueeze(0)

        while not done:
            if args.render:
                env.render()

            policy, value = net(state)
            action = get_action(policy, num_actions)
            next_state, reward, done, _ = env.step(action)

            next_state = torch.Tensor(next_state).to(device)
            next_state = next_state.unsqueeze(0)

            # CartPole-style reward shaping: keep the +1 step reward, but give -1
            # when the episode ends before reaching the 500-step cap.
            mask = 0 if done else 1
            reward = reward if not done or score == 499 else -1

            transition = [state, next_state, action, reward, mask]
            train_model(net, optimizer, transition, policy, value)

            score += reward
            state = next_state

        # Undo the -1 terminal penalty so the logged score equals the episode length.
        score = score if score == 500.0 else score + 1
        running_score = 0.99 * running_score + 0.01 * score

        if e % args.log_interval == 0:
            print('{} episode | score: {:.2f}'.format(e, running_score))
            writer.add_scalar('log/score', float(score), e)  # log against the episode index

        if running_score > args.goal_score:
            ckpt_path = args.save_path + 'model.pth'
            torch.save(net.state_dict(), ckpt_path)
            print('running score exceeds the goal score, so training stops')
            break
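train_model, like get_action, is defined outside the excerpt shown here. A minimal one-step advantage actor-critic update that is consistent with the transition = [state, next_state, action, reward, mask] layout above might look like the following; the discount factor, the absence of an entropy bonus, and the loss weighting are assumptions rather than the source's actual implementation:

import torch
import torch.nn.functional as F

gamma = 0.99  # assumed discount factor

def train_model(net, optimizer, transition, policy, value):
    state, next_state, action, reward, mask = transition

    # One-step bootstrapped target r + gamma * V(s'), cut off at episode end by mask.
    with torch.no_grad():
        _, next_value = net(next_state)
    target = reward + mask * gamma * next_value.squeeze()
    advantage = target - value.squeeze()

    # Policy-gradient loss on the chosen action plus value-function regression.
    log_prob = torch.log(policy.squeeze(0)[action] + 1e-8)
    policy_loss = -log_prob * advantage.detach()
    value_loss = F.mse_loss(value.squeeze(), target.detach())

    loss = policy_loss + value_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()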