L'utilisation de base du package d'algorithmes de machine vectorielle de support libsvm. Ce qui est démontré ici est la machine de régression vectorielle de support.
Copiez le code comme suit :
importer java.io.BufferedReader ;
importer java.io.File ;
importer java.io.FileReader ;
importer java.util.ArrayList ;
importer java.util.List ;
importer libsvm.svm ;
importer libsvm.svm_model ;
importer libsvm.svm_node ;
importer libsvm.svm_parameter ;
importer libsvm.svm_problem ;
SVM de classe publique {
public static void main (String[] arguments) {
// Définissez le point de consigne d'entraînement a{10.0, 10.0} et le point b{-10.0, -10.0}, et l'étiquette correspondante est {1.0, -1.0}
List<Double> label = new ArrayList<Double>();
List<svm_node[]> nodeSet = new ArrayList<svm_node[]>();
getData(nodeSet, étiquette, "file/train.txt");
int dataRange=nodeSet.get(0).length;
svm_node[][] datas = new svm_node[nodeSet.size()][dataRange]; // Table vectorielle de l'ensemble d'entraînement
pour (int i = 0; i < datas.length; i++) {
pour (int j = 0; j < dataRange; j++) {
datas[i][j] = nodeSet.get(i)[j];
}
}
double[] labels = new double[label.size()]; // étiquettes correspondant à a,b
pour (int i = 0; i < étiquettes.length; i++) {
étiquettes[i] = label.get(i);
}
//Définir l'objet svm_problem
problème svm_problem = new svm_problem();
problème.l = nodeSet.size(); // Nombre de vecteurs
problème.x = datas; //Table vectorielle de l'ensemble d'entraînement
problème.y = lables; // Tableau d'étiquettes correspondant
//Définir l'objet svm_parameter
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.EPSILON_SVR ;
param.kernel_type = svm_parameter.LINEAR;
param.cache_size = 100 ;
param.eps = 0,00001 ;
param.C = 1,9 ;
//Former le modèle de classification SVM
System.out.println(svm.svm_check_parameter(problème, param));
// S'il n'y a pas de problème avec les paramètres, la fonction svm.svm_check_parameter() renvoie null, sinon elle renvoie une description d'erreur.
svm_model model = svm.svm_train(problème, param);
// svm.svm_train() entraîne le modèle de classification SVM
// Récupère les données de test
List<Double> testlabel = new ArrayList<Double>();
List<svm_node[]> testnodeSet = new ArrayList<svm_node[]>();
getData(testnodeSet, testlabel, "file/test.txt");
svm_node[][] testdatas = new svm_node[testnodeSet.size()][dataRange]; // Table vectorielle de l'ensemble d'entraînement
pour (int i = 0; i < testdatas.length; i++) {
pour (int j = 0; j < dataRange; j++) {
testdatas[i][j] = testnodeSet.get(i)[j];
}
}
double[] testlables = new double[testlabel.size()]; // étiquettes correspondant à a,b
pour (int i = 0; i < testlables.length; i++) {
testlabels[i] = testlabel.get(i);
}
// Étiquette pour prédire les données de test
double erreur = 0,0 ;
pour (int i = 0; i < testdatas.length; i++) {
double valeur vraie = testlables[i];
System.out.print(truevalue + " ");
double predictValue = svm.svm_predict(model, testdatas[i]);
System.out.println(predictValue);
err += Math.abs(predictValue - truevalue);
}
System.out.println("err=" + err / datas.length);
}
public static void getData(List<svm_node[]> nodeSet, List<Double> label,
Nom de fichier de chaîne) {
essayer {
FileReader fr = new FileReader(new File(filename));
BufferedReader br = new BufferedReader(fr);
Ligne de chaîne = null ;
while ((line = br.readLine()) != null) {
String[] datas = line.split(",");
svm_node[] vecteur = nouveau svm_node[datas.length - 1];
pour (int i = 0; i < datas.length - 1; i++) {
svm_node node = nouveau svm_node();
noeud.index = i + 1 ;
node.value = Double.parseDouble(datas[i]);
vecteur[i] = nœud ;
}
nodeSet.add(vecteur);
double lablevalue = Double.parseDouble(datas[datas.length - 1]);
label.add(lablevalue);
}
} attraper (Exception e) {
e.printStackTrace();
}
}
}
Données d'entraînement, la dernière colonne est la valeur cible
Copiez le code comme suit :
17.6,17.7,17.7,17.7,17.8
17,7,17,7,17,7,17,8,17,8
17.7,17.7,17.8,17.8,17.9
17.7,17.8,17.8,17.9,18
17.8,17.8,17.9,18,18.1
17.8,17.9,18,18.1,18.2
17.9,18,18.1,18.2,18.4
18,18.1,18.2,18.4,18.6
18.1,18.2,18.4,18.6,18.7
18.2,18.4,18.6,18.7,18.9
18.4,18.6,18.7,18.9,19.1
18.6,18.7,18.9,19.1,19.3
données de test
Copiez le code comme suit :
18.7,18.9,19.1,19.3,19.6
18.9,19.1,19.3,19.6,19.9
19.1,19.3,19.6,19.9,20.2
19.3,19.6,19.9,20.2,20.6
19.6,19.9,20.2,20.6,21
19.9,20.2,20.6,21,21.5
20.2,20.6,21,21.5,22