From 82f240a8893a287556e17bc94d91e9df3ca281b1 Mon Sep 17 00:00:00 2001 From: Maciej Dobrowolski Date: Fri, 9 Apr 2021 13:56:06 +0200 Subject: [PATCH] 2802 - Fixing ArithmeticException when using GroupBy.avg() - GAvg now uses default or user-provided MathContext for calculating the average --- .../java/com/querydsl/core/group/GAvg.java | 10 ++++++++- .../java/com/querydsl/core/group/GroupBy.java | 14 ++++++++++++- .../core/group/AbstractGroupByTest.java | 15 +++++++++++-- .../java/com/querydsl/core/group/Comment.java | 21 ++++++++++++++----- .../querydsl/core/group/GroupByMapTest.java | 21 +++++++++++++++++++ 5 files changed, 72 insertions(+), 9 deletions(-) diff --git a/querydsl-core/src/main/java/com/querydsl/core/group/GAvg.java b/querydsl-core/src/main/java/com/querydsl/core/group/GAvg.java index 8fb1ff2b6..a2d2465ed 100644 --- a/querydsl-core/src/main/java/com/querydsl/core/group/GAvg.java +++ b/querydsl-core/src/main/java/com/querydsl/core/group/GAvg.java @@ -14,6 +14,7 @@ package com.querydsl.core.group; import java.math.BigDecimal; +import java.math.MathContext; import com.querydsl.core.types.Expression; import com.querydsl.core.util.MathUtils; @@ -23,8 +24,15 @@ class GAvg extends AbstractGroupExpression { private static final long serialVersionUID = 3518868612387641383L; + private final MathContext mathContext; + public GAvg(Expression expr) { + this(expr, MathContext.DECIMAL128); + } + + public GAvg(Expression expr, MathContext mathContext) { super((Class) expr.getType(), expr); + this.mathContext = mathContext; } @Override @@ -43,7 +51,7 @@ class GAvg extends AbstractGroupExpression { @Override public T get() { - BigDecimal avg = sum.divide(BigDecimal.valueOf(count)); + BigDecimal avg = sum.divide(BigDecimal.valueOf(count), mathContext); return MathUtils.cast(avg, getType()); } diff --git a/querydsl-core/src/main/java/com/querydsl/core/group/GroupBy.java b/querydsl-core/src/main/java/com/querydsl/core/group/GroupBy.java index 00eb8f3ab..dfd21ea9b 100644 --- a/querydsl-core/src/main/java/com/querydsl/core/group/GroupBy.java +++ b/querydsl-core/src/main/java/com/querydsl/core/group/GroupBy.java @@ -13,6 +13,7 @@ */ package com.querydsl.core.group; +import java.math.MathContext; import java.util.*; import com.mysema.commons.lang.Pair; @@ -79,7 +80,7 @@ public class GroupBy { } /** - * Create a new aggregating avg expression + * Create a new aggregating avg expression, uses default MathContext.DECIMAL128 for average calculation * * @param expression expression for which the accumulated average value will be used in the group by projection * @return wrapper expression @@ -88,6 +89,17 @@ public class GroupBy { return new GAvg(expression); } + /** + * Create a new aggregating avg expression with a user-provided MathContext + * + * @param expression expression for which the accumulated average value will be used in the group by projection + * @param mathContext mathContext for average calculation + * @return wrapper expression + */ + public static AbstractGroupExpression avg(Expression expression, MathContext mathContext) { + return new GAvg(expression, mathContext); + } + /** * Create a new aggregating max expression * diff --git a/querydsl-core/src/test/java/com/querydsl/core/group/AbstractGroupByTest.java b/querydsl-core/src/test/java/com/querydsl/core/group/AbstractGroupByTest.java index 5d28212aa..627dace99 100644 --- a/querydsl-core/src/test/java/com/querydsl/core/group/AbstractGroupByTest.java +++ b/querydsl-core/src/test/java/com/querydsl/core/group/AbstractGroupByTest.java @@ -106,6 +106,15 @@ public abstract class AbstractGroupByTest { row("John", "John", 1, "post 1", comment(3)) ); + protected static final DummyFetchableQuery POSTS_W_COMMENTS_SCORE = projectable( + row(null, 1.5), + row(1, 1.5), + row(1, 2.0), + row(1, 0.5), + row(2, 1.0), + row(2, 2.0) + ); + // protected static final Projectable USERS_W_LATEST_POST_AND_COMMENTS2 = projectable( // row("John", 1, "post 1", comment(1)), // row("Jane", 2, "post 2", comment(4)), @@ -130,7 +139,9 @@ public abstract class AbstractGroupByTest { protected static final StringExpression commentText = Expressions.stringPath(comment, "text"); - protected static final ConstructorExpression qComment = Projections.constructor(Comment.class, commentId, commentText); + protected static final NumberExpression score = Expressions.numberPath(Double.class, comment, "score"); + + protected static final ConstructorExpression qComment = Projections.constructor(Comment.class, commentId, commentText, score); protected static Pair pair(K key, V value) { return new Pair(key, value); @@ -145,7 +156,7 @@ public abstract class AbstractGroupByTest { } protected static Comment comment(Integer id) { - return new Comment(id, "comment " + id); + return new Comment(id, "comment " + id, 0.0); } protected static DummyFetchableQuery projectable(final Object[]... rows) { diff --git a/querydsl-core/src/test/java/com/querydsl/core/group/Comment.java b/querydsl-core/src/test/java/com/querydsl/core/group/Comment.java index cbc0fd794..61d7a5510 100644 --- a/querydsl-core/src/test/java/com/querydsl/core/group/Comment.java +++ b/querydsl-core/src/test/java/com/querydsl/core/group/Comment.java @@ -15,15 +15,18 @@ package com.querydsl.core.group; public class Comment { - private Integer id; + private Integer id; private String text; + private Double score; + public Comment() { } - public Comment(Integer id, String text) { + public Comment(Integer id, String text, Double score) { this.id = id; this.text = text; + this.score = score; } public Integer getId() { @@ -42,9 +45,17 @@ public class Comment { this.text = text; } + public Double getScore() { + return score; + } + + public void setScore(Double score) { + this.score = score; + } + @Override public int hashCode() { - return 31 * id.hashCode() + text.hashCode(); + return 31 * id.hashCode() + text.hashCode() + score.hashCode(); } @Override @@ -53,7 +64,7 @@ public class Comment { return true; } else if (o instanceof Comment) { Comment other = (Comment) o; - return this.id.equals(other.id) && this.text.equals(other.text); + return this.id.equals(other.id) && this.text.equals(other.text) && this.score.equals(other.score); } else { return false; } @@ -61,6 +72,6 @@ public class Comment { @Override public String toString() { - return id + ": " + text; + return id + ": " + text + "(score: " + score + ")"; } } \ No newline at end of file diff --git a/querydsl-core/src/test/java/com/querydsl/core/group/GroupByMapTest.java b/querydsl-core/src/test/java/com/querydsl/core/group/GroupByMapTest.java index 8615ab242..512aeb362 100644 --- a/querydsl-core/src/test/java/com/querydsl/core/group/GroupByMapTest.java +++ b/querydsl-core/src/test/java/com/querydsl/core/group/GroupByMapTest.java @@ -30,6 +30,9 @@ import com.querydsl.core.types.dsl.NumberPath; import com.querydsl.core.types.dsl.StringExpression; import com.querydsl.core.types.dsl.StringPath; +import java.math.MathContext; +import java.math.RoundingMode; + public class GroupByMapTest extends AbstractGroupByTest { @Test @@ -330,4 +333,22 @@ public class GroupByMapTest extends AbstractGroupByTest { assertNotNull(resultTransformer); } + @Test + public void average_with_default_math_context() { + Map results = POSTS_W_COMMENTS_SCORE + .transform(groupBy(postId).as(avg(score))); + assertEquals(1.5, results.get(null), 0.0); + assertEquals(((1.5 + 2.0 + 0.5) / 3), results.get(1), 0.0); + assertEquals(((1.0 + 2.0) / 2), results.get(2), 0.0); + } + + @Test + public void average_with_user_provided_math_context() { + MathContext oneDigitMathContext = new MathContext(2, RoundingMode.HALF_EVEN); + Map results = POSTS_W_COMMENTS_SCORE + .transform(groupBy(postId).as(avg(score, oneDigitMathContext))); + assertEquals(1.5, results.get(null), 0.0); + assertEquals(1.3, results.get(1), 0.0); + assertEquals(1.5, results.get(2), 0.0); + } }