组合/子集问题,从n个元素的集合中选择k个元素作为子集的所有情况(python实现)
问题背景
机器学习基础这门课的第一个实验是熟悉实验环境,任务是利用服务器上的jupyter环境可视化一个数据集,这个数据集的特征有3个,标签有一个,要求是以两个特征为一个组合画散点图,我觉得这个时候不能太粗暴的写几个for迭代就完事了,稍微写的巧妙一点,泛化能力强一点的工具函数能更好的完成之后出现的任务。
那么问题来了,如何获得所有k个特征组合的情况呢?
解决方法
参考了网上的一片博客,觉得他的方法不错,搞懂了之后决定记录一下这个方法。
参考原文: https://blog.csdn.net/zdy0_2004/article/details/17006957
先耐心搞懂POJ2453
这个问题是来自北京大学的Online Judge
问题是: 求一个数next,这个数大于给定的数C,满足2进制形式和C有相同1的个数的最小的数。
举个例子: 26的二进制为11010,他的下一个数next为11100,他们二进制含有相同的1,并且是满足大于26的最小的一个正整数。
再举个例子: 78 => 1001110,next => 1010011(83)
可以发现其原理就是要将低位的1串的第一个往左进一位,剩下的串右移到最低位:
- 1001110 -> 1010110
- 1010110 -> 1010011
算法流程(78为例子):
找到最右的一串1,并将最前的1进位:
只要在这一串的最后一个1(也就是这个数二进制的最后一个1)加1就可以将最后的一串1都进位,通过
C & (-C)
: 负数是原码取反码加一,-C: 1001110 => 0110001 + 1 => 0110010,仅有最后一个1是与原数相同的
C & (-C): 可以得到最后一个1所在的数x = 0000010
1001110 & 0110010 0000010
用x加C得到第一个1进位后的数t = 1010000
将最后一串1中剩下的1移动到最低位:
得到仅有最右一串1的二进制数
C ^ t
用进位后的t和原数异或,可以去掉高位的1,但是会多出两个1:C ^ t: 得到0011110第三个1是进位的1,并且进位的1算在1的总个数里面,原来位置也多余了,所以只要将粗体的两个1移动就好,所以多出了两个1
1001110 ^ 1010000 0011110
对齐最后一位1,右移到最低位:
(C ^ t) / x
,很巧妙就是最后一位1截断0011110 / 0000010 0001111
由于多出了2位,反正是右移到最低,直接再右移两位即可
>>2
将这个数加在t上,或者用或运算
t | ()(C ^ t) / x) >> 2
就得到了符合条件的数啦
感觉只要涉及到二进制的都会用位运算
python 实现
def next_n(n):
"""
find next n [POJ2453]
:param n: positive int
:return next_n
"""
x = n & -n
t = n + x
ans = t | (((n ^ t) // x) >> 2) # 这里要整除 不然就被转换成float了
return ans
进入正题 解决组合问题
其实组合问题可以看成是每一个情况都有一个one hot编码,就是说被选进组合的下标所在对应特征位置为1,否则为0,那也可以看成是二进制编码!
n个元素中选则k个作为一个组合,那么可以看成是一个n为二进制的数,其中只有k个位是1
算法流程:
- 构造一个n为的二进制数,其最低的k位为1,用位运算的方法其实考虑了很久,有一个方法是用第k+1位为1的数减去1就可以得到最低k位全是1了:
s = (1 << k) - 1
- 从s开始,
s = next_n(s)
,直到s成为最高k位都是1其余都是0的数。每一个s对应一种组合情况。 - 处理组合,简单来说就是s的所有位,见代码解释
python 实现
def gen_index_comb(pool):
"""
遍历所有pool位来判断被选入的下标
:param pool: int
:return comb: tuple 下标组合
"""
i = 0 # 记录下标
# x用来判断pool在第i位上是否有1
x = 1 << i # 从1开始
comb = []
while x <= pool:
if x & pool:
# 相与不为0 说明有1
comb.append(i)
i += 1
x <<= 1 # 看下一个bit
return tuple(comb)
def gen_comb(n, k):
"""
choose k elements from n-d array
:param n: int, k: int
:return comb_list: list of combination of index
"""
pool = (1 << k) - 1 # 注意<<优先级
ceiling = pool << (n - k) # 上限通过下限左移n-k位得到
comb_list = [tuple(range(k))]
while pool < ceiling:
pool = next_n(pool)
comb_list.append(gen_index_comb(pool))
return comb_list
p.s. C++实现之后补上