Cette bibliothèque permet de lire et d'écrire des fichiers tfrecord efficacement dans Python. La bibliothèque fournit également un lecteur IterableDataSet des fichiers tfrecord pour pytorch. Les tfrecords GZIP non compressés et compressés sont actuellement pris en charge.
pip3 install 'tfrecord[torch]'
Il est recommandé de créer un fichier d'index pour chaque fichier tfrecord. Le fichier d'index doit être fourni lors de l'utilisation de plusieurs travailleurs, sinon le chargeur peut renvoyer des enregistrements en double. Vous pouvez créer un fichier d'index pour un fichier tfrecord individuel avec ce programme d'utilité:
python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
Pour créer des fichiers ".tfidnex" pour tous " .tfrecord" Fichiers dans un répertoire exécuté:
tfrecord2idx <data dir>
Utilisez tfrecordDataset pour lire les fichiers tfrecord dans pytorch.
import torch
from tfrecord . torch . dataset import TFRecordDataset
tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = { "image" : "byte" , "label" : "float" }
dataset = TFRecordDataset ( tfrecord_path , index_path , description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 )
data = next ( iter ( loader ))
print ( data )Utilisez MultitFrecordDataset pour lire plusieurs fichiers TFrecord. Cette classe échantillonne à partir de fichiers tfrecord donnés avec une probabilité donnée.
import torch
from tfrecord . torch . dataset import MultiTFRecordDataset
tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
"dataset1" : 0.8 ,
"dataset2" : 0.2 ,
}
description = { "image" : "byte" , "label" : "int" }
dataset = MultiTFRecordDataset ( tfrecord_pattern , index_pattern , splits , description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 )
data = next ( iter ( loader ))
print ( data ) Par défaut, MultiTFRecordDataset est infini, ce qui signifie qu'il échantillonne pour toujours les données. Vous pouvez le rendre fini en fournissant le drapeau approprié
dataset = MultiTFRecordDataset(..., infinite=False)
TfrecordDataset et MultitfrecordDataset mélangent automatiquement les données lorsque vous fournissez une taille de file d'attente.
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
Vous pouvez éventuellement transmettre une fonction en tant qu'argument transform pour effectuer le post-traitement des fonctionnalités avant de retourner. Cela peut par exemple être utilisé pour décoder des images ou normaliser les couleurs à une certaine séquence de longueur variable de plage ou de pad.
import tfrecord
import cv2
def decode_image ( features ):
# get BGR image from bytes
features [ "image" ] = cv2 . imdecode ( features [ "image" ], - 1 )
return features
description = {
"image" : "bytes" ,
}
dataset = tfrecord . torch . TFRecordDataset ( "/tmp/data.tfrecord" ,
index_path = None ,
description = description ,
transform = decode_image )
data = next ( iter ( dataset ))
print ( data ) import tfrecord
writer = tfrecord . TFRecordWriter ( "/tmp/data.tfrecord" )
writer . write ({
"image" : ( image_bytes , "byte" ),
"label" : ( label , "float" ),
"index" : ( index , "int" )
})
writer . close () import tfrecord
loader = tfrecord . tfrecord_loader ( "/tmp/data.tfrecord" , None , {
"image" : "byte" ,
"label" : "float" ,
"index" : "int"
})
for record in loader :
print ( record [ "label" ]) Sequence Examples peut être lue et écrite en utilisant les mêmes méthodes ci-dessus avec un argument supplémentaire ( sequence_description pour la lecture et sequence_datum pour l'écriture) qui provoquent les fonctions de lecture / écriture respectives pour traiter les données comme un échantillon de séquence.
import tfrecord
writer = tfrecord . TFRecordWriter ( "/tmp/data.tfrecord" )
writer . write ({ 'length' : ( 3 , 'int' ), 'label' : ( 1 , 'int' )},
{ 'tokens' : ([[ 0 , 0 , 1 ], [ 0 , 1 , 0 ], [ 1 , 0 , 0 ]], 'int' ), 'seq_labels' : ([ 0 , 1 , 1 ], 'int' )})
writer . write ({ 'length' : ( 3 , 'int' ), 'label' : ( 1 , 'int' )},
{ 'tokens' : ([[ 0 , 0 , 1 ], [ 1 , 0 , 0 ]], 'int' ), 'seq_labels' : ([ 0 , 1 ], 'int' )})
writer . close ()Lire à partir d'une séquence exemple Yeilds Un tuple contenant deux éléments.
import tfrecord
context_description = { "length" : "int" , "label" : "int" }
sequence_description = { "tokens" : "int" , "seq_labels" : "int" }
loader = tfrecord . tfrecord_loader ( "/tmp/data.tfrecord" , None ,
context_description ,
sequence_description = sequence_description )
for context , sequence_feats in loader :
print ( context [ "label" ])
print ( sequence_feats [ "seq_labels" ]) Comme décrit dans la section sur Transforming Input , on peut passer une fonction comme l'argument transform pour effectuer le post-traitement des fonctionnalités. Ceci doit être utilisé spécialement pour les caractéristiques de séquence car ce sont des séquences de longueur variable et doivent être rembourrées avant d'être ldée.
import torch
import numpy as np
from tfrecord . torch . dataset import TFRecordDataset
PAD_WIDTH = 5
def pad_sequence_feats ( data ):
context , features = data
for k , v in features . items ():
features [ k ] = np . pad ( v , (( 0 , PAD_WIDTH - len ( v )), ( 0 , 0 )), 'constant' )
return ( context , features )
context_description = { "length" : "int" , "label" : "int" }
sequence_description = { "tokens" : "int " , "seq_labels" : "int" }
dataset = TFRecordDataset ( "/tmp/data.tfrecord" ,
index_path = None ,
description = context_description ,
transform = pad_sequence_feats ,
sequence_description = sequence_description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 )
data = next ( iter ( loader ))
print ( data ) Alternativement, vous pouvez choisir d'implémenter un collate_fn personnalisé afin d'assembler le lot, par exemple, pour effectuer un rembourrage dynamique.
import torch
import numpy as np
from tfrecord . torch . dataset import TFRecordDataset
def collate_fn ( batch ):
from torch . utils . data . _utils import collate
from torch . nn . utils import rnn
context , feats = zip ( * batch )
feats_ = { k : [ torch . Tensor ( d [ k ]) for d in feats ] for k in feats [ 0 ]}
return ( collate . default_collate ( context ),
{ k : rnn . pad_sequence ( f , True ) for ( k , f ) in feats_ . items ()})
context_description = { "length" : "int" , "label" : "int" }
sequence_description = { "tokens" : "int " , "seq_labels" : "int" }
dataset = TFRecordDataset ( "/tmp/data.tfrecord" ,
index_path = None ,
description = context_description ,
transform = pad_sequence_feats ,
sequence_description = sequence_description )
loader = torch . utils . data . DataLoader ( dataset , batch_size = 32 , collate_fn = collate_fn )
data = next ( iter ( loader ))
print ( data )