This article describes a simple artificial neural network implemented in Java and shares the complete code for your reference, as follows:
1. Algorithm diagram

Let's first take a look at the algorithm diagram I drew:
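In outline, the network is a single neuron with a step activation (these are the same formulas that appear as comments in the code below): for an input vector (x1, ..., xm) it computes z = w0*x0 + w1*x1 + ... + wm*xm, with x0 fixed at 1 so that w0 acts as a threshold, and outputs y' = 1 if z >= 0, otherwise -1. During training, each weight is adjusted by Δw = η*(y - y')*x, where η is the learning rate, y the true label and y' the computed output.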
2. Data class
package cn.edu.hbut.chenjie; // same package as the ANN2 class below, so its package-private fields are visible there

import java.util.Arrays;

public class Data {
    double[] vector;  // feature vector
    int dimension;    // number of dimensions of the vector
    int type;         // class label, +1 or -1

    public double[] getVector() {
        return vector;
    }

    public void setVector(double[] vector) {
        this.vector = vector;
    }

    public int getDimension() {
        return dimension;
    }

    public void setDimension(int dimension) {
        this.dimension = dimension;
    }

    public int getType() {
        return type;
    }

    public void setType(int type) {
        this.type = type;
    }

    public Data(double[] vector, int dimension, int type) {
        super();
        this.vector = vector;
        this.dimension = dimension;
        this.type = type;
    }

    public Data() {
    }

    @Override
    public String toString() {
        return "Data [vector=" + Arrays.toString(vector) + ", dimension=" + dimension + ", type=" + type + "]";
    }
}
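As a quick illustration of these conventions (the snippet below is only an example, not part of the article's classes): the training code later labels a point +1 when x2 > x1 + 0.5 and -1 otherwise, so a sample above that line would be built like this:

    double[] v = {0.3, 0.9};          // x1 = 0.3, x2 = 0.9, which lies above x2 = x1 + 0.5
    Data sample = new Data(v, 2, 1);  // 2-dimensional vector, type +1
    System.out.println(sample);       // prints: Data [vector=[0.3, 0.9], dimension=2, type=1]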
3. Simple artificial neural network

package cn.edu.hbut.chenjie;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.JFreeChart;
import org.jfree.data.xy.DefaultXYDataset;
import org.jfree.ui.RefineryUtilities;

public class ANN2 {
    private double eta;           // learning rate
    private int n_iter;           // number of training passes over the weight vector w[]
    private List<Data> exercise;  // training data set
    private double w0 = 0;        // threshold
    private double x0 = 1;        // fixed input value paired with the threshold
    private double[] weights;     // weight vector; its length is the data dimension + 1 (here the data is 2-dimensional, so the length is 3)
    private int testSum = 0;      // total number of test samples
    private int error = 0;        // number of misclassified test samples
    DefaultXYDataset xydataset = new DefaultXYDataset();

    /**
     * Add all data of one type to the chart.
     * @param type series label
     * @param a first component of every data point
     * @param b second component of every data point
     */
    public void add(String type, double[] a, double[] b) {
        double[][] data = new double[2][a.length];
        for (int i = 0; i < a.length; i++) {
            data[0][i] = a[i];
            data[1][i] = b[i];
        }
        xydataset.addSeries(type, data);
    }

    /**
     * Draw the scatter plot of the training data.
     */
    public void draw() {
        JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset);
        ChartFrame frame = new ChartFrame("training data", jfreechart);
        frame.pack();
        RefineryUtilities.centerFrameOnScreen(frame);
        frame.setVisible(true);
    }

    public static void main(String[] args) {
        ANN2 ann2 = new ANN2(0.001, 100);             // construct the artificial neural network
        List<Data> exercise = new ArrayList<Data>();  // construct the training set
        // Simulate 1,000,000 training samples; the dividing line is x2 = x1 + 0.5
        for (int i = 0; i < 1000000; i++) {
            Random rd = new Random();
            double x1 = rd.nextDouble();              // randomly generate one component
            double x2 = rd.nextDouble();              // randomly generate the other component
            double[] da = {x1, x2};                   // build the data vector
            Data d = new Data(da, 2, x2 > x1 + 0.5 ? 1 : -1); // construct the sample
            exercise.add(d);                          // add it to the training set
        }
        int sum1 = 0; // number of type +1 training samples
        int sum2 = 0; // number of type -1 training samples
        for (int i = 0; i < exercise.size(); i++) {
            if (exercise.get(i).getType() == 1)
                sum1++;
            else if (exercise.get(i).getType() == -1)
                sum2++;
        }
        double[] x1 = new double[sum1];
        double[] y1 = new double[sum1];
        double[] x2 = new double[sum2];
        double[] y2 = new double[sum2];
        int index1 = 0;
        int index2 = 0;
        for (int i = 0; i < exercise.size(); i++) {
            if (exercise.get(i).getType() == 1) {
                x1[index1] = exercise.get(i).vector[0];
                y1[index1++] = exercise.get(i).vector[1];
            } else if (exercise.get(i).getType() == -1) {
                x2[index2] = exercise.get(i).vector[0];
                y2[index2++] = exercise.get(i).vector[1];
            }
        }
        ann2.add("1", x1, y1);
        ann2.add("-1", x2, y2);
        ann2.draw();
        ann2.input(exercise);   // feed the training set into the artificial neural network
        ann2.fit();             // train
        ann2.showWeigths();     // show the weight vector
        // Generate 10,000 test samples
        for (int i = 0; i < 10000; i++) {
            Random rd = new Random();
            double x1_ = rd.nextDouble();
            double x2_ = rd.nextDouble();
            double[] da = {x1_, x2_};
            Data test = new Data(da, 2, x2_ > x1_ + 0.5 ? 1 : -1);
            ann2.predict(test); // test
        }
        System.out.println("A total of " + ann2.testSum + " test samples, " + ann2.error
                + " errors, error rate: " + ann2.error * 1.0 / ann2.testSum * 100 + "%");
    }

    /**
     * @param eta    learning rate
     * @param n_iter number of training passes for the weight vector
     */
    public ANN2(double eta, int n_iter) {
        this.eta = eta;
        this.n_iter = n_iter;
    }

    /**
     * Feed the training set into the artificial neural network.
     * @param exercise training data set
     */
    private void input(List<Data> exercise) {
        this.exercise = exercise;                                  // save the training set
        weights = new double[exercise.get(0).getDimension() + 1];  // the weight vector's length is the data dimension + 1
        weights[0] = w0;                                           // the first component of the weight vector is w0
        for (int i = 1; i < weights.length; i++)
            weights[i] = 0;                                        // the remaining components start at 0
    }

    private void fit() {
        for (int i = 0; i < n_iter; i++) {               // adjust the weight vector n_iter times
            for (int j = 0; j < exercise.size(); j++) {  // train on every sample in the training set
                int real_result = exercise.get(j).type;                   // y
                int calculate_result = CalculateResult(exercise.get(j));  // y'
                double delta0 = eta * (real_result - calculate_result);   // threshold update
                w0 += delta0;                                             // update the threshold
                weights[0] = w0;                                          // keep weights[0] in sync
                for (int k = 0; k < exercise.get(j).getDimension(); k++) { // update the other weight components
                    double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k]; // Δw = η*(y-y')*x
                    weights[k + 1] += delta; // w = w + Δw
                }
            }
        }
    }

    private int CalculateResult(Data data) {
        double z = w0 * x0;
        for (int i = 0; i < data.getDimension(); i++)
            z += data.vector[i] * weights[i + 1]; // z = w0*x0 + w1*x1 + ... + wm*xm
        // activation function
        if (z >= 0)
            return 1;
        else
            return -1;
    }

    private void showWeigths() {
        for (double w : weights)
            System.out.println(w);
    }

    private void predict(Data data) {
        int type = CalculateResult(data);
        if (type == data.getType()) {
            // System.out.println("prediction correct");
        } else {
            // System.out.println("prediction wrong");
            error++;
        }
        testSum++;
    }
}
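Stripped of the plotting and bookkeeping, the core of fit() and CalculateResult() is a single update rule applied to one sample at a time. The sketch below restates that inner-loop logic on its own (the class, method and variable names here are illustrative only, not part of the article's code):

// Minimal, self-contained sketch of the perceptron update used by fit().
public class PerceptronStepSketch {
    static void updateOnSample(double[] weights, double eta, double[] x, int y) {
        double z = weights[0];                  // z = w0*1 + w1*x1 + ... + wm*xm (weights[0] is the threshold)
        for (int i = 0; i < x.length; i++)
            z += weights[i + 1] * x[i];
        int predicted = (z >= 0) ? 1 : -1;      // step activation, as in CalculateResult
        double factor = eta * (y - predicted);  // zero when the sample is already classified correctly
        weights[0] += factor;                   // threshold update (x0 is fixed at 1)
        for (int i = 0; i < x.length; i++)
            weights[i + 1] += factor * x[i];    // Δw = η*(y - y')*x
    }

    public static void main(String[] args) {
        double[] w = new double[3];             // w0 (threshold), w1, w2, all starting at 0
        // one update on a sample of type -1 (a point below the line x2 = x1 + 0.5)
        updateOnSample(w, 0.001, new double[]{0.3, 0.2}, -1);
        System.out.println(java.util.Arrays.toString(w)); // [-0.002, -6.0E-4, -4.0E-4]
    }
}

Calling such an update once per training sample, and repeating the whole pass n_iter times, is exactly what fit() does above.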
Running results:

-0.22000000000000017
-0.4416843982815453
0.442444202054685
A total of 10000 test samples, 17 errors, error rate: 0.16999999999999998%
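As a rough sanity check, the learned decision boundary is where w0 + w1*x1 + w2*x2 = 0. Plugging in the weights printed above and solving for x2 gives approximately x2 = 0.998*x1 + 0.497, which is very close to the true dividing line x2 = x1 + 0.5 used to label the data and is consistent with the low error rate.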
For more on Java algorithms, interested readers can also browse this site's topics: "Java Data Structure and Algorithm Tutorial", "Summary of Java Operation DOM Node Tips", "Summary of Java File and Directory Operation Tips" and "Summary of Java Cache Operation Tips".
I hope this article is helpful for your Java programming.