En este ejemplo vamos a clasificar imágenes de 5 x 5 píxeles mediante una red neuronal que distinga entre las que tienen una cruz:
00100
00100
11111
00100
00100
Y las que no tienen una cruz:
10010
01000
10100
00010
01001
Las imágenes se aplanarán en un CSV con 25 columnas de datos (una por píxel) y una columna de etiqueta con valores {0,1} que indica si la imagen no tiene (0) o sí tiene (1) una cruz:
0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1
1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,0,0,0,1,0,0,1,0,0,1,0
Para el ejemplo tenemos 500 imágenes, 450 imágenes aleatorias sin cruces y 50 imágenes con cruces. El código es el siguiente:
package org.dune;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
*
* @author egdepedro
*/
/**
 * Trains and evaluates a small feed-forward neural network that classifies flattened
 * 5 x 5 binary images into two groups: images with a cross (label 1) and images
 * without one (label 0).
 *
 * <p>The data set is read from a CSV file whose rows hold {@link #FEATURES_COUNT} pixel
 * values followed by one label column. {@code GeneradorCSV} can be used to (re)create
 * that file.
 *
 * @author egdepedro
 */
public class EntrenaRedNeuronal {

    // Path of the data file.
    public static final String DATA_SET_PATH = "./src/main/resources/datos.csv";
    // Number of attribute columns (25 values, zeros and ones, one per pixel of the image).
    public static final int FEATURES_COUNT = 25;
    // Column holding the label (column 25 because column indices start at zero).
    public static final int LABEL_INDEX = 25;
    // Number of groups to classify the data into: two groups, with a cross or without one.
    public static final int NUM_POSSIBLE_LABLES = 2;
    // Percentage of the data used for training.
    public static final int TRAIN_TO_TEST_RATIO = 70;
    // Number of lines of the file to load (used as the batch size of the data set iterator).
    public static final int BATCH_SIZE = 500;

    /**
     * Loads the labelled data, splits it into training and test sets, trains the network
     * and prints the evaluation statistics for the test set.
     *
     * @param args command-line arguments (unused)
     * @throws IOException if the data file cannot be read
     * @throws InterruptedException if reading the data file is interrupted
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        // Uncomment the next two lines to regenerate the labelled CSV file before training.
        //System.out.println("Generamos el fichero CSV con las imágenes etiquetadas");
        //GeneradorCSV(DATA_SET_PATH);

        System.out.println("Cargamos los datos etiquetados y los barajamos para mejorar el rendimiento del modelo");
        final DataSet allData = loadData(DATA_SET_PATH);
        allData.shuffle();

        System.out.println("Preparamos los datos etiquetados de entrenamiento y de test");
        // TRAIN_TO_TEST_RATIO is a percentage. Passing the int directly would select the
        // splitTestAndTrain(int) overload, which interprets the value as an absolute number
        // of training examples (70 rows instead of 70 %). Converting to a fraction selects
        // the splitTestAndTrain(double) overload.
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(TRAIN_TO_TEST_RATIO / 100.0);
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        System.out.println("Configuramos la red neuronal");
        MultiLayerConfiguration configuration = buildConfiguration();

        System.out.println("Creamos la red neuronal y lanzamos el entrenamiento");
        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();
        model.fit(trainingData);

        System.out.println("Probamos la red entrenada");
        INDArray output = model.output(testData.getFeatures());
        Evaluation eval = new Evaluation(NUM_POSSIBLE_LABLES);
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());
    }

    /**
     * Builds the network configuration: four dense layers of 25 nodes each followed by a
     * softmax output layer with {@link #NUM_POSSIBLE_LABLES} outputs.
     *
     * @return the multi-layer configuration used to create the network
     */
    private static MultiLayerConfiguration buildConfiguration() {
        return new NeuralNetConfiguration.Builder()
                // Number of iterations over the training data.
                .iterations(2000)
                // Activation function of every node.
                .activation(Activation.RELU)
                // Initial weights, learning rate and L2 regularization.
                .weightInit(WeightInit.RELU_UNIFORM)
                .learningRate(0.05)
                .regularization(true)
                .l2(0.0001)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(25).build())
                .layer(1, new DenseLayer.Builder().nIn(25).nOut(25).build())
                .layer(2, new DenseLayer.Builder().nIn(25).nOut(25).build())
                .layer(3, new DenseLayer.Builder().nIn(25).nOut(25).build())
                .layer(4,
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX)
                                .nIn(25)
                                .nOut(NUM_POSSIBLE_LABLES)
                                .build())
                .backprop(true)
                .pretrain(false)
                .build();
    }

    /**
     * Reads the first batch of up to {@link #BATCH_SIZE} records from the CSV file,
     * taking the label from column {@link #LABEL_INDEX}.
     *
     * @param path path of the CSV file to read
     * @return the data set returned by the first call to the iterator
     * @throws IOException if the file cannot be read
     * @throws InterruptedException if initializing the record reader is interrupted
     */
    private static DataSet loadData(String path) throws IOException, InterruptedException {
        DataSet allData;
        // No header lines are skipped (0) and the field delimiter is a comma.
        try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
            recordReader.initialize(new FileSplit(new File(path)));
            DataSetIterator iterator =
                    new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, LABEL_INDEX, NUM_POSSIBLE_LABLES);
            allData = iterator.next();
        }
        return allData;
    }

    /**
     * Writes a CSV file with {@link #BATCH_SIZE} labelled images: the fixed images below
     * labelled 1, interleaved at regular intervals with randomly generated images
     * labelled 0.
     *
     * @param path path of the CSV file to write (overwritten if it exists)
     * @throws IOException if the file cannot be written
     */
    private static void GeneradorCSV(String path) throws IOException {
        // The 50 images labelled as containing a cross.
        final String[] matrizCruces =
                {"1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
                 "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
                 "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
                 "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
                 "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,1",
                 "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1",
                 "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
                 "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
                 "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
                 "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
                 "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
                 "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
                 "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,1",
                 "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
                 "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
                 "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
                 "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
                 "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,1",
                 "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
                 "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
                 "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
                 "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
                 "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1"};
        final int totalCruces = matrizCruces.length;
        // The remaining rows of the file are noise images (450 with the current constants).
        final int totalRuido = BATCH_SIZE - totalCruces;

        // Random generator seeded with the current time.
        final Random r = new Random(System.currentTimeMillis());

        // Noise images. Nothing checks that a random image does not happen to contain a
        // cross; that risk is accepted and every noise image is labelled 0.
        final String[] matrizRuido = new String[totalRuido];
        for (int l = 0; l < totalRuido; l++) {
            final StringBuilder imagenRuido = new StringBuilder();
            for (int c = 0; c < FEATURES_COUNT; c++) {
                // nextInt(2) yields 0 or 1 and is available on every Java version.
                imagenRuido.append(r.nextInt(2)).append(',');
            }
            // The label goes in the last column.
            imagenRuido.append('0');
            matrizRuido[l] = imagenRuido.toString();
        }

        // Interleave both sets: one cross image every `paso` rows (every 10th row with the
        // current constants), noise images everywhere else.
        final int paso = BATCH_SIZE / totalCruces;
        final String[] matrizDatos = new String[BATCH_SIZE];
        int contadorCruces = 0;
        int contadorRuido = 0;
        for (int l = 0; l < BATCH_SIZE; l++) {
            if (l % paso == 0 && contadorCruces < totalCruces) {
                matrizDatos[l] = matrizCruces[contadorCruces++];
            } else {
                matrizDatos[l] = matrizRuido[contadorRuido++];
            }
        }
        System.out.println("contadorCruces["+contadorCruces+"] contadorRuido ["+contadorRuido+"]" );

        // Dump the rows to the CSV file received as parameter. try-with-resources closes
        // (and therefore flushes) the writer even if a write fails.
        try (BufferedWriter bfw = new BufferedWriter(new FileWriter(new File(path)))) {
            for (String fila : matrizDatos) {
                bfw.write(fila);
                bfw.write("\n");
            }
        }
    }
}
Probamos la red entrenada
Examples labeled as 0 classified by model as 0: 384 times
Examples labeled as 0 classified by model as 1: 5 times
Examples labeled as 1 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 30 times
==========================Scores========================================
# of classes: 2
Accuracy: 0,9628
Precision: 0,9146
Recall: 0,8594
F1 Score: 0,7895
========================================================================