Exemples de code Java montrant comment charger un fichier de modèle TensorFlow pré-entraîné et effectuer des prédictions à partir de celui-ci
Vous trouverez ci-dessous le code de démonstration de la classe Cifar10ImageClassifier, qui charge le fichier de modèle TensorFlow cnn_cifar10.pb et l'utilise pour classifier des images :
package com . github . chen0040 . tensorflow . classifiers . demo ;
import com . github . chen0040 . tensorflow . classifiers . utils . ResourceUtils ;
import com . github . chen0040 . tensorflow . classifiers . cifar10 . Cifar10ImageClassifier ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . awt . image . BufferedImage ;
import java . io . IOException ;
import java . io . InputStream ;
/**
 * Demo: loads the pre-trained {@code cnn_cifar10.pb} TensorFlow graph and classifies the bundled
 * CIFAR-10 sample images, logging the predicted label for each one.
 */
public class Cifar10ImageClassifierDemo {
    private static final Logger logger = LoggerFactory.getLogger(Cifar10ImageClassifierDemo.class);

    /** Names (without the {@code .png} extension) of the sample images under {@code images/cifar10/}. */
    private static final String[] IMAGE_NAMES = {
            "airplane1", "airplane2", "airplane3",
            "automobile1", "automobile2", "automobile3",
            "bird1", "bird2", "bird3",
            "cat1", "cat2", "cat3"
    };

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws IOException if the model or one of the sample images cannot be read
     */
    public static void main(String[] args) throws IOException {
        Cifar10ImageClassifier classifier = new Cifar10ImageClassifier();
        // Close the model stream once it has been loaded so the resource is not leaked.
        // NOTE(review): assumes load_model reads the whole stream before returning — confirm.
        try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/cnn_cifar10.pb")) {
            classifier.load_model(inputStream);
        }

        for (String imageName : IMAGE_NAMES) {
            String imagePath = "images/cifar10/" + imageName + ".png";
            BufferedImage img = ResourceUtils.getImage(imagePath);
            String predictedLabel = classifier.predict_image(img);
            logger.info("predicted class for {}: {}", imageName, predictedLabel);
        }
    }
}
Vous trouverez ci-dessous le code de démonstration de la classe InceptionImageClassifier, qui charge le fichier de modèle TensorFlow tensorflow_inception_graph.pb et l'utilise pour classifier des images :
package com . github . chen0040 . tensorflow . classifiers . demo ;
import com . github . chen0040 . tensorflow . classifiers . inception . InceptionImageClassifier ;
import com . github . chen0040 . tensorflow . classifiers . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . awt . image . BufferedImage ;
import java . io . IOException ;
/**
 * Demo: loads the Inception graph ({@code tensorflow_inception_graph.pb}) and its ImageNet label
 * file, then classifies the bundled sample images and logs the predicted label for each one.
 */
public class InceptionImageClassifierDemo {
    private static final Logger logger = LoggerFactory.getLogger(InceptionImageClassifierDemo.class);

    /** Names (without the {@code .jpg} extension) of the sample images under {@code images/inception/}. */
    private static final String[] IMAGE_NAMES = {"tiger", "lion"};

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws IOException if the model, the label file, or a sample image cannot be read
     */
    public static void main(String[] args) throws IOException {
        InceptionImageClassifier classifier = new InceptionImageClassifier();
        // Each resource stream is closed right after it has been consumed, so neither is leaked.
        // NOTE(review): assumes load_model/load_labels read their stream fully before returning.
        try (java.io.InputStream modelStream =
                     ResourceUtils.getInputStream("tf_models/tensorflow_inception_graph.pb")) {
            classifier.load_model(modelStream);
        }
        try (java.io.InputStream labelStream =
                     ResourceUtils.getInputStream("tf_models/imagenet_comp_graph_label_strings.txt")) {
            classifier.load_labels(labelStream);
        }

        for (String imageName : IMAGE_NAMES) {
            String imagePath = "images/inception/" + imageName + ".jpg";
            BufferedImage img = ResourceUtils.getImage(imagePath);
            String predictedLabel = classifier.predict_image(img);
            logger.info("predicted class for {}: {}", imageName, predictedLabel);
        }
    }
}
Vous trouverez ci-dessous le code de démonstration de la classe CnnSentimentClassifier, qui charge le fichier de modèle TensorFlow wordvec_cnn.pb et l'utilise pour effectuer une analyse de sentiment :
import com . github . chen0040 . tensorflow . classifiers . sentiment . models . CnnSentimentClassifier ;
import com . github . chen0040 . tensorflow . classifiers . sentiment . utils . ResourceUtils ;
import java . io . IOException ;
import java . util . List ;
/**
 * Demo: loads the CNN sentiment model ({@code wordvec_cnn.pb}) and its vocabulary
 * ({@code wordvec_cnn.csv}), then prints the predicted and actual label for each well-formed
 * line of {@code data/umich-sentiment-train.txt}.
 */
public class CnnSentimentClassifierDemo {

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws IOException if the model, the vocabulary, or the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        CnnSentimentClassifier classifier = new CnnSentimentClassifier();
        // Close each resource stream as soon as it has been consumed.
        // NOTE(review): assumes load_model/load_vocab read their stream fully before returning.
        try (java.io.InputStream modelStream = ResourceUtils.getInputStream("tf_models/wordvec_cnn.pb")) {
            classifier.load_model(modelStream);
        }
        try (java.io.InputStream vocabStream = ResourceUtils.getInputStream("tf_models/wordvec_cnn.csv")) {
            classifier.load_vocab(vocabStream);
        }

        List<String> lines = ResourceUtils.getLines("data/umich-sentiment-train.txt");
        for (String line : lines) {
            // Each record is "<label><TAB><text>": split once, at the first tab only, so any
            // further tabs stay part of the text.
            String[] fields = line.split("\t", 2);
            if (fields.length < 2) {
                // Blank or malformed line: skip it rather than fail with ArrayIndexOutOfBoundsException.
                continue;
            }
            String label = fields[0];
            String text = fields[1];

            float[] predicted = classifier.predict(text);
            String predictedLabel = classifier.predict_label(text);
            System.out.println(text);
            // NOTE(review): assumes predict() returns at least two class scores — confirm.
            System.out.println("Outcome: " + predicted[0] + ", " + predicted[1]);
            System.out.println("Predicted: " + predictedLabel + " Actual: " + label);
        }
    }
}
Vous trouverez ci-dessous le code de démonstration de la classe BidirectionalLstmSentimentClassifier, qui charge le fichier de modèle TensorFlow bidirectional_lstm_softmax.pb et l'utilise pour effectuer une analyse de sentiment :
import com . github . chen0040 . tensorflow . classifiers . sentiment . models . BidirectionalLstmSentimentClassifier ;
import com . github . chen0040 . tensorflow . classifiers . sentiment . utils . ResourceUtils ;
import java . io . IOException ;
import java . util . List ;
/**
 * Demo: loads the bidirectional-LSTM sentiment model ({@code bidirectional_lstm_softmax.pb}) and
 * its vocabulary ({@code bidirectional_lstm_softmax.csv}), then prints the predicted and actual
 * label for each well-formed line of {@code data/umich-sentiment-train.txt}.
 */
public class BidirectionalLstmSentimentClassifierDemo {

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws IOException if the model, the vocabulary, or the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        BidirectionalLstmSentimentClassifier classifier = new BidirectionalLstmSentimentClassifier();
        // Close each resource stream as soon as it has been consumed.
        // NOTE(review): assumes load_model/load_vocab read their stream fully before returning.
        try (java.io.InputStream modelStream =
                     ResourceUtils.getInputStream("tf_models/bidirectional_lstm_softmax.pb")) {
            classifier.load_model(modelStream);
        }
        try (java.io.InputStream vocabStream =
                     ResourceUtils.getInputStream("tf_models/bidirectional_lstm_softmax.csv")) {
            classifier.load_vocab(vocabStream);
        }

        List<String> lines = ResourceUtils.getLines("data/umich-sentiment-train.txt");
        for (String line : lines) {
            // Each record is "<label><TAB><text>": split once, at the first tab only, so any
            // further tabs stay part of the text.
            String[] fields = line.split("\t", 2);
            if (fields.length < 2) {
                // Blank or malformed line: skip it rather than fail with ArrayIndexOutOfBoundsException.
                continue;
            }
            String label = fields[0];
            String text = fields[1];

            float[] predicted = classifier.predict(text);
            String predictedLabel = classifier.predict_label(text);
            System.out.println(text);
            // NOTE(review): assumes predict() returns at least two class scores — confirm.
            System.out.println("Outcome: " + predicted[0] + ", " + predicted[1]);
            System.out.println("Predicted: " + predictedLabel + " Actual: " + label);
        }
    }
}
Vous trouverez ci-dessous le code de démonstration de la classe Cifar10AudioClassifier, qui charge le fichier de modèle TensorFlow cifar10.pb et l'utilise pour prédire le genre musical de fichiers audio :
import com . github . chen0040 . tensorflow . classifiers . audio . models . cifar10 . Cifar10AudioClassifier ;
import com . github . chen0040 . tensorflow . classifiers . audio . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . io . File ;
import java . io . IOException ;
import java . io . InputStream ;
import java . util . ArrayList ;
import java . util . Collections ;
import java . util . List ;
/**
 * Demo: loads the {@code cifar10.pb} TensorFlow audio model and predicts a label for every
 * {@code .au} clip found in the sub-folders of {@code gtzan/genres}, visiting clips in random order.
 */
public class Cifar10AudioClassifierDemo {
    private static final Logger logger = LoggerFactory.getLogger(Cifar10AudioClassifierDemo.class);

    /** Default dataset root; each immediate sub-folder is scanned for clips of one class. */
    private static final String DEFAULT_DATASET_DIR = "gtzan/genres";

    /** Collects the absolute paths of the {@code .au} clips under {@code gtzan/genres}. */
    private static List<String> getAudioFiles() {
        return getAudioFiles(new File(DEFAULT_DATASET_DIR));
    }

    /**
     * Collects the absolute paths of the {@code .au} files located inside the immediate
     * sub-directories of {@code root}.
     *
     * @param root dataset root directory
     * @return a mutable list of absolute paths; empty if {@code root} is not a readable directory
     */
    private static List<String> getAudioFiles(File root) {
        logger.info("{}", root.getAbsolutePath());
        List<String> result = new ArrayList<>();
        // File.listFiles returns null when the path is not a directory or cannot be read.
        File[] classFolders = root.listFiles(f -> f.isDirectory());
        if (classFolders == null) {
            return result;
        }
        for (File classFolder : classFolders) {
            File[] clips = classFolder.listFiles(Cifar10AudioClassifierDemo::isAuFile);
            if (clips == null) {
                continue;
            }
            for (File clip : clips) {
                result.add(clip.getAbsolutePath());
            }
        }
        return result;
    }

    /** Returns true for regular files whose name ends with {@code .au}, ignoring case. */
    private static boolean isAuFile(File f) {
        return f.isFile() && f.getName().toLowerCase(java.util.Locale.ROOT).endsWith(".au");
    }

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws IOException if the model or an audio clip cannot be read
     */
    public static void main(String[] args) throws IOException {
        Cifar10AudioClassifier classifier = new Cifar10AudioClassifier();
        // Close the model stream once it has been loaded.
        // NOTE(review): assumes load_model reads the whole stream before returning — confirm.
        try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/cifar10.pb")) {
            classifier.load_model(inputStream);
        }

        List<String> paths = getAudioFiles();
        Collections.shuffle(paths);
        for (String path : paths) {
            logger.info("Predicting {} ...", path);
            String label = classifier.predict_audio(new File(path));
            logger.info("Predicted: {}", label);
        }
    }
}
Vous trouverez ci-dessous le code de démonstration de la classe ResNetV2AudioClassifier, qui charge le fichier de modèle TensorFlow resnet-v2.pb et l'utilise pour prédire le genre musical de fichiers audio :
import com . github . chen0040 . tensorflow . classifiers . audio . models . resnet . ResNetV2AudioClassifier ;
import com . github . chen0040 . tensorflow . classifiers . audio . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . io . File ;
import java . io . IOException ;
import java . io . InputStream ;
import java . util . ArrayList ;
import java . util . Collections ;
import java . util . List ;
/**
 * Demo: loads the {@code resnet-v2.pb} TensorFlow audio model and predicts a label for every
 * {@code .au} clip directly inside {@code music_samples}, visiting clips in random order.
 */
public class ResNetV2AudioClassifierDemo {
    private static final Logger logger = LoggerFactory.getLogger(ResNetV2AudioClassifierDemo.class);

    /** Default directory scanned (non-recursively) for {@code .au} clips. */
    private static final String DEFAULT_SAMPLES_DIR = "music_samples";

    /** Collects the absolute paths of the {@code .au} clips directly inside {@code music_samples}. */
    private static List<String> getAudioFiles() {
        return getAudioFiles(new File(DEFAULT_SAMPLES_DIR));
    }

    /**
     * Collects the absolute paths of the {@code .au} files directly inside {@code dir}.
     *
     * @param dir directory to scan (not recursed into)
     * @return a mutable list of absolute paths; empty if {@code dir} is not a readable directory
     */
    private static List<String> getAudioFiles(File dir) {
        logger.info("{}", dir.getAbsolutePath());
        List<String> result = new ArrayList<>();
        // File.listFiles returns null when the path is not a directory or cannot be read.
        File[] clips = dir.listFiles(ResNetV2AudioClassifierDemo::isAuFile);
        if (clips == null) {
            return result;
        }
        for (File clip : clips) {
            result.add(clip.getAbsolutePath());
        }
        return result;
    }

    /** Returns true for regular files whose name ends with {@code .au}, ignoring case. */
    private static boolean isAuFile(File f) {
        return f.isFile() && f.getName().toLowerCase(java.util.Locale.ROOT).endsWith(".au");
    }

    /**
     * Runs the demo.
     *
     * @param args ignored
     * @throws IOException if the model or an audio clip cannot be read
     */
    public static void main(String[] args) throws IOException {
        ResNetV2AudioClassifier classifier = new ResNetV2AudioClassifier();
        // Close the model stream once it has been loaded.
        // NOTE(review): assumes load_model reads the whole stream before returning — confirm.
        try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/resnet-v2.pb")) {
            classifier.load_model(inputStream);
        }

        List<String> paths = getAudioFiles();
        Collections.shuffle(paths);
        for (String path : paths) {
            logger.info("Predicting {} ...", path);
            String label = classifier.predict_audio(new File(path));
            logger.info("Predicted: {}", label);
        }
    }
}
L'exemple de code ci-dessous montre comment indexer des fichiers audio et rechercher des morceaux similaires à l'aide de la classe AudioSearchEngine :
// List the .au clips in music_samples/ once; the same list is used for indexing and querying.
// File.listFiles returns null when the directory is missing or unreadable.
File[] musicFiles = new File("music_samples").listFiles(
        f -> f.isFile() && f.getName().toLowerCase(java.util.Locale.ROOT).endsWith(".au"));
if (musicFiles == null) {
    throw new IllegalStateException("Cannot list audio files in music_samples");
}

// Reuse a previously saved index when one exists; otherwise build it and save it.
AudioSearchEngine searchEngine = new AudioSearchEngine();
if (!searchEngine.loadIndexDbIfExists()) {
    searchEngine.indexAll(musicFiles);
    searchEngine.saveIndexDb();
}

// Paging/filter arguments forwarded unchanged to query().
int pageIndex = 0;
int pageSize = 20;
boolean skipPerfectMatch = true;
for (File f : musicFiles) {
    System.out.println("querying similar music to " + f.getName());
    List<AudioSearchEntry> result = searchEngine.query(f, pageIndex, pageSize, skipPerfectMatch);
    for (int i = 0; i < result.size(); ++i) {
        AudioSearchEntry entry = result.get(i);
        System.out.println("# " + i + ": " + entry.getPath() + " (distSq: " + entry.getDistanceSq() + ")");
    }
}
L'exemple de code ci-dessous montre comment recommander des morceaux de musique en fonction de l'historique d'écoute de l'utilisateur à l'aide de la classe KnnAudioRecommender :
// Build a simulated listening history from (at most) 40 randomly chosen clips.
AudioUserHistory userHistory = new AudioUserHistory();
List<String> audioFiles = FileUtils.getAudioFiles();
Collections.shuffle(audioFiles);
// Cap at the number of available clips so get(i) cannot go out of bounds.
int historySize = Math.min(40, audioFiles.size());
for (int i = 0; i < historySize; ++i) {
    String filePath = audioFiles.get(i);
    userHistory.logAudio(filePath);
    try {
        // NOTE(review): presumably spaces out the logged entries in time — confirm against AudioUserHistory.
        Thread.sleep(100L);
    } catch (InterruptedException e) {
        // Restore the interrupt flag and stop building the history instead of swallowing it.
        Thread.currentThread().interrupt();
        break;
    }
}

// Reuse a previously saved index when one exists; otherwise index the .au clips in music_samples/.
KnnAudioRecommender recommender = new KnnAudioRecommender();
if (!recommender.loadIndexDbIfExists()) {
    File[] clips = new File("music_samples").listFiles(
            a -> a.isFile() && a.getName().toLowerCase(java.util.Locale.ROOT).endsWith(".au"));
    // File.listFiles returns null when the directory is missing or unreadable.
    if (clips == null) {
        throw new IllegalStateException("Cannot list audio files in music_samples");
    }
    recommender.indexAll(clips);
    recommender.saveIndexDb();
}

System.out.println(userHistory.head(10));
int k = 10;
List<AudioSearchEntry> result = recommender.recommends(userHistory.getHistory(), k);
for (int i = 0; i < result.size(); ++i) {
    AudioSearchEntry entry = result.get(i);
    System.out.println("Search Result #" + (i + 1) + ": " + entry.getPath());
}