Páginas

martes, 28 de febrero de 2023

Entrenando una red neuronal para reconocimiento de imágenes con DeepLearning4J

En este ejemplo vamos a clasificar imágenes de 5 x 5 pixeles mediante una red neuronal que diferencie entre las que tienen una cruz:

 00100
 00100
 11111
 00100
 00100

Y las que no tienen una cruz:

10010
01000
10100
00010
01001

Las imágenes se aplanarán en un CSV con 25 parámetros de datos y una columna etiqueta con valores {0,1} para indicar que la imagen no tiene o sí tiene una cruz:

0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1

1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,0,0,0,1,0,0,1,0,0,1,0 

Para el ejemplo tenemos 500 imágenes, 450 imágenes aleatorias sin cruces y 50 imágenes con cruces. El código es el siguiente:

package org.dune;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 *
 * @author egdepedro
 */
public class EntrenaRedNeuronal {
    
    // Ruta del fichero de datos
    public static final String DATA_SET_PATH = "./src/main/resources/datos.csv";
    // Número de columnas de artibutos (son 25 datos con ceros y unos para el valor de los pixeles de la imagen)
    public static final int FEATURES_COUNT = 25;
    // Columna en donde está la etiqueta (la etiqueta está en la columna 25 porque la matriz java comienza en la posición cero)
    public static final int LABEL_INDEX = 25;
    // Número de grupos en los que clasificar los datos, dos grupos, con cruz o sin cruz
    public static final int NUM_POSSIBLE_LABLES = 2;
    // Fijamos el % de los datos usados para entrenar
    public static final int TRAIN_TO_TEST_RATIO = 70;
    // Fijamos el número de líneas del fichero a cargar (
    public static final int BATCH_SIZE = 500;

    
    public static void main(String[] args) throws IOException, InterruptedException {
        
        //System.out.println("Generamos el fichero CSV con las imágenes etiquetadas");
        //GeneradorCSV(DATA_SET_PATH);

        System.out.println("Cargamos los datos etiquetados y los barajamos para mejorar el rendimiento del modelo");
        final DataSet allData = loadData(DATA_SET_PATH);
        allData.shuffle();

        System.out.println("Preparamos los datos etiquetados de entrenamiento y de test");
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(TRAIN_TO_TEST_RATIO);
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        System.out.println("Configuramos la red neuronal");
        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                // Ajustamos el número de iteraciones con los datos de entrenamiento
                .iterations(2000)
                // Ajustamos la función de activación de cada nodo
                .activation(Activation.RELU)
                // Ajustamos los pesos iniciales
                .weightInit(WeightInit.RELU_UNIFORM).learningRate(0.05).regularization(true).l2(0.0001).list()
                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(25).build())
                .layer(1, new DenseLayer.Builder().nIn(25).nOut(25).build())
                .layer(2, new DenseLayer.Builder().nIn(25).nOut(25).build())
                .layer(3, new DenseLayer.Builder().nIn(25).nOut(25).build())
                .layer(4,
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX).nIn(25).nOut(NUM_POSSIBLE_LABLES).build())
                .backprop(true).pretrain(false).build();

        System.out.println("Creamos la red neuronal y lanzamos el entrenamiento");
        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();
        model.fit(trainingData);

        System.out.println("Probamos la red entrenada");
        INDArray output = model.output(testData.getFeatures());
        Evaluation eval = new Evaluation(NUM_POSSIBLE_LABLES);
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());
    }

    private static DataSet loadData(String path) throws IOException, InterruptedException {
        DataSet allData;
        try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
            recordReader.initialize(new FileSplit(new File(path)));
            DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, LABEL_INDEX, NUM_POSSIBLE_LABLES);
            allData = iterator.next();
        }
        return allData;
    }
    
    private static void GeneradorCSV(String path) throws IOException {
        
        //Tenemos la matriz de las 50 imágenes etiquetadas como con cruces
        final String[] matrizCruces =
        {"1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
        "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
        "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
        "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
        "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,1",
        "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1",
        "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
        "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
        "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
        "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
        "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
        "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
        "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,1",
        "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
        "1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1",
        "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
        "0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1",
        "0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,1,1,1,1,0,1,0,0,0,1",
        "0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1",
        "0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1,1",
        "0,0,1,0,0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1",
        "1,1,1,1,1,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1",
        "1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,1,1,1,1,1"};
        
        //Preparamos el generador aleatorio
        final Random r = new Random(System.currentTimeMillis());
                
        //Preparamos la matriz de las 450 imágenes con ruido (en las que nos la estamos jugando porque entendemos que no deberían aparecer cruces pero no lo controlamos)
        final String [] matrizRuido = new String [450];
        String imagen_ruido;
        for (int l=0;l<450;l++) {
            imagen_ruido = "";
            for (int c=0;c<25;c++) {
                imagen_ruido +=+r.nextInt(0,2)+",";
            }
            //Metemos la etiqueta en la última columna
            imagen_ruido += "0";
            matrizRuido [l] = imagen_ruido;
        }
        
        //Generamos una matriz con los datos mezclados
        final String [] matrizDatos = new String [BATCH_SIZE];
        int contadorCruces=0;
        int contadorRuido=0;
        for (int l=0;l<BATCH_SIZE;l++) {
            if (l%10==0) {
                matrizDatos[l]=matrizCruces[contadorCruces++];
            } else {
                matrizDatos[l]=matrizRuido[contadorRuido++];
            }
        }
        System.out.println("contadorCruces["+contadorCruces+"] contadorRuido ["+contadorRuido+"]" );
        
        //Volcamos los datos al fichero CSV recibido como parámetro
        File f = new File(path);
        FileWriter fw = new FileWriter(f);
        BufferedWriter bfw = new BufferedWriter(fw);
        for (int l=0;l<BATCH_SIZE;l++) {
            bfw.write(matrizDatos[l]+"\n");
        }
        bfw.flush();
        bfw.close();
    }
}


Probamos la red entrenada

Examples labeled as 0 classified by model as 0: 384 times

Examples labeled as 0 classified by model as 1: 5 times

Examples labeled as 1 classified by model as 0: 11 times

Examples labeled as 1 classified by model as 1: 30 times

==========================Scores========================================

# of classes: 2

Accuracy: 0,9628

Precision: 0,9146

Recall: 0,8594

F1 Score: 0,7895

========================================================================

No hay comentarios:

Publicar un comentario