当前位置: 首页>>代码示例>>Python>>正文


Python NeuralNetwork.test方法代码示例

本文整理汇总了Python中NeuralNetwork.NeuralNetwork.test方法的典型用法代码示例。如果您正苦于以下问题:Python NeuralNetwork.test方法的具体用法?Python NeuralNetwork.test怎么用?Python NeuralNetwork.test使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在NeuralNetwork.NeuralNetwork的用法示例。


在下文中一共展示了NeuralNetwork.test方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: DigitClassifier

# 需要导入模块: from NeuralNetwork import NeuralNetwork [as 别名]
# 或者: from NeuralNetwork.NeuralNetwork import test [as 别名]
class DigitClassifier(tkinter.Tk):
    def __init__(self):
        tkinter.Tk.__init__(self)
        self.nn = NeuralNetwork(784, 300, 10)

        self.background = tkinter.Canvas(self, width = 308, height = 308)
        self.background.config(background="black")
        self.input_canvas = InputCanvas(self, width = 300, height = 300)
        self.result_label = tkinter.Label(self, text='')
        self.recog_button = tkinter.Button(self, text='Recognize', command=self.recognize)
        self.clear_button = tkinter.Button(self, text='Clear', command=self.input_canvas.clear)

        self.background.pack()
        self.input_canvas.place(x=4, y=4)
        self.result_label.pack()
        self.recog_button.pack()
        self.clear_button.pack()

    def train_nn(self, epochs=100000, edit_image=False):
        """ニューラルネットワークを訓練する"""
        import Mnist
        labels = Mnist.trainLabels
        images = Mnist.trainImages
        inputs, targets = [], []
        for _ in range(epochs):
            i = int(random.random() * len(labels))
            target = np.zeros(10)
            if edit_image:
                # 訓練データを加工する
                img = Image.fromarray(images[i])
                new_img = Image.new('L', (28, 28))
                new_img.paste(img.rotate(random.uniform(-45.0, 45.0)),
                              (random.randint(-5.0, 5.0), random.randint(-5.0, 5.0)))
                image = np.asarray(new_img).ravel()
            else:
                # 加工なし
                image = images[i].ravel()
            inputs.append(image/255.0)
            target[labels[i]] = 1.0
            targets.append(target)
        print("start training...")
        self.nn.train(np.array(inputs), np.array(targets), n=0.01)

        labels = Mnist.testLabels
        images = Mnist.testImages
        inputs, targets = [], []
        for i in range(len(labels)):
            target = np.zeros(10)
            inputs.append(images[i].ravel() / 255.0)
            target[labels[i]] = 1.0
            targets.append(target)
        print("start testing...")
        results = self.nn.test(np.array(inputs), np.array(targets))
        #print(results)

        overall = np.zeros((10, 10), dtype=int)
        correct = 0
        for result, target in zip(results, targets):
            ri = max(enumerate(result), key=lambda x: x[1])[0]
            ti = max(enumerate(target), key=lambda x: x[1])[0]
            overall[ti, ri] += 1
            if ti == ri:
                correct += 1
        print(overall)
        print(float(correct)/len(labels))

        # 訓練後のパラメータを保存する
        np.save('parameters/w1_2.npy', self.nn.w1_2)
        np.save('parameters/w2_3.npy', self.nn.w2_3)

    def load_nn_parameters(self):
        # パラメータを読み込む
        self.nn.w1_2 = np.load('parameters/w1_2.npy')
        self.nn.w2_3 = np.load('parameters/w2_3.npy')

    def recognize(self):
        # キャンバスに書き込まれた数字を認識する
        img = self.input_canvas.getImage().filter(ImageFilter.BLUR).convert('L')
        img.thumbnail((28, 28), getattr(Image, 'ANTIALIAS'))
        img = img.point(lambda x: 255 - x)
        input = np.asarray(img).ravel()
        result = self.nn.test([input / 255.0], np.zeros(10))[0]
        num = max(enumerate(result), key=lambda x: x[1])[0]
        self.result_label.configure(text = str(num))
        print(num, result)
开发者ID:ommadawn46,项目名称:DigitClassifier,代码行数:87,代码来源:DigitClassifier.py


注:本文中的NeuralNetwork.NeuralNetwork.test方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。