diff --git a/querydsl-core/src/main/java/com/mysema/query/types/Expr.java b/querydsl-core/src/main/java/com/mysema/query/types/Expr.java index 4e714eaf1..73a16bacc 100644 --- a/querydsl-core/src/main/java/com/mysema/query/types/Expr.java +++ b/querydsl-core/src/main/java/com/mysema/query/types/Expr.java @@ -159,7 +159,7 @@ public abstract class Expr implements Serializable{ * @param right rhs of the comparison * @return */ - public final EBoolean notIn(D... right) { + public EBoolean notIn(D... right) { if (right.length == 1){ return ne(right[0]); }else{ diff --git a/querydsl-core/src/main/java/com/mysema/query/types/expr/ENumber.java b/querydsl-core/src/main/java/com/mysema/query/types/expr/ENumber.java index 6b724418a..80b99f08b 100644 --- a/querydsl-core/src/main/java/com/mysema/query/types/expr/ENumber.java +++ b/querydsl-core/src/main/java/com/mysema/query/types/expr/ENumber.java @@ -7,6 +7,8 @@ package com.mysema.query.types.expr; import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; import javax.annotation.Nullable; @@ -14,6 +16,7 @@ import com.mysema.query.types.Expr; import com.mysema.query.types.Operator; import com.mysema.query.types.Ops; import com.mysema.query.types.Ops.MathOps; +import com.mysema.util.MathUtils; /** * ENumber represents a numeric expression @@ -197,7 +200,6 @@ public abstract class ENumber> extends ECompara return ONumber.create(Double.class, Ops.DIV, this, ENumberConst.create(right)); } - /** * Get the double expression of this numeric expression * @@ -517,4 +519,22 @@ public abstract class ENumber> extends ECompara return sum; } + @Override + public EBoolean in(Number... numbers){ + return super.in(convert(numbers)); + } + + @Override + public EBoolean notIn(Number... numbers){ + return super.notIn(convert(numbers)); + } + + private List convert(Number... numbers){ + List list = new ArrayList(numbers.length); + for (int i = 0; i < numbers.length; i++){ + list.add(MathUtils.cast(numbers[i], getType())); + } + return list; + } + } \ No newline at end of file diff --git a/querydsl-core/src/main/java/com/mysema/util/MathUtils.java b/querydsl-core/src/main/java/com/mysema/util/MathUtils.java index 83c2aacd2..300cc7d26 100644 --- a/querydsl-core/src/main/java/com/mysema/util/MathUtils.java +++ b/querydsl-core/src/main/java/com/mysema/util/MathUtils.java @@ -25,9 +25,9 @@ public final class MathUtils { } @SuppressWarnings("unchecked") - private static > D cast(BigDecimal num, Class type){ + public static > D cast(Number num, Class type){ Number rv; - if (type.equals(Double.class)){ + if (type.equals(Byte.class)){ rv = num.byteValue(); }else if (type.equals(Double.class)){ rv = num.doubleValue(); @@ -40,9 +40,20 @@ public final class MathUtils { }else if (type.equals(Short.class)){ rv = num.shortValue(); }else if (type.equals(BigDecimal.class)){ + if (num instanceof BigDecimal){ + rv = num; + }else{ + rv = new BigDecimal(num.toString()); + } rv = num; }else if (type.equals(BigInteger.class)){ - rv = num.toBigInteger(); + if (num instanceof BigInteger){ + rv = num; + }else if (num instanceof BigDecimal){ + rv = ((BigDecimal)num).toBigInteger(); + }else{ + rv = new BigInteger(num.toString()); + } }else{ throw new IllegalArgumentException(String.format("Illegal type : %s", type.getSimpleName())); } diff --git a/querydsl-core/src/test/java/com/mysema/query/Filters.java b/querydsl-core/src/test/java/com/mysema/query/Filters.java index 4ae02836a..c05cd6e2d 100644 --- a/querydsl-core/src/test/java/com/mysema/query/Filters.java +++ b/querydsl-core/src/test/java/com/mysema/query/Filters.java @@ -175,6 +175,9 @@ public class Filters { rv.add(expr.loe(knownValue)); rv.add(expr.lt(other)); rv.add(expr.lt(knownValue)); + + rv.add(expr.in(1,2,3)); + rv.add(expr.in(1l,2l,3l)); if (expr.getType().equals(Integer.class)){ ENumber eint = (ENumber)expr; diff --git a/querydsl-core/src/test/java/com/mysema/query/types/path/PNumberTest.java b/querydsl-core/src/test/java/com/mysema/query/types/path/PNumberTest.java new file mode 100644 index 000000000..e891d0a1b --- /dev/null +++ b/querydsl-core/src/test/java/com/mysema/query/types/path/PNumberTest.java @@ -0,0 +1,39 @@ +package com.mysema.query.types.path; + +import static org.junit.Assert.assertEquals; + +import java.util.List; + +import org.junit.Test; + +import com.mysema.query.types.Constant; +import com.mysema.query.types.Operation; + +public class PNumberTest { + + private PNumber bytePath = new PNumber(Byte.class, "bytePath"); + + @SuppressWarnings("unchecked") + @Test + public void bytePath_in(){ + Operation operation = (Operation) bytePath.in(1, 2, 3); + + List numbers = (List) ((Constant)operation.getArg(1)).getConstant(); + assertEquals(Byte.valueOf((byte)1), numbers.get(0)); + assertEquals(Byte.valueOf((byte)2), numbers.get(1)); + assertEquals(Byte.valueOf((byte)3), numbers.get(2)); + } + + @SuppressWarnings("unchecked") + @Test + public void bytePath_notIn(){ + Operation operation = (Operation) bytePath.notIn(1, 2, 3); + // unwrap negation + operation = (Operation) operation.getArg(0); + + List numbers = (List) ((Constant)operation.getArg(1)).getConstant(); + assertEquals(Byte.valueOf((byte)1), numbers.get(0)); + assertEquals(Byte.valueOf((byte)2), numbers.get(1)); + assertEquals(Byte.valueOf((byte)3), numbers.get(2)); + } +}