对一批数据分batch的代码

def get_batches(x, y, n_batches=10):
    """ 这是一个生成器函数,按照n_batches的大小将数据划分了小块 """
    batch_size = len(x)//n_batches
    
    for ii in range(0, n_batches*batch_size, batch_size):
        # 如果不是最后一个batch,那么这个batch中应该有batch_size个数据
        if ii != (n_batches-1)*batch_size:
            X, Y = x[ii: ii+batch_size], y[ii: ii+batch_size] 
        # 否则的话,那剩余的不够batch_size的数据都凑入到一个batch中
        else:
            X, Y = x[ii:], y[ii:]
        # 生成器语法,返回X和Y
        yield X, Y

猜你喜欢

转载自www.cnblogs.com/datascience1/p/11430017.html
今日推荐