博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
ImageDataGenerator读取的数据集转Numpy array
阅读量:3960 次
发布时间:2019-05-24

本文共 2140 字,大约阅读时间需要 7 分钟。

ImageDataGenerator读取的数据集转Numpy array

常用的数据集类型可以分为两种:
(1)一种是网上的经典数据集,比如mnist,一般会给写好的读取方法,比如mnist.load_data(),读取出来的返回值是Numpy array;
(2)一种是自己本地的数据集,路径下每个文件夹代表一类图像,目录结构类似于

data--type1----img1-1----img1-2--type2----img2-2----img2-2--type3----img3-2----img3-2

这时就要用keras的ImageDataGenerator生成flow来读取数据。


使用shap做可解释机器学习时发现原作者给的例子里用的数据集是mnist(链接:),计算相关的代码如下

background = x_train[np.random.choice(x_train.shape[0], 100, replace=False)]# explain predictions of the model on four imagese = shap.DeepExplainer(model, background)# ...or pass tensors directly# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)shap_values = e.shap_values(x_test[1:5])# plot the feature attributionsshap.image_plot(shap_values, -x_test[1:5])

这里DeepExplainer、shap_values和image_plot的参数都要求是Numpy array,但是用ImageDataGenerator生成的数据集不是这种格式的,就需要进行转换。作者没有给相关的例子,就自己写一个,代码效率有点低,不知道有没有更好的写法,欢迎大家补充。

test_dir = './data' #数据集路径,路径下每个文件夹代表一类图像test_pic_gen = ImageDataGenerator(rescale = 1./255) #图像预处理,和训练好的模型保持一致,通常是缩放为1/255以归一化test_flow = test_pic_gen.flow_from_directory(test_dir, (224, 224), batch_size = 1, class_mode = 'categorical') 	#注意这里batch_size最好设为1,方便读取;设成其他值(比如比较常用的8)需要额外进行遍历def imgFlow2npArray(img_flow, img_sum, img_size):    x = np.zeros(shape = (img_sum, img_size[0], img_size[1], img_size[2]))    y = np.zeros(shape = (img_sum))    for image in test_flow:        img_sum = img_sum - 1        x[img_sum] = image[0][0] #image[0]以矩阵形式保存图像数据,需要去除多余的一个维度        y[img_sum] = image[1][0].tolist().index(1.) #image[1]保存图像对应类别,同样需要去除多余的一个维度        if img_sum <= 0:            break    return x, yx_test, y_test = imgFlow2npArray(test_flow, 5, (224, 224, 3)) #此时的x_test和y_test同(x_train, y_train), (x_test, y_test) = mnist.load_data()中的x_test和y_test

这里x_test, y_test就已经是Numpy array格式了,可以顺利跑通shap代码。


但是这里也出现了一个奇怪的问题:图表中最左边一列应该对应原始图像的位置显示为纯黑图像,暂时不知道是为啥导致的,以后找到解决方法了再补充。

跑原始shap代码的显示结果:

跑原始shap代码的显示结果
跑自己的代码的显示结果:

跑自己的代码的显示结果


2021/6/21更新:

问题解决了,自己用的数据集图像是24位RGB图,改成8位灰度图就行了,代码如下:

# shap.image_plot(shap_values, x_test[0:]) 改成下面两行代码x_test_gray = x_test[:, :, :, :1]shap.image_plot(shap_values, x_test_gray[0:])

结果图:

在这里插入图片描述
但是作者的教程里也有用到显示RGB图的例子,代码里没看到有什么特殊处理,不知道他是怎么做到的,以后找到办法再更新。

转载地址:http://kbqzi.baihongyu.com/

你可能感兴趣的文章
3-python之PyCharm如何新建项目
查看>>
15-python之while循环嵌套应用场景
查看>>
17-python之for循环
查看>>
18-python之while循环,for循环与else的配合
查看>>
19-python之字符串简单介绍
查看>>
20-python之切片详细介绍
查看>>
P24-c++类继承-01详细的例子演示继承的好处
查看>>
P8-c++对象和类-01默认构造函数详解
查看>>
P1-c++函数详解-01函数的默认参数
查看>>
P3-c++函数详解-03函数模板详细介绍
查看>>
P4-c++函数详解-04函数重载,函数模板和函数模板重载,编译器选择使用哪个函数版本?
查看>>
P5-c++内存模型和名称空间-01头文件相关
查看>>
P6-c++内存模型和名称空间-02存储连续性、作用域和链接性
查看>>
P9-c++对象和类-02构造函数和析构函数总结
查看>>
P10-c++对象和类-03this指针详细介绍,详细的例子演示
查看>>
bat备份数据库
查看>>
linux数据库导出结果集且比对 && grep -v ---无法过滤的问题
查看>>
shell函数与自带变量
查看>>
linux下shell获取不到PID
查看>>
sort详解
查看>>