How to create simple 3-layer neural network and teach it using supervised learning?

青春惊慌失措 2021-01-05 15:16

Based on PyBrain's tutorials, I managed to knock together the following code:

#!/usr/bin/env python2
# coding: utf-8

from pybrain.structure import FeedForwar         


        
  • 2021-01-05 16:22

    There are four problems with your approach, all easy to identify after reading the Neural Network FAQ:

    • "Why use a bias/threshold?": you should add a bias node. Without a bias, learning is very limited: each unit's separating hyperplane can only pass through the origin. With a bias node the hyperplane can shift freely, so the network can fit the data much better:

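      from pybrain.structure import BiasUnit, FullConnection

      # A bias unit always outputs 1; its outgoing weights act as thresholds.
      bias = BiasUnit()
      n.addModule(bias)

      bias_to_hidden = FullConnection(bias, hiddenLayer)
      n.addConnection(bias_to_hidden)
      # (Add these lines before the call to n.sortModules().)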
      
    • "Why not code binary inputs as 0 and 1?": all your samples lie in a single quadrant of the input space. Re-encode them as -1 and 1 so that they are scattered around the origin:

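      from pybrain.datasets import SupervisedDataSet

      # Same targets as before, but inputs encoded as -1/+1 instead of 0/1
      ds = SupervisedDataSet(2, 1)
      ds.addSample((-1, -1), (0,))
      ds.addSample((-1, 1), (1,))
      ds.addSample((1, -1), (1,))
      ds.addSample((1, 1), (0,))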
      

      (Adjust the validation code at the end of your script to use the same encoding.)

    • The trainUntilConvergence method holds out part of the dataset for validation and does something resembling early stopping. That makes no sense for a dataset this small: with only four samples, any sample held out for validation is never trained on. Use trainEpochs instead; 1000 epochs are more than enough for the network to learn this problem:

      trainer.trainEpochs(1000)
      
    • "What learning rate should be used for backprop?": tune the learningrate parameter. You have to do this every time you use a neural network. In this case, a value of 0.1 or even 0.2 (the BackpropTrainer default is 0.01) dramatically speeds up learning:

      from pybrain.supervised.trainers import BackpropTrainer

      trainer = BackpropTrainer(n, dataset=ds, learningrate=0.1, verbose=True)
      

      (Note the verbose=True parameter. Observing how the error behaves is essential when tuning parameters.)
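    The bias point in the first item can be checked without PyBrain. The plain-Python sketch below is not PyBrain code; the helper names and the hand-picked AND weights are illustrative assumptions. It models a single sigmoid unit: without a bias its net input at the origin is always 0, so its output there is exactly 0.5 whatever the weights are, while a bias lets the boundary w1*x1 + w2*x2 + b = 0 move away from the origin.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def neuron(x, w, b=0.0):
    """A single sigmoid unit; b=0.0 plays the role of 'no bias node'."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


rng = random.Random(0)

# Without a bias, the boundary w1*x1 + w2*x2 = 0 always passes through the
# origin, so the output at (0, 0) is exactly 0.5 whatever the weights are.
for _ in range(5):
    w = (rng.uniform(-10, 10), rng.uniform(-10, 10))
    print(neuron((0, 0), w))  # 0.5 every time

# With a bias the boundary can shift. Hand-picked (hypothetical) weights (1, 1)
# and bias -1.5 turn the unit into a 0/1 AND gate. No bias-free unit can do this:
# outputs below 0.5 at (1, 0) and (0, 1) need w1 < 0 and w2 < 0, which
# contradicts the w1 + w2 > 0 needed for an output above 0.5 at (1, 1).
patterns = [(0, 0), (0, 1), (1, 0), (1, 1)]
print({x: round(neuron(x, (1.0, 1.0), b=-1.5)) for x in patterns})
# {(0, 0): 0, (0, 1): 0, (1, 0): 0, (1, 1): 1}
```

    With the original 0/1 encoding, (0, 0) is one of the training samples, so in a hidden layer of sigmoid units without a bias connection every hidden unit outputs exactly 0.5 on that sample, no matter how training changes the weights.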
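    One concrete consequence of the 0/1 encoding from the second item: for a sigmoid unit trained on squared error, the gradient of each input weight is proportional to that input, so an input of 0 contributes no weight update at all. The sketch below (plain Python; the weights, bias, and helper names are arbitrary illustrative assumptions) compares the (0, 0) sample with its -1/+1 re-encoding and checks that the re-encoded inputs are centered on the origin.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def weight_gradient(x, t, w, b):
    """dE/dw_i for one sigmoid unit y = sigmoid(w . x + b) with E = 0.5 * (y - t)**2."""
    y = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    delta = (y - t) * y * (1.0 - y)  # dE/d(net input)
    return [delta * xi for xi in x]  # each weight's gradient is scaled by its own input


def to_bipolar(sample):
    """Re-encode 0/1 inputs as -1/+1 (0 -> -1, 1 -> +1)."""
    return tuple(2 * v - 1 for v in sample)


w, b = (0.3, -0.2), 0.1  # arbitrary illustrative weights and bias

# Sample (0, 0) with target 0, in both encodings
print(weight_gradient((0, 0), 0, w, b))              # [0.0, 0.0]: no weight update
print(weight_gradient(to_bipolar((0, 0)), 0, w, b))  # both entries non-zero

# Column means of the four inputs: 0.5 each for 0/1, 0.0 each for -1/+1
inputs = [(0, 0), (0, 1), (1, 0), (1, 1)]
for encoded in (inputs, [to_bipolar(x) for x in inputs]):
    print([sum(col) / len(col) for col in zip(*encoded)])
```

    On the 0/1-encoded (0, 0) sample only a bias weight, if the unit has one, would still receive an update.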
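    The learning-rate advice in the last item follows from the gradient-descent update that backprop uses, w <- w - lr * dE/dw: the step size scales linearly with the learning rate. The toy below minimizes a hypothetical one-parameter error E(w) = (w - 3)**2 (not a neural network) with that same rule for 50 steps at three learning rates; the remaining distance to the minimum shrinks by a factor |1 - 2*lr| per step.

```python
def descend(lr, steps=50, w=0.0):
    """Minimize the toy error E(w) = (w - 3)**2 by gradient descent.

    Returns the remaining distance |w - 3| from the minimum.
    """
    for _ in range(steps):
        grad = 2.0 * (w - 3.0)  # dE/dw
        w -= lr * grad          # same form as the backprop update: w -= lr * dE/dw
    return abs(w - 3.0)


for lr in (0.01, 0.1, 0.2):  # 0.01 is BackpropTrainer's default learningrate
    print(f"learning rate {lr}: distance after 50 steps = {descend(lr):.2e}")
```

    Raising the rate is not free: for this toy any rate above 1.0 makes |1 - 2*lr| exceed 1 and the iterates diverge, which is why watching the error while tuning matters.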

    With these fixes I get consistent and correct results for the given network and dataset, with an error below 1e-23.
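    To see the same recipe end to end without installing PyBrain, here is a dependency-free Python 3 sketch that applies all four fixes by hand: bias units, -1/+1 inputs, a fixed number of training epochs, and a learning rate of 0.2. It is not PyBrain's implementation. The 2-3-1 layer sizes, the sigmoid hidden layer with a linear output, the extra output bias, the Gaussian weight initialization, and the fixed sample order are all assumptions chosen for illustration.

```python
import math
import random


def sigmoid(z):
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


# -1/+1 inputs with 0/1 targets, matching the dataset in the answer
DATA = [((-1, -1), 0.0), ((-1, 1), 1.0), ((1, -1), 1.0), ((1, 1), 0.0)]


class TinyNet:
    """2 inputs -> n_hidden sigmoid units -> 1 linear output, with bias on both layers."""

    def __init__(self, n_hidden=3, seed=0):
        rng = random.Random(seed)
        # w_ih[j] = [weight from x1, weight from x2, bias weight] of hidden unit j
        self.w_ih = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(n_hidden)]
        # w_ho = [one weight per hidden unit..., output bias weight]
        self.w_ho = [rng.gauss(0.0, 1.0) for _ in range(n_hidden + 1)]

    def forward(self, x):
        h = [sigmoid(w[0] * x[0] + w[1] * x[1] + w[2]) for w in self.w_ih]
        # zip stops after the hidden weights; the output bias is added separately
        y = sum(v * hj for v, hj in zip(self.w_ho, h)) + self.w_ho[-1]
        return h, y

    def train_step(self, x, t, lr):
        """One online backprop update on E = 0.5 * (y - t)**2; returns E before the update."""
        h, y = self.forward(x)
        d_out = y - t  # dE/dy for a linear output unit
        for j, hj in enumerate(h):
            # dE/d(net input of hidden unit j), using the not-yet-updated output weight
            d_hid = d_out * self.w_ho[j] * hj * (1.0 - hj)
            self.w_ho[j] -= lr * d_out * hj
            self.w_ih[j][0] -= lr * d_hid * x[0]
            self.w_ih[j][1] -= lr * d_hid * x[1]
            self.w_ih[j][2] -= lr * d_hid  # the bias unit's output is always 1
        self.w_ho[-1] -= lr * d_out
        return 0.5 * d_out ** 2

    def train_epochs(self, data, epochs, lr):
        """Train for a fixed number of passes (no validation split, no early stopping)."""
        err = float("nan")
        for _ in range(epochs):
            err = sum(self.train_step(x, t, lr) for x, t in data) / len(data)
        return err  # mean per-sample error seen during the last pass


net = TinyNet(n_hidden=3, seed=0)
print("mean error during last epoch:", net.train_epochs(DATA, epochs=1000, lr=0.2))
for x, t in DATA:
    print(x, "target:", t, "output:", round(net.forward(x)[1], 4))
```

    Because the weights start random, results depend on the seed; an unlucky initialization can still stall in a local minimum, so watching the error during training (as verbose=True lets you do in PyBrain) remains useful.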
