当前位置:编程学习 > JAVA >>

JAVA反向传播神经网络

From Quantums blog

好久没写东西了,随便记一下~~如果不记录忘记了真是很浪费啊~

下面贴的这个是JAVA写的反向传播神经网络,面向对象,把神经元和连结都抽象成对象。

其实这个已经不是什么新奇的玩意了,对于分类 回归而言 神经网络是一个选择但并不一定是最好的选择。

类似很多进化计算都存在同样的问题,如GA遗传算法。

NN迭代学习过后,每个神经元权重的含义很难被理解,所以才有人提出神经网络的规则抽取,如何对网络的神经元进行裁剪

接下来就是NN 的学习过程中,最具有价值的

1. 计算 delta规则 ,一般采用梯度下降法等等(有本书说得不错的《最优化理论与方法》讲了很多牛顿法 最速下降法等等)

2. 激活函数,其实这个最终目的就是让数据2值化,神经网络设计一书上标准的方式,其实你也可以通过这个思想自己搞一个

3. 当然是NN的整个数据结构模型,模拟了大脑神经的信息传递过程。

以下代码 main函数是使用DEMO

package cn.isto.ai.algorithm.nn;
import java.util.HashMap;
import java.util.Random;
import java.util.Set;
import java.util.Map.Entry;

/**
 * JAVA 反向传输神经网络
 * @author kj021320  , codeby 2008.12.10
 *
 */
public class JavaBackPropagationNeuralNetwork {
    /**
     * 神经元
     */
    public class Neuron {
        HashMap<Integer, Link> target = new HashMap<Integer, Link>();// 连接其他神经元的
        HashMap<Integer, Link> source = new HashMap<Integer, Link>();// 被其他神经元连接的
        double data = 0.0;
        public Link sourceGet(int index) {
            return source.get(index);
        }
        public Link targetGet(int index) {
            return target.get(index);
        }
        public boolean targetContains(Link l) {
            return target.containsValue(l);
        }
        public boolean sourceContains(Link l) {
            return source.containsValue(l);
        }
        public Link sourceLink(int index, Link l) {
            if (l.linker != this) {
                l.setLinker(this);
            }
            return source.put(index, l);
        }
        public Link targetLink(int index, Link l) {
            if (l.owner != this) {
                l.setOwner(this);
            }
            return target.put(index, l);
        }
    }

    /**
     * 神经链
     */
    public class Link {
        Neuron owner;
        public void setOwner(Neuron o) {
            owner = o;
            if (!o.targetContains(this)) {
                o.targetLink(o.target.size(), this);
            }
        }
        public Link() {
            weight = rand(-1, 1);
        }
        public void setLinker(Neuron o) {
            linker = o;
            if (!o.sourceContains(this)) {
                o.sourceLink(o.source.size(), this);
            }
        }
        @Override
        public String toString(){
            return super.toString()+" weight:"+weight;
        }
        Neuron linker;
        double weight;
    }

    Random random = new Random();
    {
        random.setSeed(System.nanoTime());
    }
    Neuron[] inputnode;    //输入层神经元
    Neuron[] hiddennode;    //隐含层神经元
    Neuron[] outputnode;    //输出层神经元
    double learnrate;// 学习速度
    double threshold;// 阀值,误差允许度

    private final int inputCount;
    private final int hiddenCount;
    private final int outputCount;
    /**
     *
     * @param input        输入层的神经元个数
     * @param hidden    隐含层的神经元个数
     * @param output    输出层的神经元的个数
     */
    public JavaBackPropagationNeuralNetwork(int input, int hidden, int output) {
        inputCount = input;
        hiddenCount = hidden;
        outputCount = output;
        build();
    }
    public void reBuildNeuralNetwork(){
        build();
    }
    private void build(){
        inputnode = new Neuron[inputCount+1];
        hiddennode = new Neuron[hiddenCount];
        outputnode = new Neuron[outputCount];
        initNeurons(inputnode);
        initNeurons(hiddennode);
        initNeurons(outputnode);
        makeLink(inputnode, hiddennode);
        makeLink(hiddennode, outputnode);
    }
    /**
     * 思考方法
     * @param inputs    前馈层神经个数相符的浮点数    -1~1之间
     * @return      

补充:软件开发 , Java ,
CopyRight © 2022 站长资源库 编程知识问答 zzzyk.com All Rights Reserved
部分文章来自网络,