Mnist数据集 将其转化为图片

1、可以前往https://blog.csdn.net/qq_36511401/article/details/102788424进行简单的了解,并将.idx3-ubyte文件下载下来。将将要生成的训练图片放在train文件夹下面,测试图片放在test文件夹下面。

2、工具类。

public class FileUtils {
    /**
     * 删除指定文件夹下所有文件
     *
     * @param path 文件夹完整绝对路径
     */
    public static void delAllFile(String path) {
        File file = new File(path);
        if (!file.exists() || !file.isDirectory()) {
            return;
        }
        String[] fileStrList = file.list();
        File fileItem;
        for (int i = 0; i < fileStrList.length; i++) {
            if (path.endsWith(File.separator)) {
                fileItem = new File(path + fileStrList[i]);
            } else {
                fileItem = new File(path + File.separator + fileStrList[i]);
            }
            if (fileItem.isFile()) {
                fileItem.delete();
            }
            if (fileItem.isDirectory()) {
                delAllFile(path + "//" + fileStrList[i]);// 先删除文件夹里面的文件
                delFolder(path + "//" + fileStrList[i]);// 再删除空文件夹
            }
        }
        return;
    }

    /**
     * 删除文件夹
     *
     * @param folderPath folderPath文件夹完整绝对路径
     */
    public static void delFolder(String folderPath) {
        try {
            delAllFile(folderPath); // 删除完里面所有内容
            File folder = new File(folderPath);
            folder.delete(); // 删除空文件夹
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
public class ByteUtils {
    public static int getTenHex(byte[] bytes) {
        int result = 0;
        for (int i = 0; i < bytes.length; i++) {
            int move = bytes.length - i - 1;
            int value = (bytes[i] & 255) << (8 * move);
            result += value;
        }
        return result;
    }
}

3、读取图片的方法,返回一个byte[][]数组,第一个[]代表的是第几个图片,第二个[]代表的是该图片的内容。记得用BufferedInputStream来读取,因为文件比较大。如果直接使用FileInputStream来读取的话,程序会卡住的。

//获取图片
private static byte[][] getMnistImg(String filePath) {
    InputStream inputStream = null;//输入流
    byte[][] imgArray = null;
    try {
        inputStream = new BufferedInputStream(new FileInputStream(filePath));
        //读取基本信息
        byte[] readBytes = new byte[4];
        inputStream.read(readBytes);
        System.out.println("读取到的幻数:" + ByteUtils.getTenHex(readBytes));
        inputStream.read(readBytes);
        int imgCount = ByteUtils.getTenHex(readBytes);
        inputStream.read(readBytes);
        imgWidth = ByteUtils.getTenHex(readBytes);
        inputStream.read(readBytes);
        imgHeight = ByteUtils.getTenHex(readBytes);
        int imgSize = imgWidth * imgHeight;
        System.out.println(String.format("一共有%d张图片。\n" + "每张图片每行%d个像素点,每列%d个像素点。" +
                "\n每张图片一共有%d个像素点", imgCount, imgWidth, imgHeight, imgSize));
        //读取每张图片
        imgArray = new byte[imgCount][imgSize];
        for (int i = 0; i < imgCount; i++) {
            byte[] imgBytes = new byte[imgSize];
            for (int j = 0; j < imgSize; j++) {
                imgBytes[j] = (byte) inputStream.read();
            }
            imgArray[i] = imgBytes;
        }
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        try {
            if (inputStream != null) {
                inputStream.close();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return imgArray;
}

4、读取图片所代表数字的方法,返回int[],其顺序是和步骤3中的图片顺序是一一对应的。

//获取图片所代表的数字
private static byte[] getMnistLable(String filePath) {
    InputStream inputStream = null;//输入流
    byte[] labelArray = null;
    try {
        inputStream = new BufferedInputStream(new FileInputStream(filePath));
        //读取基本信息
        byte[] readBytes = new byte[4];
        inputStream.read(readBytes);
        System.out.println("读取到的幻数:" + ByteUtils.getTenHex(readBytes));
        inputStream.read(readBytes);
        int labelCount = ByteUtils.getTenHex(readBytes);
        System.out.println(String.format("一共有%d个标签", labelCount));
        //读取每个标签
        labelArray = new byte[labelCount];
        for (int i = 0; i < labelCount; i++) {
            labelArray[i] = (byte) inputStream.read();
        }
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        try {
            if (inputStream != null) {
                inputStream.close();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return labelArray;
}

5、主程序和保存图片的方法。因为我们从Mnist数据集中获取到的图片是8位灰度图像,每个像素存放在一个byte空间(8位,0-255:0表示最暗色,255表示最亮色)。8位灰度图像可以看成是一系列1位“位平面”的叠加。所以初始化BufferedImage的时候要用TYPE_INT_GRAY,只有8bit大小的存储空间,图片的存储空间也比其他如TYPE_INT_RGB类型生成的小。这样子在使用bufferedImage.setRGB的时候,按28*28大小的形状,将灰度值一个个的放进去就可以了。

private static final String Pre_Path = "G:\\xiaojie-java-test\\mnist\\";
//下载的测试集(二进制文件)。
private static final String Train_Img_Path = Pre_Path + "download\\train-images.idx3-ubyte";//训练集图像
private static final String Train_Label_Path = Pre_Path + "download\\train-labels.idx1-ubyte";//训练集标签(标签指明图像代表的意思)
//下载的测试集(二进制文件)。测试集的前5000个示例来自原始的NIST训练集。最后的5000个来自原始的NIST测试集。前5000个比后5000个更干净,更容易识别。
private static final String Test_Images_Path = Pre_Path + "download\\t10k-images.idx3-ubyte";//测试集图像
private static final String Test_Lable_Path = Pre_Path + "download\\t10k-labels.idx1-ubyte";//测试集标签(标签指明图像代表的意思)
//图片文件的保存地址
private static final String Train_Save_Path = Pre_Path + "train\\";
private static final String Test_Save_Path = Pre_Path + "test\\";

private static final int Img_Page_Count = 1000;

private static int imgWidth;
private static int imgHeight;

public static void main(String[] args) {
    byte[][] mnistImg = getMnistImg(Test_Images_Path);
    byte[] label = getMnistLable(Test_Lable_Path);
    saveImg(mnistImg, label, Test_Save_Path);
}

//保存图片
private static void saveImg(byte[][] mnistImg, byte[] label, String savePath) {
    BufferedImage bufferedImage;//输出流
    try {
        for (int i = 0; i < mnistImg.length; i++) {
            byte[] imgArray = mnistImg[i];
            //生成BufferedImage对象
            bufferedImage = new BufferedImage(imgWidth, imgHeight, BufferedImage.TYPE_BYTE_GRAY);
            for (int j = 0; j < imgHeight; j++) {
                for (int k = 0; k < imgWidth; k++) {
                    //System.out.print(String.format("%4d", imgArray[j * imgWidth + k]));//可以在控制台打出图片的样子
                    bufferedImage.setRGB(k, j, imgArray[j * imgWidth + k]);
                }
                // System.out.println();//可以在控制台打出图片的样子
            }
            //生成文件夹 由于文件太多了,让其分成60个文件夹,每个文件夹里面有1000张图片
            int pageIndex = i / Img_Page_Count + 1;
            String filePath = savePath + (pageIndex < 10 ? "0" + pageIndex : pageIndex);
            if (i % Img_Page_Count == 0) {
                FileUtils.delAllFile(filePath);
                System.out.println(String.format("准备生成第%d个有%d个图像的文件夹,", pageIndex, Img_Page_Count));
            }
            File file = new File(filePath);
            if (!file.exists()) {
                file.mkdir();
            }
            //生成图片
            String fileName = filePath + "//" + label[i] + "_" + System.currentTimeMillis() + ".png";
            ImageIO.write(bufferedImage, "PNG", new File(fileName));
            Thread.sleep(1);//要休眠一下,不然有时图像会生成失败
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}

6、结果。演示的是生成测试集的图片,一共有10000个图片,生成10个文件夹,每个文件夹中有1000个图片。如果要生成训练集的图片,将main()方法中的,源文件和目标文件的路径改成相应的就可以了。

    Connected to the target VM, address: '127.0.0.1:53725', transport: 'socket'
    读取到的幻数:2051
    一共有10000张图片。
    每张图片每行28个像素点,每列28个像素点。
    每张图片一共有784个像素点
    读取到的幻数:2049
    一共有10000个标签
    准备生成第1个有1000个图像的文件夹,
    准备生成第2个有1000个图像的文件夹,
    准备生成第3个有1000个图像的文件夹,
    准备生成第4个有1000个图像的文件夹,
    准备生成第5个有1000个图像的文件夹,
    准备生成第6个有1000个图像的文件夹,
    准备生成第7个有1000个图像的文件夹,
    准备生成第8个有1000个图像的文件夹,
    准备生成第9个有1000个图像的文件夹,
    准备生成第10个有1000个图像的文件夹,
    Disconnected from the target VM, address: '127.0.0.1:53725', transport: 'socket'
    
    Process finished with exit code 0

发布了67 篇原创文章 · 获赞 401 · 访问量 41万+

猜你喜欢

转载自blog.csdn.net/qq_36511401/article/details/102788557