sklearn LabelEncoder: handling new values that are not in the fitted encoding

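sklearn's LabelEncoder raises a ValueError when `transform` meets a value that was not seen during `fit`. Spark behaves differently: its StringIndexer, with handleInvalid="keep", puts unseen values into one extra index after all the known labels. Two solutions follow: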
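A minimal sketch of the failure described above, using made-up labels 'a', 'b', 'c': fitting succeeds and known values encode, but a value not seen during `fit` makes `transform` raise `ValueError`.

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(['a', 'b', 'c'])

# Known values encode without problems
known_codes = le.transform(['a', 'c'])
print(known_codes)  # [0 2]

# 'g' was never seen during fit, so transform raises ValueError
try:
    le.transform(['a', 'g'])
    raised = False
except ValueError as e:
    raised = True
    print('ValueError:', e)
```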

Overriding LabelEncoder (not recommended)

You have probably seen this article:

https://blog.csdn.net/qq_19446965/article/details/120110169

from sklearn.preprocessing import LabelEncoder as LEncoder

# Override LabelEncoder: values not in the fitted classes are mapped to 'Unknown'
class LabelEncoder(LEncoder):

    def fit(self, y):
        # Fit on y plus an extra 'Unknown' class, so unseen values get a code too
        return super().fit(list(y) + ['Unknown'])

    def fit_transform(self, y):
        # Fit (with 'Unknown'), then return codes for y only, one per input value
        return self.fit(y).transform(y)

    def transform(self, y):
        known = set(self.classes_)  # build the lookup set once, not per element
        new_y = ['Unknown' if x not in known else x for x in y]
        return super().transform(new_y)

Subclassing and overriding the class is convenient, but it has problems. The following examples illustrate them:

country_list = ['A', 'a', 'b', 'c', 'd']

label_encoder = LabelEncoder()
label_encoder.fit(country_list)
print('country_list: ', label_encoder.classes_)
print('encode_country_list: ', label_encoder.transform(country_list))

new_country_list = ['a', 'b', 'c', 'g', 'h', 'i']
print('new_encode_country_list: ', label_encoder.transform(new_country_list))
Output:

country_list:  ['A' 'Unknown' 'a' 'b' 'c' 'd']
encode_country_list:  [0 2 3 4 5]
new_encode_country_list:  [2 3 4 1 1 1]

A second example, with numeric strings:

country_list = ['889', '778', '567', '1920', '999']

label_encoder = LabelEncoder()
label_encoder.fit(country_list)
print('country_list: ', label_encoder.classes_)
print('encode_country_list: ', label_encoder.transform(country_list))

new_country_list = ['889', '778', '100', '200', '300']
print('new_encode_country_list: ', label_encoder.transform(new_country_list))
Output:

country_list:  ['1920' '567' '778' '889' '999' 'Unknown']
encode_country_list:  [3 2 1 0 4]
new_encode_country_list:  [3 2 5 5 5]

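My requirement is that values not covered by the fitted encoding must be dropped. With the subclass, I cannot tell in advance which code marks those values: 'Unknown' is sorted in with the real classes, so its code depends on the data. With letters it lands in the middle (code 1), while with numeric strings it sorts after every digit and lands at the end (code 5). Hence the following approach: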

Working from the saved encoding mapping (recommended)

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(X)  # X: the training values of the column to encode

# A label encoding is really just a mapping dict, so save it as one
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))
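As a sketch (sample values reused from the numeric-string example above), the saved dict is simply each class mapped to its position in the sorted `le.classes_`, so it can be cross-checked against an `enumerate` over `classes_`:

```python
from sklearn.preprocessing import LabelEncoder

X = ['889', '778', '567', '1920', '999']  # sample values from the example above

le = LabelEncoder()
le.fit(X)

# classes_ is sorted, and each class is encoded as its index in classes_,
# so the saved dict matches a plain enumerate over classes_
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))
same_dict = {cls: i for i, cls in enumerate(le.classes_)}
print(le_dict == same_dict)  # True
```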

To look up the code for a single new item, falling back to 'Unknown' if the item was not seen during fit:

le_dict.get(new_item, 'Unknown')
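A small runnable sketch of the single-item lookup, assuming the same sample values; `new_items` is a hypothetical list of incoming values:

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(['889', '778', '567', '1920', '999'])
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))

# dict.get returns the fallback 'Unknown' instead of raising for unseen items
print(le_dict.get('889', 'Unknown'))  # 3
print(le_dict.get('100', 'Unknown'))  # Unknown

# The same lookup applied to a list of new values (hypothetical input)
new_items = ['889', '778', '100', '200', '300']
encoded = [le_dict.get(x, 'Unknown') for x in new_items]
print(encoded)
```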

To look up the codes for a whole DataFrame column:

df['col'] = df['col'].apply(lambda x: le_dict.get(x, 'Unknown'))

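# Then drop the rows that hold new (unseen) values
df = df[df['col'] != 'Unknown'].copy()  # .copy() avoids SettingWithCopyWarning on the next line
df['col'] = df['col'].astype(dtype='int64')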
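The DataFrame steps above, put together in one runnable sketch. The `train` and `df` frames are hypothetical stand-ins for the training data and the new data, built from the sample values used earlier:

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical training data and new data
train = pd.DataFrame({'col': ['889', '778', '567', '1920', '999']})
df = pd.DataFrame({'col': ['889', '778', '100', '200', '300']})

le = LabelEncoder()
le.fit(train['col'])
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))

# Map known values to their codes; unseen values become 'Unknown'
df['col'] = df['col'].apply(lambda x: le_dict.get(x, 'Unknown'))

# Drop the rows with unseen values, then restore an integer dtype
df = df[df['col'] != 'Unknown'].copy()
df['col'] = df['col'].astype('int64')
print(df)
```

Because the codes come from the saved dict, the rows to drop are identified by the explicit 'Unknown' marker rather than by a code whose value depends on where 'Unknown' happens to sort.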



There are other approaches as well; see:

https://stackoom.com/question/1QM33
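One commonly used alternative, not from the original article, is sklearn's `OrdinalEncoder`, which since scikit-learn 0.24 accepts `handle_unknown='use_encoded_value'` together with an `unknown_value` for unseen categories. A minimal sketch, assuming `-1` as the marker for values to drop and reusing the sample values from above:

```python
import numpy as np
from sklearn.preprocessing import OrdinalEncoder

# OrdinalEncoder expects 2-D input: one column per feature
train = np.array(['889', '778', '567', '1920', '999']).reshape(-1, 1)
new = np.array(['889', '778', '100', '200', '300']).reshape(-1, 1)

# Unseen categories are encoded as -1 instead of raising an error
enc = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
enc.fit(train)

codes = enc.transform(new).ravel()          # float codes; unseen values are -1.0
kept = codes[codes != -1].astype('int64')   # drop the unseen values
print(codes)
print(kept)
```

Here the marker for unseen values is fixed in advance, so the rows to drop are easy to find.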


Reposted from blog.csdn.net/qq_42363032/article/details/121514951