TensorFlow로 사전 학습된(pretrained) 모델 파일을 Java에서 로드하고, 로드한 모델을 기반으로 예측을 수행하는 방법을 보여 주는 샘플 코드입니다.
아래는 cnn_cifar10.pb TensorFlow 모델 파일을 로드하고 이미지 분류를 수행하는 Cifar10ImageClassifier의 데모 코드입니다.
package com . github . chen0040 . tensorflow . classifiers . demo ;
import com . github . chen0040 . tensorflow . classifiers . utils . ResourceUtils ;
import com . github . chen0040 . tensorflow . classifiers . cifar10 . Cifar10ImageClassifier ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . awt . image . BufferedImage ;
import java . io . IOException ;
import java . io . InputStream ;
public class Cifar10ImageClassifierDemo {
private static final Logger logger = LoggerFactory . getLogger ( Cifar10ImageClassifierDemo . class );
/**
 * Loads the {@code tf_models/cnn_cifar10.pb} model and logs the predicted label for each
 * sample image under {@code images/cifar10/}.
 *
 * @param args unused
 * @throws IOException propagated from resource loading or from closing the model stream
 */
public static void main(String[] args) throws IOException {
    Cifar10ImageClassifier classifier;
    // try-with-resources closes the model stream once it has been consumed, even if loading fails.
    try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/cnn_cifar10.pb")) {
        classifier = new Cifar10ImageClassifier();
        classifier.load_model(inputStream);
    }

    String[] imageNames = {
        "airplane1",
        "airplane2",
        "airplane3",
        "automobile1",
        "automobile2",
        "automobile3",
        "bird1",
        "bird2",
        "bird3",
        "cat1",
        "cat2",
        "cat3"
    };

    for (String imageName : imageNames) {
        String imagePath = "images/cifar10/" + imageName + ".png";
        BufferedImage img = ResourceUtils.getImage(imagePath);
        String predictedLabel = classifier.predict_image(img);
        logger.info("predicted class for {}: {}", imageName, predictedLabel);
    }
}
}
아래는 tensorflow_inception_graph.pb TensorFlow 모델 파일을 로드하고 이미지 분류를 수행하는 InceptionImageClassifier의 데모 코드입니다.
package com . github . chen0040 . tensorflow . classifiers . demo ;
import com . github . chen0040 . tensorflow . classifiers . inception . InceptionImageClassifier ;
import com . github . chen0040 . tensorflow . classifiers . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . awt . image . BufferedImage ;
import java . io . IOException ;
public class InceptionImageClassifierDemo {
private static final Logger logger = LoggerFactory . getLogger ( InceptionImageClassifierDemo . class );
/**
 * Loads the Inception graph and its label strings, then logs the predicted label for each
 * sample image under {@code images/inception/}.
 *
 * @param args unused
 * @throws IOException propagated from resource loading or from closing a resource stream
 */
public static void main(String[] args) throws IOException {
    InceptionImageClassifier classifier = new InceptionImageClassifier();
    // Each stream is closed as soon as the classifier has consumed it.
    try (java.io.InputStream model =
            ResourceUtils.getInputStream("tf_models/tensorflow_inception_graph.pb")) {
        classifier.load_model(model);
    }
    try (java.io.InputStream labels =
            ResourceUtils.getInputStream("tf_models/imagenet_comp_graph_label_strings.txt")) {
        classifier.load_labels(labels);
    }

    String[] imageNames = {
        "tiger",
        "lion"
    };

    for (String imageName : imageNames) {
        String imagePath = "images/inception/" + imageName + ".jpg";
        BufferedImage img = ResourceUtils.getImage(imagePath);
        String predictedLabel = classifier.predict_image(img);
        logger.info("predicted class for {}: {}", imageName, predictedLabel);
    }
}
}
아래는 wordvec_cnn.pb TensorFlow 모델 파일을 로드하고 감정 분석을 수행하는 CnnSentimentClassifier의 데모 코드입니다.
import com . github . chen0040 . tensorflow . classifiers . sentiment . models . CnnSentimentClassifier ;
import com . github . chen0040 . tensorflow . classifiers . sentiment . utils . ResourceUtils ;
import java . io . IOException ;
import java . util . List ;
/**
 * Demo that runs the CNN word-vector sentiment model over the UMich sentiment training data and
 * prints, for every record, the text, the raw model outputs and the predicted vs. actual label.
 */
public class CnnSentimentClassifierDemo {

    /**
     * Loads {@code tf_models/wordvec_cnn.pb} and its vocabulary {@code tf_models/wordvec_cnn.csv},
     * then classifies each record of {@code data/umich-sentiment-train.txt}.
     *
     * @param args unused
     * @throws IOException propagated from resource loading or from closing a resource stream
     */
    public static void main(String[] args) throws IOException {
        CnnSentimentClassifier classifier = new CnnSentimentClassifier();
        // try-with-resources closes each stream once it has been consumed, even on failure.
        try (java.io.InputStream model = ResourceUtils.getInputStream("tf_models/wordvec_cnn.pb")) {
            classifier.load_model(model);
        }
        try (java.io.InputStream vocab = ResourceUtils.getInputStream("tf_models/wordvec_cnn.csv")) {
            classifier.load_vocab(vocab);
        }

        List<String> lines = ResourceUtils.getLines("data/umich-sentiment-train.txt");
        for (String line : lines) {
            // Each record is expected to be "<label><TAB><text>". Split on the tab exactly once
            // (limit 2) so any further tabs stay part of the text.
            String[] fields = line.split("\t", 2);
            if (fields.length < 2) {
                continue; // blank or malformed record: skip it instead of indexing past the array
            }
            String label = fields[0];
            String text = fields[1];

            float[] predicted = classifier.predict(text);
            String predictedLabel = classifier.predict_label(text);
            System.out.println(text);
            System.out.println("Outcome: " + predicted[0] + ", " + predicted[1]);
            System.out.println("Predicted: " + predictedLabel + " Actual: " + label);
        }
    }
}
아래는 bidirectional_lstm_softmax.pb TensorFlow 모델 파일을 로드하고 감정 분석을 수행하는 BidirectionalLstmSentimentClassifier의 데모 코드입니다.
import com . github . chen0040 . tensorflow . classifiers . sentiment . models . BidirectionalLstmSentimentClassifier ;
import com . github . chen0040 . tensorflow . classifiers . sentiment . utils . ResourceUtils ;
import java . io . IOException ;
import java . util . List ;
public class BidirectionalLstmSentimentClassifierDemo {
/**
 * Loads {@code tf_models/bidirectional_lstm_softmax.pb} and its vocabulary, then prints the
 * model outputs and predicted vs. actual label for each record of
 * {@code data/umich-sentiment-train.txt}.
 *
 * @param args unused
 * @throws IOException propagated from resource loading or from closing a resource stream
 */
public static void main(String[] args) throws IOException {
    BidirectionalLstmSentimentClassifier classifier = new BidirectionalLstmSentimentClassifier();
    // try-with-resources closes each stream once it has been consumed, even on failure.
    try (java.io.InputStream model =
            ResourceUtils.getInputStream("tf_models/bidirectional_lstm_softmax.pb")) {
        classifier.load_model(model);
    }
    try (java.io.InputStream vocab =
            ResourceUtils.getInputStream("tf_models/bidirectional_lstm_softmax.csv")) {
        classifier.load_vocab(vocab);
    }

    List<String> lines = ResourceUtils.getLines("data/umich-sentiment-train.txt");
    for (String line : lines) {
        // Each record is expected to be "<label><TAB><text>". Split on the tab exactly once
        // (limit 2) so any further tabs stay part of the text.
        String[] fields = line.split("\t", 2);
        if (fields.length < 2) {
            continue; // blank or malformed record: skip it instead of indexing past the array
        }
        String label = fields[0];
        String text = fields[1];

        float[] predicted = classifier.predict(text);
        String predictedLabel = classifier.predict_label(text);
        System.out.println(text);
        System.out.println("Outcome: " + predicted[0] + ", " + predicted[1]);
        System.out.println("Predicted: " + predictedLabel + " Actual: " + label);
    }
}
}
아래는 cifar10.pb TensorFlow 모델 파일을 로드하고 음악 장르 예측을 수행하는 Cifar10AudioClassifier의 데모 코드입니다.
import com . github . chen0040 . tensorflow . classifiers . audio . models . cifar10 . Cifar10AudioClassifier ;
import com . github . chen0040 . tensorflow . classifiers . audio . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . io . File ;
import java . io . IOException ;
import java . io . InputStream ;
import java . util . ArrayList ;
import java . util . Collections ;
import java . util . List ;
public class Cifar10AudioClassifierDemo {
private static final Logger logger = LoggerFactory . getLogger ( Cifar10AudioClassifierDemo . class );
/**
 * Collects the absolute paths of the {@code .au} files inside each sub-folder of
 * {@code gtzan/genres} (one level deep). Prints the absolute path of {@code gtzan/genres} first.
 *
 * @return the matching paths; empty if the folder is missing or cannot be listed
 */
private static List<String> getAudioFiles() {
    List<String> result = new ArrayList<>();
    File file = new File("gtzan/genres");
    System.out.println(file.getAbsolutePath());

    // listFiles() returns null for a non-directory or on an I/O error.
    File[] classFolders = file.listFiles();
    if (classFolders == null) {
        return result;
    }
    for (File classFolder : classFolders) {
        File[] candidates = classFolder.isDirectory() ? classFolder.listFiles() : null;
        if (candidates == null) {
            continue;
        }
        for (File f : candidates) {
            // Match the ".au" extension itself (case-insensitively), not any name ending in "au".
            if (f.isFile() && f.getName().toLowerCase(java.util.Locale.ROOT).endsWith(".au")) {
                result.add(f.getAbsolutePath());
            }
        }
    }
    return result;
}
/**
 * Loads the {@code tf_models/cifar10.pb} audio model and prints a predicted label for every
 * file returned by {@code getAudioFiles()}, in shuffled order.
 *
 * @param args unused
 * @throws IOException propagated from model loading, prediction, or closing the model stream
 */
public static void main(String[] args) throws IOException {
    Cifar10AudioClassifier classifier;
    // try-with-resources closes the model stream once it has been consumed, even if loading fails.
    try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/cifar10.pb")) {
        classifier = new Cifar10AudioClassifier();
        classifier.load_model(inputStream);
    }

    List<String> paths = getAudioFiles();
    Collections.shuffle(paths);
    for (String path : paths) {
        System.out.println("Predicting " + path + " ...");
        String label = classifier.predict_audio(new File(path));
        System.out.println("Predicted: " + label);
    }
}
}
아래는 resnet-v2.pb TensorFlow 모델 파일을 로드하고 음악 장르 예측을 수행하는 ResNetV2AudioClassifier의 데모 코드입니다.
import com . github . chen0040 . tensorflow . classifiers . audio . models . resnet . ResNetV2AudioClassifier ;
import com . github . chen0040 . tensorflow . classifiers . audio . utils . ResourceUtils ;
import org . slf4j . Logger ;
import org . slf4j . LoggerFactory ;
import java . io . File ;
import java . io . IOException ;
import java . io . InputStream ;
import java . util . ArrayList ;
import java . util . Collections ;
import java . util . List ;
public class ResNetV2AudioClassifierDemo {
private static final Logger logger = LoggerFactory . getLogger ( ResNetV2AudioClassifierDemo . class );
/**
 * Collects the absolute paths of the {@code .au} files directly inside {@code music_samples}.
 * Prints the absolute path of {@code music_samples} first.
 *
 * @return the matching paths; empty if the folder is missing or cannot be listed
 */
private static List<String> getAudioFiles() {
    List<String> result = new ArrayList<>();
    File dir = new File("music_samples");
    System.out.println(dir.getAbsolutePath());

    // listFiles() returns null for a non-directory or on an I/O error.
    File[] entries = dir.listFiles();
    if (entries == null) {
        return result;
    }
    for (File f : entries) {
        // Match the ".au" extension itself (case-insensitively), not any name ending in "au".
        if (f.isFile() && f.getName().toLowerCase(java.util.Locale.ROOT).endsWith(".au")) {
            result.add(f.getAbsolutePath());
        }
    }
    return result;
}
/**
 * Loads the {@code tf_models/resnet-v2.pb} audio model and prints a predicted label for every
 * file returned by {@code getAudioFiles()}, in shuffled order.
 *
 * @param args unused
 * @throws IOException propagated from model loading, prediction, or closing the model stream
 */
public static void main(String[] args) throws IOException {
    ResNetV2AudioClassifier classifier;
    // try-with-resources closes the model stream once it has been consumed, even if loading fails.
    try (InputStream inputStream = ResourceUtils.getInputStream("tf_models/resnet-v2.pb")) {
        classifier = new ResNetV2AudioClassifier();
        classifier.load_model(inputStream);
    }

    List<String> paths = getAudioFiles();
    Collections.shuffle(paths);
    for (String path : paths) {
        System.out.println("Predicting " + path + " ...");
        String label = classifier.predict_audio(new File(path));
        System.out.println("Predicted: " + label);
    }
}
}
아래 샘플 코드는 AudioSearchEngine 클래스를 사용하여 오디오 파일을 색인하고 검색하는 방법을 보여 줍니다.
// Reuse a saved index when loadIndexDbIfExists() reports one; otherwise index every entry of
// music_samples and persist the result with saveIndexDb().
AudioSearchEngine searchEngine = new AudioSearchEngine();
if (!searchEngine.loadIndexDbIfExists()) {
    // File.listFiles() returns null when the folder is missing or cannot be read.
    File[] sampleFiles = new File("music_samples").listFiles();
    if (sampleFiles == null) {
        throw new IllegalStateException("cannot list the music_samples directory");
    }
    searchEngine.indexAll(sampleFiles);
    searchEngine.saveIndexDb();
}

// Query paging options passed to searchEngine.query(...).
int pageIndex = 0;
int pageSize = 20;
boolean skipPerfectMatch = true;
for ( File f : new File ( "music_samples" ). listFiles ()) {
System.out.println("querying similar music to " + f.getName());
// Print each match with its zero-based rank, path and squared distance.
List<AudioSearchEntry> matches = searchEngine.query(f, pageIndex, pageSize, skipPerfectMatch);
int rank = 0;
for (AudioSearchEntry hit : matches) {
    System.out.println("# " + rank + ": " + hit.getPath() + " (distSq: " + hit.getDistanceSq() + ")");
    rank++;
}
}
아래 샘플 코드는 KnnAudioRecommender 클래스를 사용하여 사용자의 음악 재생 기록을 기반으로 음악을 추천하는 방법을 보여 줍니다.
// Build a listening history from up to 40 randomly chosen audio files.
AudioUserHistory userHistory = new AudioUserHistory();
List<String> audioFiles = FileUtils.getAudioFiles();
Collections.shuffle(audioFiles);
int historySize = Math.min(40, audioFiles.size()); // avoid get(i) past the end of a short list
for (int i = 0; i < historySize; ++i) {
    String filePath = audioFiles.get(i);
    userHistory.logAudio(filePath);
    try {
        // Pause 100 ms between entries (presumably so each log entry gets a distinct time — confirm).
        Thread.sleep(100L);
    } catch (InterruptedException e) {
        // Restore the interrupt flag for the caller and stop building the history.
        Thread.currentThread().interrupt();
        break;
    }
}

// Reuse a saved index when available; otherwise index the .au files in music_samples.
KnnAudioRecommender recommender = new KnnAudioRecommender();
if (!recommender.loadIndexDbIfExists()) {
    File[] indexCandidates = new File("music_samples").listFiles(
            a -> a.getAbsolutePath().toLowerCase(java.util.Locale.ROOT).endsWith(".au"));
    // File.listFiles() returns null when the folder is missing or cannot be read.
    if (indexCandidates == null) {
        throw new IllegalStateException("cannot list the music_samples directory");
    }
    recommender.indexAll(indexCandidates);
    recommender.saveIndexDb();
}

System.out.println(userHistory.head(10));

int k = 10;
List<AudioSearchEntry> result = recommender.recommends(userHistory.getHistory(), k);
for (int i = 0; i < result.size(); ++i) {
    AudioSearchEntry entry = result.get(i);
    System.out.println("Search Result #" + (i + 1) + ": " + entry.getPath());
}