Remodeled GroupBy2

This commit is contained in:
Samppa Saarela 2011-09-16 12:16:21 +03:00
parent af3a2a8385
commit 8a58c22f0d
2 changed files with 229 additions and 271 deletions

View File

@ -6,18 +6,17 @@
package com.mysema.query.support;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.mysema.commons.lang.CloseableIterator;
import com.mysema.query.Projectable;
import com.mysema.query.ResultTransformer;
import com.mysema.query.Tuple;
import com.mysema.query.support.GroupBy2.Group2;
import com.mysema.query.types.Expression;
import com.mysema.query.types.Visitor;
/**
* Groups results by the first expression.
@ -29,228 +28,243 @@ import com.mysema.query.types.Visitor;
*
* @author sasa
*/
//@SuppressWarnings("unchecked")
public class GroupBy2 implements ResultTransformer<Collection<Tuple>> {
private final Expression<?>[] expressions;
@SuppressWarnings("unchecked")
public class GroupBy2<S> implements ResultTransformer<Map<S, Group2>> {
private static interface GroupFactory<T, R> {
public static interface GroupColumnDefinition<T, R> {
public void add(Object o);
public R get();
Expression<T> getExpression();
GroupColumn<R> createGroupColumn();
}
/**
* NOTE: This expression only applies to GroupBy.transform
*
* @param <T>
* @param <R>
*/
public static abstract class GroupAsExpression<T, R> implements Expression<R> {
private static final long serialVersionUID = -8164758792405567077L;
public static interface GroupColumn<R> {
void add(Object o);
R get();
}
public static interface Group2 {
Object[] toArray();
<T> T first(Expression<T> expr);
<T> Set<T> set(Expression<T> expr);
<T> List<T> list(Expression<T> expr);
}
public static abstract class AbstractGroupColumnDefinition<T, R> implements GroupColumnDefinition<T, R> {
private final Expression<T> expr;
public GroupAsExpression(Expression<T> expr) {
public AbstractGroupColumnDefinition(Expression<T> expr) {
this.expr = expr;
}
@Override
public <S, C> S accept(Visitor<S, C> v, C context) {
throw new UnsupportedOperationException();
public Expression<T> getExpression() {
return expr;
}
}
public static class GSet<T> extends AbstractGroupColumnDefinition<T, Set<T>>{
public GSet(Expression<T> expr) {
super(expr);
}
@Override
public Class<? extends R> getType() {
throw new UnsupportedOperationException();
public GroupColumn<Set<T>> createGroupColumn() {
return new GroupColumn<Set<T>>() {
private final Set<T> set = new LinkedHashSet<T>();
@Override
public void add(Object o) {
set.add((T) o);
}
@Override
public Set<T> get() {
return set;
}
};
}
public int hashCode() {
return expr.hashCode();
}
public boolean equals(Object o) {
if (o == this) {
return true;
} else if (o instanceof GroupAsExpression) {
GroupAsExpression<?, ?> other = (GroupAsExpression<?, ?>) o;
return this.expr.equals(other.expr);
} else {
return false;
}
}
public abstract GroupFactory<T, R> createGroupFactory();
}
public static <T> GroupAsExpression<T, Set<T>> set(Expression<T> expr) {
return new GroupAsExpression<T, Set<T>>(expr) {
private static final long serialVersionUID = -2507144565843468159L;
@Override
public GroupFactory<T, Set<T>> createGroupFactory() {
return new GroupFactory<T, Set<T>>() {
private final Set<T> set = new HashSet<T>();
@Override
public void add(Object o) {
set.add((T) o);
}
@Override
public Set<T> get() {
return set;
}
};
}
};
}
public static <T> GroupAsExpression<T, List<T>> list(Expression<T> expr) {
return new GroupAsExpression<T, List<T>>(expr) {
private static final long serialVersionUID = -6941324182786049824L;
@Override
public GroupFactory<T, List<T>> createGroupFactory() {
return new GroupFactory<T, List<T>>() {
private final List<T> list = new ArrayList<T>();
@Override
public void add(Object o) {
list.add((T) o);
}
@Override
public List<T> get() {
return list;
}
};
}
};
}
private static class ValueGroupFactory<T> implements GroupFactory<T, T> {
private T val;
public static class GList<T> extends AbstractGroupColumnDefinition<T, List<T>>{
private boolean first = true;
public GList(Expression<T> expr) {
super(expr);
}
@Override
public GroupColumn<List<T>> createGroupColumn() {
return new GroupColumn<List<T>>() {
private final List<T> list = new ArrayList<T>();
@Override
public void add(Object o) {
list.add((T) o);
}
@Override
public List<T> get() {
return list;
}
};
}
}
public static class GFirst<T> extends AbstractGroupColumnDefinition<T, T>{
public GFirst(Expression<T> expr) {
super(expr);
}
@Override
public GroupColumn<T> createGroupColumn() {
return new GroupColumn<T>() {
private T val;
private boolean first = true;
@Override
public void add(Object o) {
if (first) {
val = (T) o;
first = false;
}
}
@Override
public T get() {
return val;
}
};
}
}
private final List<GroupColumnDefinition<?, ?>> columns = new ArrayList<GroupBy2.GroupColumnDefinition<?,?>>();
public static <T> GroupBy2<T> groupBy(Expression<T> expr) {
return new GroupBy2<T>(expr);
}
public GroupBy2(Expression<S> groupBy) {
columns.add(new GFirst<S>(groupBy));
}
public <T> GroupBy2(Expression<S> groupBy, GroupColumnDefinition<?, ?> group, GroupColumnDefinition<?, ?>... groups) {
this(groupBy);
columns.add(group);
for (GroupColumnDefinition<?, ?> g : groups) {
columns.add(g);
}
}
public GroupBy2<S> group(GroupColumnDefinition<?, ?> g) {
columns.add(g);
return this;
}
public <T> GroupBy2<S> set(Expression<T> expr) {
columns.add(new GSet<T>(expr));
return this;
}
public <T> GroupBy2<S> list(Expression<T> expr) {
columns.add(new GList<T>(expr));
return this;
}
public <T> GroupBy2<S> first(Expression<T> expr) {
columns.add(new GFirst<T>(expr));
return this;
}
private class GroupImpl implements Group2 {
private final Map<Expression<?>, GroupColumn<?>> groupColumns;
public GroupImpl() {
groupColumns = new LinkedHashMap<Expression<?>, GroupColumn<?>>();
for (int i=0; i < columns.size(); i++) {
GroupColumnDefinition<?, ?> coldef = columns.get(i);
groupColumns.put(coldef.getExpression(), coldef.createGroupColumn());
}
}
@Override
public void add(Object o) {
if (first) {
val = (T) o;
first = false;
public <T> T first(Expression<T> expr) {
return (T) groupColumns.get(expr).get();
}
@Override
public <T> Set<T> set(Expression<T> expr) {
return (Set<T>) groupColumns.get(expr).get();
}
@Override
public <T> List<T> list(Expression<T> expr) {
return (List<T>) groupColumns.get(expr).get();
}
public void add(Object[] row) {
int i=0;
for (GroupColumn<?> groupColumn : groupColumns.values()) {
groupColumn.add(row[i]);
i++;
}
}
@Override
public T get() {
return val;
public Object[] toArray() {
List<Object> arr = new ArrayList<Object>(groupColumns.size());
for (GroupColumn<?> col : groupColumns.values()) {
arr.add(col.get());
}
return arr.toArray();
}
}
public GroupBy2(Expression<?> groupBy, Expression<?>... args) {
expressions = new Expression<?>[args.length + 1];
expressions[0] = groupBy;
System.arraycopy(args, 0, expressions, 1, args.length);
}
@Override
public Collection<Tuple> transform(Projectable projectable) {
final LinkedHashMap<Object, Tuple> groups = new LinkedHashMap<Object, Tuple>();
public Map<S, Group2> transform(Projectable projectable) {
final Map<S, Group2> groups = new LinkedHashMap<S, Group2>();
CloseableIterator<Object[]> iter = projectable.iterate(unwrap(expressions));
CloseableIterator<Object[]> iter = projectable.iterate(unwrapExpressions());
try {
while (iter.hasNext()) {
Object[] row = iter.next();
Object groupBy = row[0];
// groups.values() should return Collection<GTuple> instead of Collection<? extends GTuple>
GroupTuple group = (GroupTuple) groups.get(groupBy);
S groupId = (S) row[0];
GroupImpl group = (GroupImpl) groups.get(groupId);
if (group == null) {
group = new GroupTuple();
groups.put(groupBy, group);
group = new GroupImpl();
groups.put(groupId, group);
}
group.add(row);
}
} finally {
iter.close();
}
return groups.values();
return groups;
}
private static Expression<?>[] unwrap(Expression<?>... expressions) {
Expression<?>[] unwrapped = new Expression<?>[expressions.length];
for (int i=0; i < expressions.length; i++) {
if (expressions[i] instanceof GroupAsExpression) {
unwrapped[i] = ((GroupAsExpression<?, ?>) expressions[i]).expr;
} else {
unwrapped[i] = expressions[i];
}
private Expression<?>[] unwrapExpressions() {
Expression<?>[] unwrapped = new Expression<?>[columns.size()];
for (int i=0; i < columns.size(); i++) {
unwrapped[i] = columns.get(i).getExpression();
}
return unwrapped;
}
private int indexOf(Expression<?> expr) {
for (int i=0; i < expressions.length; i++) {
if (expressions[i].equals(expr)) {
return i;
}
}
return -1;
}
private class GroupTuple implements Tuple {
private final GroupFactory<?, ?>[] groupFactories;
public GroupTuple() {
groupFactories = new GroupFactory<?, ?>[expressions.length];
for (int i=0; i < expressions.length; i++) {
if (expressions[i] instanceof GroupAsExpression<?, ?>) {
groupFactories[i] = ((GroupAsExpression<?, ?>) expressions[i]).createGroupFactory();
} else {
groupFactories[i] = new ValueGroupFactory<Object>();
}
}
}
@Override
public <T> T get(int index, Class<T> type) {
return (T) groupFactories[index].get();
}
@Override
public <T> T get(Expression<T> expr) {
int index = indexOf(expr);
return (T) groupFactories[index].get();
}
@Override
public Object[] toArray() {
Object[] row = new Object[groupFactories.length];
for (int i=0; i < groupFactories.length; i++) {
row[i] = groupFactories[i].get();
}
return row;
}
public void add(Object[] row) {
for (int i=0; i < groupFactories.length; i++) {
groupFactories[i].add(row[i]);
}
}
}
}

View File

@ -1,24 +1,19 @@
package com.mysema.query.support;
import static com.mysema.query.support.GroupBy2.list;
import static com.mysema.query.support.GroupBy2.set;
import static junit.framework.Assert.assertEquals;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Test;
import com.mysema.commons.lang.CloseableIterator;
import com.mysema.commons.lang.EmptyCloseableIterator;
import com.mysema.commons.lang.IteratorAdapter;
import com.mysema.query.Projectable;
import com.mysema.query.Tuple;
import com.mysema.query.support.GroupBy2.Group2;
import com.mysema.query.types.Expression;
import com.mysema.query.types.expr.NumberExpression;
import com.mysema.query.types.expr.StringExpression;
@ -33,6 +28,8 @@ public class GroupBy2Test {
private final NumberExpression<Integer> commentId = new NumberPath<Integer>(Integer.class, "commentId");
private final StringExpression commentText = new StringPath("commentText");
static class PostWithComments {
public Integer id;
public String name;
@ -44,95 +41,38 @@ public class GroupBy2Test {
}
}
@Test
public void Expression_Order_And_Type() {
new GroupBy2(postId, postName, set(commentId)).transform(new AbstractProjectable(){
public CloseableIterator<Object[]> iterate(Expression<?>[] args) {
assertEquals(postId, args[0]);
assertEquals(postName, args[1]);
assertEquals(commentId, args[2]);
return new EmptyCloseableIterator<Object[]>();
}
});
}
/**
* <ol>
* <li>Order of groups by first row of a group
* <li>Rows belonging to a group may appear in any order
* <li>Group of null is handled correctly
* </ol>
*/
@Test
public void Multiple_Groups() {
Collection<Tuple> results = new GroupBy2(postId, postName, list(commentId)).transform(
projectable(
row(1, "post 1", 1),
row(2, "post 2", 4),
row(1, "post 1", 2),
row(2, "post 2", 5),
row(3, "post 3", 6),
row(null, "null post", 7),
row(null, "null post", 8),
row(1, "post 1", 3)
)
private static final Projectable BASIC_RESULTS = projectable(
row(1, "post 1", 1, "comment 1"),
row(2, "post 2", 4, "comment 4"),
row(1, "post 1", 2, "comment 2"),
row(2, "post 2", 5, "comment 5"),
row(3, "post 3", 6, "comment 6"),
row(null, "null post", 7, "comment 7"),
row(null, "null post", 8, "comment 8"),
row(1, "post 1", 3, "comment 3")
);
assertEquals(4, results.size());
Iterator<Tuple> iter = results.iterator();
Tuple g = iter.next();
assertEquals(toInt(1), g.get(postId));
assertEquals("post 1", g.get(postName));
List<Integer> comments = g.get(list(commentId));
assertEquals(toInt(1), comments.get(0));
assertEquals(toInt(2), comments.get(1));
assertEquals(toInt(3), comments.get(2));
g = iter.next();
assertEquals(toInt(2), g.get(postId));
assertEquals("post 2", g.get(postName));
comments = g.get(list(commentId));
assertEquals(toInt(4), comments.get(0));
assertEquals(toInt(5), comments.get(1));
g = iter.next();
assertEquals(toInt(3), g.get(postId));
assertEquals("post 3", g.get(postName));
comments = g.get(list(commentId));
assertEquals(toInt(6), comments.get(0));
// Group by null value
g = iter.next();
assertEquals(null, g.get(postId));
assertEquals("null post", g.get(postName));
comments = g.get(list(commentId));
assertEquals(toInt(7), comments.get(0));
assertEquals(toInt(8), comments.get(1));
}
private static <T> Set<T> toSet(T... o) {
return new HashSet<T>(Arrays.asList(o));
@Test
public void Group_Order() {
Map<Integer, Group2> results =
GroupBy2.groupBy(postId).first(postName).set(commentId).transform(BASIC_RESULTS);
assertEquals(4, results.size());
}
@Test
public void Group_As_Set() {
Collection<Tuple> results = new GroupBy2(postId, postName, set(commentId)).transform(projectable(
row(1, "post 1", 1),
row(null, "null post", 2)
));
assertEquals(2, results.size());
Iterator<Tuple> iter = results.iterator();
Tuple group = iter.next();
assertEquals(toInt(1), group.get(postId));
assertEquals("post 1", group.get(postName));
assertEquals(toSet(1), group.get(set(commentId)));
group = iter.next();
assertEquals(null, group.get(postId));
public void First_Set_And_List() {
Map<Integer, Group2> results =
GroupBy2.groupBy(postId).first(postName).set(commentId).list(commentText).transform(BASIC_RESULTS);
Group2 group = results.get(1);
assertEquals(toInt(1), group.first(postId));
assertEquals("post 1", group.first(postName));
assertEquals(toSet(1, 2, 3), group.set(commentId));
assertEquals(Arrays.asList("comment 1", "comment 2", "comment 3"), group.list(commentText));
}
private Projectable projectable(final Object[]... rows) {
private static Projectable projectable(final Object[]... rows) {
return new AbstractProjectable(){
public CloseableIterator<Object[]> iterate(Expression<?>[] args) {
return iterator(rows);
@ -144,6 +84,10 @@ public class GroupBy2Test {
return Integer.valueOf(i);
}
private <T >Set<T> toSet(T... s) {
return new HashSet<T>(Arrays.asList(s));
}
private static Object[] row(Object... row) {
return row;
}