diff --git a/querydsl-sql/src/main/java/com/querydsl/sql/SQLListeners.java b/querydsl-sql/src/main/java/com/querydsl/sql/SQLListeners.java index e21cf10a3..df4ca87a1 100644 --- a/querydsl-sql/src/main/java/com/querydsl/sql/SQLListeners.java +++ b/querydsl-sql/src/main/java/com/querydsl/sql/SQLListeners.java @@ -245,4 +245,9 @@ public class SQLListeners implements SQLDetailedListener { listener.exception(context); } } + + public Set getListeners() { + return listeners; + } + } diff --git a/querydsl-sql/src/main/java/com/querydsl/sql/SQLNoCloseListener.java b/querydsl-sql/src/main/java/com/querydsl/sql/SQLNoCloseListener.java new file mode 100644 index 000000000..6e989155a --- /dev/null +++ b/querydsl-sql/src/main/java/com/querydsl/sql/SQLNoCloseListener.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015, The Querydsl Team (http://www.querydsl.com/team) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.querydsl.sql; + +/** + * SQLNoCloseListener can be used to block {@link SQLCloseListener} from closing the connection, useful for + * helper query executions + */ +public final class SQLNoCloseListener extends SQLBaseListener { + + public static final SQLNoCloseListener DEFAULT = new SQLNoCloseListener(); + + private SQLNoCloseListener() { } + + @Override + public void start(SQLListenerContext context) { + context.setData(AbstractSQLQuery.PARENT_CONTEXT, context); + } + + @Override + public void end(SQLListenerContext context) { + context.setData(AbstractSQLQuery.PARENT_CONTEXT, null); + } + +} \ No newline at end of file diff --git a/querydsl-sql/src/main/java/com/querydsl/sql/dml/SQLMergeClause.java b/querydsl-sql/src/main/java/com/querydsl/sql/dml/SQLMergeClause.java index adab3688f..29220f56a 100644 --- a/querydsl-sql/src/main/java/com/querydsl/sql/dml/SQLMergeClause.java +++ b/querydsl-sql/src/main/java/com/querydsl/sql/dml/SQLMergeClause.java @@ -320,6 +320,10 @@ public class SQLMergeClause extends AbstractSQLClause implements private boolean hasRow() { SQLQuery query = new SQLQuery(connection(), configuration).from(entity); + for (SQLListener listener : listeners.getListeners()) { + query.addListener(listener); + } + query.addListener(SQLNoCloseListener.DEFAULT); addKeyConditions(query); return query.select(Expressions.ONE).fetchFirst() != null; } @@ -344,17 +348,25 @@ public class SQLMergeClause extends AbstractSQLClause implements // update SQLUpdateClause update = new SQLUpdateClause(connection(), configuration, entity); populate(update); + addListeners(update); addKeyConditions(update); return update.execute(); } else { // insert SQLInsertClause insert = new SQLInsertClause(connection(), configuration, entity); + addListeners(insert); populate(insert); return insert.execute(); } } + private void addListeners(AbstractSQLClause clause) { + for (SQLListener listener : listeners.getListeners()) { + clause.addListener(listener); + } + } + @SuppressWarnings("unchecked") private void populate(StoreClause clause) { for (int i = 0; i < columns.size(); i++) { diff --git a/querydsl-sql/src/test/java/com/querydsl/sql/MergeBase.java b/querydsl-sql/src/test/java/com/querydsl/sql/MergeBase.java index d2947a728..22f11a6a7 100644 --- a/querydsl-sql/src/test/java/com/querydsl/sql/MergeBase.java +++ b/querydsl-sql/src/test/java/com/querydsl/sql/MergeBase.java @@ -21,6 +21,7 @@ import static org.junit.Assert.*; import java.sql.ResultSet; import java.sql.SQLException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; @@ -223,5 +224,24 @@ public class MergeBase extends AbstractBaseTest { assertEquals(1, merge.execute()); } + @Test + public void merge_listener() { + final AtomicInteger calls = new AtomicInteger(0); + SQLListener listener = new SQLBaseListener() { + @Override + public void end(SQLListenerContext context) { + if (context.getData(AbstractSQLQuery.PARENT_CONTEXT) == null) { + calls.incrementAndGet(); + } + } + }; + + SQLMergeClause clause = merge(survey).keys(survey.id) + .set(survey.id, 5) + .set(survey.name, "Hello World"); + clause.addListener(listener); + assertEquals(1, clause.execute()); + assertEquals(1, calls.intValue()); + } }