Diese Bibliothek ermöglicht das Lesen und Schreiben von TFRECORD -Dateien effizient in Python. Die Bibliothek bietet auch einen iterableDataset -Leser von TFRECORD -Dateien für PyTorch. Derzeit werden unkomprimierte und komprimierte GZIP -TFRECORDS unterstützt.
pip3 install 'tfrecord[torch]'
Es wird empfohlen, eine Indexdatei für jede TFRECORD -Datei zu erstellen. In der Indexdatei muss bei der Verwendung mehrerer Mitarbeiter bereitgestellt werden, da der Loader doppelte Datensätze zurückgeben kann. Mit diesem Dienstprogramm können Sie eine Indexdatei für eine einzelne TFRECORD -Datei erstellen:
python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
So erstellen Sie " .tfidNex" -Dateien für alle " .tfrecord" -Dateien in einem Verzeichnislauf:
tfrecord2idx <data dir>
Verwenden Sie TFRECORDDATASET, um TFRECORD -Dateien in Pytorch zu lesen.
import torch
from tfrecord . torch . dataset import TFRecordDataset
tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = { "image" : "byte" , "label" : "float" }
dataset = TFRecordDataset ( tfrecord_path , index_path , description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 )
data = next ( iter ( loader ))
print ( data )Verwenden Sie MultitfreCordDataset, um mehrere TFRECORD -Dateien zu lesen. Diese Klasse Beispiele aus angegebenen TFRECORD -Dateien mit gegebener Wahrscheinlichkeit.
import torch
from tfrecord . torch . dataset import MultiTFRecordDataset
tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
"dataset1" : 0.8 ,
"dataset2" : 0.2 ,
}
description = { "image" : "byte" , "label" : "int" }
dataset = MultiTFRecordDataset ( tfrecord_pattern , index_pattern , splits , description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 )
data = next ( iter ( loader ))
print ( data ) Standardmäßig ist MultiTFRecordDataset unendlich, was bedeutet, dass die Daten für immer abgetastet werden. Sie können es endlich machen, indem Sie die entsprechende Flagge bereitstellen
dataset = MultiTFRecordDataset(..., infinite=False)
Sowohl tFrecordDataset als auch MultitfreCordDataset mischen die Daten automatisch, wenn Sie eine Warteschlangengröße angeben.
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
Sie können optional eine Funktion als transform übergeben, um vor der Rückkehr die Funktionen nach der Verarbeitung von Funktionen durchzuführen. Dies kann beispielsweise verwendet werden, um Bilder zu dekodieren oder Farben auf eine bestimmte Sequenz für Bereiche oder pad -variable Länge zu normalisieren.
import tfrecord
import cv2
def decode_image ( features ):
# get BGR image from bytes
features [ "image" ] = cv2 . imdecode ( features [ "image" ], - 1 )
return features
description = {
"image" : "bytes" ,
}
dataset = tfrecord . torch . TFRecordDataset ( "/tmp/data.tfrecord" ,
index_path = None ,
description = description ,
transform = decode_image )
data = next ( iter ( dataset ))
print ( data ) import tfrecord
writer = tfrecord . TFRecordWriter ( "/tmp/data.tfrecord" )
writer . write ({
"image" : ( image_bytes , "byte" ),
"label" : ( label , "float" ),
"index" : ( index , "int" )
})
writer . close () import tfrecord
loader = tfrecord . tfrecord_loader ( "/tmp/data.tfrecord" , None , {
"image" : "byte" ,
"label" : "float" ,
"index" : "int"
})
for record in loader :
print ( record [ "label" ]) SequenceExamples können mit denselben Methoden gelesen und geschrieben werden, die oben mit einem zusätzlichen Argument ( sequence_description für das Lesen und sequence_datum zum Schreiben) gezeigt werden, die die jeweiligen Lese-/Schreibfunktionen verursachen, um die Daten als Sequence -Beispiel zu behandeln.
import tfrecord
writer = tfrecord . TFRecordWriter ( "/tmp/data.tfrecord" )
writer . write ({ 'length' : ( 3 , 'int' ), 'label' : ( 1 , 'int' )},
{ 'tokens' : ([[ 0 , 0 , 1 ], [ 0 , 1 , 0 ], [ 1 , 0 , 0 ]], 'int' ), 'seq_labels' : ([ 0 , 1 , 1 ], 'int' )})
writer . write ({ 'length' : ( 3 , 'int' ), 'label' : ( 1 , 'int' )},
{ 'tokens' : ([[ 0 , 0 , 1 ], [ 1 , 0 , 0 ]], 'int' ), 'seq_labels' : ([ 0 , 1 ], 'int' )})
writer . close ()Lesen aus einer Sequenz, die ein Tupel mit zwei Elementen enthält.
import tfrecord
context_description = { "length" : "int" , "label" : "int" }
sequence_description = { "tokens" : "int" , "seq_labels" : "int" }
loader = tfrecord . tfrecord_loader ( "/tmp/data.tfrecord" , None ,
context_description ,
sequence_description = sequence_description )
for context , sequence_feats in loader :
print ( context [ "label" ])
print ( sequence_feats [ "seq_labels" ]) Wie im Abschnitt über die Transforming Input beschrieben, kann man eine Funktion als transform für die Durchführung der Nachverarbeitung von Merkmalen übergeben. Dies sollte insbesondere für die Sequenzfunktionen verwendet werden, da es sich um eine variable Längensequenz handelt und vor dem Batching ausgelöst werden muss.
import torch
import numpy as np
from tfrecord . torch . dataset import TFRecordDataset
PAD_WIDTH = 5
def pad_sequence_feats ( data ):
context , features = data
for k , v in features . items ():
features [ k ] = np . pad ( v , (( 0 , PAD_WIDTH - len ( v )), ( 0 , 0 )), 'constant' )
return ( context , features )
context_description = { "length" : "int" , "label" : "int" }
sequence_description = { "tokens" : "int " , "seq_labels" : "int" }
dataset = TFRecordDataset ( "/tmp/data.tfrecord" ,
index_path = None ,
description = context_description ,
transform = pad_sequence_feats ,
sequence_description = sequence_description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 )
data = next ( iter ( loader ))
print ( data ) Alternativ können Sie eine benutzerdefinierte collate_fn implementieren, um die Charge zusammenzustellen, um beispielsweise dynamische Polsterung durchzuführen.
import torch
import numpy as np
from tfrecord . torch . dataset import TFRecordDataset
def collate_fn ( batch ):
from torch . utils . data . _utils import collate
from torch . nn . utils import rnn
context , feats = zip ( * batch )
feats_ = { k : [ torch . Tensor ( d [ k ]) for d in feats ] for k in feats [ 0 ]}
return ( collate . default_collate ( context ),
{ k : rnn . pad_sequence ( f , True ) for ( k , f ) in feats_ . items ()})
context_description = { "length" : "int" , "label" : "int" }
sequence_description = { "tokens" : "int " , "seq_labels" : "int" }
dataset = TFRecordDataset ( "/tmp/data.tfrecord" ,
index_path = None ,
description = context_description ,
transform = pad_sequence_feats ,
sequence_description = sequence_description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 , collate_fn = collate_fn )
data = next ( iter ( loader ))
print ( data )