當前位置: 首頁>>代碼示例>>Python>>正文


Python Net.return_stats方法代碼示例

本文整理匯總了Python中net.Net.return_stats方法的典型用法代碼示例。如果您正苦於以下問題:Python Net.return_stats方法的具體用法?Python Net.return_stats怎麽用?Python Net.return_stats使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在net.Net的用法示例。


在下文中一共展示了Net.return_stats方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: NSupervise

# 需要導入模塊: from net import Net [as 別名]
# 或者: from net.Net import return_stats [as 別名]
class NSupervise():

    def __init__(self, grid, mdp, moves=40,net = 'Net'):
        self.grid = grid
        self.mdp = mdp
        self.net_name = net
        self.svm = LinearSVM(grid, mdp)
        self.net = Net(grid,mdp,net,T=moves)
        self.moves = moves
        #self.reward = np.zeros(40)
        self.super_pi = mdp.pi
        self.mdp.pi_noise = False
        self.reward = np.zeros(self.moves)
        self.animate = False
        self.train_loss = 0
        self.test_loss = 0
        self.record = True
        
    def rollout(self):
        self.grid.reset_mdp()
        self.reward = np.zeros(self.moves)
        for t in range(self.moves):
            a = self.super_pi.get_next(self.mdp.state)
            #print "action ",a
            
            #Get current state and action
            x_t = self.mdp.state
            a_t = self.mdp.pi.get_next(x_t)



            #Take next step 
            a_taken = self.grid.step(self.mdp)

            print "action taken ", a_taken
            print "timestep ", t
            if(self.record):
                if(self.net_name == 'UB'):
                    self.net.add_datum(x_t, a,a_taken)
                else:
                    self.net.add_datum(x_t,a)

            x_t_1 = self.mdp.state

            #Evaualte reward recieved 
            self.reward[t] = self.grid.reward(x_t,a_t,x_t_1)


        if(self.animate):
            self.grid.show_recording()
        
        #print self.svm.data
    def sample_policy(self):
        self.record = True
        self.net.clear_data()
    def get_states(self):
        return self.net.get_states()
    def get_weights(self):
        return self.net.get_weights()
    def get_reward(self):
        return np.sum(self.reward)
    def set_supervisor_pi(self, pi):
        self.super_pi = pi

    def train(self):
        self.net.fit()
        stats = self.net.return_stats()
        self.train_loss = stats[0]
        self.test_loss = stats[1]
        self.mdp.pi_noise = False
        self.mdp.pi = NetPolicy(self.net)
        self.record = False

    def get_train_loss(self):
        return self.train_loss

    def get_test_loss(self):
        return self.test_loss
開發者ID:jon--lee,項目名稱:daggermdp,代碼行數:80,代碼來源:nsupervise.py


注:本文中的net.Net.return_stats方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。