【Loss Functions】Keras Loss Function

PS: the code in the first half of this post is all PyTorch. A TensorFlow version of uPIT-SiSNR is attached at the end.

The loss function is a crucial part of model training.

Common loss functions:

  • Speech separation: uPIT-SiSNR
  • Speech enhancement: L1, MSE

Examples of loss functions
Speech separation
【SepFormer】: uPIT-SiSNR (https://github.com/speechbrain/speechbrain)
【DualPath RNN】: SiSNR
【TransMask】: SiSNR
【Conv-TasNet】: SiSNR (https://github.com/kaituoxu/Conv-TasNet)

Music source separation
【Demucs】: L1 (https://github.com/facebookresearch/demucs)

Speech denoising
【Denoiser】: L1, stft_loss (https://github.com/facebookresearch/denoiser)
【Phasen】: SiSNR or mag_spec (https://github.com/huyanxin/phasen/blob/master/model/phasen.py)
【Transformer】: L1
【Conformer】: L1
【DCCRN】: SiSNR (https://github.com/huyanxin/DeepComplexCRN, https://huyanxin.github.io/DeepComplexCRN/)
【DCUNet】: wSDR
【DF-Conformer】: SNR

References

On uPIT and Si-SNR: https://blog.csdn.net/zjuPeco/article/details/106300674

SpeechBrain's loss implementations 【speechbrain/speechbrain/nnet/losses.py】: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/losses.py

SNR (Signal-to-Noise Ratio)

ref: https://blog.csdn.net/zjuPeco/article/details/106300674


SNR = 10 · log10( ‖s‖² / ‖s − ŝ‖² ), where s is the target signal, ŝ is the estimate, and s − ŝ is the residual noise.

SI-SNR (Scale-Invariant Signal-to-Noise Ratio)

s_target = (⟨ŝ, s⟩ / ‖s‖²) · s
e_noise = ŝ − s_target
SI-SNR = 10 · log10( ‖s_target‖² / ‖e_noise‖² )
See also the formulation in the paper 【Optimal scale-invariant signal-to-noise ratio and curriculum learning for monaural multi-speaker speech separation in noisy environment】: http://www.apsipa.org/proceedings/2020/pdfs/0000711.pdf

As you can see, there is actually more than one definition of SI-SNR.
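A minimal self-contained sketch (toy signals of my own, not from the cited references) that makes the "scale-invariant" part concrete: scaling the estimate shifts plain SNR, but leaves SI-SNR unchanged.

import torch

def snr(s, s_hat, eps=1e-8):
    # plain SNR: target energy over residual energy
    return 10 * torch.log10(s.pow(2).sum() / ((s - s_hat).pow(2).sum() + eps))

def si_snr(s, s_hat, eps=1e-8):
    # zero-mean both signals, project the estimate onto the target,
    # then compare projection energy with residual energy
    s, s_hat = s - s.mean(), s_hat - s_hat.mean()
    s_target = (s_hat * s).sum() * s / (s.pow(2).sum() + eps)
    e_noise = s_hat - s_target
    return 10 * torch.log10(s_target.pow(2).sum() / (e_noise.pow(2).sum() + eps))

s = torch.randn(16000)
s_hat = s + 0.1 * torch.randn(16000)
print(snr(s, s_hat), snr(s, 2 * s_hat))        # changes with the scale
print(si_snr(s, s_hat), si_snr(s, 2 * s_hat))  # unchanged: scale invariant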

Here we take the SpeechBrain implementation as the example. See the SpeechBrain GitHub page: https://github.com/speechbrain/speechbrain
【speechbrain/speechbrain/nnet/losses.py】: https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/losses.py

def cal_si_snr(source, estimate_source):
    """Calculate SI-SNR.

    Arguments:
    ---------
    source: [T, B, C],
        Where B is batch size, T is the length of the sources, C is the number of sources
        the ordering is made so that this loss is compatible with the class PitWrapper.

    estimate_source: [T, B, C]
        The estimated source.

    Example:
    ---------
    >>> import numpy as np
    >>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
    >>> xhat = x[:, (1, 0)]
    >>> x = x.unsqueeze(-1).repeat(1, 1, 2)
    >>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
    >>> si_snr = -cal_si_snr(x, xhat)
    >>> print(si_snr)
    tensor([[[ 25.2142, 144.1789],
             [130.9283,  25.2142]]])
    """
    EPS = 1e-8
    assert source.size() == estimate_source.size()
    device = estimate_source.device.type

    source_lengths = torch.tensor(
        [estimate_source.shape[0]] * estimate_source.shape[1], device=device
    )
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    num_samples = (
        source_lengths.contiguous().reshape(1, -1, 1).float()
    )  # [1, B, 1]
    mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
    mean_estimate = (
        torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
    )
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = zero_mean_target  # [T, B, C]
    s_estimate = zero_mean_estimate  # [T, B, C]
    # s_target = <s', s>s / ||s||^2
    dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True)  # [1, B, C]
    s_target_energy = (
        torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS
    )  # [1, B, C]
    proj = dot * s_target / s_target_energy  # [T, B, C]
    # e_noise = s' - s_target
    e_noise = s_estimate - proj  # [T, B, C]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
        torch.sum(e_noise ** 2, dim=0) + EPS
    )
    si_snr = 10 * torch.log10(si_snr_beforelog + EPS)  # [B, C]

    return -si_snr.unsqueeze(0)
   


def get_mask(source, source_lengths):
    """
    Arguments
    ---------
    source : [T, B, C]
    source_lengths : [B]

    Returns
    -------
    mask : [T, B, 1]

    Example:
    ---------
    >>> source = torch.randn(4, 3, 2)
    >>> source_lengths = torch.Tensor([2, 1, 4]).int()
    >>> mask = get_mask(source, source_lengths)
    >>> print(mask)
    tensor([[[1.],
             [1.],
             [1.]],
    <BLANKLINE>
            [[1.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]]])
    """
    T, B, _ = source.size()
    mask = source.new_ones((T, B, 1))
    for i in range(B):
        mask[source_lengths[i] :, i, :] = 0
    return mask

Note that both s_target and s_estimate have their means subtracted. Also, EPS is added to the denominators to guard against division-by-zero errors.
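A quick way to see what the EPS guard buys (a toy example of my own, calling cal_si_snr above with an all-silent target):

import torch

# with an all-zero target, the dot product and the projection energy vanish;
# without EPS the log argument would be 0/0 = NaN, with EPS it is simply EPS,
# giving 10 * log10(1e-8) = -80 dB before negation
T, B, C = 100, 1, 2
silent = torch.zeros(T, B, C)
est = torch.randn(T, B, C)
print(cal_si_snr(silent, est))  # finite values, shape [1, B, C]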

PIT (Permutation Invariant Training)

PIT is a training method, short for Permutation Invariant Training, and it makes end-to-end training possible. The overall idea is intuitive: first assume an arbitrary ordering between the speakers and the network outputs and train for a few steps to get a model. Then, at each subsequent step, evaluate a metric such as SI-SDR under each assignment of outputs to references, e.g. (output 1 → speaker 1, output 2 → speaker 2) versus (output 1 → speaker 2, output 2 → speaker 1), take the assignment with the smaller loss as the ordering, and continue training in that order.

uPIT (utterance-level PIT)

uPIT amounts to picking, among all the permutation combinations above, the single optimal assignment at the utterance level (one permutation for the whole utterance rather than per frame). A minimal sketch of the idea is shown below.
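A toy illustration of PIT (hypothetical two-speaker data of my own, not the SpeechBrain code): score every assignment of outputs to references and keep the cheapest one.

import torch
from itertools import permutations

targets = [torch.randn(100), torch.randn(100)]        # reference speakers 1 and 2
estimates = [targets[1] + 0.01 * torch.randn(100),    # the network swapped the order
             targets[0] + 0.01 * torch.randn(100)]

def mse(a, b):
    return ((a - b) ** 2).mean().item()

# average loss under each possible output-to-reference assignment
scores = {p: sum(mse(estimates[i], targets[p[i]]) for i in range(2)) / 2
          for p in permutations(range(2))}
best_perm = min(scores, key=scores.get)
print(best_perm)  # (1, 0): output 0 matches reference 1, and vice versa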

Implementation:

class PitWrapper(nn.Module):
    """
    Permutation Invariant Wrapper to allow Permutation Invariant Training
    (PIT) with existing losses.

    Permutation invariance is calculated over the sources/classes axis which is
    assumed to be the rightmost dimension: predictions and targets tensors are
    assumed to have shape [batch, ..., channels, sources].

    Arguments
    ---------
    base_loss : function
        Base loss function, e.g. torch.nn.MSELoss. It is assumed that it takes
        two arguments:
        predictions and targets and no reduction is performed.
        (if a pytorch loss is used, the user must specify reduction="none").

    Returns
    ---------
    pit_loss : torch.nn.Module
        Torch module supporting forward method for PIT.

    Example
    -------
    >>> pit_mse = PitWrapper(nn.MSELoss(reduction="none"))
    >>> targets = torch.rand((2, 32, 4))
    >>> p = (3, 0, 2, 1)
    >>> predictions = targets[..., p]
    >>> loss, opt_p = pit_mse(predictions, targets)
    >>> loss
    tensor([0., 0.])
    """

    def __init__(self, base_loss):
        super(PitWrapper, self).__init__()
        self.base_loss = base_loss

    def _fast_pit(self, loss_mat):
        """
        Arguments
        ----------
        loss_mat : torch.Tensor
            Tensor of shape [sources, source] containing loss values for each
            possible permutation of predictions.

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for the current batch, tensor of shape [1]

        assigned_perm : tuple
            Indexes for optimal permutation of the input over sources which
            minimizes the loss.
        """

        loss = None
        assigned_perm = None
        for p in permutations(range(loss_mat.shape[0])):
            c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
            # return loss_mat[range(loss_mat.shape[0]), p][0], p
            #########################################################
            ### IMPORTANT ###########################################
            if loss is None or loss > c_loss:
                loss = c_loss
                assigned_perm = p
            #########################################################
        return loss, assigned_perm

    def _opt_perm_loss(self, pred, target):
        """
        Arguments
        ---------
        pred : torch.Tensor
            Network prediction for the current example, tensor of
            shape [..., sources].
        target : torch.Tensor
            Target for the current example, tensor of shape [..., sources].

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for the current example, tensor of shape [1]

        assigned_perm : tuple
            Indexes for optimal permutation of the input over sources which
            minimizes the loss.

        """

        n_sources = pred.size(-1)

        pred = pred.unsqueeze(-2).repeat(
            *[1 for x in range(len(pred.shape) - 1)], n_sources, 1
        )
        target = target.unsqueeze(-1).repeat(
            1, *[1 for x in range(len(target.shape) - 1)], n_sources
        )

        loss_mat = self.base_loss(pred, target)
        assert (
            len(loss_mat.shape) >= 2
        ), "Base loss should not perform any reduction operation"
        mean_over = [x for x in range(len(loss_mat.shape))]
        loss_mat = loss_mat.mean(dim=mean_over[:-2])

        return self._fast_pit(loss_mat)

    def reorder_tensor(self, tensor, p):
        """
        Arguments
        ---------
        tensor : torch.Tensor
            Tensor to reorder given the optimal permutation, of shape
            [batch, ..., sources].
        p : list of tuples
            List of optimal permutations, e.g. for batch=2 and n_sources=3
[(0, 1, 2), (0, 2, 1)].

        Returns
        -------
        reordered : torch.Tensor
            Reordered tensor given permutation p.
        """

        reordered = torch.zeros_like(tensor, device=tensor.device)
        for b in range(tensor.shape[0]):
            reordered[b] = tensor[b][..., p[b]].clone()
        return reordered

    def forward(self, preds, targets):
        """
            Arguments
            ---------
            preds : torch.Tensor
                Network predictions tensor, of shape
                [batch, channels, ..., sources].
            targets : torch.Tensor
                Target tensor, of shape [batch, channels, ..., sources].

            Returns
            -------
            loss : torch.Tensor
                Permutation invariant loss for current examples, tensor of
                shape [batch]

            perms : list
                List of indexes for optimal permutation of the inputs over
                sources.
                e.g., [(0, 1, 2), (2, 1, 0)] for three sources and 2 examples
                per batch.
        """
        losses = []
        perms = []
        for pred, label in zip(preds, targets):
            loss, p = self._opt_perm_loss(pred, label)
            perms.append(p)
            losses.append(loss)
        loss = torch.stack(losses)
        return loss, perms

Here, permutations is the class from Python's itertools that enumerates all permutations:

For example, when separating 2 speakers, the two output channels may correspond to either (speaker 1, speaker 2) or (speaker 2, speaker 1).

Likewise, when separating 3 speakers there are 6 possible permutations.

Generalizing: with m input sources and n output channels (n <= m), the candidate assignments can be enumerated with the permutations class: perms = permutations(range(m), n) (see the example below).
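For example, enumerating all orderings of three sources, and all length-2 assignments for m = 3, n = 2:

from itertools import permutations

print(list(permutations(range(3))))
# [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
print(list(permutations(range(3), 2)))
# [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]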

In the uPIT-SiSNR above, the concrete usage is:

for p in permutations(range(loss_mat.shape[0])):
    c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
    # return loss_mat[range(loss_mat.shape[0]), p][0], p
    #########################################################
    ### IMPORTANT ###########################################
    if loss is None or loss > c_loss:
        loss = c_loss
        assigned_perm = p
    #########################################################
return loss, assigned_perm

For reference, the docstring stub of itertools.permutations looks like this:

class permutations(object):
    """
    permutations(iterable[, r]) --> permutations object
    
    Return successive r-length permutations of elements in the iterable.
    
    permutations(range(3), 2) --> (0,1), (0,2), (1,0), (1,2), (2,0), (2,1)
    """
    def __getattribute__(self, *args, **kwargs): # real signature unknown
        """ Return getattr(self, name). """
        pass

    def __init__(self, iterable, r=None): # real signature unknown; restored from __doc__
        pass

    def __iter__(self, *args, **kwargs): # real signature unknown
        """ Implement iter(self). """
        pass

    @staticmethod # known case of __new__
    def __new__(*args, **kwargs): # real signature unknown
        """ Create and return a new object.  See help(type) for accurate signature. """
        pass

    def __next__(self, *args, **kwargs): # real signature unknown
        """ Implement next(self). """
        pass

    def __reduce__(self, *args, **kwargs): # real signature unknown
        """ Return state information for pickling. """
        pass

    def __setstate__(self, *args, **kwargs): # real signature unknown
        """ Set state information for unpickling. """
        pass

    def __sizeof__(self, *args, **kwargs): # real signature unknown
        """ Returns size in memory, in bytes. """
        pass

The base_loss here can be any loss function you specify, e.g., SI-SNR in our case. The concrete implementation follows.

uPIT-SiSNR

PyTorch implementation

def get_si_snr_with_pitwrapper(source, estimate_source):
    """This function wraps si_snr calculation with the speechbrain pit-wrapper.

    Arguments:
    ---------
    source: [B, T, C],
        Where B is the batch size, T is the length of the sources, C is
        the number of sources the ordering is made so that this loss is
        compatible with the class PitWrapper.

    estimate_source: [B, T, C]
        The estimated source.

    Example:
    ---------
    >>> x = torch.arange(600).reshape(3, 100, 2)
    >>> xhat = x[:, :, (1, 0)]
    >>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
    >>> print(si_snr)
    tensor([135.2284, 135.2284, 135.2284])
    """

    pit_si_snr = PitWrapper(cal_si_snr)
    loss, perms = pit_si_snr(source, estimate_source)

    return loss
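For completeness, here is a hedged sketch of dropping this loss into a training step (the nn.Linear "separator", the data, and the optimizer settings are placeholders of my own, not from SpeechBrain). The wrapper returns one loss value per example, so reduce with mean() before backprop:

import torch
import torch.nn as nn

B, T, C = 2, 100, 2
model = nn.Linear(C, C)                      # stand-in for a real separator
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

clean = torch.randn(B, T, C)                 # references: one source per channel
mix = clean.sum(dim=-1, keepdim=True).repeat(1, 1, C)

est = model(mix)                                      # [B, T, C]
loss = get_si_snr_with_pitwrapper(clean, est).mean()  # scalar for backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())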

TensorFlow implementation

Here is a TensorFlow (v2.4.0) implementation of the uPIT-SiSNR described above.

Note: we subclass the Keras Loss base class. Some of the docstrings and commented-out lines below are kept from the original PyTorch version.

Some PyTorch-to-TensorFlow correspondences:

# pytorch
c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
# tensorflow
c_loss = tf.reduce_mean([loss_mat[i][p[i]] for i in range(loss_mat.shape[0])])


# pytorch
pred = pred.unsqueeze(-2).repeat(
            *[1 for x in range(len(pred.shape) - 1)], n_sources, 1
        )
# tensorflow
pred = tf.tile(tf.expand_dims(pred, axis=-2), [1] * (len(pred.shape) - 1) + [n_sources, 1])  # multiples = [1, ..., 1, n_sources, 1]


# pytorch
loss_mat = loss_mat.mean(dim=mean_over[:-2])
# tensorflow
loss_mat = tf.reduce_mean(loss_mat, axis=mean_over[:-2])


# pytorch
def forward(self, preds, targets):
# tensorflow
def call(self, preds, targets):


# pytorch
-si_snr.unsqueeze(0)
# tensorflow
-tf.expand_dims(si_snr, 0)


# pytorch
x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
xhat = x[:, (1, 0)]
# tensorflow
x = tf.constant([[1, 0], [123, 45], [34, 5], [2312, 421]], dtype=float)
xhat = tf.slice(x, [0, 1], [x.shape[0], 1])
xhat = tf.concat([xhat, tf.slice(x, [0, 0], [x.shape[0], 1])], axis=1)

# pytorch
xhat = x[:, :, (1, 0)]
# tensorflow
xhat = tf.slice(x, [0, 0, 1], [x.shape[0], x.shape[1], 1])
xhat = tf.concat([xhat, tf.slice(x, [0, 0, 0], [x.shape[0], x.shape[1], 1])], axis=2)

The full code:

import tensorflow as tf
from tensorflow.python.keras.losses import Loss, mse
from itertools import permutations
from tensorflow.python.keras.utils import losses_utils

class PitWrapper(Loss):
    """
    Permutation Invariant Wrapper to allow Permutation Invariant Training
    (PIT) with existing losses.

    Permutation invariance is calculated over the sources/classes axis which is
    assumed to be the rightmost dimension: predictions and targets tensors are
    assumed to have shape [batch, ..., channels, sources].

    Arguments
    ---------
    base_loss : function
        Base loss function, e.g. torch.nn.MSELoss. It is assumed that it takes
        two arguments:
        predictions and targets and no reduction is performed.
        (if a pytorch loss is used, the user must specify reduction="none").

    Returns
    ---------
    pit_loss : torch.nn.Module
        Torch module supporting forward method for PIT.

    Example
    -------
    >>> pit_mse = PitWrapper(nn.MSELoss(reduction="none"))
    >>> targets = torch.rand((2, 32, 4))
    >>> p = (3, 0, 2, 1)
    >>> predictions = targets[..., p]
    >>> loss, opt_p = pit_mse(predictions, targets)
    >>> loss
    tensor([0., 0.])
    """

    def __init__(self, base_loss):
        super().__init__()
        self.reduction = losses_utils.ReductionV2.NONE ## IMPORTANT ##
        self.base_loss = base_loss

    def _fast_pit(self, loss_mat):
        """
        Arguments
        ----------
        loss_mat : torch.Tensor
            Tensor of shape [sources, source] containing loss values for each
            possible permutation of predictions.

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for the current batch, tensor of shape [1]

        assigned_perm : tuple
            Indexes for optimal permutation of the input over sources which
            minimizes the loss.
        """

        loss = None
        assigned_perm = None
        for p in permutations(range(loss_mat.shape[0])):
            c_loss = tf.reduce_mean([loss_mat[i][p[i]] for i in range(loss_mat.shape[0])]) # loss_mat[range(loss_mat.shape[0]), p].mean()
            if loss is None or loss > c_loss:
                loss = c_loss
                assigned_perm = p
        return loss, assigned_perm

    def _opt_perm_loss(self, pred, target):
        """
        Arguments
        ---------
        pred : torch.Tensor
            Network prediction for the current example, tensor of
            shape [..., sources].
        target : torch.Tensor
            Target for the current example, tensor of shape [..., sources].

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for the current example, tensor of shape [1]

        assigned_perm : tuple
            Indexes for optimal permutation of the input over sources which
            minimizes the loss.

        """

        n_sources = pred.shape[-1] #pred.size(-1)

        # pred = pred.unsqueeze(-2).repeat(
        #     *[1 for x in range(len(pred.shape) - 1)], n_sources, 1
        # )
        # tf.tile needs one multiple per axis: [1, ..., 1, n_sources, 1]
        pred = tf.tile(tf.expand_dims(pred, axis=-2), [1] * (len(pred.shape) - 1) + [n_sources, 1])

        # target = target.unsqueeze(-1).repeat(
        #     1, *[1 for x in range(len(target.shape) - 1)], n_sources
        # )
        # multiples: [1, ..., 1, n_sources]
        target = tf.tile(tf.expand_dims(target, axis=-1), [1] * len(target.shape) + [n_sources])

        loss_mat = self.base_loss(pred, target)
        assert (
            len(loss_mat.shape) >= 2
        ), "Base loss should not perform any reduction operation"
        mean_over = [x for x in range(len(loss_mat.shape))]
        # loss_mat = loss_mat.mean(dim=mean_over[:-2])
        loss_mat = tf.reduce_mean(loss_mat, axis=mean_over[:-2])

        return self._fast_pit(loss_mat)

    def reorder_tensor(self, tensor, p):
        """
        Arguments
        ---------
        tensor : torch.Tensor
            Tensor to reorder given the optimal permutation, of shape
            [batch, ..., sources].
        p : list of tuples
            List of optimal permutations, e.g. for batch=2 and n_sources=3
            [(0, 1, 2), (0, 2, 1)].

        Returns
        -------
        reordered : torch.Tensor
            Reordered tensor given permutation p.
        """

        # TF tensors do not support item assignment or .clone(), so gather the
        # permuted sources per example and stack them along the batch axis
        reordered = tf.stack(
            [tf.gather(tensor[b], p[b], axis=-1) for b in range(tensor.shape[0])],
            axis=0,
        )
        return reordered

    def call(self, preds, targets): #forward(self, preds, targets):
        """
            Arguments
            ---------
            preds : torch.Tensor
                Network predictions tensor, of shape
                [batch, channels, ..., sources].
            targets : torch.Tensor
                Target tensor, of shape [batch, channels, ..., sources].

            Returns
            -------
            loss : torch.Tensor
                Permutation invariant loss for current examples, tensor of
                shape [batch]

            perms : list
                List of indexes for optimal permutation of the inputs over
                sources.
                e.g., [(0, 1, 2), (2, 1, 0)] for three sources and 2 examples
                per batch.
        """
        losses = []
        perms = []
        for pred, label in zip(preds, targets):
            loss, p = self._opt_perm_loss(pred, label)
            perms.append(p)
            losses.append(loss)
        loss = tf.stack(losses)
        return loss  # Keras Loss.call must return a single tensor, so perms is dropped here

def get_si_snr_with_pitwrapper(source, estimate_source):
    """This function wraps si_snr calculation with the speechbrain pit-wrapper.

    Arguments:
    ---------
    source: [B, T, C],
        Where B is the batch size, T is the length of the sources, C is
        the number of sources the ordering is made so that this loss is
        compatible with the class PitWrapper.

    estimate_source: [B, T, C]
        The estimated source.

    Example:
    ---------
    >>> x = torch.arange(600).reshape(3, 100, 2)
    >>> xhat = x[:, :, (1, 0)]
    >>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
    >>> print(si_snr)
    tensor([135.2284, 135.2284, 135.2284])
    """

    pit_si_snr = PitWrapper(cal_si_snr)
    loss = pit_si_snr(estimate_source, source) # , perms

    return loss

def get_mse_with_pitwrapper(source, estimate_source):
    """Wrap MSE with the PIT wrapper, analogous to the SI-SNR version above."""
    pit_mse = PitWrapper(mse)
    loss = pit_mse(estimate_source, source)  # , perms

    return loss

def cal_si_snr(estimate_source, source):
    """Calculate SI-SNR.

    Arguments:
    ---------
    source: [T, B, C],
        Where B is batch size, T is the length of the sources, C is the number of sources
        the ordering is made so that this loss is compatible with the class PitWrapper.

    estimate_source: [T, B, C]
        The estimated source.

    Example:
    ---------
    >>> import numpy as np
    >>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
    >>> xhat = x[:, (1, 0)]
    >>> x = x.unsqueeze(-1).repeat(1, 1, 2)
    >>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
    >>> si_snr = -cal_si_snr(x, xhat)
    >>> print(si_snr)
    tensor([[[ 25.2142, 144.1789],
             [130.9283,  25.2142]]])
    """
    EPS = 1e-8
    source_lengths = tf.constant([estimate_source.shape[0]] * estimate_source.shape[1])

    mask = get_mask(source, source_lengths)
    estimate_source = tf.multiply(estimate_source, mask)

    # num_samples = tf.reshape(source_lengths, [1, -1, 1])

    mean_target = tf.math.reduce_mean(source, axis=0, keepdims=True) #/ num_samples
    mean_estimate = (
        tf.math.reduce_mean(estimate_source, axis=0, keepdims=True) #/ num_samples
    )
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = zero_mean_target  # [T, B, C]
    s_estimate = zero_mean_estimate  # [T, B, C]
    # s_target = <s', s>s / ||s||^2
    dot = tf.math.reduce_sum(s_estimate * s_target, axis=0, keepdims=True)  # [1, B, C]
    s_target_energy = (
        tf.math.reduce_sum(s_target ** 2, axis=0, keepdims=True) + EPS
    )  # [1, B, C]
    proj = dot * s_target / s_target_energy  # [T, B, C]
    # e_noise = s' - s_target
    e_noise = s_estimate - proj  # [T, B, C]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    si_snr_beforelog = tf.math.reduce_sum(proj ** 2, axis=0) / (
        tf.reduce_sum(e_noise ** 2, axis=0) + EPS
    )
    si_snr = 10 * tf.math.log(si_snr_beforelog + EPS) / tf.math.log(10.0) # [B, C]

    return -tf.expand_dims(si_snr, 0) #-si_snr.unsqueeze(0)

def get_mask(source, source_lengths):
    """
    Arguments
    ---------
    source : [T, B, C]
    source_lengths : [B]

    Returns
    -------
    mask : [T, B, 1]

    Example:
    ---------
    >>> source = torch.randn(4, 3, 2)
    >>> source_lengths = torch.Tensor([2, 1, 4]).int()
    >>> mask = get_mask(source, source_lengths)
    >>> print(mask)
    tensor([[[1.],
             [1.],
             [1.]],
    <BLANKLINE>
            [[1.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]]])
    """
    T, B, _ = source.shape
    # per-example column of shape [T, 1, 1]: ones up to the valid length,
    # zeros over the padding; concatenate the columns along the batch axis
    masks = [
        tf.concat(
            [tf.ones(shape=[source_lengths[i], 1, 1], dtype=source.dtype),
             tf.zeros(shape=[T - source_lengths[i], 1, 1], dtype=source.dtype)],
            axis=0,
        )
        for i in range(B)
    ]
    return tf.concat(masks, axis=1)

def test_get_mask():
    '''
    Example:
    ---------
    >>> source = torch.randn(4, 3, 2)
    >>> source_lengths = torch.Tensor([2, 1, 4]).int()
    >>> mask = get_mask(source, source_lengths)
    >>> print(mask)
    tensor([[[1.],
             [1.],
             [1.]],
    <BLANKLINE>
            [[1.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]],
    <BLANKLINE>
            [[0.],
             [0.],
             [1.]]])
    '''
    source = tf.random.uniform([4, 3, 2])
    source_length = tf.constant([2, 1, 4])
    mask = get_mask(source, source_length)
    print(mask)

def test_cal_si_snr():
    '''
    Example:
    ---------
    >>> import numpy as np
    >>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
    >>> xhat = x[:, (1, 0)]
    >>> x = x.unsqueeze(-1).repeat(1, 1, 2)
    >>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
    >>> si_snr = -cal_si_snr(x, xhat)
    >>> print(si_snr)
    tensor([[[ 25.2142, 144.1789],
             [130.9283,  25.2142]]])
    '''
    x = tf.constant([[1, 0], [123, 45], [34, 5], [2312, 421]], dtype=float)
    xhat = tf.slice(x, [0, 1], [x.shape[0], 1])
    xhat = tf.concat([xhat, tf.slice(x, [0, 0], [x.shape[0], 1])], axis=1)
    x = tf.expand_dims(x, axis=-1)
    x = tf.concat([x, x], axis=-1)
    xhat = tf.expand_dims(xhat, axis=1)
    xhat = tf.concat([xhat, xhat], axis=1)
    si_snr = -cal_si_snr(xhat, x)
    print(si_snr)

def test_upit_sisnr():
    '''
    Example:
    ---------
    >>> x = torch.arange(600).reshape(3, 100, 2)
    >>> xhat = x[:, :, (1, 0)]
    >>> si_snr = -get_si_snr_with_pitwrapper(x, xhat)
    >>> print(si_snr)
    tensor([135.2284, 135.2284, 135.2284])
    '''
    x = tf.random.uniform([800], 0, 800) #tf.range(800, dtype=float)
    x = tf.reshape(x, [4, 100, 2])
    xhat = tf.slice(x, [0, 0, 1], [x.shape[0], x.shape[1], 1])
    xhat = tf.concat([xhat, tf.slice(x, [0, 0, 0], [x.shape[0], x.shape[1], 1])], axis=2)
    si_snr = -get_si_snr_with_pitwrapper(xhat, x)
    print(si_snr)
    si_snr = -get_si_snr_with_pitwrapper(x, xhat)
    print(si_snr)
    # tf.Tensor(135.22835, shape=(), dtype=float32), which is different from the pytorch version by a mean manipulation
    # tf.Tensor([135.22835 135.22835 135.22835], shape=(3,), dtype=float32), where reduction is none

if __name__ == '__main__':

    # test_get_mask() # ok

    # test_cal_si_snr() # ok

    # test_upit_sisnr()  # ok, somehow: according to reduction

    print('done')
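Finally, a hedged sketch of plugging the wrapper into Keras training (the Dense "separator" and the random data are placeholders of my own). Because PitWrapper loops over examples in Python and reads static shapes, I compile with run_eagerly=True to keep the loss out of graph mode; since the wrapper's reduction is NONE, Keras averages the per-example losses itself:

def demo_keras_training():
    # toy model: a Dense layer standing in for a real separator network
    inputs = tf.keras.Input(shape=(100, 2))
    outputs = tf.keras.layers.Dense(2)(inputs)
    model = tf.keras.Model(inputs, outputs)

    # y_true -> source, y_pred -> estimate_source, matching the wrapper above
    model.compile(optimizer="adam", loss=get_si_snr_with_pitwrapper,
                  run_eagerly=True)

    mix = tf.random.uniform([4, 100, 2])
    clean = tf.random.uniform([4, 100, 2])
    model.fit(mix, clean, epochs=1, batch_size=2)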


Reposted from blog.csdn.net/u010637291/article/details/118381364