2802 - Fixing ArithmeticException when using GroupBy.avg() - GAvg now uses default or user-provided MathContext for calculating the average

This commit is contained in:
Maciej Dobrowolski 2021-04-09 13:56:06 +02:00
parent 99e3ca936f
commit 82f240a889
5 changed files with 72 additions and 9 deletions

View File

@ -14,6 +14,7 @@
package com.querydsl.core.group;
import java.math.BigDecimal;
import java.math.MathContext;
import com.querydsl.core.types.Expression;
import com.querydsl.core.util.MathUtils;
@ -23,8 +24,15 @@ class GAvg<T extends Number> extends AbstractGroupExpression<T, T> {
private static final long serialVersionUID = 3518868612387641383L;
private final MathContext mathContext;
public GAvg(Expression<T> expr) {
this(expr, MathContext.DECIMAL128);
}
public GAvg(Expression<T> expr, MathContext mathContext) {
super((Class) expr.getType(), expr);
this.mathContext = mathContext;
}
@Override
@ -43,7 +51,7 @@ class GAvg<T extends Number> extends AbstractGroupExpression<T, T> {
@Override
public T get() {
BigDecimal avg = sum.divide(BigDecimal.valueOf(count));
BigDecimal avg = sum.divide(BigDecimal.valueOf(count), mathContext);
return MathUtils.cast(avg, getType());
}

View File

@ -13,6 +13,7 @@
*/
package com.querydsl.core.group;
import java.math.MathContext;
import java.util.*;
import com.mysema.commons.lang.Pair;
@ -79,7 +80,7 @@ public class GroupBy {
}
/**
* Create a new aggregating avg expression
* Create a new aggregating avg expression, uses default MathContext.DECIMAL128 for average calculation
*
* @param expression expression for which the accumulated average value will be used in the group by projection
* @return wrapper expression
@ -88,6 +89,17 @@ public class GroupBy {
return new GAvg<E>(expression);
}
/**
* Create a new aggregating avg expression with a user-provided MathContext
*
* @param expression expression for which the accumulated average value will be used in the group by projection
* @param mathContext mathContext for average calculation
* @return wrapper expression
*/
public static <E extends Number> AbstractGroupExpression<E, E> avg(Expression<E> expression, MathContext mathContext) {
return new GAvg<E>(expression, mathContext);
}
/**
* Create a new aggregating max expression
*

View File

@ -106,6 +106,15 @@ public abstract class AbstractGroupByTest {
row("John", "John", 1, "post 1", comment(3))
);
protected static final DummyFetchableQuery<Tuple> POSTS_W_COMMENTS_SCORE = projectable(
row(null, 1.5),
row(1, 1.5),
row(1, 2.0),
row(1, 0.5),
row(2, 1.0),
row(2, 2.0)
);
// protected static final Projectable USERS_W_LATEST_POST_AND_COMMENTS2 = projectable(
// row("John", 1, "post 1", comment(1)),
// row("Jane", 2, "post 2", comment(4)),
@ -130,7 +139,9 @@ public abstract class AbstractGroupByTest {
protected static final StringExpression commentText = Expressions.stringPath(comment, "text");
protected static final ConstructorExpression<Comment> qComment = Projections.constructor(Comment.class, commentId, commentText);
protected static final NumberExpression<Double> score = Expressions.numberPath(Double.class, comment, "score");
protected static final ConstructorExpression<Comment> qComment = Projections.constructor(Comment.class, commentId, commentText, score);
protected static <K, V> Pair<K, V> pair(K key, V value) {
return new Pair<K, V>(key, value);
@ -145,7 +156,7 @@ public abstract class AbstractGroupByTest {
}
protected static Comment comment(Integer id) {
return new Comment(id, "comment " + id);
return new Comment(id, "comment " + id, 0.0);
}
protected static DummyFetchableQuery<Tuple> projectable(final Object[]... rows) {

View File

@ -15,15 +15,18 @@ package com.querydsl.core.group;
public class Comment {
private Integer id;
private Integer id;
private String text;
private Double score;
public Comment() { }
public Comment(Integer id, String text) {
public Comment(Integer id, String text, Double score) {
this.id = id;
this.text = text;
this.score = score;
}
public Integer getId() {
@ -42,9 +45,17 @@ public class Comment {
this.text = text;
}
public Double getScore() {
return score;
}
public void setScore(Double score) {
this.score = score;
}
@Override
public int hashCode() {
return 31 * id.hashCode() + text.hashCode();
return 31 * id.hashCode() + text.hashCode() + score.hashCode();
}
@Override
@ -53,7 +64,7 @@ public class Comment {
return true;
} else if (o instanceof Comment) {
Comment other = (Comment) o;
return this.id.equals(other.id) && this.text.equals(other.text);
return this.id.equals(other.id) && this.text.equals(other.text) && this.score.equals(other.score);
} else {
return false;
}
@ -61,6 +72,6 @@ public class Comment {
@Override
public String toString() {
return id + ": " + text;
return id + ": " + text + "(score: " + score + ")";
}
}

View File

@ -30,6 +30,9 @@ import com.querydsl.core.types.dsl.NumberPath;
import com.querydsl.core.types.dsl.StringExpression;
import com.querydsl.core.types.dsl.StringPath;
import java.math.MathContext;
import java.math.RoundingMode;
public class GroupByMapTest extends AbstractGroupByTest {
@Test
@ -330,4 +333,22 @@ public class GroupByMapTest extends AbstractGroupByTest {
assertNotNull(resultTransformer);
}
@Test
public void average_with_default_math_context() {
Map<Integer, Double> results = POSTS_W_COMMENTS_SCORE
.transform(groupBy(postId).as(avg(score)));
assertEquals(1.5, results.get(null), 0.0);
assertEquals(((1.5 + 2.0 + 0.5) / 3), results.get(1), 0.0);
assertEquals(((1.0 + 2.0) / 2), results.get(2), 0.0);
}
@Test
public void average_with_user_provided_math_context() {
MathContext oneDigitMathContext = new MathContext(2, RoundingMode.HALF_EVEN);
Map<Integer, Double> results = POSTS_W_COMMENTS_SCORE
.transform(groupBy(postId).as(avg(score, oneDigitMathContext)));
assertEquals(1.5, results.get(null), 0.0);
assertEquals(1.3, results.get(1), 0.0);
assertEquals(1.5, results.get(2), 0.0);
}
}