當前位置: 首頁>>代碼示例>>Python>>正文


Python Net.add_datum方法代碼示例

本文整理匯總了Python中net.Net.add_datum方法的典型用法代碼示例。如果您正苦於以下問題:Python Net.add_datum方法的具體用法?Python Net.add_datum怎麽用?Python Net.add_datum使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在net.Net的用法示例。


在下文中一共展示了Net.add_datum方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: Dagger

# 需要導入模塊: from net import Net [as 別名]
# 或者: from net.Net import add_datum [as 別名]
class Dagger():

    def __init__(self, grid, mdp, moves=40):
        self.grid = grid
        self.mdp = mdp
        self.svm = LinearSVM(grid, mdp)
        self.net = Net(grid,mdp)
        self.moves = moves
        #self.reward = np.zeros(40)
        self.super_pi = mdp.pi
        self.reward = np.zeros(self.moves)
        self.animate = False
        self.record = True
        self.recent_rollout_states = None
        
    def rollout(self):
        self.grid.reset_mdp()
        self.reward = np.zeros(self.moves)
        self.recent_rollout_states = [self.mdp.state]
        self.mistakes = 0.0
        for t in range(self.moves):
            if self.record:
                assert self.super_pi.desc == ClassicPolicy.desc
                self.net.add_datum(self.mdp.state, self.super_pi.get_next(self.mdp.state))
            #Get current state and action
            x_t = self.mdp.state
            a_t = self.mdp.pi.get_next(x_t)

            self.compare_policies(x_t, a_t)

            #Take next step 
            self.grid.step(self.mdp)

            x_t_1 = self.mdp.state

            #Evaualte reward recieved 
            self.reward[t] = self.grid.reward(x_t,a_t,x_t_1)
            self.recent_rollout_states.append(self.mdp.state)

        if(self.animate):
            self.grid.show_recording()
        #print self.svm.data

    def compare_policies(self, x, a):
        if self.super_pi.get_next(x) != a:
            self.mistakes += 1

    def get_states(self):
        return self.net.get_states()
    def get_reward(self):
        return np.sum(self.reward)
    def set_supervisor_pi(self, pi):
        self.super_pi = pi

    def get_loss(self):
        return float(self.mistakes) / float(self.moves)

    def get_recent_rollout_states(self):
        N = len(self.recent_rollout_states)
        states = np.zeros([N,2])
        for i in range(N):
            x = self.recent_rollout_states[i].toArray()
            states[i,:] = x        
        return states
    
    def retrain(self):
        self.net.fit()
        self.mdp.pi = NetPolicy(self.net)
開發者ID:jon--lee,項目名稱:daggermdp,代碼行數:70,代碼來源:dagger.py

示例2: NSupervise

# 需要導入模塊: from net import Net [as 別名]
# 或者: from net.Net import add_datum [as 別名]
class NSupervise():

    def __init__(self, grid, mdp, moves=40,net = 'Net'):
        self.grid = grid
        self.mdp = mdp
        self.net_name = net
        self.svm = LinearSVM(grid, mdp)
        self.net = Net(grid,mdp,net,T=moves)
        self.moves = moves
        #self.reward = np.zeros(40)
        self.super_pi = mdp.pi
        self.mdp.pi_noise = False
        self.reward = np.zeros(self.moves)
        self.animate = False
        self.train_loss = 0
        self.test_loss = 0
        self.record = True
        
    def rollout(self):
        self.grid.reset_mdp()
        self.reward = np.zeros(self.moves)
        for t in range(self.moves):
            a = self.super_pi.get_next(self.mdp.state)
            #print "action ",a
            
            #Get current state and action
            x_t = self.mdp.state
            a_t = self.mdp.pi.get_next(x_t)



            #Take next step 
            a_taken = self.grid.step(self.mdp)

            print "action taken ", a_taken
            print "timestep ", t
            if(self.record):
                if(self.net_name == 'UB'):
                    self.net.add_datum(x_t, a,a_taken)
                else:
                    self.net.add_datum(x_t,a)

            x_t_1 = self.mdp.state

            #Evaualte reward recieved 
            self.reward[t] = self.grid.reward(x_t,a_t,x_t_1)


        if(self.animate):
            self.grid.show_recording()
        
        #print self.svm.data
    def sample_policy(self):
        self.record = True
        self.net.clear_data()
    def get_states(self):
        return self.net.get_states()
    def get_weights(self):
        return self.net.get_weights()
    def get_reward(self):
        return np.sum(self.reward)
    def set_supervisor_pi(self, pi):
        self.super_pi = pi

    def train(self):
        self.net.fit()
        stats = self.net.return_stats()
        self.train_loss = stats[0]
        self.test_loss = stats[1]
        self.mdp.pi_noise = False
        self.mdp.pi = NetPolicy(self.net)
        self.record = False

    def get_train_loss(self):
        return self.train_loss

    def get_test_loss(self):
        return self.test_loss
開發者ID:jon--lee,項目名稱:daggermdp,代碼行數:80,代碼來源:nsupervise.py


注:本文中的net.Net.add_datum方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。