TensorFlowの学習済みモデルファイルをロードし、それらの学習済みモデルに基づいて予測を行う方法を示すJavaサンプルコード
以下に、cnn_cifar10.pb TensorFlowモデルファイルをロードし、画像分類を実行するために使用するCifar10ImageClassifierのデモコードを示します。
package com . github . chen0040 . tensorflow . classifiers . demo ;
import com . github . chen0040 . tensorflow . classifiers . utils . ResourceUtils ;
import com . github . chen0040 . tensorflow . classifiers . cifar10 . Cifar10ImageClassifier ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . awt . image . BufferedImage ;
import java . io . IOException ;
import java . io . InputStream ;
public class Cifar10ImageClassifierDemo {
/** SLF4J logger used by {@link #main} to report the predicted class of each sample image. */
private static final Logger logger = LoggerFactory.getLogger(Cifar10ImageClassifierDemo.class);

/**
 * Loads the CIFAR-10 CNN model from the resource {@code tf_models/cnn_cifar10.pb}, then logs the
 * predicted class for each sample image {@code images/cifar10/<name>.png}.
 *
 * @param args unused
 * @throws IOException as declared; presumably raised by the ResourceUtils loaders or when closing
 *     the model stream — TODO confirm which calls actually throw
 */
public static void main(String[] args) throws IOException {
    Cifar10ImageClassifier classifier = new Cifar10ImageClassifier();

    // Close the model stream once the classifier has consumed it; the original never closed it.
    // NOTE(review): assumes load_model reads the stream fully before returning — confirm.
    try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/cnn_cifar10.pb")) {
        classifier.load_model(inputStream);
    }

    String[] image_names = {
        "airplane1", "airplane2", "airplane3",
        "automobile1", "automobile2", "automobile3",
        "bird1", "bird2", "bird3",
        "cat1", "cat2", "cat3"
    };

    for (String image_name : image_names) {
        String image_path = "images/cifar10/" + image_name + ".png";
        BufferedImage img = ResourceUtils.getImage(image_path);
        String predicted_label = classifier.predict_image(img);
        logger.info("predicted class for {}: {}", image_name, predicted_label);
    }
}
}
以下に、tensorflow_inception_graph.pb TensorFlowモデルファイルをロードし、画像分類を実行するために使用するInceptionImageClassifierのデモコードを示します。
package com . github . chen0040 . tensorflow . classifiers . demo ;
import com . github . chen0040 . tensorflow . classifiers . inception . InceptionImageClassifier ;
import com . github . chen0040 . tensorflow . classifiers . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . awt . image . BufferedImage ;
import java . io . IOException ;
public class InceptionImageClassifierDemo {
/** SLF4J logger used by {@link #main} to report the predicted class of each sample image. */
private static final Logger logger = LoggerFactory.getLogger(InceptionImageClassifierDemo.class);

/**
 * Loads the Inception graph and its label list from resources, then logs the predicted class
 * for each sample image {@code images/inception/<name>.jpg}.
 *
 * @param args unused
 * @throws IOException as declared; presumably raised by the ResourceUtils loaders or when closing
 *     a resource stream — TODO confirm which calls actually throw
 */
public static void main(String[] args) throws IOException {
    InceptionImageClassifier classifier = new InceptionImageClassifier();

    // Each resource stream is closed right after it is consumed (the original passed them inline
    // and never closed them). Model first, then labels, matching the original load order.
    // NOTE(review): assumes load_model/load_labels read the stream fully before returning.
    try (java.io.InputStream modelStream =
            ResourceUtils.getInputStream("tf_models/tensorflow_inception_graph.pb")) {
        classifier.load_model(modelStream);
    }
    try (java.io.InputStream labelStream =
            ResourceUtils.getInputStream("tf_models/imagenet_comp_graph_label_strings.txt")) {
        classifier.load_labels(labelStream);
    }

    String[] image_names = {"tiger", "lion"};

    for (String image_name : image_names) {
        String image_path = "images/inception/" + image_name + ".jpg";
        BufferedImage img = ResourceUtils.getImage(image_path);
        String predicted_label = classifier.predict_image(img);
        logger.info("predicted class for {}: {}", image_name, predicted_label);
    }
}
}
以下に、wordvec_cnn.pb TensorFlowモデルファイルをロードし、センチメント分析を行うために使用するCnnSentimentClassifierのデモコードを示します。
import com . github . chen0040 . tensorflow . classifiers . sentiment . models . CnnSentimentClassifier ;
import com . github . chen0040 . tensorflow . classifiers . sentiment . utils . ResourceUtils ;
import java . io . IOException ;
import java . util . List ;
/**
 * Demo: loads the CNN sentiment model and its vocabulary from resources. For every line of
 * {@code data/umich-sentiment-train.txt} it then prints the text, the raw model output and the
 * predicted versus actual label.
 */
public class CnnSentimentClassifierDemo {

    /**
     * Runs the demo.
     *
     * @param args unused
     * @throws IOException as declared; presumably raised by the ResourceUtils loaders or when
     *     closing a resource stream — TODO confirm which calls actually throw
     */
    public static void main(String[] args) throws IOException {
        CnnSentimentClassifier classifier = new CnnSentimentClassifier();

        // Close each resource stream once consumed (the original never closed them).
        // NOTE(review): assumes load_model/load_vocab read the stream fully before returning.
        try (java.io.InputStream modelStream = ResourceUtils.getInputStream("tf_models/wordvec_cnn.pb")) {
            classifier.load_model(modelStream);
        }
        try (java.io.InputStream vocabStream = ResourceUtils.getInputStream("tf_models/wordvec_cnn.csv")) {
            classifier.load_vocab(vocabStream);
        }

        List<String> lines = ResourceUtils.getLines("data/umich-sentiment-train.txt");
        for (String line : lines) {
            // Each line is "<label>\t<text>". The original split on the literal " t " (the backslash
            // of "\t" appears to have been lost). That usually does not match, so [1] threw
            // ArrayIndexOutOfBoundsException. Split once so any later tabs stay in the text.
            String[] fields = line.split("\t", 2);
            if (fields.length < 2) {
                // Blank or malformed line: skip it instead of aborting the whole run.
                continue;
            }
            String label = fields[0];
            String text = fields[1];

            float[] predicted = classifier.predict(text);
            String predicted_label = classifier.predict_label(text);
            System.out.println(text);
            // NOTE(review): assumes the model emits (at least) two class scores.
            System.out.println("Outcome: " + predicted[0] + ", " + predicted[1]);
            System.out.println("Predicted: " + predicted_label + " Actual: " + label);
        }
    }
}
以下に、bidirectional_lstm_softmax.pb TensorFlowモデルファイルをロードし、センチメント分析を行うために使用するBidirectionalLstmSentimentClassifierのデモコードを示します。
import com . github . chen0040 . tensorflow . classifiers . sentiment . models . BidirectionalLstmSentimentClassifier ;
import com . github . chen0040 . tensorflow . classifiers . sentiment . utils . ResourceUtils ;
import java . io . IOException ;
import java . util . List ;
public class BidirectionalLstmSentimentClassifierDemo {
/**
 * Loads the bidirectional-LSTM sentiment model and its vocabulary from resources. For every line
 * of {@code data/umich-sentiment-train.txt} it then prints the text, the raw model output and the
 * predicted versus actual label.
 *
 * @param args unused
 * @throws IOException as declared; presumably raised by the ResourceUtils loaders or when closing
 *     a resource stream — TODO confirm which calls actually throw
 */
public static void main(String[] args) throws IOException {
    BidirectionalLstmSentimentClassifier classifier = new BidirectionalLstmSentimentClassifier();

    // Close each resource stream once consumed (the original never closed them).
    // NOTE(review): assumes load_model/load_vocab read the stream fully before returning.
    try (java.io.InputStream modelStream =
            ResourceUtils.getInputStream("tf_models/bidirectional_lstm_softmax.pb")) {
        classifier.load_model(modelStream);
    }
    try (java.io.InputStream vocabStream =
            ResourceUtils.getInputStream("tf_models/bidirectional_lstm_softmax.csv")) {
        classifier.load_vocab(vocabStream);
    }

    List<String> lines = ResourceUtils.getLines("data/umich-sentiment-train.txt");
    for (String line : lines) {
        // Each line is "<label>\t<text>". The original split on the literal " t " (the backslash
        // of "\t" appears to have been lost). That usually does not match, so [1] threw
        // ArrayIndexOutOfBoundsException. Split once so any later tabs stay in the text.
        String[] fields = line.split("\t", 2);
        if (fields.length < 2) {
            // Blank or malformed line: skip it instead of aborting the whole run.
            continue;
        }
        String label = fields[0];
        String text = fields[1];

        float[] predicted = classifier.predict(text);
        String predicted_label = classifier.predict_label(text);
        System.out.println(text);
        // NOTE(review): assumes the model emits (at least) two class scores.
        System.out.println("Outcome: " + predicted[0] + ", " + predicted[1]);
        System.out.println("Predicted: " + predicted_label + " Actual: " + label);
    }
}
}
以下に、cifar10.pb TensorFlowモデルファイルをロードし、音楽ジャンルの予測を行うために使用するCifar10AudioClassifierのデモコードを示します。
import com . github . chen0040 . tensorflow . classifiers . audio . models . cifar10 . Cifar10AudioClassifier ;
import com . github . chen0040 . tensorflow . classifiers . audio . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . io . File ;
import java . io . IOException ;
import java . io . InputStream ;
import java . util . ArrayList ;
import java . util . Collections ;
import java . util . List ;
public class Cifar10AudioClassifierDemo {
// NOTE(review): this logger is not referenced by the visible code of this class; either use it
// (e.g. for the progress output) or remove it.
private static final Logger logger = LoggerFactory.getLogger(Cifar10AudioClassifierDemo.class);

/**
 * Collects the absolute paths of {@code .au} files one level below {@code gtzan/genres}
 * (layout {@code gtzan/genres/<genre>/<file>.au}), resolved against the working directory.
 * The scanned directory's absolute path is printed first.
 *
 * @return matching paths in directory-listing order; empty if the directory is missing or
 *     cannot be listed
 */
private static List<String> getAudioFiles() {
    List<String> result = new ArrayList<>();
    File file = new File("gtzan/genres");
    System.out.println(file.getAbsolutePath());
    if (!file.isDirectory()) {
        return result;
    }
    // File.listFiles() returns null on an I/O error even for an existing directory.
    File[] class_folders = file.listFiles();
    if (class_folders == null) {
        return result;
    }
    for (File class_folder : class_folders) {
        if (!class_folder.isDirectory()) {
            continue;
        }
        File[] entries = class_folder.listFiles();
        if (entries == null) {
            continue; // unreadable genre folder
        }
        for (File f : entries) {
            String file_path = f.getAbsolutePath();
            // Require a real ".au" extension (case-insensitive). The original endsWith("au") also
            // matched names like "...tableau" and directories.
            if (f.isFile() && file_path.toLowerCase(java.util.Locale.ROOT).endsWith(".au")) {
                result.add(file_path);
            }
        }
    }
    return result;
}
/**
 * Loads the audio CIFAR-10-style model from the resource {@code tf_models/cifar10.pb}, then
 * prints a genre prediction for every audio file found by {@link #getAudioFiles()}, visited in
 * random order.
 *
 * @param args unused
 * @throws IOException as declared; presumably raised by the model loader or when closing the
 *     model stream — TODO confirm which calls actually throw
 */
public static void main(String[] args) throws IOException {
    Cifar10AudioClassifier classifier = new Cifar10AudioClassifier();

    // Close the model stream once consumed (the original never closed it).
    // NOTE(review): assumes load_model reads the stream fully before returning — confirm.
    try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/cifar10.pb")) {
        classifier.load_model(inputStream);
    }

    List<String> paths = getAudioFiles();
    Collections.shuffle(paths);
    for (String path : paths) {
        System.out.println("Predicting " + path + " ...");
        String label = classifier.predict_audio(new File(path));
        System.out.println("Predicted: " + label);
    }
}
}
以下に、resnet-v2.pb TensorFlowモデルファイルをロードし、音楽ジャンルの予測を実行するために使用するResNetV2AudioClassifierのデモコードを示します。
import com . github . chen0040 . tensorflow . classifiers . audio . models . resnet . ResNetV2AudioClassifier ;
import com . github . chen0040 . tensorflow . classifiers . audio . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . io . File ;
import java . io . IOException ;
import java . io . InputStream ;
import java . util . ArrayList ;
import java . util . Collections ;
import java . util . List ;
public class ResNetV2AudioClassifierDemo {
// NOTE(review): this logger is not referenced by the visible code of this class; either use it
// (e.g. for the progress output) or remove it.
private static final Logger logger = LoggerFactory.getLogger(ResNetV2AudioClassifierDemo.class);

/**
 * Collects the absolute paths of {@code .au} files directly inside {@code music_samples},
 * resolved against the working directory. The scanned directory's absolute path is printed first.
 *
 * @return matching paths in directory-listing order; empty if the directory is missing or
 *     cannot be listed
 */
private static List<String> getAudioFiles() {
    List<String> result = new ArrayList<>();
    File dir = new File("music_samples");
    System.out.println(dir.getAbsolutePath());
    if (!dir.isDirectory()) {
        return result;
    }
    // File.listFiles() returns null on an I/O error even for an existing directory.
    File[] entries = dir.listFiles();
    if (entries == null) {
        return result;
    }
    for (File f : entries) {
        String file_path = f.getAbsolutePath();
        // Require a real ".au" extension (case-insensitive). The original endsWith("au") also
        // matched names like "...tableau" and directories.
        if (f.isFile() && file_path.toLowerCase(java.util.Locale.ROOT).endsWith(".au")) {
            result.add(file_path);
        }
    }
    return result;
}
/**
 * Loads the ResNet-V2 audio model from the resource {@code tf_models/resnet-v2.pb}, then prints
 * a genre prediction for every audio file found by {@link #getAudioFiles()}, visited in random
 * order.
 *
 * @param args unused
 * @throws IOException as declared; presumably raised by the model loader or when closing the
 *     model stream — TODO confirm which calls actually throw
 */
public static void main(String[] args) throws IOException {
    ResNetV2AudioClassifier classifier = new ResNetV2AudioClassifier();

    // Close the model stream once consumed (the original never closed it).
    // NOTE(review): assumes load_model reads the stream fully before returning — confirm.
    try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/resnet-v2.pb")) {
        classifier.load_model(inputStream);
    }

    List<String> paths = getAudioFiles();
    Collections.shuffle(paths);
    for (String path : paths) {
        System.out.println("Predicting " + path + " ...");
        String label = classifier.predict_audio(new File(path));
        System.out.println("Predicted: " + label);
    }
}
}
以下のサンプルコードは、AudioSearchEngineクラスを使用してオーディオファイルのインデックスを作成し、類似したオーディオファイルを検索する方法を示しています。
// Build (or reload) the similarity index over music_samples, then print the closest matches for
// each file in that directory.
File musicDir = new File("music_samples");
// listFiles() returns null if the directory is missing or unreadable. Fail with a clear message
// instead of an NPE inside indexAll or the query loop. List once and reuse the result.
File[] musicFiles = musicDir.listFiles();
if (musicFiles == null) {
    throw new IllegalStateException("Cannot list audio directory: " + musicDir.getAbsolutePath());
}

AudioSearchEngine searchEngine = new AudioSearchEngine();
if (!searchEngine.loadIndexDbIfExists()) {
    searchEngine.indexAll(musicFiles);
    searchEngine.saveIndexDb();
}

int pageIndex = 0;
int pageSize = 20;
boolean skipPerfectMatch = true;
for (File f : musicFiles) {
    System.out.println("querying similar music to " + f.getName());
    List<AudioSearchEntry> result = searchEngine.query(f, pageIndex, pageSize, skipPerfectMatch);
    for (int i = 0; i < result.size(); ++i) {
        AudioSearchEntry entry = result.get(i);
        System.out.println("# " + i + ": " + entry.getPath() + " (distSq: " + entry.getDistanceSq() + ")");
    }
}
以下のサンプルコードは、KnnAudioRecommenderクラスを使用して、ユーザーの再生履歴に基づいて楽曲を推薦する方法を示しています。
// Simulate a listening history of up to 40 randomly chosen tracks. Then ask the KNN recommender
// for the k tracks it recommends for that history.
AudioUserHistory userHistory = new AudioUserHistory();
List<String> audioFiles = FileUtils.getAudioFiles();
Collections.shuffle(audioFiles);

// The original always read 40 entries and threw IndexOutOfBoundsException on smaller collections.
int historySize = Math.min(40, audioFiles.size());
for (int i = 0; i < historySize; ++i) {
    String filePath = audioFiles.get(i);
    userHistory.logAudio(filePath);
    try {
        // Presumably spaces out the logged entries in time — TODO confirm why the delay is needed.
        Thread.sleep(100L);
    } catch (InterruptedException e) {
        // Restore the interrupt flag and stop building the history instead of swallowing it.
        Thread.currentThread().interrupt();
        break;
    }
}

KnnAudioRecommender recommender = new KnnAudioRecommender();
if (!recommender.loadIndexDbIfExists()) {
    File musicDir = new File("music_samples");
    File[] auFiles = musicDir.listFiles(
            a -> a.getAbsolutePath().toLowerCase(java.util.Locale.ROOT).endsWith(".au"));
    // listFiles() returns null if the directory is missing or unreadable.
    if (auFiles == null) {
        throw new IllegalStateException("Cannot list audio directory: " + musicDir.getAbsolutePath());
    }
    recommender.indexAll(auFiles);
    recommender.saveIndexDb();
}

System.out.println(userHistory.head(10));
int k = 10;
List<AudioSearchEntry> result = recommender.recommends(userHistory.getHistory(), k);
for (int i = 0; i < result.size(); ++i) {
    AudioSearchEntry entry = result.get(i);
    System.out.println("Search Result #" + (i + 1) + ": " + entry.getPath());
}