加权随机采样WeightedRandomSampler

链接

其实在WeightedRandomSampler中,采样的权重针对的是每一个样本,所以我们可以确定好每个类对应的权重,再一一对应到样本上。并且,权重其实就是比值,num_samples就是一次采样的数目,里面的比值其实就是权重的比值。

class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights)."""
    def __init__(self, weights, num_samples, replacement=True):
         # ...省略类型检查
         # weights用于确定生成索引的权重
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        # 用于控制是否对数据进行有放回采样
        self.replacement = replacement
    def __iter__(self):
        # 按照加权返回随机索引值
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

对于Weighted Random Sampler类的__init__()来说,replacement参数依旧用于控制采样是否是有放回的;num_sampler用于控制生成的个数

weights参数对应的是“样本”的权重而不是“类别的权重”

其中__iter__()方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的,测试代码如下:

# 位置[0]的权重为0,位置[1]的权重为10,其余位置权重均为1.1
weights = torch.Tensor([0, 10, 1.1, 1.1, 1.1, 1.1, 1.1])
wei_sampler = sampler.WeightedRandomSampler(weights, 6, True)
# 下面是输出:
index: 1
index: 2
index: 3
index: 4
index: 1
index: 1

从输出可以看出,位置[1]由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。

class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.
    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        return iter(rand_tensor.tolist())

    def __len__(self):
        return self.num_samples

猜你喜欢

转载自blog.csdn.net/caihuanqia/article/details/113258690