当前位置: 首页>>代码示例>>Python>>正文


Python HRLutils.datafile方法代码示例

本文整理汇总了Python中hrlproject.misc.HRLutils.datafile方法的典型用法代码示例。如果您正苦于以下问题:Python HRLutils.datafile方法的具体用法?Python HRLutils.datafile怎么用?Python HRLutils.datafile使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在hrlproject.misc.HRLutils的用法示例。


在下文中一共展示了HRLutils.datafile方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_terminationnode

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def test_terminationnode():
    net = nef.Network("testTerminationNode")

    actions = [("up", [0, 1]), ("right", [1, 0]), ("down", [0, -1]), ("left", [-1, 0])]
    env = deliveryenvironment.DeliveryEnvironment(
        actions,
        HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5),
        dx=0.001,
        placedev=0.5,
    )
    net.add(env)

    term_node = terminationnode.TerminationNode(
        {"a": [0, 1], "b": [1, 0], terminationnode.Timer((30, 30)): None}, env, contextD=2, rewardval=1
    )
    net.add(term_node)

    print term_node.conds

    context_input = net.make_input("contextinput", {0.0: [0, 0.1], 0.5: [1, 0], 1.0: [0, 1]})
    net.connect(context_input, term_node.getTermination("context"))

    net.add_to_nengo()
    net.view()
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:28,代码来源:test.py

示例2: saveParams

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
    def saveParams(self, prefix):
        """Write this network's learned parameters to data files under prefix."""

        def _write_matrix(fname, matrix):
            # one row per line, values separated by single spaces
            with open(HRLutils.datafile(fname), "w") as outfile:
                outfile.write("\n".join([" ".join([str(v) for v in row])
                                         for row in matrix]))

        if self.neuron_learning:
            # connection weights are stored by the populations themselves
            self.getNode("actionvals").saveWeights(prefix)
            self.getNode("old_actionvals").saveWeights(prefix)
        else:
            _write_matrix(prefix + "_state_decoders.txt",
                          self.getNode("state_pop").getOrigin("vals").getDecoders())
            _write_matrix(prefix + "_old_state_decoders.txt",
                          self.getNode("old_state_pop").getOrigin("vals").getDecoders())

        # state encoders are saved in either mode
        _write_matrix(prefix + "_state_encoders.txt",
                      self.getNode("state_pop").getEncoders())
开发者ID:Seanny123,项目名称:HRL_1.0,代码行数:20,代码来源:Qnetwork.py

示例3: loadParams

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
    def loadParams(self, prefix):
        print "loading params: %s" % prefix

        #load connection weights
        if self.neuron_learning:
            self.getNode("actionvals").loadWeights(prefix)
            self.getNode("old_actionvals").loadWeights(prefix)
        else:
            with open(HRLutils.datafile(prefix + "_state_decoders.txt")) as f:
                self.getNode("state_pop").getOrigin("vals").setDecoders(
                            [[float(x) for x in d.split(" ")] for d in f.readlines()])

            with open(HRLutils.datafile(prefix + "_old_state_decoders.txt")) as f:
                self.getNode("old_state_pop").getOrigin("vals").setDecoders(
                            [[float(x) for x in d.split(" ")] for d in f.readlines()])

        #load state encoders
        with open(HRLutils.datafile(prefix + "_state_encoders.txt")) as f:
            enc = [[float(x) for x in e.split(" ")] for e in f.readlines()]
        self.getNode("state_pop").setEncoders(enc)
        self.getNode("old_state_pop").setEncoders(enc) #note we assume that state_pop and old_state_pop use the same encoders
开发者ID:Seanny123,项目名称:HRL_1.0,代码行数:23,代码来源:Qnetwork.py

示例4: test_bmp

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def test_bmp():
    from javax.imageio import ImageIO
    from java.io import File

    img = ImageIO.read(File(HRLutils.datafile("contextmap.bmp")))

    colours = [int(val) for val in img.getRGB(0, 0, img.getWidth(), img.getHeight(), None, 0, img.getWidth())]
    unique_colours = []
    for c in colours:
        if c not in unique_colours:
            unique_colours += [c]

    print unique_colours
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:15,代码来源:test.py

示例5: __init__

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
    def __init__(self, actions, mapname, colormap, name="PlaceCellEnvironment",
                 imgsize=(1.0, 1.0), dx=0.01, placedev=0.1, num_places=None):
        """Initialize environment variables.

        :param actions: actions available to the system
            :type actions: list of tuples (action_name,action_vector)
        :param mapname: name of file describing environment map
        :param colormap: dict mapping pixel colours to labels
        :param name: name for environment
        :param imgsize: width of space represented by the map image
        :param dx: distance agent moves each timestep
        :param placedev: standard deviation of gaussian place cell activations
        :param num_places: number of placecells to use (if None it will attempt
            to fill the space)
        """

        # the environment state is 2-dimensional (the agent's location)
        EnvironmentTemplate.__init__(self, name, 2, actions)

        # parameters
        self.colormap = colormap
        self.rewardamount = 0  # number of timesteps spent in reward

        # number of timesteps to spend in reward before agent is reset
        # note: convenient to express this as time_in_reward / dt
        self.rewardresetamount = 0.6 / 0.001

        self.num_actions = len(actions)
        self.imgsize = [float(x) for x in imgsize]
        self.dx = dx
        self.placedev = placedev
        self.num_places = num_places
        # label of the optimal action for the current state (maintained
        # elsewhere in the class; None until first computed)
        self.optimal_move = None
        # reward delivered when no other reward condition applies
        # (small negative value -- presumably a per-step cost; confirm
        # against the class's reward logic)
        self.defaultreward = -0.075

        # load environment map image (Java ImageIO, so this runs under Jython)
        self.map = ImageIO.read(File(HRLutils.datafile(mapname)))

        # generate place cells
        self.gen_placecells(min_spread=1.0 * placedev)

        # initial conditions: start somewhere that isn't a wall or target
        self.state = self.random_location(avoid=["wall", "target"])
        self.place_activations = [0 for _ in self.placecells]

        # expose place cell activations as a network origin
        self.create_origin("place", lambda: self.place_activations)

        # note: making the value small, so that the noise node will give us
        # some random exploration as well
        self.create_origin("optimal_move",
                           lambda: [0.1 if self.optimal_move == a[0] else 0.0
                                    for a in self.actions])
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:53,代码来源:placecell_bmp.py

示例6: saveWeights

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
    def saveWeights(self, prefix):
        """Save the connection weights to file.

        One file is written per action population, containing the random
        seed on the first line followed by one space-separated weight row
        per post-synaptic neuron (the format loadWeights expects).

        :param prefix: path prefix for the output files
        """

        prefix = prefix + "_" + self.name
        for n in self.getNodes():
            if n.getName().startswith("action"):
                term = n.getTermination("learning")
                weights = [t.getWeights() for t in term.getNodeTerminations()]

                # use a context manager so the file is closed even if a
                # write fails (the original leaked the handle on error)
                with open(HRLutils.datafile(prefix + "_" + n.getName() + ".txt"), "w") as f:
                    f.write(str(HRLutils.SEED) + "\n")
                    for row in weights:
                        f.write(" ".join([str(x) for x in row]) + "\n")
开发者ID:Seanny123,项目名称:HRL_1.0,代码行数:16,代码来源:actionvalues.py

示例7: loadWeights

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
    def loadWeights(self, prefix):
        """Load the connection weights from file.

        Expects the format written by saveWeights: the random seed on the
        first line, then one space-separated weight row per post-synaptic
        neuron.

        :param prefix: path prefix for the input files
        """

        prefix = prefix + "_" + self.name
        for n in self.getNodes():
            if n.getName().startswith("action"):
                # context manager guarantees the file is closed on error
                with open(HRLutils.datafile(prefix + "_" + n.getName() + ".txt"), "r") as f:
                    seed = int(f.readline())
                    if seed != HRLutils.SEED:
                        # bug fix: the original concatenated ints onto a str,
                        # which raised TypeError instead of printing a warning
                        print ("Warning, loading weights with a seed (%s) that "
                               "doesn't match current (%s)" % (seed, HRLutils.SEED))
                    weights = []
                    for line in f:
                        weights += [[float(x) for x in line.split()]]

                term = n.getTermination("learning")
                for i, t in enumerate(term.getNodeTerminations()):
                    t.setWeights(weights[i], True)
开发者ID:Seanny123,项目名称:HRL_1.0,代码行数:20,代码来源:actionvalues.py

示例8: test_placecell_bmp

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def test_placecell_bmp():
    net = nef.Network("TestPlacecellBmp")

    actions = [("up", [0, 1]), ("right", [1, 0]), ("down", [0, -1]), ("left", [-1, 0])]

    env = placecell_bmp.PlaceCellEnvironment(
        actions,
        HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "target", -2088896: "b"},
        imgsize=(5, 5),
        dx=0.001,
        placedev=0.5,
    )
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    net.add_to_nengo()
    net.view()
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:21,代码来源:test.py

示例9: run_gridworld

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def run_gridworld(args, seed=None):
    """Run the SMDP agent on the simple gridworld task.

    :param args: kwargs passed through to SMDPAgent.__init__
    :param seed: random seed for this run (if None, the current seed is kept)
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_gridworld")

    stateN = 400  # number of neurons in the agent's state population
    stateD = 2  # state dimension -- presumably (x, y) given cartesian=True below
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    agent = smdpagent.SMDPAgent(stateN, stateD, actions, stateradius=3,
                                **args)
    net.add(agent)

    env = gridworldenvironment.GridWorldEnvironment(
        stateD, actions, HRLutils.datafile("smallgrid.txt"), cartesian=True,
        delay=(0.6, 0.9), datacollection=False)
    net.add(env)

    # environment drives the agent's state, reward, and control signals;
    # the reset signal also triggers state/action saving
    net.connect(env.getOrigin("state"), agent.getTermination("state_input"))
    net.connect(env.getOrigin("reward"), agent.getTermination("reward"))
    net.connect(env.getOrigin("reset"), agent.getTermination("reset"))
    net.connect(env.getOrigin("learn"), agent.getTermination("learn"))
    net.connect(env.getOrigin("reset"), agent.getTermination("save_state"))
    net.connect(env.getOrigin("reset"), agent.getTermination("save_action"))

    # the agent's chosen action and Q-values feed back into the environment
    net.connect(agent.getOrigin("action_output"), env.getTermination("action"))
    net.connect(agent.getOrigin("Qs"), env.getTermination("Qs"))

    # launch the interactive viewer with a gridworld visualization
    net.add_to_nengo()
    view = timeview.View(net.network, update_frequency=5)
    view.add_watch(gridworldwatch.GridWorldWatch())
    view.restore()
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:38,代码来源:run.py

示例10: gen_evalpoints

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def gen_evalpoints(filename, seed=None):
    """Runs an environment for some length of time and records state values,
    to be used as eval points for agent initialization.

    :param filename: name of file in which to save eval points
    :param seed: random seed
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("gen_evalpoints")

    contextD = 2  # dimension of the context vector
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    # context labels and the rewards for reaching them
    rewards = {"a": 1, "b": 1}

    env = contextenvironment.ContextEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"), contextD, rewards,
        imgsize=(5, 5), dx=0.001, placedev=0.5,
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"})

    net.add(env)

    # full state = place cell activations plus the context vector
    stateD = len(env.placecells) + contextD
    actions = env.actions
    actionD = len(actions)

    class EvalRecorder(nef.SimpleNode):
        """SimpleNode that samples states and periodically dumps them to file."""

        def __init__(self, evalfile):
            self.action = actions[0]  # currently selected action
            self.evalpoints = []  # sampled state vectors
            self.evalfile = evalfile  # output filename

            nef.SimpleNode.__init__(self, "EvalRecorder")

        def tick(self):
            # sample the current state every 0.1s of simulation time
            if self.t % 0.1 < 0.001:
                self.evalpoints += [self.state]

            # every 10s, rewrite the output file, keeping at most the
            # most recent 10000 sampled points
            if self.t % 10.0 < 0.001:
                if len(self.evalpoints) > 10000:
                    self.evalpoints = self.evalpoints[len(self.evalpoints) -
                                                      10000:]

                with open(self.evalfile, "w") as f:
                    f.write("\n".join([" ".join([str(x) for x in e])
                                       for e in self.evalpoints]))

        def termination_state(self, x, dimensions=stateD):
            # latest state vector from the environment
            self.state = x

        def termination_action_in(self, x, dimensions=actionD):
            # follow whichever action has the strongest input signal
            self.action = actions[x.index(max(x))]

        def origin_action_out(self):
            # output the vector of the currently selected action
            return self.action[1]

    em = EvalRecorder(HRLutils.datafile("%s_%s.txt" % (filename, seed)))
    net.add(em)

    # drive the environment with the recorder's action, and feed the
    # environment's optimal-move signal and state back into the recorder
    net.connect(em.getOrigin("action_out"), env.getTermination("action"))
    net.connect(env.getOrigin("optimal_move"), em.getTermination("action_in"))
    net.connect(env.getOrigin("placewcontext"), em.getTermination("state"))

#     net.add_to_nengo()
    net.run(10)
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:72,代码来源:run.py

示例11: run_badreenvironment

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]

#.........这里部分代码省略.........
#             if [round(a) for a in env.state[-2:]] == [round(b)
#                                                       for b in x[1:]]:
#                 return 1.5
#             else:
#                 return -1.5
#     net.connect(reward_relay, ctrl_agent.getTermination("reward"),
#                 func=ctrl_reward_func)

    # nav rewarded for picking ctrl target
    def nav_reward_func(x):
        if abs(x[0]) < 0.5 or env.action is None:
            return 0.0

        if x[1] + x[2] < 0.5:
            return x[0]

        if x[1] > x[2]:
            return (1.5 if env.action[1] == env.state[:env.num_orientations]
                    else -1.5)
        else:
            return (1.5 if env.action[1] == env.state[env.num_orientations:
                                                      - env.num_colours]
                    else -1.5)
    net.connect(reward_relay, nav_agent.getTermination("reward"),
                func=nav_reward_func)

    # state for navagent controlled by ctrlagent
    ctrl_state_inhib = net.make_array("ctrl_state_inhib", 50, env.stateD,
                                      radius=2, mode=HRLutils.SIMULATION_MODE)
    ctrl_state_inhib.fixMode([SimulationMode.DEFAULT, SimulationMode.RATE])

    inhib_matrix = [[0, -5]] * 50 * env.num_orientations + \
                   [[-5, 0]] * 50 * env.num_shapes + \
                   [[-5, -5]] * 50 * env.num_colours

    # ctrl output inhibits all the non-selected aspects of the state
    net.connect(env.getOrigin("state"), ctrl_state_inhib)
    net.connect(ctrl_agent.getOrigin("action_output"), ctrl_state_inhib,
                transform=inhib_matrix)

    # also give a boost to the selected aspects (so that neurons are roughly
    # equally activated).
    def boost_func(x):
        if x[0] > 0.5:
            return [3 * v for v in x[1:]]
        else:
            return x[1:]
    boost = net.make("boost", 1, 1 + env.stateD, mode="direct")
    boost.fixMode()
    net.connect(ctrl_state_inhib, boost,
                transform=([[0 for _ in range(env.stateD)]] +
                           list(MU.I(env.stateD))))
    net.connect(ctrl_agent.getOrigin("action_output"), boost,
                transform=[[1, 1]] + [[0, 0] for _ in range(env.stateD)])

    net.connect(boost, nav_agent.getTermination("state_input"),
                func=boost_func)

    # save weights
    weight_save = 1.0  # period to save weights (realtime, not simulation time)
    threads = [
        HRLutils.WeightSaveThread(nav_agent.getNode("QNetwork").saveParams,
                                  os.path.join("weights", "%s_%s" %
                                               (nav_agent.name, seed)),
                                  weight_save),
        HRLutils.WeightSaveThread(ctrl_agent.getNode("QNetwork").saveParams,
                                  os.path.join("weights", "%s_%s" %
                                               (ctrl_agent.name, seed)),
                                  weight_save)]
    for t in threads:
        t.start()

    # data collection node
    data = datanode.DataNode(period=1,
                             filename=HRLutils.datafile("dataoutput_%s.txt" %
                                                        label),
                             header="%s %s %s %s %s" % (nav_args, ctrl_args,
                                                        bias, seed, flat))
    print "saving data to", data.filename
    print "header", data.header
    net.add(data)
    nav_q = nav_agent.getNode("QNetwork")
    ctrl_q = ctrl_agent.getNode("QNetwork")
    ctrl_bg = ctrl_agent.getNode("BGNetwork").getNode("weight_actions")
    data.record_avg(env.getOrigin("reward"))
    data.record_avg(ctrl_q.getNode("actionvals").getOrigin("X"))
    data.record_sparsity(ctrl_q.getNode("state_pop").getOrigin("AXON"))
    data.record_sparsity(nav_q.getNode("state_pop").getOrigin("AXON"))
    data.record_avg(ctrl_q.getNode("valdiff").getOrigin("X"))
    data.record_avg(ctrl_agent.getNode("ErrorNetwork").getOrigin("error"))
    data.record_avg(ctrl_bg.getNode("0").getOrigin("AXON"))
    data.record_avg(ctrl_bg.getNode("1").getOrigin("AXON"))
    data.record(env.getOrigin("score"))

#     net.add_to_nengo()
#     net.network.simulator.run(0, 300, 0.001)
    net.view()

    for t in threads:
        t.stop()
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:104,代码来源:run.py

示例12: run_flat_delivery

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def run_flat_delivery(args, seed=None):
    """Runs the model on the delivery task with only one hierarchical level.

    :param args: kwargs for the nav agent (see SMDPAgent.__init__)
    :param seed: random seed for this run (if None, the current seed is kept)
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("run_flat_delivery")

    # append the seed so weight files are matched to this particular run
    if "load_weights" in args and args["load_weights"] is not None:
        args["load_weights"] += "_%s" % seed

    stateN = 1200  # number of neurons in the state population
    contextD = 2  # dimension of the context vector
    context_scale = 1.0  # relative scale of context vs place representation
    max_state_input = 2  # maximum length of the state input vector
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    # ##ENVIRONMENT

    env = deliveryenvironment.DeliveryEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # ##NAV AGENT

    # scale encoders down by max_state_input so inputs fit in radius 1
    eenc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    # eval points are precomputed (by gen_evalpoints) for this seed
    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % seed)) as f:
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    nav_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                    actions, name="NavAgent",
                                    state_encoders=enc, state_evals=evals,
                                    state_threshold=0.8, **args)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    # agent output drives the environment; environment state drives the agent
    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))
    net.connect(env.getOrigin("placewcontext"),
                nav_agent.getTermination("state_input"))

    # termination node: a timer firing every 0.6-0.9s ends each
    # action-selection interval (reset also triggers state/action saving)
    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): None}, env, name="NavTermNode",
        contextD=2)
    net.add(nav_term_node)
    net.connect(env.getOrigin("context"),
                nav_term_node.getTermination("context"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    # combine environment reward and termination pseudoreward into one signal
    reward_relay = net.make("reward_relay", 1, 1, mode="direct")
    reward_relay.fixMode()
    net.connect(env.getOrigin("reward"), reward_relay)
    net.connect(nav_term_node.getOrigin("pseudoreward"), reward_relay)
    net.connect(reward_relay, nav_agent.getTermination("reward"))

    # period to save weights (realtime, not simulation time)
    weight_save = 600.0
    HRLutils.WeightSaveThread(nav_agent.getNode("QNetwork").saveParams,
                              os.path.join("weights", "%s_%s" %
                                           (nav_agent.name, seed)),
                              weight_save).start()

    # data collection node
    data = datanode.DataNode(period=5,
                             filename=HRLutils.datafile("dataoutput_%s.txt" %
                                                        seed))
    net.add(data)
    q_net = nav_agent.getNode("QNetwork")
    data.record_avg(env.getOrigin("reward"))
    data.record_avg(q_net.getNode("actionvals").getOrigin("X"))
    data.record_sparsity(q_net.getNode("state_pop").getOrigin("AXON"))
    data.record_avg(q_net.getNode("valdiff").getOrigin("X"))
    data.record_avg(nav_agent.getNode("ErrorNetwork").getOrigin("error"))

#    net.add_to_nengo()
#    net.run(10000)
    net.view()
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:95,代码来源:run.py

示例13: run_deliveryenvironment

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def run_deliveryenvironment(navargs, ctrlargs, tag=None, seed=None):
    """Runs the model on the delivery task.

    :param navargs: kwargs for the nav_agent (see SMDPAgent.__init__)
    :param ctrlargs: kwargs for the ctrl_agent (see SMDPAgent.__init__)
    :param tag: string appended to datafiles associated with this run
    :param seed: random seed used for this run
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    if tag is None:
        tag = str(seed)

    net = nef.Network("runDeliveryEnvironment", seed=seed)

    stateN = 1200  # number of neurons to use in state population
    contextD = 2  # dimension of context vector
    context_scale = 1.0  # relative scale of context vector vs state vector
    max_state_input = 2  # maximum length of input vector to state population

    # labels and vectors corresponding to basic actions available to the system
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    if "load_weights" in navargs and navargs["load_weights"] is not None:
        navargs["load_weights"] += "_%s" % tag
    if "load_weights" in ctrlargs and ctrlargs["load_weights"] is not None:
        ctrlargs["load_weights"] += "_%s" % tag

    # ##ENVIRONMENT

    env = deliveryenvironment.DeliveryEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"),
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # ##NAV AGENT

    # generate encoders and divide them by max_state_input (so that inputs
    # will be scaled down to radius 1)
    enc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    # read in eval points from file
    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % tag)) as f:
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    nav_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                    actions, name="NavAgent",
                                    state_encoders=enc, state_evals=evals,
                                    state_threshold=0.8,
                                    **navargs)
    net.add(nav_agent)

    print "agent neurons:", nav_agent.countNeurons()

    # output of nav_agent is what goes to the environment
    net.connect(nav_agent.getOrigin("action_output"),
                env.getTermination("action"))

    # termination node for nav_agent (just a timer that goes off regularly)
    nav_term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): None}, env, contextD=2,
        name="NavTermNode")
    net.add(nav_term_node)

    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"),
                nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"),
                nav_agent.getTermination("save_action"))

    # ##CTRL AGENT

    # actions corresponding to "go to A" or "go to B"
    actions = [("a", [0, 1]), ("b", [1, 0])]
    ctrl_agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                     actions, name="CtrlAgent",
                                     state_encoders=enc, state_evals=evals,
                                     state_threshold=0.8, **ctrlargs)
    net.add(ctrl_agent)

    print "agent neurons:", ctrl_agent.countNeurons()

    # ctrl_agent gets environmental state and reward
    net.connect(env.getOrigin("placewcontext"),
                ctrl_agent.getTermination("state_input"))
    net.connect(env.getOrigin("reward"),
                ctrl_agent.getTermination("reward"))

    # termination node for ctrl_agent (terminates whenever the agent is in the
#.........这里部分代码省略.........
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:103,代码来源:run.py

示例14: run_contextenvironment

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]
def run_contextenvironment(args, seed=None):
    """Runs the model on the context task.

    :param args: kwargs for the agent
    :param seed: random seed
    """

    if seed is not None:
        HRLutils.set_seed(seed)
    seed = HRLutils.SEED

    net = nef.Network("runContextEnvironment")

    if "load_weights" in args and args["load_weights"] is not None:
        args["load_weights"] += "_%s" % seed

    stateN = 1200  # number of neurons to use in state population
    contextD = 2  # dimension of context vector
    context_scale = 1.0  # scale of context representation
    max_state_input = 2  # max length of input vector for state population
    # actions (label and vector) available to the system
    actions = [("up", [0, 1]), ("right", [1, 0]),
               ("down", [0, -1]), ("left", [-1, 0])]

    # context labels and rewards for achieving those context goals
    rewards = {"a": 1.5, "b": 1.5}

    env = contextenvironment.ContextEnvironment(
        actions, HRLutils.datafile("contextmap.bmp"), contextD, rewards,
        colormap={-16777216: "wall", -1: "floor", -256: "a", -2088896: "b"},
        imgsize=(5, 5), dx=0.001, placedev=0.5)
    net.add(env)

    print "generated", len(env.placecells), "placecells"

    # termination node for agent (just goes off on some regular interval)
    term_node = terminationnode.TerminationNode(
        {terminationnode.Timer((0.6, 0.9)): 0.0}, env)
    net.add(term_node)

    # generate encoders and divide by max_state_input (so that all inputs
    # will end up being radius 1)
    enc = env.gen_encoders(stateN, contextD, context_scale)
    enc = MU.prod(enc, 1.0 / max_state_input)

    # load eval points from file
    with open(HRLutils.datafile("contextbmp_evalpoints_%s.txt" % seed)) as f:
        print "loading contextbmp_evalpoints_%s.txt" % seed
        evals = [[float(x) for x in l.split(" ")] for l in f.readlines()]

    agent = smdpagent.SMDPAgent(stateN, len(env.placecells) + contextD,
                                actions, state_encoders=enc, state_evals=evals,
                                state_threshold=0.8, **args)
    net.add(agent)

    print "agent neurons:", agent.countNeurons()

    # period to save weights (realtime, not simulation time)
    weight_save = 600.0
    t = HRLutils.WeightSaveThread(agent.getNode("QNetwork").saveParams,
                                  os.path.join("weights", "%s_%s" %
                                               (agent.name, seed)),
                                  weight_save)
    t.start()

    # data collection node
    data = datanode.DataNode(period=5,
                             filename=HRLutils.datafile("dataoutput_%s.txt" %
                                                        seed))
    net.add(data)
    q_net = agent.getNode("QNetwork")
    data.record(env.getOrigin("reward"))
    data.record(q_net.getNode("actionvals").getOrigin("X"), func=max)
    data.record(q_net.getNode("actionvals").getOrigin("X"), func=min)
    data.record_sparsity(q_net.getNode("state_pop").getOrigin("AXON"))
    data.record_avg(q_net.getNode("valdiff").getOrigin("X"))
    data.record_avg(env.getOrigin("state"))

    net.connect(env.getOrigin("placewcontext"),
                agent.getTermination("state_input"))
    net.connect(env.getOrigin("reward"), agent.getTermination("reward"))
    net.connect(term_node.getOrigin("reset"), agent.getTermination("reset"))
    net.connect(term_node.getOrigin("learn"), agent.getTermination("learn"))
    net.connect(term_node.getOrigin("reset"),
                agent.getTermination("save_state"))
    net.connect(term_node.getOrigin("reset"),
                agent.getTermination("save_action"))

    net.connect(agent.getOrigin("action_output"), env.getTermination("action"))

#    net.add_to_nengo()
#    net.run(2000)
    net.view()

    t.stop()
开发者ID:drasmuss,项目名称:nhrlmodel,代码行数:97,代码来源:run.py

示例15: run_badreenvironment

# 需要导入模块: from hrlproject.misc import HRLutils [as 别名]
# 或者: from hrlproject.misc.HRLutils import datafile [as 别名]

#.........这里部分代码省略.........

    nav_term_node = terminationnode.TerminationNode({terminationnode.Timer((0.6, 0.6)):None}, env,
                                                    name="NavTermNode", state_delay=0.1)
    net.add(nav_term_node)

    net.connect(nav_term_node.getOrigin("reset"), nav_agent.getTermination("reset"))
    net.connect(nav_term_node.getOrigin("learn"), nav_agent.getTermination("learn"))
    net.connect(nav_term_node.getOrigin("reset"), nav_agent.getTermination("save_state"))
    net.connect(nav_term_node.getOrigin("reset"), nav_agent.getTermination("save_action"))

    net.connect(nav_agent.getOrigin("action_output"), env.getTermination("action"))

    ###CTRL AGENT
    enc = env.gen_encoders(stateN, 0, 0)
    enc = MU.prod(enc, 1.0 / max_state_input)
    actions = [("shape", [0, 1]), ("orientation", [1, 0]), ("null", [0, 0])]
    ctrl_agent = smdpagent.SMDPAgent(stateN, env.stateD, actions, name="CtrlAgent",
                                     load_weights=None, state_encoders=enc,
                                     state_evals=evals, discount=0.4, **ctrl_args)
    net.add(ctrl_agent)

    print "agent neurons:", ctrl_agent.countNeurons()

    net.connect(env.getOrigin("state"), ctrl_agent.getTermination("state_input"))

    ctrl_term_node = terminationnode.TerminationNode({terminationnode.Timer((0.6, 0.6)):None},
                                                     env, name="CtrlTermNode",
                                                     state_delay=0.1)
    net.add(ctrl_term_node)

    net.connect(ctrl_term_node.getOrigin("reset"), ctrl_agent.getTermination("reset"))
    net.connect(ctrl_term_node.getOrigin("learn"), ctrl_agent.getTermination("learn"))
    net.connect(ctrl_term_node.getOrigin("reset"), ctrl_agent.getTermination("save_state"))
    net.connect(ctrl_term_node.getOrigin("reset"), ctrl_agent.getTermination("save_action"))
    
    
    ## reward for nav/ctrl
    reward_relay = net.make("reward_relay", 1, 2, mode="direct")
    reward_relay.fixMode()
    net.connect(env.getOrigin("reward"), reward_relay, transform=[[1], [0]])
    net.connect(ctrl_agent.getOrigin("action_output"), reward_relay, transform=[[0, 0], [1, 1]])
    
    # nav reward is just environment
    net.connect(reward_relay, nav_agent.getTermination("reward"), 
                func=lambda x: x[0], origin_name="nav_reward")
    
    # ctrl gets a slight bonus if it selects a rule (as opposed to null), to encourage it not
    # to just pick null all the time
    net.connect(reward_relay, ctrl_agent.getTermination("reward"), 
                func=lambda x: x[0]+0.25*abs(x[0]) if x[1] > 0.5 else x[0], origin_name="ctrl_reward")

    ## state for navagent controlled by ctrlagent
#    ctrl_output_relay = net.make("ctrl_output_relay", 1, env.stateD+2, mode="direct")
#    ctrl_output_relay.fixMode()
    ctrl_output_relay = net.make_array("ctrl_output_relay", 50, env.stateD,
                                       radius=2, mode=HRLutils.SIMULATION_MODE)
    ctrl_output_relay.fixMode([SimulationMode.DEFAULT, SimulationMode.RATE])
    
    inhib_matrix = [[0,-5]]*50*env.num_orientations + \
                   [[-5,0]]*50*env.num_shapes + \
                   [[-5,-5]]*50*env.num_colours

    # ctrl output inhibits all the non-selected aspects of the state
    net.connect(env.getOrigin("state"), ctrl_output_relay)
    net.connect(ctrl_agent.getOrigin("action_output"), ctrl_output_relay,
#                transform=zip([0]*env.num_orientations + [-1]*(env.num_shapes+env.num_colours),
#                              [-1]*env.num_orientations + [0]*env.num_shapes + [-1]*env.num_colours))
                transform=inhib_matrix)
    
    # also give a boost to the selected aspects (so that neurons are roughly equally activated).
    # adding 2/3 to each element (base vector has length 3, inhibited vector has length 1, so add 2/3*3 --> 3)
    net.connect(ctrl_agent.getOrigin("action_output"), ctrl_output_relay,
                transform=zip([0.66]*env.num_orientations + [0]*(env.num_shapes+env.num_colours),
                              [0]*env.num_orientations + [0.66]*env.num_shapes + [2]*env.num_colours))

    net.connect(ctrl_output_relay, nav_agent.getTermination("state_input"))

    # save weights
    weight_save = 600.0 # period to save weights (realtime, not simulation time)
    HRLutils.WeightSaveThread(nav_agent.getNode("QNetwork").saveParams,
                     os.path.join("weights", "%s_%s" % (nav_agent.name, seed)), weight_save).start()
    HRLutils.WeightSaveThread(ctrl_agent.getNode("QNetwork").saveParams,
                     os.path.join("weights", "%s_%s" % (ctrl_agent.name, seed)), weight_save).start()

    # data collection node
    data = datanode.DataNode(period=5, show_plots=None, filename=HRLutils.datafile("dataoutput_%s.txt" % seed))
    filter = 1e-5
    net.add(data)
    data.record_avg(env.getOrigin("reward"), filter=filter)
    data.record_avg(ctrl_agent.getNode("QNetwork").getNode("actionvals").getOrigin("X"), filter=filter)
    data.record_sparsity(ctrl_agent.getNode("QNetwork").getNode("state_pop").getOrigin("AXON"), filter=filter)
    data.record_sparsity(nav_agent.getNode("QNetwork").getNode("state_pop").getOrigin("AXON"), filter=filter)
    data.record_avg(ctrl_agent.getNode("QNetwork").getNode("valdiff").getOrigin("X"), filter=filter)
    data.record_avg(ctrl_agent.getNode("ErrorNetwork").getOrigin("error"), filter=filter)
    data.record_avg(ctrl_agent.getNode("BGNetwork").getNode("weight_actions").getNode("0").getOrigin("AXON"), filter=filter)
    data.record_avg(ctrl_agent.getNode("BGNetwork").getNode("weight_actions").getNode("1").getOrigin("AXON"), filter=filter)

    net.add_to_nengo()
#    net.view()
    net.run(2000)
开发者ID:Seanny123,项目名称:HRL_1.0,代码行数:104,代码来源:run.py


注:本文中的hrlproject.misc.HRLutils.datafile方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。