This library allows reading and writing TFRecord files efficiently in Python. It also provides an IterableDataset reader of TFRecord files for PyTorch. Currently uncompressed and gzip-compressed TFRecords are supported.

Install the library (with the optional PyTorch extras) via pip:
pip3 install 'tfrecord[torch]'
It is recommended to create an index file for each TFRecord file. An index file must be provided when using multiple workers, otherwise the loader may return duplicate records. You can use this utility program to create an index file for a single tfrecord file:
python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
To create index files for all .tfrecord files in a directory, run:
tfrecord2idx <data dir>
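Index files can also be created programmatically. A minimal sketch, assuming the create_index helper that backs the CLI is exposed by tfrecord.tools.tfrecord2idx (as in recent releases):

import os
from tfrecord.tools.tfrecord2idx import create_index  # assumed helper backing the CLI

data_dir = "/tmp/records"  # hypothetical directory of tfrecord files
for name in os.listdir(data_dir):
    if name.endswith(".tfrecord"):
        path = os.path.join(data_dir, name)
        # write each index file next to its tfrecord file
        create_index(path, path[:-len(".tfrecord")] + ".index")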
Use TFRecordDataset to read TFRecord files in PyTorch.
import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)
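Gzip-compressed TFRecords mentioned above can be read the same way. A brief sketch, assuming the compression_type parameter available in recent versions of the library:

# same reader, pointed at a gzip-compressed file (hypothetical path)
dataset = TFRecordDataset("/tmp/data.tfrecord.gz",
                          index_path=None,
                          description={"image": "byte", "label": "float"},
                          compression_type="gzip")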
Use MultiTFRecordDataset to read multiple TFRecord files. This class samples from the given tfrecord files with the given probabilities.

import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

By default, MultiTFRecordDataset is infinite, meaning that it samples the data forever. You can make it finite by providing the appropriate flag:
dataset = MultiTFRecordDataset(..., infinite=False)
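A finite dataset can then be consumed with an ordinary epoch loop. A short sketch, reusing the pattern, splits, and description from the example above:

# finite: iteration stops once the underlying files are exhausted
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description,
                               infinite=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

for epoch in range(3):
    for batch in loader:
        pass  # training step goes here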
Both TFRecordDataset and MultiTFRecordDataset automatically shuffle the data when you provide a queue size:
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
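As noted above, an index file is required when loading with multiple workers. A brief sketch combining a shuffle queue with a multi-worker DataLoader, assuming an index built with tfrecord2idx at a hypothetical path:

import torch
from tfrecord.torch.dataset import TFRecordDataset

# with num_workers > 0 the index file is required, otherwise
# workers may return duplicate records
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path="/tmp/data.index",
                          description={"image": "byte", "label": "float"},
                          shuffle_queue_size=1024)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)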
You can optionally pass a function as the transform argument to perform post-processing of features before they are returned. This can, for example, be used to decode images, normalize colors to a certain range, or pad variable-length sequences.
import tfrecord
import cv2

def decode_image(features):
    # get BGR image from bytes
    features["image"] = cv2.imdecode(features["image"], -1)
    return features

description = {
    "image": "byte",
}

dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)

data = next(iter(dataset))
print(data)
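The other transform use cases work the same way. A hedged sketch of a variant that also normalizes the decoded image to the [0, 1] range (decode_and_normalize is a hypothetical name; pass it as transform=decode_and_normalize):

import cv2
import numpy as np

def decode_and_normalize(features):
    # decode the image, then scale uint8 values to floats in [0, 1]
    image = cv2.imdecode(features["image"], -1)
    features["image"] = image.astype(np.float32) / 255.0
    return features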
Use TFRecordWriter to write tfrecord files:

import tfrecord

# image_bytes, label and index hold your own data
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({
    "image": (image_bytes, "byte"),
    "label": (label, "float"),
    "index": (index, "int")
})
writer.close()

Use tfrecord_loader to read tfrecord files:

import tfrecord
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {
    "image": "byte",
    "label": "float",
    "index": "int"
})

for record in loader:
    print(record["label"])

SequenceExamples can be read and written using the same methods shown above, with an extra argument (sequence_description for reading and sequence_datum for writing) which causes the respective read/write functions to treat the data as a SequenceExample.
import tfrecord

writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'),
              'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (2, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'),
              'seq_labels': ([0, 1], 'int')})
writer.close()

Reading from a SequenceExample yields a tuple containing two elements:
import tfrecord

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None,
                                  context_description,
                                  sequence_description=sequence_description)

for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])

As described in the section on transforming input, a function can be passed via the transform argument to perform post-processing of features. This should be used especially for sequence features, since these are variable-length sequences that need to be padded before batching.
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

PAD_WIDTH = 5

def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        # pad each variable-length sequence feature up to PAD_WIDTH steps
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

Alternatively, you can implement a custom collate_fn to assemble the batch, for example to perform dynamic padding.
import torch
from tfrecord.torch.dataset import TFRecordDataset

def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    context, feats = zip(*batch)
    # collect each sequence feature across the batch as a list of tensors
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

data = next(iter(loader))
print(data)