Commit f61f13e1 authored by Marcus Gartner's avatar Marcus Gartner

schemaexpr: use sqlbase.TableColSet instead of maps

This commit replaces maps used as sets of integers with
sqlbase.TableColSet because it is a more efficient set implementation.

Release note: None
parent be12f0e8
......@@ -14,7 +14,6 @@ import (
"bytes"
"context"
"fmt"
"sort"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
......@@ -102,7 +101,7 @@ func (b *CheckConstraintBuilder) Build(
// Replace the column variables with dummyColumns so that they can be
// type-checked.
replacedExpr, colIDsUsed, err := b.desc.ReplaceColumnVarsInExprWithDummies(expr)
replacedExpr, colIDs, err := replaceVars(&b.desc.TableDescriptor, expr)
if err != nil {
return nil, err
}
......@@ -121,17 +120,10 @@ func (b *CheckConstraintBuilder) Build(
return nil, err
}
// Collect and sort the column IDs referenced in the check expression.
colIDs := make(sqlbase.ColumnIDs, 0, colIDsUsed.Len())
colIDsUsed.ForEach(func(id int) {
colIDs = append(colIDs, sqlbase.ColumnID(id))
})
sort.Sort(colIDs)
return &sqlbase.TableDescriptor_CheckConstraint{
Expr: tree.Serialize(typedExpr),
Name: name,
ColumnIDs: colIDs,
ColumnIDs: colIDs.Ordered(),
Hidden: c.Hidden,
}, nil
}
......
......@@ -60,46 +60,6 @@ func DequalifyColumnRefs(
)
}
// iterColDescriptors iterates over the expression's variable columns and
// calls f on each.
//
// If the expression references a column that does not exist in the table
// descriptor, iterColDescriptors errs with pgcode.UndefinedColumn.
func iterColDescriptors(
desc *sqlbase.MutableTableDescriptor, rootExpr tree.Expr, f func(*sqlbase.ColumnDescriptor) error,
) error {
_, err := tree.SimpleVisit(rootExpr, func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
vBase, ok := expr.(tree.VarName)
if !ok {
// Not a VarName, don't do anything to this node.
return true, expr, nil
}
v, err := vBase.NormalizeVarName()
if err != nil {
return false, nil, err
}
c, ok := v.(*tree.ColumnItem)
if !ok {
return true, expr, nil
}
col, dropped, err := desc.FindColumnByName(c.ColumnName)
if err != nil || dropped {
return false, nil, pgerror.Newf(pgcode.UndefinedColumn,
"column %q does not exist, referenced in %q", c.ColumnName, rootExpr.String())
}
if err := f(col); err != nil {
return false, nil, err
}
return false, expr, err
})
return err
}
// DeserializeTableDescExpr takes in a serialized expression and a table, and
// returns an expression that has all user defined types resolved for
// formatting. It is intended to be used when displaying a serialized
......@@ -116,7 +76,7 @@ func DeserializeTableDescExpr(
if err != nil {
return nil, err
}
expr, _, err = desc.ReplaceColumnVarsInExprWithDummies(expr)
expr, _, err = replaceVars(desc, expr)
if err != nil {
return nil, err
}
......@@ -163,3 +123,127 @@ func FormatColumnForDisplay(
}
return f.CloseAndGetString(), nil
}
// iterColDescriptors iterates over the expression's variable columns and
// calls f on each.
//
// If the expression references a column that does not exist in the table
// descriptor, iterColDescriptors errs with pgcode.UndefinedColumn.
func iterColDescriptors(
desc *sqlbase.MutableTableDescriptor, rootExpr tree.Expr, f func(*sqlbase.ColumnDescriptor) error,
) error {
_, err := tree.SimpleVisit(rootExpr, func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
vBase, ok := expr.(tree.VarName)
if !ok {
// Not a VarName, don't do anything to this node.
return true, expr, nil
}
v, err := vBase.NormalizeVarName()
if err != nil {
return false, nil, err
}
c, ok := v.(*tree.ColumnItem)
if !ok {
return true, expr, nil
}
col, dropped, err := desc.FindColumnByName(c.ColumnName)
if err != nil || dropped {
return false, nil, pgerror.Newf(pgcode.UndefinedColumn,
"column %q does not exist, referenced in %q", c.ColumnName, rootExpr.String())
}
if err := f(col); err != nil {
return false, nil, err
}
return false, expr, err
})
return err
}
// dummyColumn represents a variable column that can type-checked. It is used
// in validating check constraint and partial index predicate expressions. This
// validation requires that the expression can be both both typed-checked and
// examined for variable expressions.
type dummyColumn struct {
typ *types.T
name tree.Name
}
// String implements the Stringer interface.
func (d *dummyColumn) String() string {
return tree.AsString(d)
}
// Format implements the NodeFormatter interface.
func (d *dummyColumn) Format(ctx *tree.FmtCtx) {
d.name.Format(ctx)
}
// Walk implements the Expr interface.
func (d *dummyColumn) Walk(_ tree.Visitor) tree.Expr {
return d
}
// TypeCheck implements the Expr interface.
func (d *dummyColumn) TypeCheck(
_ context.Context, _ *tree.SemaContext, desired *types.T,
) (tree.TypedExpr, error) {
return d, nil
}
// Eval implements the TypedExpr interface.
func (*dummyColumn) Eval(_ *tree.EvalContext) (tree.Datum, error) {
panic("dummyColumnItem.Eval() is undefined")
}
// ResolvedType implements the TypedExpr interface.
func (d *dummyColumn) ResolvedType() *types.T {
return d.typ
}
// replaceVars replaces the occurrences of column names in an expression with
// dummyColumns containing their type, so that they may be type-checked. It
// returns this new expression tree alongside a set containing the ColumnID of
// each column seen in the expression.
//
// If the expression references a column that does not exist in the table
// descriptor, replaceVars errs with pgcode.UndefinedColumn.
func replaceVars(
desc *sqlbase.TableDescriptor, rootExpr tree.Expr,
) (tree.Expr, sqlbase.TableColSet, error) {
var colIDs sqlbase.TableColSet
newExpr, err := tree.SimpleVisit(rootExpr, func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
vBase, ok := expr.(tree.VarName)
if !ok {
// Not a VarName, don't do anything to this node.
return true, expr, nil
}
v, err := vBase.NormalizeVarName()
if err != nil {
return false, nil, err
}
c, ok := v.(*tree.ColumnItem)
if !ok {
return true, expr, nil
}
col, dropped, err := desc.FindColumnByName(c.ColumnName)
if err != nil || dropped {
return false, nil, pgerror.Newf(pgcode.UndefinedColumn,
"column %q does not exist, referenced in %q", c.ColumnName, rootExpr.String())
}
colIDs.Add(col.ID)
// Convert to a dummyColumn of the correct type.
return false, &dummyColumn{typ: col.Type, name: c.ColumnName}, nil
})
return newExpr, colIDs, err
}
......@@ -60,15 +60,14 @@ func (v *ComputedColumnValidator) Validate(d *tree.ColumnTableDef) error {
)
}
// TODO(mgartner): Use util.FastIntSet here instead.
dependencies := make(map[sqlbase.ColumnID]struct{})
var depColIDs sqlbase.TableColSet
// First, check that no column in the expression is a computed column.
err := iterColDescriptors(v.desc, d.Computed.Expr, func(c *sqlbase.ColumnDescriptor) error {
if c.IsComputed() {
return pgerror.New(pgcode.InvalidTableDefinition,
"computed columns cannot reference other computed columns")
}
dependencies[c.ID] = struct{}{}
depColIDs.Add(c.ID)
return nil
})
......@@ -82,7 +81,7 @@ func (v *ComputedColumnValidator) Validate(d *tree.ColumnTableDef) error {
for i := range v.desc.OutboundFKs {
fk := &v.desc.OutboundFKs[i]
for _, id := range fk.OriginColumnIDs {
if _, ok := dependencies[id]; !ok {
if !depColIDs.Contains(id) {
// We don't depend on this column.
continue
}
......@@ -103,7 +102,7 @@ func (v *ComputedColumnValidator) Validate(d *tree.ColumnTableDef) error {
// Replace the column variables with dummyColumns so that they can be
// type-checked.
replacedExpr, _, err := v.desc.ReplaceColumnVarsInExprWithDummies(d.Computed.Expr)
replacedExpr, _, err := replaceVars(&v.desc.TableDescriptor, d.Computed.Expr)
if err != nil {
return err
}
......
......@@ -11,25 +11,6 @@
/*
Package schemaexpr provides utilities for dealing with expressions with table
schemas, such as check constraints, computed columns, and partial index
predicates. It provides the following utilities.
CheckConstraintBuilder
Validates and builds sql.TableDescriptor_CheckConstraints from
tree.CheckConstraintTableDefs.
ComputedColumnValidator
Validates computed columns and can determine if a non-computed column has
dependent computed columns.
PartialIndexValidator
Validates partial index predicates and dequalifies the columns referenced.
DequalifyColumnRefs
Strips database and table names from qualified columns.
predicates.
*/
package schemaexpr
......@@ -58,7 +58,7 @@ func NewIndexPredicateValidator(
func (v *IndexPredicateValidator) Validate(expr tree.Expr) (tree.Expr, error) {
// Replace the column variables with dummyColumns so that they can be
// type-checked.
replacedExpr, _, err := v.desc.ReplaceColumnVarsInExprWithDummies(expr)
replacedExpr, _, err := replaceVars(&v.desc.TableDescriptor, expr)
if err != nil {
return nil, err
}
......
......@@ -28,7 +28,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util"
"github.com/cockroachdb/cockroach/pkg/util/encoding"
"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
"github.com/cockroachdb/cockroach/pkg/util/hlc"
......@@ -836,85 +835,6 @@ func (desc *TableDescriptor) AllActiveAndInactiveChecks() []*TableDescriptor_Che
return checks
}
// ReplaceColumnVarsInExprWithDummies replaces the occurrences of column names in an expression with
// dummies containing their type, so that they may be typechecked. It returns
// this new expression tree alongside a set containing the ColumnID of each
// column seen in the expression.
func (desc *TableDescriptor) ReplaceColumnVarsInExprWithDummies(
rootExpr tree.Expr,
) (tree.Expr, util.FastIntSet, error) {
var colIDs util.FastIntSet
newExpr, err := tree.SimpleVisit(rootExpr, func(expr tree.Expr) (recurse bool, newExpr tree.Expr, err error) {
vBase, ok := expr.(tree.VarName)
if !ok {
// Not a VarName, don't do anything to this node.
return true, expr, nil
}
v, err := vBase.NormalizeVarName()
if err != nil {
return false, nil, err
}
c, ok := v.(*tree.ColumnItem)
if !ok {
return true, expr, nil
}
col, dropped, err := desc.FindColumnByName(c.ColumnName)
if err != nil || dropped {
return false, nil, pgerror.Newf(pgcode.UndefinedColumn,
"column %q does not exist, referenced in %q", c.ColumnName, rootExpr.String())
}
colIDs.Add(int(col.ID))
// Convert to a dummy node of the correct type.
return false, &dummyColumnItem{typ: col.Type, name: c.ColumnName}, nil
})
return newExpr, colIDs, err
}
// dummyColumnItem is used in makeCheckConstraint and validateIndexPredicate to
// construct an expression that can be both type-checked and examined for
// variable expressions. It can also be used to format typed expressions
// containing column references.
type dummyColumnItem struct {
typ *types.T
name tree.Name
}
// String implements the Stringer interface.
func (d *dummyColumnItem) String() string {
return tree.AsString(d)
}
// Format implements the NodeFormatter interface.
// It should be kept in line with ColumnItem.Format.
func (d *dummyColumnItem) Format(ctx *tree.FmtCtx) {
ctx.FormatNode(&d.name)
}
// Walk implements the Expr interface.
func (d *dummyColumnItem) Walk(_ tree.Visitor) tree.Expr {
return d
}
// TypeCheck implements the Expr interface.
func (d *dummyColumnItem) TypeCheck(
_ context.Context, _ *tree.SemaContext, desired *types.T,
) (tree.TypedExpr, error) {
return d, nil
}
// Eval implements the TypedExpr interface.
func (*dummyColumnItem) Eval(_ *tree.EvalContext) (tree.Datum, error) {
panic("dummyColumnItem.Eval() is undefined")
}
// ResolvedType implements the TypedExpr interface.
func (d *dummyColumnItem) ResolvedType() *types.T {
return d.typ
}
// GetColumnFamilyForShard returns the column family that a newly added shard column
// should be assigned to, given the set of columns it's computed from.
//
......
......@@ -29,9 +29,6 @@ func MakeTableColSet(vals ...ColumnID) TableColSet {
// Add adds a column to the set. No-op if the column is already in the set.
func (s *TableColSet) Add(col ColumnID) { s.set.Add(int(col)) }
// Remove removes a column from the set. No-op if the column is not in the set.
func (s *TableColSet) Remove(col ColumnID) { s.set.Remove(int(col)) }
// Contains returns true if the set contains the column.
func (s TableColSet) Contains(col ColumnID) bool { return s.set.Contains(int(col)) }
......@@ -41,13 +38,6 @@ func (s TableColSet) Empty() bool { return s.set.Empty() }
// Len returns the number of the columns in the set.
func (s TableColSet) Len() int { return s.set.Len() }
// Next returns the first value in the set which is >= startVal. If there is no
// such column, the second return value is false.
func (s TableColSet) Next(startVal ColumnID) (ColumnID, bool) {
c, ok := s.set.Next(int(startVal))
return ColumnID(c), ok
}
// ForEach calls a function for each column in the set (in increasing order).
func (s TableColSet) ForEach(f func(col ColumnID)) { s.set.ForEach(func(i int) { f(ColumnID(i)) }) }
......@@ -64,40 +54,6 @@ func (s TableColSet) Ordered() []ColumnID {
return result
}
// Copy returns a copy of s which can be modified independently.
func (s TableColSet) Copy() TableColSet { return TableColSet{set: s.set.Copy()} }
// UnionWith adds all the columns from rhs to this set.
func (s *TableColSet) UnionWith(rhs TableColSet) { s.set.UnionWith(rhs.set) }
// Union returns the union of s and rhs as a new set.
func (s TableColSet) Union(rhs TableColSet) TableColSet { return TableColSet{set: s.set.Union(rhs.set)} }
// IntersectionWith removes any columns not in rhs from this set.
func (s *TableColSet) IntersectionWith(rhs TableColSet) { s.set.IntersectionWith(rhs.set) }
// Intersection returns the intersection of s and rhs as a new set.
func (s TableColSet) Intersection(rhs TableColSet) TableColSet {
return TableColSet{set: s.set.Intersection(rhs.set)}
}
// DifferenceWith removes any elements in rhs from this set.
func (s *TableColSet) DifferenceWith(rhs TableColSet) { s.set.DifferenceWith(rhs.set) }
// Difference returns the elements of s that are not in rhs as a new set.
func (s TableColSet) Difference(rhs TableColSet) TableColSet {
return TableColSet{set: s.set.Difference(rhs.set)}
}
// Intersects returns true if s has any elements in common with rhs.
func (s TableColSet) Intersects(rhs TableColSet) bool { return s.set.Intersects(rhs.set) }
// Equals returns true if the two sets are identical.
func (s TableColSet) Equals(rhs TableColSet) bool { return s.set.Equals(rhs.set) }
// SubsetOf returns true if rhs contains all the elements in s.
func (s TableColSet) SubsetOf(rhs TableColSet) bool { return s.set.SubsetOf(rhs.set) }
// String returns a list representation of elements. Sequential runs of positive
// numbers are shown as ranges. For example, for the set {1, 2, 3 5, 6, 10},
// the output is "(1-3,5,6,10)".
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment