View Javadoc

1   // BSD License (http://www.galagosearch.org/license)
2   
3   package org.galagosearch.core.eval;
4   
5   import java.util.Map;
6   import java.util.Random;
7   import java.util.Set;
8   import java.util.TreeSet;
9   import org.galagosearch.core.eval.stat.Stat;
10  
11  /***
12   *
13   * @author trevor
14   */
15  public class SetRetrievalComparator {
16      double[] baseline;
17      double[] treatment;
18  
19      /*** Creates a new instance of SetRetrievalComparator */
20      public SetRetrievalComparator(Map<String, Double> baseline, Map<String, Double> treatment) {
21          Set<String> commonQueries = new TreeSet<String>(baseline.keySet());
22          commonQueries.retainAll(treatment.keySet());
23  
24          this.baseline = new double[commonQueries.size()];
25          this.treatment = new double[commonQueries.size()];
26          int i = 0;
27  
28          for (String key : commonQueries) {
29              this.baseline[i] = baseline.get(key);
30              this.treatment[i] = treatment.get(key);
31              i++;
32          }
33      }
34  
35      private double[] multiply(double[] numbers, double boost) {
36          double[] result = new double[numbers.length];
37  
38          for (int i = 0; i < result.length; i++) {
39              result[i] = numbers[i] * boost;
40          }
41  
42          return result;
43      }
44  
45      private double mean(double[] numbers) {
46          double sum = 0;
47          for (int i = 0; i < numbers.length; i++) {
48              sum += numbers[i];
49          }
50  
51          return sum / (double) numbers.length;
52      }
53  
54      public double meanBaselineMetric() {
55          return mean(baseline);
56      }
57  
58      public double meanTreatmentMetric() {
59          return mean(treatment);
60      }
61  
62      public int countTreatmentBetter() {
63          int better = 0;
64  
65          for (int i = 0; i < baseline.length; i++) {
66              if (baseline[i] < treatment[i]) {
67                  better++;
68              }
69          }
70  
71          return better;
72      }
73  
74      public int countBaselineBetter() {
75          int better = 0;
76  
77          for (int i = 0; i < baseline.length; i++) {
78              if (baseline[i] > treatment[i]) {
79                  better++;
80              }
81          }
82  
83          return better;
84      }
85  
86      public int countEqual() {
87          int same = 0;
88  
89          for (int i = 0; i < baseline.length; i++) {
90              if (baseline[i] == treatment[i]) {
91                  same++;
92              }
93          }
94  
95          return same;
96      }
97  
98      public double supportedHypothesis(String testName, double pvalue) {
99          double currentBoost = 1.0;
100         double currentPvalue = test(testName, currentBoost);
101         double lastBoost = 1.0;
102         double lastPvalue = currentPvalue;
103         int iterations = 0;
104 
105         // search until we find an interval
106         while ((lastPvalue < pvalue) == (currentPvalue < pvalue)) {
107             double nextBoost = currentBoost;
108 
109             if (currentPvalue < pvalue) {
110                 nextBoost *= 1.05;
111             } else if (currentPvalue > pvalue) {
112                 nextBoost *= 0.95;
113             }
114 
115             double nextPvalue = test(testName, nextBoost);
116 
117             lastBoost = currentBoost;
118             lastPvalue = currentPvalue;
119             currentBoost = nextBoost;
120             currentPvalue = nextPvalue;
121 
122             iterations++;
123 
124             if (iterations > 50) {
125                 return 0;
126             }
127         }
128 
129         // now we have an interval to search in
130         double lowBoost = Math.min(lastBoost, currentBoost);
131         double highBoost = Math.max(lastBoost, currentBoost);
132 
133         while (highBoost - lowBoost > 0.00005) {
134             double middleBoost = (highBoost + lowBoost) / 2;
135             currentPvalue = test(testName, middleBoost);
136 
137             if (currentPvalue > pvalue) {
138                 highBoost = middleBoost;
139             } else {
140                 lowBoost = middleBoost;
141             }
142 
143             iterations++;
144 
145             if (iterations > 100) {
146                 return 0;
147             }
148         }
149 
150         return lowBoost;
151     }
152 
153     public double test(String testName, double boost) {
154         if (testName.compareToIgnoreCase("ttest") == 0 || testName.compareToIgnoreCase("pairedTTest") == 0) {
155             return pairedTTest(boost);
156         } else if (testName.compareToIgnoreCase("sign") == 0) {
157             return signTest(boost);
158         } else if (testName.compareToIgnoreCase("randomized") == 0) {
159             return randomizedTest(boost);
160         } else {
161             throw new RuntimeException("'" + testName + "' is not a recognized test.");
162         }
163     }
164 
165     public double pairedTTest() {
166         return pairedTTest(1.0);
167     }
168 
169     public double pairedTTest(double boost) {
170         double[] boostedBaseline = multiply(baseline, boost);
171         double sampleSum = 0;
172         double sampleSumSquares = 0;
173         int n = boostedBaseline.length;
174 
175         for (int i = 0; i < baseline.length; i++) {
176             double delta = treatment[i] - boostedBaseline[i];
177             sampleSum += delta;
178             sampleSumSquares += delta * delta;
179         }
180 
181         double sampleVariance = sampleSumSquares / (n - 1);
182         double sampleMean = sampleSum / baseline.length;
183 
184         double sampleDeviation = Math.sqrt(sampleVariance);
185         double meanDeviation = sampleDeviation / Math.sqrt(n);
186         double t = sampleMean / meanDeviation;
187 
188         return 1.0 - Stat.studentTProb(t, n - 1);
189     }
190 
191     public double signTest() {
192         return signTest(1.0);
193     }
194 
195     public double signTest(double boost) {
196         int treatmentIsBetter = 0;
197         int different = 0;
198 
199         for (int i = 0; i < treatment.length; i++) {
200             double boostedBaseline = baseline[i] * boost;
201             if (treatment[i] > boostedBaseline) {
202                 treatmentIsBetter++;
203             }
204             if (treatment[i] != boostedBaseline) {
205                 different++;
206             }
207         }
208 
209         double pvalue = Stat.binomialProb(0.5, different, treatmentIsBetter);
210         return pvalue;
211     }
212 
213     public double randomizedTest() {
214         return randomizedTest(1.0);
215     }
216 
217     public double randomizedTest(double boost) {
218         double[] boostedBaseline = multiply(baseline, boost);
219         double baseMean = mean(boostedBaseline);
220         double treatmentMean = mean(treatment);
221         double difference = treatmentMean - baseMean;
222         int batch = 10000;
223 
224         final int maxIterationsWithoutMatch = 1000000;
225         long iterations = 0;
226         long matches = 0;
227 
228         double[] leftSample = new double[boostedBaseline.length];
229         double[] rightSample = new double[boostedBaseline.length];
230         Random random = new Random();
231         double pValue = 0.0;
232 
233         while (true) {
234             for (int i = 0; i < batch; i++) {
235                 // create a sample from both distributions
236                 for (int j = 0; j < boostedBaseline.length; j++) {
237                     if (random.nextBoolean()) {
238                         leftSample[j] = boostedBaseline[j];
239                         rightSample[j] = treatment[j];
240                     } else {
241                         leftSample[j] = treatment[j];
242                         rightSample[j] = boostedBaseline[j];
243                     }
244                 }
245 
246                 double sampleDifference = mean(leftSample) - mean(rightSample);
247 
248                 if (difference <= sampleDifference) {
249                     matches++;
250                 }
251             }
252 
253             iterations += batch;
254 
255             // this is the current p-value estimate
256             pValue = (double) matches / (double) iterations;
257 
258             // if we still haven't found a match, keep looking
259             if (matches == 0) {
260                 if (iterations < maxIterationsWithoutMatch) {
261                     continue;
262                 } else {
263                     break;
264                 }
265             }
266 
267             // this is our accepted level of deviation in the p-value; we require:
268             //      - accuracy at the fourth decimal place, and
269             //      - less than 5% error in the p-value, or
270             //      - accuracy at the sixth decimal place.
271 
272             double maxDeviation = Math.max(0.0000005 / pValue, Math.min(0.00005 / pValue, 0.05));
273 
274             // this estimate is derived in Efron and Tibshirani, p.209.
275             // this is the estimated number of iterations necessary for convergence, given
276             // our current p-value estimate.
277             double estimatedIterations = Math.sqrt(pValue * (1.0 - pValue)) / maxDeviation;
278 
279             if (estimatedIterations > iterations) {
280                 break;
281             }
282         }
283 
284         return pValue;
285     }
286 }