tensor内部结构

内部结构

1.tensor分为头信息区(Tensor)和存储区(Storage);
    信息区:tensor的形状(size)、步长(stride)、数据类型(type),信息区占用内存较少
    存储区:数据保存为连续数组,主要内存占用在存储区
2.每一个tensor有着一个对应的storage,storage是在data之上封装的接口;

具体操作

 
  1. a = t.arange(0,6)  
  2. a.storage()  
  3. b = a.view(2,3)  
  4. b.storage()  
  5. #a和b的storage的内存地址一样,即他们是用同一个storage  
  6. print( id(b.storage) == id(a.storage) )  
  7.   
  8. #a改变,b也随之改变,因为他们共享storage  
  9. a[1] = 100  
  10. print(b)  
  11.   
  12. c = a[2:]  
  13. c.storage()  
  14. print(c)  
  15.   
  16. #3198436924144    3198436924128,首地址差16,因为两个元素2*8,每个元素占8个字节  
  17. print(c.data_ptr())  
  18. print(a.data_ptr())  
  19.   
  20. c[0] = -100  
  21. print(a)  
  22.   
  23. #3个tensor共享storage  
  24. print(id( a.storage() ) == id( b.storage() ) == id( c.storage()) )  
  25.   
  26.   
  27. #以储存元素的个数的形式返回tensor在地城内存中的偏移量  
  28. print( a.storage_offset() )  
  29. print( b.storage_offset() )  
  30. print( c.storage_offset() )  
  31. '''''0  0  2'''  
  32.   
  33.   
  34. print('b',b)  
  35. e = b[::1,::2]  
  36. print('e',e)  
  37. '''''b tensor([[   0,  100, -100], 
  38.         [   3,    4,    5]]) 
  39. e tensor([[   0, -100], 
  40.         [   3,    5]])'''  
  41.   
  42. #tensor步长  
  43. print(b.stride(),e.stride())  
  44. '''''(3, 1) (3, 2)'''  
  45.   
  46. #判断tensor是否连续  
  47. print(e.is_contiguous())  
  48. '''''False'''  
  49. f = e.contiguous()  
  50. print(f.is_contiguous())  
  51. '''''True'''  

总结

大部分操作并不修改tensor的数据,只修改了tensor的头信息,这种做法更节省内存,提升了处理速度。
注意:有些操作会导致tensor不连续,可以用tensor.contiguous方法将它们变成连续的数据。

2018-11-22 20:36:00

猜你喜欢

转载自www.cnblogs.com/monkeyT/p/10003747.html