diff --git a/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2.java b/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2.java index 0861f5891..0d99bc941 100644 --- a/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2.java +++ b/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2.java @@ -6,18 +6,17 @@ package com.mysema.query.support; import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; import com.mysema.commons.lang.CloseableIterator; import com.mysema.query.Projectable; import com.mysema.query.ResultTransformer; -import com.mysema.query.Tuple; +import com.mysema.query.support.GroupBy2.Group2; import com.mysema.query.types.Expression; -import com.mysema.query.types.Visitor; /** * Groups results by the first expression. @@ -29,228 +28,243 @@ import com.mysema.query.types.Visitor; * * @author sasa */ -//@SuppressWarnings("unchecked") -public class GroupBy2 implements ResultTransformer> { - - private final Expression[] expressions; +@SuppressWarnings("unchecked") +public class GroupBy2 implements ResultTransformer> { - private static interface GroupFactory { + public static interface GroupColumnDefinition { - public void add(Object o); - - public R get(); + Expression getExpression(); + + GroupColumn createGroupColumn(); } - /** - * NOTE: This expression only applies to GroupBy.transform - * - * @param - * @param - */ - public static abstract class GroupAsExpression implements Expression { - - private static final long serialVersionUID = -8164758792405567077L; + public static interface GroupColumn { + + void add(Object o); + R get(); + + } + + public static interface Group2 { + Object[] toArray(); + T first(Expression expr); + Set set(Expression expr); + List list(Expression expr); + } + + public static abstract class AbstractGroupColumnDefinition implements GroupColumnDefinition { + private final Expression expr; - public GroupAsExpression(Expression expr) { + public AbstractGroupColumnDefinition(Expression expr) { this.expr = expr; } @Override - public S accept(Visitor v, C context) { - throw new UnsupportedOperationException(); + public Expression getExpression() { + return expr; + } + } + + public static class GSet extends AbstractGroupColumnDefinition>{ + + public GSet(Expression expr) { + super(expr); } @Override - public Class getType() { - throw new UnsupportedOperationException(); + public GroupColumn> createGroupColumn() { + return new GroupColumn>() { + + private final Set set = new LinkedHashSet(); + + @Override + public void add(Object o) { + set.add((T) o); + } + + @Override + public Set get() { + return set; + } + + }; } - - public int hashCode() { - return expr.hashCode(); - } - - public boolean equals(Object o) { - if (o == this) { - return true; - } else if (o instanceof GroupAsExpression) { - GroupAsExpression other = (GroupAsExpression) o; - return this.expr.equals(other.expr); - } else { - return false; - } - } - - public abstract GroupFactory createGroupFactory(); - } - public static GroupAsExpression> set(Expression expr) { - return new GroupAsExpression>(expr) { - - private static final long serialVersionUID = -2507144565843468159L; - - @Override - public GroupFactory> createGroupFactory() { - return new GroupFactory>() { - - private final Set set = new HashSet(); - - @Override - public void add(Object o) { - set.add((T) o); - } - - @Override - public Set get() { - return set; - } - - }; - } - - }; - } - - public static GroupAsExpression> list(Expression expr) { - return new GroupAsExpression>(expr) { - - private static final long serialVersionUID = -6941324182786049824L; - - @Override - public GroupFactory> createGroupFactory() { - return new GroupFactory>() { - - private final List list = new ArrayList(); - - @Override - public void add(Object o) { - list.add((T) o); - } - - @Override - public List get() { - return list; - } - - }; - } - - }; - } - - private static class ValueGroupFactory implements GroupFactory { - private T val; + public static class GList extends AbstractGroupColumnDefinition>{ - private boolean first = true; + public GList(Expression expr) { + super(expr); + } + + @Override + public GroupColumn> createGroupColumn() { + return new GroupColumn>() { + + private final List list = new ArrayList(); + + @Override + public void add(Object o) { + list.add((T) o); + } + + @Override + public List get() { + return list; + } + + }; + } + } + + + public static class GFirst extends AbstractGroupColumnDefinition{ + + public GFirst(Expression expr) { + super(expr); + } + + @Override + public GroupColumn createGroupColumn() { + return new GroupColumn() { + private T val; + + private boolean first = true; + + @Override + public void add(Object o) { + if (first) { + val = (T) o; + first = false; + } + } + + @Override + public T get() { + return val; + } + }; + } + } + + private final List> columns = new ArrayList>(); + + public static GroupBy2 groupBy(Expression expr) { + return new GroupBy2(expr); + } + + public GroupBy2(Expression groupBy) { + columns.add(new GFirst(groupBy)); + } + + public GroupBy2(Expression groupBy, GroupColumnDefinition group, GroupColumnDefinition... groups) { + this(groupBy); + columns.add(group); + for (GroupColumnDefinition g : groups) { + columns.add(g); + } + } + + public GroupBy2 group(GroupColumnDefinition g) { + columns.add(g); + return this; + } + + public GroupBy2 set(Expression expr) { + columns.add(new GSet(expr)); + return this; + } + + public GroupBy2 list(Expression expr) { + columns.add(new GList(expr)); + return this; + } + + public GroupBy2 first(Expression expr) { + columns.add(new GFirst(expr)); + return this; + } + + + + private class GroupImpl implements Group2 { + + private final Map, GroupColumn> groupColumns; + + public GroupImpl() { + groupColumns = new LinkedHashMap, GroupColumn>(); + for (int i=0; i < columns.size(); i++) { + GroupColumnDefinition coldef = columns.get(i); + groupColumns.put(coldef.getExpression(), coldef.createGroupColumn()); + } + } @Override - public void add(Object o) { - if (first) { - val = (T) o; - first = false; + public T first(Expression expr) { + return (T) groupColumns.get(expr).get(); + } + + @Override + public Set set(Expression expr) { + return (Set) groupColumns.get(expr).get(); + } + + @Override + public List list(Expression expr) { + return (List) groupColumns.get(expr).get(); + } + + public void add(Object[] row) { + int i=0; + for (GroupColumn groupColumn : groupColumns.values()) { + groupColumn.add(row[i]); + i++; } } @Override - public T get() { - return val; + public Object[] toArray() { + List arr = new ArrayList(groupColumns.size()); + for (GroupColumn col : groupColumns.values()) { + arr.add(col.get()); + } + return arr.toArray(); } - - } - - public GroupBy2(Expression groupBy, Expression... args) { - expressions = new Expression[args.length + 1]; - expressions[0] = groupBy; - System.arraycopy(args, 0, expressions, 1, args.length); + + } @Override - public Collection transform(Projectable projectable) { - final LinkedHashMap groups = new LinkedHashMap(); + public Map transform(Projectable projectable) { + final Map groups = new LinkedHashMap(); - CloseableIterator iter = projectable.iterate(unwrap(expressions)); + CloseableIterator iter = projectable.iterate(unwrapExpressions()); try { while (iter.hasNext()) { Object[] row = iter.next(); - Object groupBy = row[0]; - // groups.values() should return Collection instead of Collection - GroupTuple group = (GroupTuple) groups.get(groupBy); + S groupId = (S) row[0]; + GroupImpl group = (GroupImpl) groups.get(groupId); if (group == null) { - group = new GroupTuple(); - groups.put(groupBy, group); + group = new GroupImpl(); + groups.put(groupId, group); } group.add(row); } } finally { iter.close(); } - return groups.values(); + return groups; } - private static Expression[] unwrap(Expression... expressions) { - Expression[] unwrapped = new Expression[expressions.length]; - for (int i=0; i < expressions.length; i++) { - if (expressions[i] instanceof GroupAsExpression) { - unwrapped[i] = ((GroupAsExpression) expressions[i]).expr; - } else { - unwrapped[i] = expressions[i]; - } + private Expression[] unwrapExpressions() { + Expression[] unwrapped = new Expression[columns.size()]; + for (int i=0; i < columns.size(); i++) { + unwrapped[i] = columns.get(i).getExpression(); } return unwrapped; } - private int indexOf(Expression expr) { - for (int i=0; i < expressions.length; i++) { - if (expressions[i].equals(expr)) { - return i; - } - } - return -1; - } - - private class GroupTuple implements Tuple { - - private final GroupFactory[] groupFactories; - - public GroupTuple() { - groupFactories = new GroupFactory[expressions.length]; - for (int i=0; i < expressions.length; i++) { - if (expressions[i] instanceof GroupAsExpression) { - groupFactories[i] = ((GroupAsExpression) expressions[i]).createGroupFactory(); - } else { - groupFactories[i] = new ValueGroupFactory(); - } - } - } - - @Override - public T get(int index, Class type) { - return (T) groupFactories[index].get(); - } - - @Override - public T get(Expression expr) { - int index = indexOf(expr); - return (T) groupFactories[index].get(); - } - - @Override - public Object[] toArray() { - Object[] row = new Object[groupFactories.length]; - for (int i=0; i < groupFactories.length; i++) { - row[i] = groupFactories[i].get(); - } - return row; - } - - public void add(Object[] row) { - for (int i=0; i < groupFactories.length; i++) { - groupFactories[i].add(row[i]); - } - } - } - } diff --git a/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2Test.java b/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2Test.java index b8271a4e7..72c150b87 100644 --- a/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2Test.java +++ b/querydsl-core/src/test/java/com/mysema/query/support/GroupBy2Test.java @@ -1,24 +1,19 @@ package com.mysema.query.support; -import static com.mysema.query.support.GroupBy2.list; -import static com.mysema.query.support.GroupBy2.set; import static junit.framework.Assert.assertEquals; import java.util.Arrays; -import java.util.Collection; import java.util.HashSet; -import java.util.Iterator; -import java.util.List; +import java.util.Map; import java.util.Set; import org.junit.Test; import com.mysema.commons.lang.CloseableIterator; -import com.mysema.commons.lang.EmptyCloseableIterator; import com.mysema.commons.lang.IteratorAdapter; import com.mysema.query.Projectable; -import com.mysema.query.Tuple; +import com.mysema.query.support.GroupBy2.Group2; import com.mysema.query.types.Expression; import com.mysema.query.types.expr.NumberExpression; import com.mysema.query.types.expr.StringExpression; @@ -33,6 +28,8 @@ public class GroupBy2Test { private final NumberExpression commentId = new NumberPath(Integer.class, "commentId"); + private final StringExpression commentText = new StringPath("commentText"); + static class PostWithComments { public Integer id; public String name; @@ -44,95 +41,38 @@ public class GroupBy2Test { } } - @Test - public void Expression_Order_And_Type() { - new GroupBy2(postId, postName, set(commentId)).transform(new AbstractProjectable(){ - public CloseableIterator iterate(Expression[] args) { - assertEquals(postId, args[0]); - assertEquals(postName, args[1]); - assertEquals(commentId, args[2]); - return new EmptyCloseableIterator(); - } - }); - } - - /** - *
    - *
  1. Order of groups by first row of a group - *
  2. Rows belonging to a group may appear in any order - *
  3. Group of null is handled correctly - *
- */ - @Test - public void Multiple_Groups() { - Collection results = new GroupBy2(postId, postName, list(commentId)).transform( - projectable( - row(1, "post 1", 1), - row(2, "post 2", 4), - row(1, "post 1", 2), - row(2, "post 2", 5), - row(3, "post 3", 6), - row(null, "null post", 7), - row(null, "null post", 8), - row(1, "post 1", 3) - ) + private static final Projectable BASIC_RESULTS = projectable( + row(1, "post 1", 1, "comment 1"), + row(2, "post 2", 4, "comment 4"), + row(1, "post 1", 2, "comment 2"), + row(2, "post 2", 5, "comment 5"), + row(3, "post 3", 6, "comment 6"), + row(null, "null post", 7, "comment 7"), + row(null, "null post", 8, "comment 8"), + row(1, "post 1", 3, "comment 3") ); - assertEquals(4, results.size()); - Iterator iter = results.iterator(); - - Tuple g = iter.next(); - assertEquals(toInt(1), g.get(postId)); - assertEquals("post 1", g.get(postName)); - List comments = g.get(list(commentId)); - assertEquals(toInt(1), comments.get(0)); - assertEquals(toInt(2), comments.get(1)); - assertEquals(toInt(3), comments.get(2)); - - g = iter.next(); - assertEquals(toInt(2), g.get(postId)); - assertEquals("post 2", g.get(postName)); - comments = g.get(list(commentId)); - assertEquals(toInt(4), comments.get(0)); - assertEquals(toInt(5), comments.get(1)); - - g = iter.next(); - assertEquals(toInt(3), g.get(postId)); - assertEquals("post 3", g.get(postName)); - comments = g.get(list(commentId)); - assertEquals(toInt(6), comments.get(0)); - - // Group by null value - g = iter.next(); - assertEquals(null, g.get(postId)); - assertEquals("null post", g.get(postName)); - comments = g.get(list(commentId)); - assertEquals(toInt(7), comments.get(0)); - assertEquals(toInt(8), comments.get(1)); - } - private static Set toSet(T... o) { - return new HashSet(Arrays.asList(o)); + @Test + public void Group_Order() { + Map results = + GroupBy2.groupBy(postId).first(postName).set(commentId).transform(BASIC_RESULTS); + + assertEquals(4, results.size()); } @Test - public void Group_As_Set() { - Collection results = new GroupBy2(postId, postName, set(commentId)).transform(projectable( - row(1, "post 1", 1), - row(null, "null post", 2) - )); - assertEquals(2, results.size()); - Iterator iter = results.iterator(); - - Tuple group = iter.next(); - assertEquals(toInt(1), group.get(postId)); - assertEquals("post 1", group.get(postName)); - assertEquals(toSet(1), group.get(set(commentId))); - - group = iter.next(); - assertEquals(null, group.get(postId)); + public void First_Set_And_List() { + Map results = + GroupBy2.groupBy(postId).first(postName).set(commentId).list(commentText).transform(BASIC_RESULTS); + + Group2 group = results.get(1); + assertEquals(toInt(1), group.first(postId)); + assertEquals("post 1", group.first(postName)); + assertEquals(toSet(1, 2, 3), group.set(commentId)); + assertEquals(Arrays.asList("comment 1", "comment 2", "comment 3"), group.list(commentText)); } - private Projectable projectable(final Object[]... rows) { + private static Projectable projectable(final Object[]... rows) { return new AbstractProjectable(){ public CloseableIterator iterate(Expression[] args) { return iterator(rows); @@ -144,6 +84,10 @@ public class GroupBy2Test { return Integer.valueOf(i); } + private Set toSet(T... s) { + return new HashSet(Arrays.asList(s)); + } + private static Object[] row(Object... row) { return row; }