Skip to content

Commit

Permalink
HHH-16867 - support index and join hints in the CockroachDB dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
maesenka authored and beikov committed Sep 20, 2023
1 parent b7bdcd1 commit 8df6d39
Show file tree
Hide file tree
Showing 4 changed files with 423 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.time.temporal.TemporalAccessor;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -1079,6 +1080,33 @@ public SQLExceptionConversionDelegate buildSQLExceptionConversionDelegate() {
};
}

/**
* Applies the hints to the query string.
*
* The hints can be <a href="https://www.cockroachlabs.com/docs/v23.1/table-expressions#force-index-selection">index selection hints</a>
* or <a href="https://www.cockroachlabs.com/docs/stable/sql-grammar#opt_join_hint">join hints</a>.
* <p>
* For index selection hints, use the format {@code <tablename>@{FORCE_INDEX=<index>[,<DIRECTION>]}}
* where the optional DIRECTION is either ASC (ascending) or DESC (descending). Multiple index hints can be provided.
* The effect is that in the final SQL statement the hint is added to the table name mentioned in the hint.
*<p>
* For join hints, use the format {@code "<MERGE|HASH|LOOKUP|INVERTED> JOIN"}. Only one join hint will be added. It is
* applied to all join statements in the SQL statement.
* <p>
* Hints are only added to select statements.
*
* @param query The query to which to apply the hint.
* @param hintList The hints to apply
*
* @return the query with hints added
*/
@Override
public String getQueryHintString(String query, List<String> hintList) {
return new CockroachDialectQueryHints(query, hintList).getQueryHintString();
}



// CockroachDB doesn't support this by default. See sql.multiple_modifications_of_table.enabled
//
// @Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/

package org.hibernate.dialect;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class CockroachDialectQueryHints {

final private Pattern TABLE_QUERY_PATTERN = Pattern.compile(
"(?i)^\\s*(select\\b.+?\\bfrom\\b)(.+?)(\\bwhere\\b.+?)$" );
final private Pattern JOIN_HINT_PATTERN = Pattern.compile( "(?i)(MERGE|HASH|LOOKUP|INVERTED)\\s+JOIN" );

//If matched, group 1 contains everything before the join keyword.
final private Pattern JOIN_PATTERN = Pattern.compile(
"(?i)\\b(cross|natural\\s+(.*)\\b|(full|left|right)(\\s+outer)?)?\\s+join" );

final private String query;
final private List<String> hints;

public CockroachDialectQueryHints(String query, List<String> hintList) {
this.query = query;
this.hints = hintList;
}

public String getQueryHintString() {
List<IndexHint> indexHints = new ArrayList<>();
JoinHint joinHint = null;
for ( var h : hints ) {
IndexHint indexHint = parseIndexHints( h );
if ( indexHint != null ) {
indexHints.add( indexHint );
continue;
}
joinHint = parseJoinHints( h );
}

String result = addIndexHints( query, indexHints );
return joinHint == null ? result : addJoinHint( query, joinHint );
}

private IndexHint parseIndexHints(String hint) {
var parts = hint.split( "@" );
if ( parts.length == 2 ) {
return new IndexHint( parts[0], hint );
}
return null;
}

private JoinHint parseJoinHints(String hint) {
var matcher = JOIN_HINT_PATTERN.matcher( hint );
if ( matcher.find() ) {
return new JoinHint( matcher.group( 1 ) );
}
return null;
}

String addIndexHints(String query, List<IndexHint> hints) {

Matcher statementMatcher = TABLE_QUERY_PATTERN.matcher( query );

if ( statementMatcher.matches() && statementMatcher.groupCount() > 2 ) {
String prefix = statementMatcher.group( 1 );
String fromList = statementMatcher.group( 2 );
String suffix = statementMatcher.group( 3 );
fromList = addIndexHintsToFromList( fromList, hints );
return prefix + fromList + suffix;
}
else {
return query;
}
}

String addJoinHint(String query, JoinHint hint) {
var m = JOIN_PATTERN.matcher( query );
StringBuilder buffer = new StringBuilder();
int start = 0;
while ( m.find() ) {
buffer.append( query.substring( start, m.start() ) );

if ( m.group( 1 ) == null ) {
buffer.append( " inner" );
}
else {
buffer.append( m.group( 1 ) );
}
buffer.append( " " )
.append( hint.joinType )
.append( " join" );
start = m.end();
}
buffer.append( query.substring( start ) );
return buffer.toString();
}

String addIndexHintsToFromList(String fromList, List<IndexHint> hints) {
String result = fromList;
for ( var hint : hints ) {
result = result.replaceAll( "(?i)\\b" + hint.table + "\\b", hint.text );
}
return result;
}


static class IndexHint {
final String table;
final String text;

IndexHint(String table, String text) {
this.table = table;
this.text = text;
}

}

static class JoinHint {
final String joinType;

JoinHint(String type) {
this.joinType = type.toLowerCase( Locale.ROOT );
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/

package org.hibernate.orm.test.dialect.functional;

import java.util.HashSet;
import java.util.Set;

import org.hibernate.dialect.CockroachDialect;

import org.hibernate.testing.jdbc.SQLStatementInspector;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Index;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.OneToMany;
import jakarta.persistence.Table;
import jakarta.persistence.TypedQuery;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

@RequiresDialect(CockroachDialect.class)
@SessionFactory(useCollectingStatementInspector = true)
@DomainModel(annotatedClasses = {
SimpleEntity.class, ChildEntity.class
})
@JiraKey("HHH-16867")
public class CockroachDBQueryHintsTest {

@BeforeAll
public void setUp(SessionFactoryScope scope) {
scope.inTransaction( session -> {
var se1 = new SimpleEntity( 1, "se1" );
se1.addChild( new ChildEntity( "se1child1" ) );
session.persist( se1 );
var se2 = new SimpleEntity( 2, "se2" );
session.persist( se2 );
var se3 = new SimpleEntity( 3, "se3" );
session.persist( se3 );
} );
}

@Test
public void testIndexHint(SessionFactoryScope scope) {
final SQLStatementInspector statementInspector = scope.getCollectingStatementInspector();
statementInspector.clear();
scope.inTransaction( session -> {
TypedQuery<Integer> query = session.createQuery( "select id from SimpleEntity where id < 3", Integer.class )
.addQueryHint( "parents@{FORCE_INDEX=idx,DESC}" );
var ignored = query.getResultList();
} );
assertThat( statementInspector.getSqlQueries().get( 0 ) ).contains(
" parents@{FORCE_INDEX=idx,DESC} " );
}

@Test
public void testJoinHint(SessionFactoryScope scope) {
final SQLStatementInspector statementInspector = scope.getCollectingStatementInspector();
statementInspector.clear();
scope.inTransaction( session -> {
TypedQuery<ChildEntity> query = session.createQuery(
"select c from SimpleEntity s join s.children c where s.id < 3",
ChildEntity.class
)
.addQueryHint( "haSh join" );
var ignored = query.getResultList();
} );
assertThat( statementInspector.getSqlQueries().get( 0 ) ).contains(
" hash join " );
}

@Test
public void testOuterJoinHint(SessionFactoryScope scope) {
final SQLStatementInspector statementInspector = scope.getCollectingStatementInspector();
statementInspector.clear();
scope.inTransaction( session -> {
TypedQuery<ChildEntity> query = session.createQuery(
"select c from SimpleEntity s left join s.children c where s.id < 3",
ChildEntity.class
)
.addQueryHint( "haSh join" );
var ignored = query.getResultList();
} );
assertThat( statementInspector.getSqlQueries().get( 0 ) ).contains(
" hash join " );
}
}

@Entity
@Table(name = "children")
class ChildEntity {
@Id
private Integer id;

private String childName;

@ManyToOne
@JoinColumn(name = "parent_id", nullable = false)
private SimpleEntity parent;

public ChildEntity() {
}

public ChildEntity(String childName) {
this.childName = childName;
}

public Integer getId() {
return id;
}

public void setId(Integer id) {
this.id = id;
}

public SimpleEntity getParent() {
return parent;
}
}

@Entity(name = "SimpleEntity")
@Table(name = "parents", indexes = { @Index(name = "idx", columnList = "id") })
class SimpleEntity {
@Id
private Integer id;

private String name;

@OneToMany(mappedBy = "parent")
private Set<ChildEntity> children;

public SimpleEntity() {
}

public SimpleEntity(Integer id, String name) {
this.id = id;
this.name = name;
this.children = new HashSet<>();
}

public Integer getId() {
return id;
}

public void setId(Integer id) {
this.id = id;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}


public Set<ChildEntity> getChildren() {
return children;
}

public void setChildren(Set<ChildEntity> children) {
this.children = children;
}

public void addChild(ChildEntity child) {
this.children.add( child );
}
}
Loading

0 comments on commit 8df6d39

Please sign in to comment.