001package org.opengion.penguin.math.statistics; 002 003import java.util.Arrays; 004 005import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; 006 007/** 008 * apache.commons.mathを利用したOLS重回帰計算のクラスです。 009 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。 010 * c0の切片を考慮するかどうかはnoInterceptで決めます。 011 * 012 */ 013public class HybsMultiRegression { 014 private double coe[]; // 各係数(xの種類+1になる?) 015 private double rsquare; // 決定係数 016 private boolean noIntercept; //切片を利用するかどうか 017 018 /** 019 * コンストラクタ。 020 * 与えた二次元データを元に重回帰を計算します。 021 * xデータとして二次元配列を与えます。 022 * noInterceptで切片有り無しを選択します。 023 * @param in_x 説明変数 024 * @param in_y 目的変数 025 * @param noIntercept 切片利用有無(trueで利用しない) 026 */ 027 public HybsMultiRegression(final double[][] in_x, final double[] in_y, final boolean noIntercept){ 028 this.noIntercept = noIntercept; 029 030 // ここで重回帰計算 031 OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 032 regression.setNoIntercept(noIntercept); 033 regression.newSampleData(in_y, in_x); 034 035 coe = regression.estimateRegressionParameters(); 036 rsquare = regression.calculateRSquared(); 037 } 038 039 /** 040 * コンストラクタ。 041 * 係数配列を与えられるようにしておきます。 042 * (以前に計算したものを利用) 043 * @param in_c 係数配列 044 * @param noIntercept 切片利用有無(trueで利用しない) 045 * 046 */ 047 public HybsMultiRegression( final double[] in_c, final boolean noIntercept){ 048 this.coe = in_c; 049 this.noIntercept = noIntercept; 050 } 051 052 053 /** 054 * 係数の取得。 055 * @return 係数配列 056 */ 057 public double[] getParam(){ 058 return coe; 059 } 060 061 /** 062 * 決定係数の取得。 063 * @return 決定係数 064 */ 065 public double getRSquare(){ 066 return rsquare; 067 } 068 069 /** 070 * 計算( c0 + c1x1...)を行う。 071 * noInterceptによってc0の利用を決める。 072 * xの大きさが足りない場合は0を返す。 073 * 074 * @param in_x 必要な大きさの変数配列 075 * @return 計算結果 076 */ 077 public double predict(final double[] in_x){ 078 double rtn = 0; 079 int itr = noIntercept ? 0 : 1; 080 if( in_x.length < coe.length-itr ){ 081 return 0; 082 } 083 084 for( int i=0; i < in_x.length; i++ ){ 085 rtn = rtn + in_x[i] * coe[i+itr]; 086 } 087 if( !noIntercept ){ rtn = rtn + coe[0]; } 088 089 return rtn; 090 } 091 092 /*** ここまでが本体 ***/ 093 /*** ここからテスト用mainメソッド ***/ 094 /** 095 * @param args *****************************************/ 096 public static void main(final String [] args) { 097 // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより 098 double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 }; 099 double[][] x = new double[10][]; 100 x[0] = new double[] { 165, 65 }; 101 x[1] = new double[] { 170, 68 }; 102 x[2] = new double[] { 172, 70 }; 103 x[3] = new double[] { 175, 65 }; 104 x[4] = new double[] { 170, 80 }; 105 x[5] = new double[] { 172, 85 }; 106 x[6] = new double[] { 183, 78 }; 107 x[7] = new double[] { 187, 79 }; 108 x[8] = new double[] { 180, 95 }; 109 x[9] = new double[] { 185, 97 }; 110 111 112 HybsMultiRegression mr = new HybsMultiRegression(x,y,true); 113 114 System.out.println( mr.getRSquare() ); 115 System.out.println( Arrays.toString( mr.getParam()) ); 116 117 System.out.println( mr.predict( new double[] { 169,85 } )); 118 } 119} 120