载入中。。。 'S bLog
 
载入中。。。
 
载入中。。。
载入中。。。
载入中。。。
载入中。。。
载入中。。。
 
填写您的邮件地址,订阅我们的精彩内容:


 
数据挖掘:K最近邻(KNN)算法的java实现
[ 2011/5/26 12:31:00 | By: 梦翔儿 ]
 

本算法只适合学习使用,可以大致了解一下KNN算法的原理。

算法作了如下的假定与简化处理:

1.小规模数据集

2.假设所有数据及类别都是数值类型的

3.直接根据数据规模设定了k值

4.对原训练集进行测试

KNN实现代码如下:

view plaincopy to clipboardprint?
01.package KNN;  
02./** 
03. * KNN结点类,用来存储最近邻的k个元组相关的信息 
04. * @author Rowen 
05. * @qq 443773264 
06. * @mail luowen3405@163.com 
07. * @blog blog.csdn.net/luowen3405 
08. * @data 2011.03.25 
09. */ 
10.public class KNNNode {  
11.    private int index; // 元组标号  
12.    private double distance; // 与测试元组的距离  
13.    private String c; // 所属类别  
14.    public KNNNode(int index, double distance, String c) {  
15.        super();  
16.        this.index = index;  
17.        this.distance = distance;  
18.        this.c = c;  
19.    }  
20.      
21.      
22.    public int getIndex() {  
23.        return index;  
24.    }  
25.    public void setIndex(int index) {  
26.        this.index = index;  
27.    }  
28.    public double getDistance() {  
29.        return distance;  
30.    }  
31.    public void setDistance(double distance) {  
32.        this.distance = distance;  
33.    }  
34.    public String getC() {  
35.        return c;  
36.    }  
37.    public void setC(String c) {  
38.        this.c = c;  
39.    }  
40.} 
package KNN;
/**
 * KNN结点类,用来存储最近邻的k个元组相关的信息
 * @author Rowen
 * @qq 443773264
 * @mail luowen3405@163.com
 * @blog blog.csdn.net/luowen3405
 * @data 2011.03.25
 */
public class KNNNode {
 private int index; // 元组标号
 private double distance; // 与测试元组的距离
 private String c; // 所属类别
 public KNNNode(int index, double distance, String c) {
  super();
  this.index = index;
  this.distance = distance;
  this.c = c;
 }
 
 
 public int getIndex() {
  return index;
 }
 public void setIndex(int index) {
  this.index = index;
 }
 public double getDistance() {
  return distance;
 }
 public void setDistance(double distance) {
  this.distance = distance;
 }
 public String getC() {
  return c;
 }
 public void setC(String c) {
  this.c = c;
 }
}
 

view plaincopy to clipboardprint?
01.package KNN;  
02.import java.util.ArrayList;  
03.import java.util.Comparator;  
04.import java.util.HashMap;  
05.import java.util.List;  
06.import java.util.Map;  
07.import java.util.PriorityQueue;  
08. 
09./** 
10. * KNN算法主体类 
11. * @author Rowen 
12. * @qq 443773264 
13. * @mail luowen3405@163.com 
14. * @blog blog.csdn.net/luowen3405 
15. * @data 2011.03.25 
16. */ 
17.public class KNN {  
18.    /** 
19.     * 设置优先级队列的比较函数,距离越大,优先级越高 
20.     */ 
21.    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {  
22.        public int compare(KNNNode o1, KNNNode o2) {  
23.            if (o1.getDistance() >= o2.getDistance()) {  
24.                return 1;  
25.            } else {  
26.                return 0;  
27.            }  
28.        }  
29.    };  
30.    /** 
31.     * 获取K个不同的随机数 
32.     * @param k 随机数的个数 
33.     * @param max 随机数最大的范围 
34.     * @return 生成的随机数数组 
35.     */ 
36.    public List<Integer> getRandKNum(int k, int max) {  
37.        List<Integer> rand = new ArrayList<Integer>(k);  
38.        for (int i = 0; i < k; i++) {  
39.            int temp = (int) (Math.random() * max);  
40.            if (!rand.contains(temp)) {  
41.                rand.add(temp);  
42.            } else {  
43.                i--;  
44.            }  
45.        }  
46.        return rand;  
47.    }  
48.    /** 
49.     * 计算测试元组与训练元组之前的距离 
50.     * @param d1 测试元组 
51.     * @param d2 训练元组 
52.     * @return 距离值 
53.     */ 
54.    public double calDistance(List<Double> d1, List<Double> d2) {  
55.        double distance = 0.00;  
56.        for (int i = 0; i < d1.size(); i++) {  
57.            distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));  
58.        }  
59.        return distance;  
60.    }  
61.    /** 
62.     * 执行KNN算法,获取测试元组的类别 
63.     * @param datas 训练数据集 
64.     * @param testData 测试元组 
65.     * @param k 设定的K值 
66.     * @return 测试元组的类别 
67.     */ 
68.    public String knn(List<List<Double>> datas, List<Double> testData, int k) {  
69.        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);  
70.        List<Integer> randNum = getRandKNum(k, datas.size());  
71.        for (int i = 0; i < k; i++) {  
72.            int index = randNum.get(i);  
73.            List<Double> currData = datas.get(index);  
74.            String c = currData.get(currData.size() - 1).toString();  
75.            KNNNode node = new KNNNode(index, calDistance(testData, currData), c);  
76.            pq.add(node);  
77.        }  
78.        for (int i = 0; i < datas.size(); i++) {  
79.            List<Double> t = datas.get(i);  
80.            double distance = calDistance(testData, t);  
81.            KNNNode top = pq.peek();  
82.            if (top.getDistance() > distance) {  
83.                pq.remove();  
84.                pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));  
85.            }  
86.        }  
87.          
88.        return getMostClass(pq);  
89.    }  
90.    /** 
91.     * 获取所得到的k个最近邻元组的多数类 
92.     * @param pq 存储k个最近近邻元组的优先级队列 
93.     * @return 多数类的名称 
94.     */ 
95.    private String getMostClass(PriorityQueue<KNNNode> pq) {  
96.        Map<String, Integer> classCount = new HashMap<String, Integer>();  
97.        for (int i = 0; i < pq.size(); i++) {  
98.            KNNNode node = pq.remove();  
99.            String c = node.getC();  
100.            if (classCount.containsKey(c)) {  
101.                classCount.put(c, classCount.get(c) + 1);  
102.            } else {  
103.                classCount.put(c, 1);  
104.            }  
105.        }  
106.        int maxIndex = -1;  
107.        int maxCount = 0;  
108.        Object[] classes = classCount.keySet().toArray();  
109.        for (int i = 0; i < classes.length; i++) {  
110.            if (classCount.get(classes[i]) > maxCount) {  
111.                maxIndex = i;  
112.                maxCount = classCount.get(classes[i]);  
113.            }  
114.        }  
115.        return classes[maxIndex].toString();  
116.    }  
117.} 
package KNN;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * KNN算法主体类
 * @author Rowen
 * @qq 443773264
 * @mail luowen3405@163.com
 * @blog blog.csdn.net/luowen3405
 * @data 2011.03.25
 */
public class KNN {
 /**
  * 设置优先级队列的比较函数,距离越大,优先级越高
  */
 private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
  public int compare(KNNNode o1, KNNNode o2) {
   if (o1.getDistance() >= o2.getDistance()) {
    return 1;
   } else {
    return 0;
   }
  }
 };
 /**
  * 获取K个不同的随机数
  * @param k 随机数的个数
  * @param max 随机数最大的范围
  * @return 生成的随机数数组
  */
 public List<Integer> getRandKNum(int k, int max) {
  List<Integer> rand = new ArrayList<Integer>(k);
  for (int i = 0; i < k; i++) {
   int temp = (int) (Math.random() * max);
   if (!rand.contains(temp)) {
    rand.add(temp);
   } else {
    i--;
   }
  }
  return rand;
 }
 /**
  * 计算测试元组与训练元组之前的距离
  * @param d1 测试元组
  * @param d2 训练元组
  * @return 距离值
  */
 public double calDistance(List<Double> d1, List<Double> d2) {
  double distance = 0.00;
  for (int i = 0; i < d1.size(); i++) {
   distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
  }
  return distance;
 }
 /**
  * 执行KNN算法,获取测试元组的类别
  * @param datas 训练数据集
  * @param testData 测试元组
  * @param k 设定的K值
  * @return 测试元组的类别
  */
 public String knn(List<List<Double>> datas, List<Double> testData, int k) {
  PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
  List<Integer> randNum = getRandKNum(k, datas.size());
  for (int i = 0; i < k; i++) {
   int index = randNum.get(i);
   List<Double> currData = datas.get(index);
   String c = currData.get(currData.size() - 1).toString();
   KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
   pq.add(node);
  }
  for (int i = 0; i < datas.size(); i++) {
   List<Double> t = datas.get(i);
   double distance = calDistance(testData, t);
   KNNNode top = pq.peek();
   if (top.getDistance() > distance) {
    pq.remove();
    pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
   }
  }
  
  return getMostClass(pq);
 }
 /**
  * 获取所得到的k个最近邻元组的多数类
  * @param pq 存储k个最近近邻元组的优先级队列
  * @return 多数类的名称
  */
 private String getMostClass(PriorityQueue<KNNNode> pq) {
  Map<String, Integer> classCount = new HashMap<String, Integer>();
  for (int i = 0; i < pq.size(); i++) {
   KNNNode node = pq.remove();
   String c = node.getC();
   if (classCount.containsKey(c)) {
    classCount.put(c, classCount.get(c) + 1);
   } else {
    classCount.put(c, 1);
   }
  }
  int maxIndex = -1;
  int maxCount = 0;
  Object[] classes = classCount.keySet().toArray();
  for (int i = 0; i < classes.length; i++) {
   if (classCount.get(classes[i]) > maxCount) {
    maxIndex = i;
    maxCount = classCount.get(classes[i]);
   }
  }
  return classes[maxIndex].toString();
 }
}
 

view plaincopy to clipboardprint?
01.package KNN;  
02.import java.io.BufferedReader;  
03.import java.io.File;  
04.import java.io.FileReader;  
05.import java.util.ArrayList;  
06.import java.util.List;  
07./** 
08. * KNN算法测试类 
09. * @author Rowen 
10. * @qq 443773264 
11. * @mail luowen3405@163.com 
12. * @blog blog.csdn.net/luowen3405 
13. * @data 2011.03.25 
14. */ 
15.public class TestKNN {  
16.      
17.    /** 
18.     * 从数据文件中读取数据 
19.     * @param datas 存储数据的集合对象 
20.     * @param path 数据文件的路径 
21.     */ 
22.    public void read(List<List<Double>> datas, String path){  
23.        try {  
24.            BufferedReader br = new BufferedReader(new FileReader(new File(path)));  
25.            String data = br.readLine();  
26.            List<Double> l = null;  
27.            while (data != null) {  
28.                String t[] = data.split(" ");  
29.                l = new ArrayList<Double>();  
30.                for (int i = 0; i < t.length; i++) {  
31.                    l.add(Double.parseDouble(t[i]));  
32.                }  
33.                datas.add(l);  
34.                data = br.readLine();  
35.            }  
36.        } catch (Exception e) {  
37.            e.printStackTrace();  
38.        }  
39.    }  
40.      
41.    /** 
42.     * 程序执行入口 
43.     * @param args 
44.     */ 
45.    public static void main(String[] args) {  
46.        TestKNN t = new TestKNN();  
47.        String datafile = new File("").getAbsolutePath() + File.separator + "datafile";  
48.        String testfile = new File("").getAbsolutePath() + File.separator + "testfile";  
49.        try {  
50.            List<List<Double>> datas = new ArrayList<List<Double>>();  
51.            List<List<Double>> testDatas = new ArrayList<List<Double>>();  
52.            t.read(datas, datafile);  
53.            t.read(testDatas, testfile);  
54.            KNN knn = new KNN();  
55.            for (int i = 0; i < testDatas.size(); i++) {  
56.                List<Double> test = testDatas.get(i);  
57.                System.out.print("测试元组: ");  
58.                for (int j = 0; j < test.size(); j++) {  
59.                    System.out.print(test.get(j) + " ");  
60.                }  
61.                System.out.print("类别为: ");  
62.                System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));  
63.            }  
64.        } catch (Exception e) {  
65.            e.printStackTrace();  
66.        }  
67.    }  
68.} 
package KNN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
/**
 * KNN算法测试类
 * @author Rowen
 * @qq 443773264
 * @mail luowen3405@163.com
 * @blog blog.csdn.net/luowen3405
 * @data 2011.03.25
 */
public class TestKNN {
 
 /**
  * 从数据文件中读取数据
  * @param datas 存储数据的集合对象
  * @param path 数据文件的路径
  */
 public void read(List<List<Double>> datas, String path){
  try {
   BufferedReader br = new BufferedReader(new FileReader(new File(path)));
   String data = br.readLine();
   List<Double> l = null;
   while (data != null) {
    String t[] = data.split(" ");
    l = new ArrayList<Double>();
    for (int i = 0; i < t.length; i++) {
     l.add(Double.parseDouble(t[i]));
    }
    datas.add(l);
    data = br.readLine();
   }
  } catch (Exception e) {
   e.printStackTrace();
  }
 }
 
 /**
  * 程序执行入口
  * @param args
  */
 public static void main(String[] args) {
  TestKNN t = new TestKNN();
  String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
  String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
  try {
   List<List<Double>> datas = new ArrayList<List<Double>>();
   List<List<Double>> testDatas = new ArrayList<List<Double>>();
   t.read(datas, datafile);
   t.read(testDatas, testfile);
   KNN knn = new KNN();
   for (int i = 0; i < testDatas.size(); i++) {
    List<Double> test = testDatas.get(i);
    System.out.print("测试元组: ");
    for (int j = 0; j < test.size(); j++) {
     System.out.print(test.get(j) + " ");
    }
    System.out.print("类别为: ");
    System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
   }
  } catch (Exception e) {
   e.printStackTrace();
  }
 }
}
 

训练数据文件:

view plaincopy to clipboardprint?
01.1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1 
02.1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1 
03.1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1 
04.1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0 
05.1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1 
06.1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0 
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

view plaincopy to clipboardprint?
01.1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 
02.1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 
03.1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 
04.1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 
05.1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 
06.1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

程序运行结果:

view plaincopy to clipboardprint?
01.测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1 
02.测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1 
03.测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1 
04.测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0 
05.测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1 
06.测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0 
测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0

由结果可以看出,分类的测试结果是比较准确的!

 

http://blog.csdn.net/luowen3405/archive/2011/03/25/6278764.aspx

 
 
发表评论:
载入中。。。

 
 
 

梦翔儿网站 梦飞翔的地方 http://www.dreamflier.net
中华人民共和国信息产业部TCP/IP系统 备案序号:辽ICP备09000550号

Powered by Oblog.