This article collects typical usage examples of the Python method hrlproject.misc.HRLutils.set_seed. If you are wondering what HRLutils.set_seed does, how to call it, or want to see it used in context, the curated examples below should help. You can also explore the other methods of the containing class hrlproject.misc.HRLutils.
Below are 8 code examples of HRLutils.set_seed, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps the system recommend better Python code samples.
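Every example on this page uses the same seeding idiom: if an explicit seed is supplied it is passed to HRLutils.set_seed, and the seed actually in effect is then read back from HRLutils.SEED so it can be embedded in the filenames of saved weights, eval points, and data logs. The following is a minimal sketch of that idiom extracted from the examples; the run_model wrapper name is illustrative only, and the assumption that set_seed updates HRLutils.SEED is inferred from how the examples use it.

from hrlproject.misc import HRLutils

def run_model(seed=None):
    # Seed hrlproject's shared random number generator when a seed is given.
    if seed is not None:
        HRLutils.set_seed(seed)
    # Read back the seed in effect (the examples treat HRLutils.SEED as the
    # authoritative value) so it can be appended to weight / eval-point /
    # data filenames for this run.
    seed = HRLutils.SEED
    return seed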
Example 1: run_gridworld
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_gridworld(args, seed=None):
    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_gridworld")

    stateN = 400
    stateD = 2
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    agent = smdpagent.SMDPAgent(stateN, stateD, actions, stateradius=3,
                                **args)
    net.add(agent)

    env = gridworldenvironment.GridWorldEnvironment(
        stateD, actions, HRLutils.datafile("smallgrid.txt"), cartesian=True,
        delay=(0.6, 0.9), datacollection=False)
    net.add(env)

    net.connect(env.getOrigin("state"), agent.getTermination("state_input"))
    net.connect(env.getOrigin("reward"), agent.getTermination("reward"))
    net.connect(env.getOrigin("reset"), agent.getTermination("reset"))
    net.connect(env.getOrigin("learn"), agent.getTermination("learn"))
    net.connect(env.getOrigin("reset"), agent.getTermination("save_state"))
    net.connect(env.getOrigin("reset"), agent.getTermination("save_action"))
    net.connect(agent.getOrigin("action_output"), env.getTermination("action"))
    net.connect(agent.getOrigin("Qs"), env.getTermination("Qs"))

    net.add_to_nengo()

    view = timeview.View(net.network, update_frequency=5)
    view.add_watch(gridworldwatch.GridWorldWatch())
    view.restore()
Example 2: gen_evalpoints
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def gen_evalpoints(filename, seed=None):
    """Runs an environment for some length of time and records state values,
    to be used as eval points for agent initialization.

    :param filename: name of file in which to save eval points
    :param seed: random seed
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("gen_evalpoints")

    contextD = 2
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    rewards = {"a": 1, "b": 1}

    env = contextenvironment.ContextEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"), contextD, rewards,
        imgsize=(5, 5), dx=0.001, placedev=0.5,
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"})
    net.add(env)

    stateD = len(env.placecells) + contextD
    actions = env.actions
    actionD = len(actions)

    class EvalRecorder(nef.SimpleNode):
        def __init__(self, evalfile):
            self.action = actions[0]
            self.evalpoints = []
            self.evalfile = evalfile

            nef.SimpleNode.__init__(self, "EvalRecorder")

        def tick(self):
            if self.t % 0.1 < 0.001:
                self.evalpoints += [self.state]

            if self.t % 10.0 < 0.001:
                if len(self.evalpoints) > 10000:
                    self.evalpoints = self.evalpoints[len(self.evalpoints) -
                                                      10000:]

                with open(self.evalfile, "w") as f:
                    f.write("\n".join([" ".join([str(x) for x in e])
                                       for e in self.evalpoints]))

        def termination_state(self, x, dimensions=stateD):
            self.state = x

        def termination_action_in(self, x, dimensions=actionD):
            self.action = actions[x.index(max(x))]

        def origin_action_out(self):
            return self.action[1]

    em = EvalRecorder(HRLutils.datafile("%s_%s.txt" % (filename, seed)))
    net.add(em)

    net.connect(em.getOrigin("action_out"), env.getTermination("action"))
    net.connect(env.getOrigin("optimal_move"), em.getTermination("action_in"))
    net.connect(env.getOrigin("placewcontext"), em.getTermination("state"))

    # net.add_to_nengo()
    net.run(10)
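The eval-point files consumed by the delivery and context examples further down (contextbmp_evalpoints_<seed>.txt) are read back with the pattern below; presumably they were produced by a call along the lines of gen_evalpoints("contextbmp_evalpoints", seed), although that call is not shown on this page. The snippet reuses only names that already appear in the examples.

# Read previously generated eval points back in: one state vector per line,
# space-separated (this mirrors the loading code in Examples 4-6).
with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % seed)) as f:
    evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]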
Example 3: run_badreenvironment
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_badreenvironment(nav_args, ctrl_args, bias=0.0, seed=None, flat=False,
                         label="tmp"):
    """Runs the model on the Badre et al. (2010) task."""

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_badreenvironment")

    env = badreenvironment.BadreEnvironment(flat=flat)
    net.add(env)

    # ##NAV AGENT
    stateN = 500
    max_state_input = 3
    enc = env.gen_encoders(stateN, 0, 0.0)

    # generate evaluation points
    orientations = MU.I(env.num_orientations)
    shapes = MU.I(env.num_shapes)
    colours = MU.I(env.num_colours)
    evals = (list(MU.diag([3 for _ in range(env.stateD)])) +
             [o + s + c
              for o in orientations for s in shapes for c in colours])

    # create lower level
    nav_agent = smdpagent.SMDPAgent(stateN, env.stateD, env.actions,
                                    name="NavAgent",
                                    stateradius=max_state_input,
                                    state_encoders=enc, state_evals=evals,
                                    discount=0.5, **nav_args)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    # actions terminate on fixed schedule (aligned with environment)
    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.6)): None}, env, name="NavTermNode",
        state_delay=0.1, reset_delay=0.05, reset_interval=0.1)
    net.add(nav_term_node)

    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))

    # ##CTRL AGENT
    stateN = 500
    enc = RandomHypersphereVG().genVectors(stateN, env.stateD)
    actions = [("shape", [0, 1]), ("orientation", [1, 0]), ("null", [0, 0])]
    ctrl_agent = smdpagent.SMDPAgent(stateN, env.stateD, actions,
                                     name="CtrlAgent", state_encoders=enc,
                                     stateradius=max_state_input,
                                     state_evals=evals, discount=0.4,
                                     **ctrl_args)
    net.add(ctrl_agent)

    print "agent neurons:", ctrl_agent.countNeurons()

    net.connect(env.getOrigin("state"),
                ctrl_agent.getTermination("state_input"))

    ctrl_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.6)): None}, env, name="CtrlTermNode",
        state_delay=0.1, reset_delay=0.05, reset_interval=0.1)
    net.add(ctrl_term_node)

    net.connect(ctrl_term_node.getOrigin("reset"),
                ctrl_agent.getTermination("reset"))
    net.connect(ctrl_term_node.getOrigin("learn"),
                ctrl_agent.getTermination("learn"))
    net.connect(ctrl_term_node.getOrigin("reset"),
                ctrl_agent.getTermination("save_state"))
    net.connect(ctrl_term_node.getOrigin("reset"),
                ctrl_agent.getTermination("save_action"))

    # ctrl gets a slight bonus if it selects a rule (as opposed to null), to
    # encourage it to not just pick null all the time
    reward_relay = net.make("reward_relay", 1, 3, mode="direct")
    reward_relay.fixMode()
    net.connect(env.getOrigin("reward"), reward_relay,
                transform=[[1], [0], [0]])
    net.connect(ctrl_agent.getOrigin("action_output"), reward_relay,
                transform=[[0, 0], [1, 0], [0, 1]])

    net.connect(reward_relay, ctrl_agent.getTermination("reward"),
                func=lambda x: ((x[0] + bias * abs(x[0]))
                                if x[1] + x[2] > 0.5 else x[0]),
                origin_name="ctrl_reward")

    # ideal reward function (for testing)
    # def ctrl_reward_func(x):
    # ... remainder of this example omitted ...
Example 4: run_flat_delivery
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_flat_delivery(args, seed=None):
    """Runs the model on the delivery task with only one hierarchical level."""

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_flat_delivery")

    if "load_weights" in args and args["load_weights"] is not None:
        args["load_weights"] += "_%s" % seed

    stateN = 1200
    contextD = 2
    context_scale = 1.0
    max_state_input = 2
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    # ##ENVIRONMENT
    env = deliveryenvironment.DeliveryEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # ##NAV AGENT
    enc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % seed)) as f:
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    nav_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                    actions, name="NavAgent",
                                    state_encoders=enc, state_evals=evals,
                                    state_threshold=0.8, **args)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))
    net.connect(env.getOrigin("placewcontext"),
                nav_agent.getTermination("state_input"))

    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): None}, env, name="NavTermNode",
        contextD=2)
    net.add(nav_term_node)

    net.connect(env.getOrigin("context"),
                nav_term_node.getTermination("context"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    reward_relay = net.make("reward_relay", 1, 1, mode="direct")
    reward_relay.fixMode()
    net.connect(env.getOrigin("reward"), reward_relay)
    net.connect(nav_term_node.getOrigin("pseudoreward"), reward_relay)
    net.connect(reward_relay, nav_agent.getTermination("reward"))

    # period to save weights (realtime, not simulation time)
    weight_save = 600.0
    HRLutils.WeightSaveThread(nav_agent.getNode("QNetwork").saveParams,
                              os.path.join("weights", "%s_%s" %
                                           (nav_agent.name, seed)),
                              weight_save).start()

    # data collection node
    data = datanode.DataNode(period=5,
                             filename=HRLutils.datafile("dataoutput_%s.txt" %
                                                        seed))
    net.add(data)
    q_net = nav_agent.getNode("QNetwork")
    data.record_avg(env.getOrigin("reward"))
    data.record_avg(q_net.getNode("actionvals").getOrigin("X"))
    data.record_sparsity(q_net.getNode("state_pop").getOrigin("AXON"))
    data.record_avg(q_net.getNode("valdiff").getOrigin("X"))
    data.record_avg(nav_agent.getNode("ErrorNetwork").getOrigin("error"))

    # net.add_to_nengo()
    # net.run(10000)
    net.view()
Example 5: run_deliveryenvironment
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_deliveryenvironment(navargs, ctrlargs, tag=None, seed=None):
    """Runs the model on the delivery task.

    :param navargs: kwargs for the nav_agent (see SMDPAgent.__init__)
    :param ctrlargs: kwargs for the ctrl_agent (see SMDPAgent.__init__)
    :param tag: string appended to datafiles associated with this run
    :param seed: random seed used for this run
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    if tag is None:
        tag = str(seed)

    net = nef.Network("runDeliveryEnvironment", seed=seed)

    stateN = 1200  # number of neurons to use in state population
    contextD = 2  # dimension of context vector
    context_scale = 1.0  # relative scale of context vector vs state vector
    max_state_input = 2  # maximum length of input vector to state population

    # labels and vectors corresponding to basic actions available to the
    # system
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    if "load_weights" in navargs and navargs["load_weights"] is not None:
        navargs["load_weights"] += "_%s" % tag
    if "load_weights" in ctrlargs and ctrlargs["load_weights"] is not None:
        ctrlargs["load_weights"] += "_%s" % tag

    # ##ENVIRONMENT
    env = deliveryenvironment.DeliveryEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # ##NAV AGENT

    # generate encoders and divide them by max_state_input (so that inputs
    # will be scaled down to radius 1)
    enc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    # read in eval points from file
    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % tag)) as f:
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    nav_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                    actions, name="NavAgent",
                                    state_encoders=enc, state_evals=evals,
                                    state_threshold=0.8,
                                    **navargs)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    # output of nav_agent is what goes to the environment
    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))

    # termination node for nav_agent (just a timer that goes off regularly)
    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): None}, env, contextD=2,
        name="NavTermNode")
    net.add(nav_term_node)

    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    # ##CTRL AGENT

    # actions corresponding to "go to A" or "go to B"
    actions = [("a", [0, 1]), ("b", [1, 0])]

    ctrl_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                     actions, name="CtrlAgent",
                                     state_encoders=enc, state_evals=evals,
                                     state_threshold=0.8, **ctrlargs)
    net.add(ctrl_agent)

    print "agent neurons:", ctrl_agent.countNeurons()

    # ctrl_agent gets environmental state and reward
    net.connect(env.getOrigin("placewcontext"),
                ctrl_agent.getTermination("state_input"))
    net.connect(env.getOrigin("reward"),
                ctrl_agent.getTermination("reward"))

    # termination node for ctrl_agent (terminates whenever the agent is in the
    # ... remainder of this example omitted ...
Example 6: run_contextenvironment
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_contextenvironment(args, seed=None):
    """Runs the model on the context task.

    :param args: kwargs for the agent
    :param seed: random seed
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("runContextEnvironment")

    if "load_weights" in args and args["load_weights"] is not None:
        args["load_weights"] += "_%s" % seed

    stateN = 1200  # number of neurons to use in state population
    contextD = 2  # dimension of context vector
    context_scale = 1.0  # scale of context representation
    max_state_input = 2  # max length of input vector for state population

    # actions (label and vector) available to the system
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    # context labels and rewards for achieving those context goals
    rewards = {"a": 1.5, "b": 1.5}

    env = contextenvironment.ContextEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"), contextD, rewards,
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # termination node for agent (just goes off on some regular interval)
    term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): 0.0}, env)
    net.add(term_node)

    # generate encoders and divide by max_state_input (so that all inputs
    # will end up being radius 1)
    enc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    # load eval points from file
    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % seed)) as f:
        print "loading contextbmp_evalpoints_%s.txt" % seed
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                actions, state_encoders=enc,
                                state_evals=evals, state_threshold=0.8,
                                **args)
    net.add(agent)

    print "agent neurons:", agent.countNeurons()

    # period to save weights (realtime, not simulation time)
    weight_save = 600.0
    t = HRLutils.WeightSaveThread(agent.getNode("QNetwork").saveParams,
                                  os.path.join("weights", "%s_%s" %
                                               (agent.name, seed)),
                                  weight_save)
    t.start()

    # data collection node
    data = datanode.DataNode(period=5,
                             filename=HRLutils.datafile("dataoutput_%s.txt" %
                                                        seed))
    net.add(data)
    q_net = agent.getNode("QNetwork")
    data.record(env.getOrigin("reward"))
    data.record(q_net.getNode("actionvals").getOrigin("X"), func=max)
    data.record(q_net.getNode("actionvals").getOrigin("X"), func=min)
    data.record_sparsity(q_net.getNode("state_pop").getOrigin("AXON"))
    data.record_avg(q_net.getNode("valdiff").getOrigin("X"))
    data.record_avg(env.getOrigin("state"))

    net.connect(env.getOrigin("placewcontext"),
                agent.getTermination("state_input"))
    net.connect(env.getOrigin("reward"), agent.getTermination("reward"))
    net.connect(term_node.getOrigin("reset"), agent.getTermination("reset"))
    net.connect(term_node.getOrigin("learn"), agent.getTermination("learn"))
    net.connect(term_node.getOrigin("reset"),
                agent.getTermination("save_state"))
    net.connect(term_node.getOrigin("reset"),
                agent.getTermination("save_action"))
    net.connect(agent.getOrigin("action_output"), env.getTermination("action"))

    # net.add_to_nengo()
    # net.run(2000)
    net.view()

    t.stop()
Example 7: run_badreenvironment
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_badreenvironment(nav_args, ctrl_args, seed=None, flat=False):
    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_badreenvironment")

    env = badreenvironment.BadreEnvironment(flat=flat)
    net.add(env)

    # ##NAV AGENT
    stateN = 500
    max_state_input = 2
    enc = env.gen_encoders(stateN, 0, 1.0)
    enc = MU.prod(enc, 1.0 / max_state_input)

    # with open(HRLutils.datafile("badre_evalpoints.txt")) as f:
    #     evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    orientations = MU.I(env.num_orientations)
    shapes = MU.I(env.num_shapes)
    colours = MU.I(env.num_colours)
    evals = list(MU.I(env.stateD)) + \
        [o + s + c for o in orientations for s in shapes for c in colours]

    nav_agent = smdpagent.SMDPAgent(stateN, env.stateD,
                                    env.actions, name="NavAgent",
                                    load_weights=None,
                                    state_encoders=enc, state_evals=evals,
                                    discount=0.4, **nav_args)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.6)): None}, env,
        name="NavTermNode", state_delay=0.1)
    net.add(nav_term_node)

    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))

    # ##CTRL AGENT
    enc = env.gen_encoders(stateN, 0, 0)
    enc = MU.prod(enc, 1.0 / max_state_input)
    actions = [("shape", [0, 1]), ("orientation", [1, 0]), ("null", [0, 0])]
    ctrl_agent = smdpagent.SMDPAgent(stateN, env.stateD, actions,
                                     name="CtrlAgent",
                                     load_weights=None, state_encoders=enc,
                                     state_evals=evals, discount=0.4,
                                     **ctrl_args)
    net.add(ctrl_agent)

    print "agent neurons:", ctrl_agent.countNeurons()

    net.connect(env.getOrigin("state"),
                ctrl_agent.getTermination("state_input"))

    ctrl_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.6)): None}, env,
        name="CtrlTermNode", state_delay=0.1)
    net.add(ctrl_term_node)

    net.connect(ctrl_term_node.getOrigin("reset"),
                ctrl_agent.getTermination("reset"))
    net.connect(ctrl_term_node.getOrigin("learn"),
                ctrl_agent.getTermination("learn"))
    net.connect(ctrl_term_node.getOrigin("reset"),
                ctrl_agent.getTermination("save_state"))
    net.connect(ctrl_term_node.getOrigin("reset"),
                ctrl_agent.getTermination("save_action"))

    # reward for nav/ctrl
    reward_relay = net.make("reward_relay", 1, 2, mode="direct")
    reward_relay.fixMode()
    net.connect(env.getOrigin("reward"), reward_relay, transform=[[1], [0]])
    net.connect(ctrl_agent.getOrigin("action_output"), reward_relay,
                transform=[[0, 0], [1, 1]])

    # nav reward is just environment
    net.connect(reward_relay, nav_agent.getTermination("reward"),
                func=lambda x: x[0], origin_name="nav_reward")

    # ctrl gets a slight bonus if it selects a rule (as opposed to null), to
    # encourage it not to just pick null all the time
    net.connect(reward_relay, ctrl_agent.getTermination("reward"),
                func=lambda x: x[0] + 0.25 * abs(x[0]) if x[1] > 0.5 else x[0],
                origin_name="ctrl_reward")

    # state for navagent controlled by ctrlagent
    # ctrl_output_relay = net.make("ctrl_output_relay", 1, env.stateD + 2,
    #                              mode="direct")
    # ctrl_output_relay.fixMode()
    ctrl_output_relay = net.make_array("ctrl_output_relay", 50, env.stateD,
                                       radius=2,
                                       mode=HRLutils.SIMULATION_MODE)
    ctrl_output_relay.fixMode([SimulationMode.DEFAULT, SimulationMode.RATE])
    inhib_matrix = [[0, -5]] * 50 * env.num_orientations + \
                   [[-5, 0]] * 50 * env.num_shapes + \
                   [[-5, -5]] * 50 * env.num_colours

    # ctrl output inhibits all the non-selected aspects of the state
    net.connect(env.getOrigin("state"), ctrl_output_relay)
    net.connect(ctrl_agent.getOrigin("action_output"), ctrl_output_relay,
                # transform=zip([0] * env.num_orientations +
                #               [-1] * (env.num_shapes + env.num_colours),
                #               [-1] * env.num_orientations +
                #               [0] * env.num_shapes +
                #               [-1] * env.num_colours))
                # ... remainder of this example omitted ...
Example 8: run_flat_delivery
# Required import: from hrlproject.misc import HRLutils [as alias]
# Or: from hrlproject.misc.HRLutils import set_seed [as alias]
def run_flat_delivery(args, seed=None):
    """Runs the model on the delivery task with only one hierarchical level."""

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_flat_delivery")

    if "load_weights" in args and args["load_weights"] is not None:
        args["load_weights"] += "_%s" % seed

    stateN = 1200
    contextD = 2
    context_scale = 1.0
    max_state_input = 2
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    # ##ENVIRONMENT
    env = deliveryenvironment.DeliveryEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # ##NAV AGENT
    enc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % seed)) as f:
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    nav_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                    actions, name="NavAgent",
                                    state_encoders=enc, state_evals=evals,
                                    state_threshold=0.8, **args)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    # connect the agent's actions to the environment so the agent can act
    # upon the environment
    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))
    # connect the environment state to the agent, so the agent knows the
    # effect of its action
    net.connect(env.getOrigin("placewcontext"),
                nav_agent.getTermination("state_input"))
    # net.connect(env.getOrigin("reward"), nav_agent.getTermination("reward"))
    # net.connect(env.getOrigin("optimal_move"),
    #             nav_agent.getTermination("bg_input"))

    # termination node for nav_agent (just a timer that goes off regularly)
    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): None}, env,
        name="NavTermNode", contextD=2)
    net.add(nav_term_node)
    net.connect(env.getOrigin("context"),
                nav_term_node.getTermination("context"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    # a direct-mode relay sums the environment reward and the termination
    # node's pseudoreward into a single reward signal for the agent
    reward_relay = net.make("reward_relay", 1, 1, mode="direct")
    reward_relay.fixMode()
    net.connect(env.getOrigin("reward"), reward_relay)
    net.connect(nav_term_node.getOrigin("pseudoreward"), reward_relay)
    net.connect(reward_relay, nav_agent.getTermination("reward"))

    # save weights
    weight_save = 600.0  # period to save weights (realtime, not sim time)
    HRLutils.WeightSaveThread(nav_agent.getNode("QNetwork").saveParams,
                              os.path.join("weights", "%s_%s" %
                                           (nav_agent.name, seed)),
                              weight_save).start()

    # data collection node
    data = datanode.DataNode(period=5, show_plots=None,
                             filename=HRLutils.datafile("dataoutput_%s.txt" %
                                                        seed))
    net.add(data)
    # data.record_avg(env.getOrigin("reward"), filter=1e-5)
    # data.record_avg(nav_agent.getNode("QNetwork").getNode("actionvals")
    #                 .getOrigin("X"), filter=1e-5)
    # data.record_sparsity(nav_agent.getNode("QNetwork").getNode("state_pop")
    #                      .getOrigin("AXON"), filter=1e-5)
    # data.record_avg(nav_agent.getNode("QNetwork").getNode("valdiff")
    #                 .getOrigin("X"), filter=1e-5)
    # note: "ErrorNetwork" may not be the correct node name here
    # data.record_avg(nav_agent.getNode("ErrorNetwork").getOrigin("error"),
    #                 filter=1e-5)
    # TODO: try recording everything

    net.add_to_nengo()
    net.view()