/*
 * Copyright (c) 2010 Mysema Ltd.
 * All rights reserved.
 *
 */
package com.mysema.query.sql;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.mysema.commons.lang.Assert;
import com.mysema.commons.lang.CloseableIterator;
import com.mysema.commons.lang.IteratorAdapter;
import com.mysema.query.DefaultQueryMetadata;
import com.mysema.query.JoinExpression;
import com.mysema.query.JoinFlag;
import com.mysema.query.QueryException;
import com.mysema.query.QueryFlag;
import com.mysema.query.QueryFlag.Position;
import com.mysema.query.QueryMetadata;
import com.mysema.query.QueryModifiers;
import com.mysema.query.SearchResults;
import com.mysema.query.support.ProjectableQuery;
import com.mysema.query.types.Expression;
import com.mysema.query.types.ExpressionUtils;
import com.mysema.query.types.FactoryExpression;
import com.mysema.query.types.OrderSpecifier;
import com.mysema.query.types.ParamExpression;
import com.mysema.query.types.ParamNotSetException;
import com.mysema.query.types.Path;
import com.mysema.query.types.Predicate;
import com.mysema.query.types.QBean;
import com.mysema.query.types.SubQueryExpression;
import com.mysema.query.types.query.ListSubQuery;
import com.mysema.query.types.template.NumberTemplate;
import com.mysema.query.types.template.SimpleTemplate;
import com.mysema.util.ResultSetAdapter;

/**
 * AbstractSQLQuery is the base type for SQL query implementations
 *
 * <p>The query is serialized with a {@link SQLSerializer}, executed through a JDBC
 * {@link PreparedStatement} on the supplied connection, and the rows are mapped via the
 * {@link Configuration}. Every execution method resets the query metadata when it
 * returns (see {@link #reset()}).</p>
 *
 * @author tiwe
 *
 * @param <Q> concrete query type; returned by the fluent builder methods
 */
public abstract class AbstractSQLQuery<Q extends AbstractSQLQuery<Q>> extends ProjectableQuery<Q> {

    /**
     * Handle returned by the {@code union}/{@code unionAll} methods. It executes the
     * union stored in the enclosing query and uses the projection of the first
     * subquery to decide how rows are mapped.
     *
     * @param <RT> row type of the union
     */
    public class UnionBuilder<RT> implements Union<RT> {

        /**
         * Execute the union and collect all rows into a list.
         *
         * @return the rows; single values when the first subquery projects exactly one
         *         expression, {@code Object[]} rows otherwise
         */
        @Override
        @SuppressWarnings("unchecked")
        public List<RT> list() {
            List<? extends Expression<?>> projection = union[0].getMetadata().getProjection();
            if (projection.size() == 1) {
                return IteratorAdapter.asList(
                        iterateSingle(union[0].getMetadata(), (Expression<RT>) projection.get(0)));
            } else {
                return (List<RT>) IteratorAdapter.asList(iterateMultiple(union[0].getMetadata()));
            }
        }

        /**
         * Execute the union and iterate over the rows.
         *
         * @return iterator over the rows; single values when the first subquery projects
         *         exactly one expression, {@code Object[]} rows otherwise
         */
        @SuppressWarnings("unchecked")
        @Override
        public CloseableIterator<RT> iterate() {
            List<? extends Expression<?>> projection = union[0].getMetadata().getProjection();
            if (projection.size() == 1) {
                return iterateSingle(union[0].getMetadata(), (Expression<RT>) projection.get(0));
            } else {
                return (CloseableIterator<RT>) iterateMultiple(union[0].getMetadata());
            }
        }

        /**
         * Add order specifiers to the enclosing query; they are passed to the
         * serializer together with the union.
         *
         * @param o order specifiers
         * @return this builder
         */
        @Override
        public UnionBuilder<RT> orderBy(OrderSpecifier<?>... o) {
            AbstractSQLQuery.this.orderBy(o);
            return this;
        }

        @Override
        public String toString() {
            return AbstractSQLQuery.this.toString();
        }

    }

    private static final Logger logger = LoggerFactory.getLogger(AbstractSQLQuery.class);

    @Nullable
    private final Connection conn;

    /** Constants collected by the serializer during the last {@link #buildQueryString(boolean)} call. */
    @Nullable
    private List<Object> constants;

    /** Paths matching {@link #constants} index by index; set together with them. */
    @Nullable
    private List<Path<?>> constantPaths;

    /** Subqueries of the union, or {@code null} when this query is not a union. */
    @Nullable
    protected SubQueryExpression<?>[] union;

    private final Configuration configuration;

    protected final SQLQueryMixin<Q> queryMixin;

    /** Passed to the serializer when serializing {@link #union}; set by the {@code unionAll} methods. */
    protected boolean unionAll;

    /**
     * Create a new query with fresh {@link DefaultQueryMetadata}.
     *
     * @param conn connection to use for execution; may be null, in which case the
     *        execution methods fail
     * @param configuration configuration providing templates and type mappings
     */
    public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration) {
        this(conn, configuration, new DefaultQueryMetadata());
    }

    /**
     * Create a new query on top of the given metadata.
     *
     * @param conn connection to use for execution; may be null, in which case the
     *        execution methods fail
     * @param configuration configuration providing templates and type mappings
     * @param metadata metadata the query operates on
     */
    @SuppressWarnings("unchecked")
    public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration, QueryMetadata metadata) {
        super(new SQLQueryMixin<Q>(metadata));
        this.queryMixin = (SQLQueryMixin<Q>) super.queryMixin;
        this.queryMixin.setSelf((Q) this);
        this.conn = conn;
        this.configuration = configuration;
    }

    /**
     * Add the given String literal as a join flag to the last added join with the position BEFORE_TARGET
     *
     * @param flag flag literal
     * @return this query
     */
    public Q addJoinFlag(String flag) {
        return addJoinFlag(flag, JoinFlag.Position.BEFORE_TARGET);
    }

    /**
     * Add the given String literal as a join flag to the last added join
     *
     * @param flag flag literal
     * @param position position of the flag
     * @return this query
     */
    @SuppressWarnings("unchecked")
    public Q addJoinFlag(String flag, JoinFlag.Position position) {
        List<JoinExpression> joins = queryMixin.getMetadata().getJoins();
        // The flag goes on the most recently added join, so at least one join must exist.
        joins.get(joins.size() - 1).addFlag(new JoinFlag(flag, position));
        return (Q) this;
    }

    /**
     * Add the given prefix and expression as a general query flag
     *
     * @param position position of the flag
     * @param prefix prefix for the flag
     * @param expr expression of the flag
     * @return this query
     */
    public Q addFlag(Position position, String prefix, Expression<?> expr) {
        Expression<?> flag = SimpleTemplate.create(expr.getType(), prefix + "{0}", expr);
        return queryMixin.addFlag(new QueryFlag(position, flag));
    }

    /**
     * Add the given String literal as query flag
     *
     * @param position position of the flag
     * @param flag flag literal
     * @return this query
     */
    public Q addFlag(Position position, String flag) {
        return queryMixin.addFlag(new QueryFlag(position, flag));
    }

    /**
     * Add the given Expression as a query flag
     *
     * @param position position of the flag
     * @param flag flag expression
     * @return this query
     */
    public Q addFlag(Position position, Expression<?> flag) {
        return queryMixin.addFlag(new QueryFlag(position, flag));
    }

    /**
     * Serialize the query (or the union, when one is set) and remember the constants
     * and constant paths collected by the serializer for the subsequent parameter
     * binding.
     *
     * @param forCountRow passed through to the serializer; not used for unions
     * @return the serialized query string
     */
    protected String buildQueryString(boolean forCountRow) {
        SQLSerializer serializer = createSerializer();
        if (union != null) {
            serializer.serializeUnion(union, queryMixin.getMetadata().getOrderBy(), unionAll);
        } else {
            serializer.serialize(queryMixin.getMetadata(), forCountRow);
        }
        constants = serializer.getConstants();
        constantPaths = serializer.getConstantPaths();
        return serializer.toString();
    }

    /**
     * Execute the count form of the query.
     *
     * @return value of the first column of the first row
     * @throws QueryException if the execution fails
     */
    @Override
    public long count() {
        try {
            return unsafeCount();
        } catch (SQLException e) {
            String error = "Caught " + e.getClass().getName();
            logger.error(error, e);
            throw new QueryException(e.getMessage(), e);
        }
    }

    /**
     * @return true, if the query limited to one row yields a non-null result
     */
    @Override
    public boolean exists() {
        return limit(1).uniqueResult(NumberTemplate.ONE) != null;
    }

    /**
     * @return a new serializer using the templates of the configuration
     */
    protected SQLSerializer createSerializer() {
        return new SQLSerializer(configuration.getTemplates());
    }

    public Q from(Expression<?>... args) {
        return queryMixin.from(args);
    }

    @SuppressWarnings("unchecked")
    public Q from(SubQueryExpression<?> subQuery, Path<?> alias) {
        return queryMixin.from(ExpressionUtils.as((Expression) subQuery, alias));
    }

    public Q fullJoin(RelationalPath<?> target) {
        return queryMixin.fullJoin(target);
    }

    public <E> Q fullJoin(RelationalFunctionCall<E> target, Path<E> alias) {
        return queryMixin.fullJoin(target, alias);
    }

    public Q fullJoin(SubQueryExpression<?> target, Path<?> alias) {
        return queryMixin.fullJoin(target, alias);
    }

    /** Full join {@code entity} using the condition derived from the foreign key. */
    public <E> Q fullJoin(ForeignKey<E> key, RelationalPath<E> entity) {
        return queryMixin.fullJoin(entity).on(key.on(entity));
    }

    public Q innerJoin(RelationalPath<?> target) {
        return queryMixin.innerJoin(target);
    }

    public <E> Q innerJoin(RelationalFunctionCall<E> target, Path<E> alias) {
        return queryMixin.innerJoin(target, alias);
    }

    public Q innerJoin(SubQueryExpression<?> target, Path<?> alias) {
        return queryMixin.innerJoin(target, alias);
    }

    /** Inner join {@code entity} using the condition derived from the foreign key. */
    public <E> Q innerJoin(ForeignKey<E> key, RelationalPath<E> entity) {
        return queryMixin.innerJoin(entity).on(key.on(entity));
    }

    public Q join(RelationalPath<?> target) {
        return queryMixin.join(target);
    }

    public <E> Q join(RelationalFunctionCall<E> target, Path<E> alias) {
        return queryMixin.join(target, alias);
    }

    public Q join(SubQueryExpression<?> target, Path<?> alias) {
        return queryMixin.join(target, alias);
    }

    /** Join {@code entity} using the condition derived from the foreign key. */
    public <E> Q join(ForeignKey<E> key, RelationalPath<E> entity) {
        return queryMixin.join(entity).on(key.on(entity));
    }

    public Q leftJoin(RelationalPath<?> target) {
        return queryMixin.leftJoin(target);
    }

    public <E> Q leftJoin(RelationalFunctionCall<E> target, Path<E> alias) {
        return queryMixin.leftJoin(target, alias);
    }

    public Q leftJoin(SubQueryExpression<?> target, Path<?> alias) {
        return queryMixin.leftJoin(target, alias);
    }

    /** Left join {@code entity} using the condition derived from the foreign key. */
    public <E> Q leftJoin(ForeignKey<E> key, RelationalPath<E> entity) {
        return queryMixin.leftJoin(entity).on(key.on(entity));
    }

    public Q rightJoin(RelationalPath<?> target) {
        return queryMixin.rightJoin(target);
    }

    public <E> Q rightJoin(RelationalFunctionCall<E> target, Path<E> alias) {
        return queryMixin.rightJoin(target, alias);
    }

    public Q rightJoin(SubQueryExpression<?> target, Path<?> alias) {
        return queryMixin.rightJoin(target, alias);
    }

    /** Right join {@code entity} using the condition derived from the foreign key. */
    public <E> Q rightJoin(ForeignKey<E> key, RelationalPath<E> entity) {
        return queryMixin.rightJoin(entity).on(key.on(entity));
    }

    /**
     * Read column {@code i} of the current row via the configuration. The expression
     * is handed over as the path only when it actually is a {@link Path}.
     */
    @SuppressWarnings("unchecked")
    @Nullable
    private <T> T get(ResultSet rs, Expression<?> expr, int i, Class<T> type) throws SQLException {
        return configuration.get(rs, expr instanceof Path ? (Path<?>) expr : null, i, type);
    }

    /**
     * Bind {@code value} at parameter index {@code i} via the configuration.
     *
     * @return the value returned by the configuration; the caller advances its
     *         parameter counter by it
     */
    private int set(PreparedStatement stmt, Path<?> path, int i, Object value) throws SQLException {
        return configuration.set(stmt, path, i, value);
    }

    public QueryMetadata getMetadata() {
        return queryMixin.getMetadata();
    }

    /**
     * Add the given expressions to the projection, execute the query and return the
     * raw JDBC result set.
     *
     * @param exprs expressions to project
     * @return result set whose {@code close()} also closes the underlying statement
     * @throws QueryException if preparing or executing the statement fails
     */
    public ResultSet getResults(Expression<?>... exprs) {
        queryMixin.addToProjection(exprs);
        String queryString = buildQueryString(false);
        logger.debug("query : {}", queryString);

        try {
            final PreparedStatement stmt = Assert.notNull(conn, "connection").prepareStatement(queryString);
            // Until the adapter has been created nobody else will close the statement.
            boolean handedOver = false;
            try {
                setParameters(stmt, constants, constantPaths, getMetadata().getParams());
                ResultSet rs = stmt.executeQuery();
                ResultSet adapter = new ResultSetAdapter(rs) {
                    @Override
                    public void close() throws SQLException {
                        try {
                            super.close();
                        } finally {
                            stmt.close();
                        }
                    }
                };
                handedOver = true;
                return adapter;
            } finally {
                if (!handedOver) {
                    closeQuietly(stmt);
                }
            }
        } catch (SQLException e) {
            throw new QueryException(e);
        } finally {
            reset();
        }
    }

    /**
     * Store the given subqueries as the union of this query.
     *
     * @throws IllegalArgumentException if joins have already been added to the query
     */
    private <RT> UnionBuilder<RT> innerUnion(SubQueryExpression<?>... sq) {
        queryMixin.getMetadata().setValidate(false);
        if (!queryMixin.getMetadata().getJoins().isEmpty()) {
            throw new IllegalArgumentException("Don't mix union and from");
        }
        this.union = sq;
        return new UnionBuilder<RT>();
    }

    protected Configuration getConfiguration() {
        return configuration;
    }

    /**
     * Iterate over rows of the given projection. {@link RelationalPath} arguments are
     * replaced in the given array by bean projections (see {@link #wrap(RelationalPath)}).
     *
     * @param args expressions to project
     * @return iterator over the rows
     */
    @SuppressWarnings("unchecked")
    @Override
    public CloseableIterator<Object[]> iterate(Expression<?>[] args) {
        for (int i = 0; i < args.length; i++) {
            if (args[i] instanceof RelationalPath) {
                args[i] = wrap((RelationalPath) args[i]);
            }
        }
        queryMixin.addToProjection(args);
        return iterateMultiple(queryMixin.getMetadata());
    }

    /**
     * Iterate over the results of the given single projection. A {@link RelationalPath}
     * is projected as a bean (see {@link #wrap(RelationalPath)}).
     *
     * @param expr expression to project
     * @return iterator over the results
     */
    @Override
    public <RT> CloseableIterator<RT> iterate(Expression<RT> expr) {
        if (expr instanceof RelationalPath) {
            return iterate(wrap((RelationalPath<RT>) expr));
        } else {
            expr = queryMixin.convert(expr);
            queryMixin.addToProjection(expr);
            return iterateSingle(queryMixin.getMetadata(), expr);
        }
    }

    /**
     * Create a bean projection for the given path. The bindings are collected
     * reflectively from the non-static fields declared on the path's class whose type
     * is an {@link Expression}, keyed by field name.
     *
     * @param expr relational path to wrap
     * @return bean projection for the type of the path
     * @throws IllegalArgumentException if the path's type equals the path's own class,
     *         or if no bindings could be derived
     * @throws QueryException if a field could not be read
     */
    @SuppressWarnings("unchecked")
    protected <T> QBean<T> wrap(RelationalPath<T> expr) {
        try {
            if (expr.getType().equals(expr.getClass())) {
                throw new IllegalArgumentException("RelationalPath based projection can only be used with generated Bean types");
            }
            Map<String, Expression<?>> bindings = new HashMap<String, Expression<?>>();
            for (Field field : expr.getClass().getDeclaredFields()) {
                if (Expression.class.isAssignableFrom(field.getType()) && !Modifier.isStatic(field.getModifiers())) {
                    field.setAccessible(true);
                    Expression<?> column = (Expression<?>) field.get(expr);
                    bindings.put(field.getName(), column);
                }
            }
            if (bindings.isEmpty()) {
                throw new IllegalArgumentException("No bindings could be derived from " + expr);
            }
            return new QBean<T>((Class<T>) expr.getType(), bindings);
        } catch (IllegalAccessException e) {
            throw new QueryException(e);
        }
    }

    /**
     * Execute the query and map each row to an {@code Object[]} according to the
     * projection of the given metadata.
     *
     * @throws QueryException if preparing or executing the statement fails
     */
    private CloseableIterator<Object[]> iterateMultiple(QueryMetadata metadata) {
        String queryString = buildQueryString(false);
        logger.debug("query : {}", queryString);
        try {
            PreparedStatement stmt = Assert.notNull(conn, "connection").prepareStatement(queryString);
            // Until the iterator has been created nobody else will close the statement.
            boolean handedOver = false;
            try {
                final List<? extends Expression<?>> projection = metadata.getProjection();
                setParameters(stmt, constants, constantPaths, metadata.getParams());
                ResultSet rs = stmt.executeQuery();

                CloseableIterator<Object[]> results = new SQLResultIterator<Object[]>(stmt, rs) {

                    @SuppressWarnings("unchecked")
                    @Override
                    protected Object[] produceNext(ResultSet rs) {
                        try {
                            List<Object> objects = new ArrayList<Object>(projection.size());
                            // zero-based offset of the next unread column
                            int index = 0;
                            for (int i = 0; i < projection.size(); i++) {
                                Expression<?> expr = projection.get(i);
                                if (expr instanceof FactoryExpression) {
                                    // a factory expression consumes one column per argument
                                    objects.add(newInstance((FactoryExpression<?>) expr, rs, index));
                                    index += ((FactoryExpression<?>) expr).getArgs().size();
                                } else if (expr.getType().isArray()) {
                                    // an array-typed expression takes all remaining columns as plain objects
                                    for (int j = index; j < rs.getMetaData().getColumnCount(); j++) {
                                        objects.add(get(rs, expr, index++ + 1, Object.class));
                                    }
                                    // NOTE(review): presumably meant to end the outer loop since no
                                    // columns are left - confirm for projections larger than the row
                                    i = objects.size();
                                } else {
                                    objects.add(get(rs, expr, index++ + 1, expr.getType()));
                                }
                            }
                            return objects.toArray();
                        } catch (InstantiationException e) {
                            close();
                            throw new QueryException(e);
                        } catch (IllegalAccessException e) {
                            close();
                            throw new QueryException(e);
                        } catch (InvocationTargetException e) {
                            close();
                            throw new QueryException(e);
                        } catch (SQLException e) {
                            close();
                            throw new QueryException(e);
                        }
                    }

                };
                handedOver = true;
                return results;
            } finally {
                if (!handedOver) {
                    closeQuietly(stmt);
                }
            }
        } catch (SQLException e) {
            throw new QueryException(e);
        } finally {
            reset();
        }
    }

    /**
     * Execute the query and map each row to a single value.
     *
     * @param metadata metadata providing the query parameters
     * @param expr projected expression; when null, the first column is returned as is
     * @throws QueryException if preparing or executing the statement fails
     */
    @SuppressWarnings("unchecked")
    private <RT> CloseableIterator<RT> iterateSingle(QueryMetadata metadata, @Nullable final Expression<RT> expr) {
        String queryString = buildQueryString(false);
        logger.debug("query : {}", queryString);
        try {
            PreparedStatement stmt = Assert.notNull(conn, "connection").prepareStatement(queryString);
            // Until the iterator has been created nobody else will close the statement.
            boolean handedOver = false;
            try {
                setParameters(stmt, constants, constantPaths, metadata.getParams());
                ResultSet rs = stmt.executeQuery();

                CloseableIterator<RT> results = new SQLResultIterator<RT>(stmt, rs) {

                    @Override
                    public RT produceNext(ResultSet rs) {
                        try {
                            if (expr == null) {
                                return (RT) rs.getObject(1);
                            } else if (expr instanceof FactoryExpression) {
                                return newInstance((FactoryExpression<RT>) expr, rs, 0);
                            } else if (expr.getType().isArray()) {
                                // an array-typed expression receives every column of the row
                                Object[] rv = new Object[rs.getMetaData().getColumnCount()];
                                for (int i = 0; i < rv.length; i++) {
                                    rv[i] = rs.getObject(i + 1);
                                }
                                return (RT) rv;
                            } else {
                                return get(rs, expr, 1, expr.getType());
                            }
                        } catch (IllegalAccessException e) {
                            close();
                            throw new QueryException(e);
                        } catch (InvocationTargetException e) {
                            close();
                            throw new QueryException(e);
                        } catch (InstantiationException e) {
                            close();
                            throw new QueryException(e);
                        } catch (SQLException e) {
                            close();
                            throw new QueryException(e);
                        }
                    }

                };
                handedOver = true;
                return results;
            } finally {
                if (!handedOver) {
                    closeQuietly(stmt);
                }
            }
        } catch (SQLException e) {
            throw new QueryException("Caught " + e.getClass().getSimpleName() + " for " + queryString, e);
        } finally {
            reset();
        }
    }

    @Override
    public List<Object[]> list(Expression<?>[] args) {
        return IteratorAdapter.asList(iterate(args));
    }

    @Override
    public <RT> List<RT> list(Expression<RT> expr) {
        return IteratorAdapter.asList(iterate(expr));
    }

    /**
     * Run the count query first and, only when it reports rows, the list query.
     *
     * @param expr expression to project
     * @return the listed rows together with the query modifiers and the total count,
     *         or empty results when the count is zero
     */
    @Override
    public <RT> SearchResults<RT> listResults(Expression<RT> expr) {
        queryMixin.addToProjection(expr);
        long total = count();
        try {
            if (total > 0) {
                QueryModifiers modifiers = queryMixin.getMetadata().getModifiers();
                return new SearchResults<RT>(list(expr), modifiers, total);
            } else {
                return SearchResults.emptyResults();
            }
        } finally {
            reset();
        }
    }

    /**
     * Instantiate the factory expression from consecutive columns of the current row.
     *
     * @param c factory expression
     * @param rs result set positioned on the row to read
     * @param offset zero-based offset of the first column belonging to {@code c}
     */
    private <RT> RT newInstance(FactoryExpression<RT> c, ResultSet rs, int offset)
        throws InstantiationException, IllegalAccessException, InvocationTargetException, SQLException {
        Object[] args = new Object[c.getArgs().size()];
        for (int i = 0; i < args.length; i++) {
            // JDBC columns are one-based, hence the + 1
            args[i] = get(rs, c.getArgs().get(i), offset + i + 1, c.getArgs().get(i).getType());
        }
        return c.newInstance(args);
    }

    public Q on(Predicate... conditions) {
        return queryMixin.on(conditions);
    }

    /**
     * Reset the query metadata and drop the serialization state of the last execution.
     */
    private void reset() {
        queryMixin.getMetadata().reset();
        // constants and constantPaths are always produced as a pair by buildQueryString
        constants = null;
        constantPaths = null;
    }

    /**
     * Bind the serialized constants to the statement. {@link ParamExpression}
     * constants are replaced by their value from {@code params} before binding.
     *
     * @param stmt statement to bind to
     * @param objects constants in serialization order
     * @param constantPaths paths matching {@code objects} index by index
     * @param params parameter values keyed by parameter expression
     * @throws IllegalArgumentException if the list sizes differ or binding fails
     * @throws ParamNotSetException if {@code params} has no entry for a parameter
     */
    protected void setParameters(PreparedStatement stmt, List<?> objects, List<Path<?>> constantPaths,
            Map<ParamExpression<?>, ?> params) {
        if (objects.size() != constantPaths.size()) {
            throw new IllegalArgumentException("Expected " + objects.size() + " paths, but got " + constantPaths.size());
        }
        int counter = 1;
        for (int i = 0; i < objects.size(); i++) {
            Object o = objects.get(i);
            try {
                if (ParamExpression.class.isInstance(o)) {
                    if (!params.containsKey(o)) {
                        throw new ParamNotSetException((ParamExpression<?>) o);
                    }
                    o = params.get(o);
                }
                // advance by whatever the configuration reports for this value
                counter += set(stmt, constantPaths.get(i), counter, o);
            } catch (SQLException e) {
                throw new IllegalArgumentException(e);
            }
        }
    }

    @Override
    public String toString() {
        return buildQueryString(false).trim();
    }

    public <RT> UnionBuilder<RT> union(ListSubQuery<RT>... sq) {
        return innerUnion(sq);
    }

    public <RT> UnionBuilder<RT> union(SubQueryExpression<RT>... sq) {
        return innerUnion(sq);
    }

    public <RT> UnionBuilder<RT> unionAll(ListSubQuery<RT>... sq) {
        unionAll = true;
        return innerUnion(sq);
    }

    public <RT> UnionBuilder<RT> unionAll(SubQueryExpression<RT>... sq) {
        unionAll = true;
        return innerUnion(sq);
    }

    /**
     * Return the unique result for the given projection. When no limit has been set
     * and the expression's string form does not contain {@code count(}, the query is
     * limited to two rows before it is executed.
     *
     * @param expr expression to project
     * @return the result as determined by the inherited {@code uniqueResult(CloseableIterator)}
     */
    @Override
    public <RT> RT uniqueResult(Expression<RT> expr) {
        if (getMetadata().getModifiers().getLimit() == null
            && !expr.toString().contains("count(")) {
            limit(2);
        }
        CloseableIterator<RT> iterator = iterate(expr);
        return uniqueResult(iterator);
    }

    /**
     * Return the unique row for the given projection. When no limit has been set, the
     * query is limited to two rows before it is executed.
     *
     * @param expr expressions to project
     * @return the row as determined by the inherited {@code uniqueResult(CloseableIterator)}
     */
    @Override
    public Object[] uniqueResult(Expression<?>[] expr) {
        if (getMetadata().getModifiers().getLimit() == null) {
            limit(2);
        }
        CloseableIterator<Object[]> iterator = iterate(expr);
        return uniqueResult(iterator);
    }

    /**
     * Execute the count form of the query and read the first column of the first row.
     *
     * @throws QueryException if preparing, executing or reading fails
     * @throws SQLException if closing the result set or the statement fails
     */
    private long unsafeCount() throws SQLException {
        String queryString = buildQueryString(true);
        logger.debug("query : {}", queryString);
        PreparedStatement stmt = null;
        ResultSet rs = null;
        try {
            stmt = Assert.notNull(conn, "connection").prepareStatement(queryString);
            setParameters(stmt, constants, constantPaths, getMetadata().getParams());
            rs = stmt.executeQuery();
            // NOTE(review): the outcome of next() is not checked; a count query is
            // presumably expected to always yield one row - confirm
            rs.next();
            return rs.getLong(1);
        } catch (SQLException e) {
            throw new QueryException(e.getMessage(), e);
        } finally {
            try {
                if (rs != null) {
                    rs.close();
                }
            } finally {
                if (stmt != null) {
                    stmt.close();
                }
            }
        }
    }

    /**
     * Close a statement whose ownership could not be handed over to a result iterator.
     * A failure to close is only logged so that it does not replace the exception that
     * is already propagating.
     */
    private static void closeQuietly(PreparedStatement stmt) {
        try {
            stmt.close();
        } catch (SQLException e) {
            logger.warn("Failed to close statement after unsuccessful query execution", e);
        }
    }

}