其实在WeightedRandomSampler中,采样的权重针对的是每一个样本,所以我们可以确定好每个类对应的权重,再一一对应到样本上。并且,权重其实就是比值,num_samples就是一次采样的数目,里面的比值其实就是权重的比值。
class WeightedRandomSampler(Sampler):
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights)."""
def __init__(self, weights, num_samples, replacement=True):
# ...省略类型检查
# weights用于确定生成索引的权重
self.weights = torch.as_tensor(weights, dtype=torch.double)
self.num_samples = num_samples
# 用于控制是否对数据进行有放回采样
self.replacement = replacement
def __iter__(self):
# 按照加权返回随机索引值
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
对于Weighted Random Sampler类的__init__()来说,replacement参数依旧用于控制采样是否是有放回的;num_sampler用于控制生成的个数;
weights参数对应的是“样本”的权重而不是“类别的权重”。
其中__iter__()方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的,测试代码如下:
# 位置[0]的权重为0,位置[1]的权重为10,其余位置权重均为1.1
weights = torch.Tensor([0, 10, 1.1, 1.1, 1.1, 1.1, 1.1])
wei_sampler = sampler.WeightedRandomSampler(weights, 6, True)
# 下面是输出:
index: 1
index: 2
index: 3
index: 4
index: 1
index: 1
从输出可以看出,位置[1]由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。
class WeightedRandomSampler(Sampler[int]):
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
Args:
weights (sequence) : a sequence of weights, not necessary summing up to one
num_samples (int): number of samples to draw
replacement (bool): if ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a
sample index is drawn for a row, it cannot be drawn again for that row.
generator (Generator): Generator used in sampling.
Example:
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
"""
weights: Tensor
num_samples: int
replacement: bool
def __init__(self, weights: Sequence[float], num_samples: int,
replacement: bool = True, generator=None) -> None:
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(replacement))
self.weights = torch.as_tensor(weights, dtype=torch.double)
self.num_samples = num_samples
self.replacement = replacement
self.generator = generator
def __iter__(self):
rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
return iter(rand_tensor.tolist())
def __len__(self):
return self.num_samples