1
2
3 package org.galagosearch.core.eval;
4
5 import java.util.Map;
6 import java.util.Random;
7 import java.util.Set;
8 import java.util.TreeSet;
9 import org.galagosearch.core.eval.stat.Stat;
10
11 /***
12 *
13 * @author trevor
14 */
15 public class SetRetrievalComparator {
16 double[] baseline;
17 double[] treatment;
18
19 /*** Creates a new instance of SetRetrievalComparator */
20 public SetRetrievalComparator(Map<String, Double> baseline, Map<String, Double> treatment) {
21 Set<String> commonQueries = new TreeSet<String>(baseline.keySet());
22 commonQueries.retainAll(treatment.keySet());
23
24 this.baseline = new double[commonQueries.size()];
25 this.treatment = new double[commonQueries.size()];
26 int i = 0;
27
28 for (String key : commonQueries) {
29 this.baseline[i] = baseline.get(key);
30 this.treatment[i] = treatment.get(key);
31 i++;
32 }
33 }
34
35 private double[] multiply(double[] numbers, double boost) {
36 double[] result = new double[numbers.length];
37
38 for (int i = 0; i < result.length; i++) {
39 result[i] = numbers[i] * boost;
40 }
41
42 return result;
43 }
44
45 private double mean(double[] numbers) {
46 double sum = 0;
47 for (int i = 0; i < numbers.length; i++) {
48 sum += numbers[i];
49 }
50
51 return sum / (double) numbers.length;
52 }
53
54 public double meanBaselineMetric() {
55 return mean(baseline);
56 }
57
58 public double meanTreatmentMetric() {
59 return mean(treatment);
60 }
61
62 public int countTreatmentBetter() {
63 int better = 0;
64
65 for (int i = 0; i < baseline.length; i++) {
66 if (baseline[i] < treatment[i]) {
67 better++;
68 }
69 }
70
71 return better;
72 }
73
74 public int countBaselineBetter() {
75 int better = 0;
76
77 for (int i = 0; i < baseline.length; i++) {
78 if (baseline[i] > treatment[i]) {
79 better++;
80 }
81 }
82
83 return better;
84 }
85
86 public int countEqual() {
87 int same = 0;
88
89 for (int i = 0; i < baseline.length; i++) {
90 if (baseline[i] == treatment[i]) {
91 same++;
92 }
93 }
94
95 return same;
96 }
97
98 public double supportedHypothesis(String testName, double pvalue) {
99 double currentBoost = 1.0;
100 double currentPvalue = test(testName, currentBoost);
101 double lastBoost = 1.0;
102 double lastPvalue = currentPvalue;
103 int iterations = 0;
104
105
106 while ((lastPvalue < pvalue) == (currentPvalue < pvalue)) {
107 double nextBoost = currentBoost;
108
109 if (currentPvalue < pvalue) {
110 nextBoost *= 1.05;
111 } else if (currentPvalue > pvalue) {
112 nextBoost *= 0.95;
113 }
114
115 double nextPvalue = test(testName, nextBoost);
116
117 lastBoost = currentBoost;
118 lastPvalue = currentPvalue;
119 currentBoost = nextBoost;
120 currentPvalue = nextPvalue;
121
122 iterations++;
123
124 if (iterations > 50) {
125 return 0;
126 }
127 }
128
129
130 double lowBoost = Math.min(lastBoost, currentBoost);
131 double highBoost = Math.max(lastBoost, currentBoost);
132
133 while (highBoost - lowBoost > 0.00005) {
134 double middleBoost = (highBoost + lowBoost) / 2;
135 currentPvalue = test(testName, middleBoost);
136
137 if (currentPvalue > pvalue) {
138 highBoost = middleBoost;
139 } else {
140 lowBoost = middleBoost;
141 }
142
143 iterations++;
144
145 if (iterations > 100) {
146 return 0;
147 }
148 }
149
150 return lowBoost;
151 }
152
153 public double test(String testName, double boost) {
154 if (testName.compareToIgnoreCase("ttest") == 0 || testName.compareToIgnoreCase("pairedTTest") == 0) {
155 return pairedTTest(boost);
156 } else if (testName.compareToIgnoreCase("sign") == 0) {
157 return signTest(boost);
158 } else if (testName.compareToIgnoreCase("randomized") == 0) {
159 return randomizedTest(boost);
160 } else {
161 throw new RuntimeException("'" + testName + "' is not a recognized test.");
162 }
163 }
164
165 public double pairedTTest() {
166 return pairedTTest(1.0);
167 }
168
169 public double pairedTTest(double boost) {
170 double[] boostedBaseline = multiply(baseline, boost);
171 double sampleSum = 0;
172 double sampleSumSquares = 0;
173 int n = boostedBaseline.length;
174
175 for (int i = 0; i < baseline.length; i++) {
176 double delta = treatment[i] - boostedBaseline[i];
177 sampleSum += delta;
178 sampleSumSquares += delta * delta;
179 }
180
181 double sampleVariance = sampleSumSquares / (n - 1);
182 double sampleMean = sampleSum / baseline.length;
183
184 double sampleDeviation = Math.sqrt(sampleVariance);
185 double meanDeviation = sampleDeviation / Math.sqrt(n);
186 double t = sampleMean / meanDeviation;
187
188 return 1.0 - Stat.studentTProb(t, n - 1);
189 }
190
191 public double signTest() {
192 return signTest(1.0);
193 }
194
195 public double signTest(double boost) {
196 int treatmentIsBetter = 0;
197 int different = 0;
198
199 for (int i = 0; i < treatment.length; i++) {
200 double boostedBaseline = baseline[i] * boost;
201 if (treatment[i] > boostedBaseline) {
202 treatmentIsBetter++;
203 }
204 if (treatment[i] != boostedBaseline) {
205 different++;
206 }
207 }
208
209 double pvalue = Stat.binomialProb(0.5, different, treatmentIsBetter);
210 return pvalue;
211 }
212
213 public double randomizedTest() {
214 return randomizedTest(1.0);
215 }
216
217 public double randomizedTest(double boost) {
218 double[] boostedBaseline = multiply(baseline, boost);
219 double baseMean = mean(boostedBaseline);
220 double treatmentMean = mean(treatment);
221 double difference = treatmentMean - baseMean;
222 int batch = 10000;
223
224 final int maxIterationsWithoutMatch = 1000000;
225 long iterations = 0;
226 long matches = 0;
227
228 double[] leftSample = new double[boostedBaseline.length];
229 double[] rightSample = new double[boostedBaseline.length];
230 Random random = new Random();
231 double pValue = 0.0;
232
233 while (true) {
234 for (int i = 0; i < batch; i++) {
235
236 for (int j = 0; j < boostedBaseline.length; j++) {
237 if (random.nextBoolean()) {
238 leftSample[j] = boostedBaseline[j];
239 rightSample[j] = treatment[j];
240 } else {
241 leftSample[j] = treatment[j];
242 rightSample[j] = boostedBaseline[j];
243 }
244 }
245
246 double sampleDifference = mean(leftSample) - mean(rightSample);
247
248 if (difference <= sampleDifference) {
249 matches++;
250 }
251 }
252
253 iterations += batch;
254
255
256 pValue = (double) matches / (double) iterations;
257
258
259 if (matches == 0) {
260 if (iterations < maxIterationsWithoutMatch) {
261 continue;
262 } else {
263 break;
264 }
265 }
266
267
268
269
270
271
272 double maxDeviation = Math.max(0.0000005 / pValue, Math.min(0.00005 / pValue, 0.05));
273
274
275
276
277 double estimatedIterations = Math.sqrt(pValue * (1.0 - pValue)) / maxDeviation;
278
279 if (estimatedIterations > iterations) {
280 break;
281 }
282 }
283
284 return pValue;
285 }
286 }