NeuronDotNet: почему моя функция возвращает разные выходы во встроенную?

Я использую NeuronDotNet для нейронных сетей в С#. Чтобы проверить сеть (а также подготовить ее), я написал свою собственную функцию, чтобы получить ошибку квадрата. Однако, когда я протестировал эту функцию, запустив ее на обучающих данных и сравнив ее со средой MeanSquaredError сети Backpropagation, результаты были разными.

Я обнаружил, что причиной разной ошибки является то, что сеть возвращает разные выходные данные, когда я запускаю ее, когда она запускается на этапе обучения. Я запускаю его для каждого TrainingSample, используя:

double[] output = xorNetwork.Run(sample.InputVector);

На этапе обучения его использование:

xorNetwork.Learn(trainingSet, cycles);

... с делегатом, чтобы захватить событие конечной выборки:

xorNetwork.EndSampleEvent +=
    delegate(object network, TrainingSampleEventArgs args)
    {
        double[] test = xorNetwork.OutputLayer.GetOutput();
        debug.addSampleOutput(test);
    };

Я попытался сделать это, используя проблему XOR, чтобы это было просто, а выходы по-прежнему разные. Например, в конце первой эпохи выходы делегата EndSampleEvent и те, что указаны в моей функции:

  • Вход: 01, Ожидаемый: 1, my_функция: 0.703332, EndSampleEvent 0.734385
  • Вход: 00, Ожидаемое: 0, my_function: 0.632568, EndSampleEvent 0.649198
  • Вход: 10, Ожидаемое: 1, my_function: 0.650141, EndSampleEvent 0.710484
  • Вход: 11, Ожидаемый: 0, my_function: 0.715175, EndSampleEvent 0.647102
  • Ошибка: my_function: 0.280508, EndSampleEvent 0.291236

Его не так просто, как его захватывают на другой фазе в эпоху, выходы не идентичны выходам в следующую/предыдущую эпоху.

Я пробовал отлаживать, но я не эксперт в Visual Studio, и я немного борюсь с этим. Мой проект ссылается на DLL NeuronDotNet. Когда я помещаю точки останова в свой код, он не будет входить в код из DLL. Я искал в другом месте советы по этому поводу и попробовал несколько решений и нигде не попадал.

Я не думаю, что это связано с "эффектом наблюдателя", т.е. методом Run в моей функции, вызывающим изменение сети. Я изучил код (в проекте, который делает DLL), и я не думаю, что Run изменяет любой вес. Ошибки от моей функции, как правило, ниже, чем у EndSampleEvent, в несколько раз превышающем уменьшение ошибки от типичной эпохи, т.е. Как будто сеть во время моего кода временно опережает себя (с точки зрения обучения).

Нейронные сети являются стохастическими в том смысле, что они корректируют свои функции во время обучения. Однако выход должен быть детерминированным. Почему два набора выходов различны?

EDIT: Вот код, который я использую.

/***********************************************************************************************
COPYRIGHT 2008 Vijeth D

This file is part of NeuronDotNet XOR Sample.
(Project Website : http://neurondotnet.freehostia.com)

NeuronDotNet is a free software. You can redistribute it and/or modify it under the terms of
the GNU General Public License as published by the Free Software Foundation, either version 3
of the License, or (at your option) any later version.

NeuronDotNet is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with NeuronDotNet.
If not, see <http://www.gnu.org/licenses/>.

***********************************************************************************************/

using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Text;
using System.Windows.Forms;
using NeuronDotNet.Core;
using NeuronDotNet.Core.Backpropagation;
using ZedGraph;

namespace NeuronDotNet.Samples.XorSample
{
    public partial class MainForm : Form
    {
        private BackpropagationNetwork xorNetwork;
        private double[] errorList;
        private int cycles = 5000;
        private int neuronCount = 3;
        private double learningRate = 0.25d;

        public MainForm()
        {
            InitializeComponent();
        }

        private void Train(object sender, EventArgs e)
        {
            EnableControls(false);
            if (!int.TryParse(txtCycles.Text.Trim(), out cycles)) { cycles = 5000; }
            if (!double.TryParse(txtLearningRate.Text.Trim(), out learningRate)) { learningRate = 0.25d; }
            if (!int.TryParse(txtNeuronCount.Text.Trim(), out neuronCount)) { neuronCount = 3; }

            if (cycles < 1) { cycles = 1; }
            if (learningRate < 0.01) { learningRate = 0.01; }
            if (neuronCount < 1) { neuronCount = 1; }

            txtNeuronCount.Text = neuronCount.ToString();
            txtCycles.Text = cycles.ToString();
            txtLearningRate.Text = learningRate.ToString();

            errorList = new double[cycles];
            InitGraph();

            LinearLayer inputLayer = new LinearLayer(2);
            SigmoidLayer hiddenLayer = new SigmoidLayer(neuronCount);
            SigmoidLayer outputLayer = new SigmoidLayer(1);
            new BackpropagationConnector(inputLayer, hiddenLayer);
            new BackpropagationConnector(hiddenLayer, outputLayer);
            xorNetwork = new BackpropagationNetwork(inputLayer, outputLayer);
            xorNetwork.SetLearningRate(learningRate);

            TrainingSet trainingSet = new TrainingSet(2, 1);
            trainingSet.Add(new TrainingSample(new double[2] { 0d, 0d }, new double[1] { 0d }));
            trainingSet.Add(new TrainingSample(new double[2] { 0d, 1d }, new double[1] { 1d }));
            trainingSet.Add(new TrainingSample(new double[2] { 1d, 0d }, new double[1] { 1d }));
            trainingSet.Add(new TrainingSample(new double[2] { 1d, 1d }, new double[1] { 0d }));
           Console.WriteLine("mse_begin,mse_end,output,outputs,myerror");
            double max = 0d;
         Console.WriteLine(NNDebug.Header);
           List < NNDebug > debugList = new List<NNDebug>();
           NNDebug debug = null;
         xorNetwork.BeginEpochEvent +=
              delegate(object network, TrainingEpochEventArgs args)
                 {
                  debug = new NNDebug(trainingSet);
                 };

           xorNetwork.EndSampleEvent +=
            delegate(object network, TrainingSampleEventArgs args)
                 {                                                  
                  double[] test = xorNetwork.OutputLayer.GetOutput();

                  debug.addSampleOutput(args.TrainingSample, test);
                 };

         xorNetwork.EndEpochEvent +=
            delegate(object network, TrainingEpochEventArgs args)
            {    
               errorList[args.TrainingIteration] = xorNetwork.MeanSquaredError;
               debug.setMSE(xorNetwork.MeanSquaredError);
               double[] test = xorNetwork.OutputLayer.GetOutput();
               GetError(trainingSet, debug);
               max = Math.Max(max, xorNetwork.MeanSquaredError);
               progressBar.Value = (int)(args.TrainingIteration * 100d / cycles);
               //Console.WriteLine(debug);
               debugList.Add(debug);
            };

            xorNetwork.Learn(trainingSet, cycles);
            double[] indices = new double[cycles];
            for (int i = 0; i < cycles; i++) { indices[i] = i; }

            lblTrainErrorVal.Text = xorNetwork.MeanSquaredError.ToString("0.000000");

            LineItem errorCurve = new LineItem("Error Dynamics", indices, errorList, Color.Tomato, SymbolType.None, 1.5f);
            errorGraph.GraphPane.YAxis.Scale.Max = max;
            errorGraph.GraphPane.CurveList.Add(errorCurve);
            errorGraph.Invalidate();
         writeOut(debugList);
            EnableControls(true);
        }

       private const String pathFileName = "C:\\Temp\\NDN_Debug_Output.txt";

      private void writeOut(IEnumerable<NNDebug> data)
      {
         using (StreamWriter streamWriter = new StreamWriter(pathFileName))
         {
            streamWriter.WriteLine(NNDebug.Header);

            //write results to a file for each load combination
            foreach (NNDebug debug in data)
            {
               streamWriter.WriteLine(debug);
            }
         } 
      }

      private void GetError(TrainingSet trainingSet, NNDebug debug)
      {
         double total = 0;
         foreach (TrainingSample sample in trainingSet.TrainingSamples)
         {
            double[] output = xorNetwork.Run(sample.InputVector);

            double[] expected = sample.OutputVector;
            debug.addOutput(sample, output);
            int len = output.Length;
            for (int i = 0; i < len; i++)
            {
               double error = output[i] - expected[i];
               total += (error * error);
            }
         }
         total = total / trainingSet.TrainingSampleCount;
         debug.setMyError(total);
      }

      private class NNDebug
      {
         public const String Header = "output(00->0),output(01->1),output(10->1),output(11->0),mse,my_output(00->0),my_output(01->1),my_output(10->1),my_output(11->0),my_error";

         public double MyErrorAtEndOfEpoch;
         public double MeanSquaredError;
         public double[][] OutputAtEndOfEpoch;
         public double[][] SampleOutput;
         private readonly List<TrainingSample> samples;

         public NNDebug(TrainingSet trainingSet)
         {
            samples =new List<TrainingSample>(trainingSet.TrainingSamples);
            SampleOutput = new double[samples.Count][];
            OutputAtEndOfEpoch = new double[samples.Count][];
         } 

         public void addSampleOutput(TrainingSample mySample, double[] output)
         {
            int index = samples.IndexOf(mySample);
            SampleOutput[index] = output;
         }

         public void addOutput(TrainingSample mySample, double[] output)
         {
            int index = samples.IndexOf(mySample);
            OutputAtEndOfEpoch[index] = output;
         }

         public void setMyError(double error)
         {
            MyErrorAtEndOfEpoch = error;
         }

         public void setMSE(double mse)
         {
            this.MeanSquaredError = mse;
         }

         public override string ToString()
         {
            StringBuilder sb = new StringBuilder();
            foreach (double[] arr in SampleOutput)
            {
               writeOut(arr, sb);
               sb.Append(',');
            }
            sb.Append(Math.Round(MeanSquaredError,6));
            sb.Append(',');
            foreach (double[] arr in OutputAtEndOfEpoch)
            {
               writeOut(arr, sb);
               sb.Append(',');
            }
            sb.Append(Math.Round(MyErrorAtEndOfEpoch,6));
            return sb.ToString();
         }
      }

      private static void writeOut(double[] arr, StringBuilder sb)
      {
         bool first = true;
         foreach (double d in arr)
         {
            if (first)
            {
               first = false;
            }
            else
            {
               sb.Append(',');
            }
            sb.Append(Math.Round(d, 6));
         }  
      }   

        private void EnableControls(bool enabled)
        {
            btnTrain.Enabled = enabled;
            txtCycles.Enabled = enabled;
            txtNeuronCount.Enabled = enabled;
            txtLearningRate.Enabled = enabled;
            progressBar.Value = 0;
            btnTest.Enabled = enabled;
            txtTestInput.Enabled = enabled;
        }

        private void LoadForm(object sender, EventArgs e)
        {
            InitGraph();
            txtCycles.Text = cycles.ToString();
            txtLearningRate.Text = learningRate.ToString();
            txtNeuronCount.Text = neuronCount.ToString();
        }

        private void InitGraph()
        {
            GraphPane pane = errorGraph.GraphPane;
            pane.Chart.Fill = new Fill(Color.AntiqueWhite, Color.Honeydew, -45F);
            pane.Title.Text = "Back Propagation Training - Error Graph";
            pane.XAxis.Title.Text = "Training Iteration";
            pane.YAxis.Title.Text = "Sum Squared Error";
            pane.XAxis.MajorGrid.IsVisible = true;
            pane.YAxis.MajorGrid.IsVisible = true;
            pane.YAxis.MajorGrid.Color = Color.LightGray;
            pane.XAxis.MajorGrid.Color = Color.LightGray;
            pane.XAxis.Scale.Max = cycles;
            pane.XAxis.Scale.Min = 0;
            pane.YAxis.Scale.Min = 0;
            pane.CurveList.Clear();
            pane.Legend.IsVisible = false;
            pane.AxisChange();
            errorGraph.Invalidate();
        }

        private void Test(object sender, EventArgs e)
        {
            if (xorNetwork != null)
            {
                lblTestOutput.Text = xorNetwork.Run(
                new double[] {double.Parse(txtTestInput.Text.Substring(2,4)),
                    double.Parse(txtTestInput.Text.Substring(8,4))})[0].ToString("0.000000");
            }
        }
    }
}

Это не связано с нормализацией, поскольку отображение между двумя наборами выходов не является монотонным. Например, результат в {0,1} выше в EndSampleEvent, но в {1,1} он ниже. Нормализация была бы простой линейной функцией.

Это не связано с дрожанием, поскольку я попытался отключить это, и результаты все еще отличаются.

Ответ 1

Я получил ответ от своего профессора. Проблема заключается в методе LearnSample из класса BackpropagationNetwork, который вызывается для каждой тестовой выборки на каждой итерации.

Порядок соответствующих событий в этом методе - это.... 1) Добавить в значение MeanSquaredError, которое рассчитывается с использованием только выходного уровня и желаемого вывода 2) Backpropagate ошибки для всех более ранних слоев; это не влияет на сеть. 3) Наконец, пересчитайте смещения для каждого слоя; это влияет на сеть.

(3) - последнее, что происходит в методе LearnSample и происходит после вычисления ошибки вывода для каждого учебного экземпляра. Для примера XOR это означает, что сеть изменяется 4 раза с момента, когда было выполнено вычисление MSE.

В теории, если вы хотите сравнить учебные и тестовые ошибки, вам следует выполнить ручной расчет (например, функцию GetError) и запустить его дважды: один раз для каждого набора данных. Однако на самом деле, возможно, нет необходимости идти на все эти проблемы, поскольку значения не отличаются друг от друга.