Commit 6a5479be authored by Jordan Lewis's avatar Jordan Lewis

execgen: extract template reading code

Previously, all template generators had to read their template file
themselves. Now, this is done by execgen main, opening the door to
global transforms that affect all templates in the same way.

Release note: None
parent 7bf34017
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
)
......@@ -26,11 +25,7 @@ type logicalOperation struct {
const andOrProjTmpl = "pkg/sql/colexec/and_or_projection_tmpl.go"
func genAndOrProjectionOps(wr io.Writer) error {
t, err := ioutil.ReadFile(andOrProjTmpl)
if err != nil {
return err
}
func genAndOrProjectionOps(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_OP_LOWER", "{{.Lower}}",
......@@ -39,7 +34,7 @@ func genAndOrProjectionOps(wr io.Writer) error {
"_L_HAS_NULLS", "$.lHasNulls",
"_R_HAS_NULLS", "$.rHasNulls",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
addTupleForRight := makeFunctionRegex("_ADD_TUPLE_FOR_RIGHT", 1)
s = addTupleForRight.ReplaceAllString(s, `{{template "addTupleForRight" buildDict "Global" $ "lHasNulls" $1}}`)
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -21,12 +20,7 @@ import (
const anyNotNullAggTmpl = "pkg/sql/colexec/any_not_null_agg_tmpl.go"
func genAnyNotNullAgg(wr io.Writer) error {
t, err := ioutil.ReadFile(anyNotNullAggTmpl)
if err != nil {
return err
}
func genAnyNotNullAgg(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
"_TYPE_WIDTH", typeWidthReplacement,
......@@ -35,7 +29,7 @@ func genAnyNotNullAgg(wr io.Writer) error {
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
findAnyNotNull := makeFunctionRegex("_FIND_ANY_NOT_NULL", 4)
s = findAnyNotNull.ReplaceAllString(s, `{{template "findAnyNotNull" buildDict "Global" . "HasNulls" $4}}`)
......
......@@ -13,7 +13,6 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -55,12 +54,7 @@ var (
const avgAggTmpl = "pkg/sql/colexec/avg_agg_tmpl.go"
func genAvgAgg(wr io.Writer) error {
t, err := ioutil.ReadFile(avgAggTmpl)
if err != nil {
return err
}
func genAvgAgg(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
"_TYPE_WIDTH", typeWidthReplacement,
......@@ -69,7 +63,7 @@ func genAvgAgg(wr io.Writer) error {
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
assignDivRe := makeFunctionRegex("_ASSIGN_DIV_INT64", 6)
s = assignDivRe.ReplaceAllString(s, makeTemplateFunctionCall("AssignDivInt64", 6))
......
......@@ -13,7 +13,6 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -60,17 +59,12 @@ var (
const boolAggTmpl = "pkg/sql/colexec/bool_and_or_agg_tmpl.go"
func genBooleanAgg(wr io.Writer) error {
t, err := ioutil.ReadFile(boolAggTmpl)
if err != nil {
return err
}
func genBooleanAgg(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_OP_TYPE", "{{.OpType}}",
"_DEFAULT_VAL", "{{.DefaultVal}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
accumulateBoolean := makeFunctionRegex("_ACCUMULATE_BOOLEAN", 3)
s = accumulateBoolean.ReplaceAllString(s, `{{template "accumulateBoolean" buildDict "Global" .}}`)
......
......@@ -12,19 +12,13 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
)
const castTmpl = "pkg/sql/colexec/cast_tmpl.go"
func genCastOperators(wr io.Writer) error {
t, err := ioutil.ReadFile(castTmpl)
if err != nil {
return err
}
func genCastOperators(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_LEFT_CANONICAL_TYPE_FAMILY", "{{.LeftCanonicalFamilyStr}}",
"_LEFT_TYPE_WIDTH", typeWidthReplacement,
......@@ -34,7 +28,7 @@ func genCastOperators(wr io.Writer) error {
"_L_TYP", "{{.Left.VecMethod}}",
"_R_TYP", "{{.Right.VecMethod}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
castRe := makeFunctionRegex("_CAST", 2)
s = castRe.ReplaceAllString(s, makeTemplateFunctionCall("Right.Cast", 2))
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -21,11 +20,7 @@ import (
const constTmpl = "pkg/sql/colexec/const_tmpl.go"
func genConstOps(wr io.Writer) error {
d, err := ioutil.ReadFile(constTmpl)
if err != nil {
return err
}
func genConstOps(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
......@@ -34,7 +29,7 @@ func genConstOps(wr io.Writer) error {
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(d))
s := r.Replace(inputFileContents)
s = replaceManipulationFuncs(s)
......
......@@ -12,22 +12,14 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
)
const countAggTmpl = "pkg/sql/colexec/count_agg_tmpl.go"
func genCountAgg(wr io.Writer) error {
t, err := ioutil.ReadFile(countAggTmpl)
if err != nil {
return err
}
s := string(t)
s = strings.ReplaceAll(s, "_KIND", "{{.Kind}}")
func genCountAgg(inputFileContents string, wr io.Writer) error {
s := strings.ReplaceAll(inputFileContents, "_KIND", "{{.Kind}}")
accumulateSum := makeFunctionRegex("_ACCUMULATE_COUNT", 4)
s = accumulateSum.ReplaceAllString(s, `{{template "accumulateCount" buildDict "Global" . "ColWithNulls" $4}}`)
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -21,12 +20,7 @@ import (
const distinctOpsTmpl = "pkg/sql/colexec/distinct_tmpl.go"
func genDistinctOps(wr io.Writer) error {
d, err := ioutil.ReadFile(distinctOpsTmpl)
if err != nil {
return err
}
func genDistinctOps(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
"_TYPE_WIDTH", typeWidthReplacement,
......@@ -34,7 +28,7 @@ func genDistinctOps(wr io.Writer) error {
"_GOTYPE", "{{.GoType}}",
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}")
s := r.Replace(string(d))
s := r.Replace(inputFileContents)
assignNeRe := makeFunctionRegex("_ASSIGN_NE", 6)
s = assignNeRe.ReplaceAllString(s, makeTemplateFunctionCall("Assign", 6))
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -21,18 +20,13 @@ import (
const hashAggTmpl = "pkg/sql/colexec/hash_aggregator_tmpl.go"
func genHashAggregator(wr io.Writer) error {
t, err := ioutil.ReadFile(hashAggTmpl)
if err != nil {
return err
}
func genHashAggregator(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
"_TYPE_WIDTH", typeWidthReplacement,
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
s = replaceManipulationFuncsAmbiguous(".Global", s)
......
......@@ -12,18 +12,13 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
)
const hashUtilsTmpl = "pkg/sql/colexec/hash_utils_tmpl.go"
func genHashUtils(wr io.Writer) error {
t, err := ioutil.ReadFile(hashUtilsTmpl)
if err != nil {
return err
}
func genHashUtils(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
......@@ -31,7 +26,7 @@ func genHashUtils(wr io.Writer) error {
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
assignHash := makeFunctionRegex("_ASSIGN_HASH", 4)
s = assignHash.ReplaceAllString(s, makeTemplateFunctionCall("Global.UnaryAssign", 4))
......
......@@ -12,19 +12,13 @@ package main
import (
"io"
"io/ioutil"
"text/template"
)
const hashJoinerTmpl = "pkg/sql/colexec/hashjoiner_tmpl.go"
func genHashJoiner(wr io.Writer) error {
t, err := ioutil.ReadFile(hashJoinerTmpl)
if err != nil {
return err
}
s := string(t)
func genHashJoiner(inputFileContents string, wr io.Writer) error {
s := inputFileContents
distinctCollectRightOuter := makeFunctionRegex("_DISTINCT_COLLECT_PROBE_OUTER", 3)
s = distinctCollectRightOuter.ReplaceAllString(s, `{{template "distinctCollectProbeOuter" buildDict "Global" . "UseSel" $3}}`)
......
......@@ -13,7 +13,6 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -73,12 +72,7 @@ var _ = hashTableMode.IsDeletingProbe
const hashTableTmpl = "pkg/sql/colexec/hashtable_tmpl.go"
func genHashTable(wr io.Writer, htm hashTableMode) error {
t, err := ioutil.ReadFile(hashTableTmpl)
if err != nil {
return err
}
func genHashTable(inputFileContents string, wr io.Writer, htm hashTableMode) error {
r := strings.NewReplacer(
"_LEFT_CANONICAL_TYPE_FAMILY", "{{.LeftCanonicalFamilyStr}}",
"_LEFT_TYPE_WIDTH", typeWidthReplacement,
......@@ -91,7 +85,7 @@ func genHashTable(wr io.Writer, htm hashTableMode) error {
"_DELETING_PROBE_MODE", "$deletingProbeMode",
"_OVERLOADS", ".Overloads",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
s = strings.ReplaceAll(s, "_L_UNSAFEGET", "execgen.UNSAFEGET")
s = replaceManipulationFuncsAmbiguous(".Global.Left", s)
......@@ -157,8 +151,8 @@ func genHashTable(wr io.Writer, htm hashTableMode) error {
func init() {
hashTableGenerator := func(htm hashTableMode) generator {
return func(wr io.Writer) error {
return genHashTable(wr, htm)
return func(inputFileContents string, wr io.Writer) error {
return genHashTable(inputFileContents, wr, htm)
}
}
......
......@@ -13,6 +13,7 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"text/template"
"github.com/cockroachdb/cockroach/pkg/sql/types"
......@@ -39,15 +40,16 @@ import (
{{end}}
`
func genLikeOps(wr io.Writer) error {
tmpl, err := getSelectionOpsTmpl()
func genLikeOps(inputFileContents string, wr io.Writer) error {
tmpl, err := getSelectionOpsTmpl(inputFileContents)
if err != nil {
return err
}
projTemplate, err := getProjConstOpTmplString(false /* isConstLeft */)
projConstFile, err := ioutil.ReadFile(projConstOpsTmpl)
if err != nil {
return err
}
projTemplate := replaceProjConstTmplVariables(string(projConstFile), false /* isConstLeft */)
tmpl, err = tmpl.Funcs(template.FuncMap{"buildDict": buildDict}).Parse(projTemplate)
if err != nil {
return err
......
......@@ -15,6 +15,7 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"regexp"
......@@ -47,20 +48,22 @@ type execgenTool struct {
cmdLine *flag.FlagSet
}
type generator func(io.Writer) error
// generator is a func that, given an input file's contents as a string,
// outputs the result of execgen to the outputFile.
type generator func(inputFileContents string, outputFile io.Writer) error
var generators = make(map[string]entry)
type entry struct {
fn generator
dep string
fn generator
inputFile string
}
func registerGenerator(g generator, filename, dep string) {
if _, ok := generators[filename]; ok {
colexecerror.InternalError(fmt.Sprintf("%s generator already registered", filename))
}
generators[filename] = entry{fn: g, dep: dep}
generators[filename] = entry{fn: g, inputFile: dep}
}
func (g *execgenTool) run(args ...string) bool {
......@@ -86,20 +89,20 @@ func (g *execgenTool) run(args ...string) bool {
for _, out := range args {
_, file := filepath.Split(out)
entry := generators[file]
if entry.fn == nil {
e := generators[file]
if e.fn == nil {
g.reportError(errors.Errorf("unrecognized filename: %s", file))
return false
}
if printDeps {
if entry.dep == "" {
if e.inputFile == "" {
// This output file has no template dependency (its template
// is embedded entirely in execgenTool). Skip it.
continue
}
fmt.Printf("%s: %s\n", out, entry.dep)
fmt.Printf("%s: %s\n", out, e.inputFile)
} else {
if err := g.generate(out, entry.fn); err != nil {
if err := g.generate(out, e); err != nil {
g.reportError(err)
return false
}
......@@ -112,11 +115,20 @@ func (g *execgenTool) run(args ...string) bool {
var emptyCommentRegex = regexp.MustCompile(`[ \t]*//[ \t]*\n`)
var emptyBlockCommentRegex = regexp.MustCompile(`[ \t]*/\*[ \t]*\*/[ \t]*\n`)
func (g *execgenTool) generate(path string, genFunc generator) error {
func (g *execgenTool) generate(path string, entry entry) error {
var buf bytes.Buffer
buf.WriteString("// Code generated by execgen; DO NOT EDIT.\n")
err := genFunc(&buf)
var inputFileContents []byte
var err error
if entry.inputFile != "" {
inputFileContents, err = ioutil.ReadFile(entry.inputFile)
if err != nil {
return err
}
}
err = entry.fn(string(inputFileContents), &buf)
if err != nil {
return err
}
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -21,19 +20,14 @@ import (
const mergeJoinBaseTmpl = "pkg/sql/colexec/mergejoinbase_tmpl.go"
func genMergeJoinBase(wr io.Writer) error {
d, err := ioutil.ReadFile(mergeJoinBaseTmpl)
if err != nil {
return err
}
func genMergeJoinBase(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
"_TYPE_WIDTH", typeWidthReplacement,
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(d))
s := r.Replace(inputFileContents)
assignEqRe := makeFunctionRegex("_ASSIGN_EQ", 6)
s = assignEqRe.ReplaceAllString(s, makeTemplateFunctionCall("Assign", 6))
......
......@@ -13,7 +13,6 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -43,12 +42,7 @@ type joinTypeInfo struct {
const mergeJoinerTmpl = "pkg/sql/colexec/mergejoiner_tmpl.go"
func genMergeJoinOps(wr io.Writer, jti joinTypeInfo) error {
d, err := ioutil.ReadFile(mergeJoinerTmpl)
if err != nil {
return err
}
func genMergeJoinOps(inputFileContents string, wr io.Writer, jti joinTypeInfo) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
"_TYPE_WIDTH", typeWidthReplacement,
......@@ -68,7 +62,7 @@ func genMergeJoinOps(wr io.Writer, jti joinTypeInfo) error {
"_HAS_SELECTION", "$.HasSelection",
"_SEL_PERMUTATION", "$.SelPermutation",
)
s := r.Replace(string(d))
s := r.Replace(inputFileContents)
leftUnmatchedGroupSwitch := makeFunctionRegex("_LEFT_UNMATCHED_GROUP_SWITCH", 1)
s = leftUnmatchedGroupSwitch.ReplaceAllString(s, `{{template "leftUnmatchedGroupSwitch" buildDict "Global" $ "JoinType" $1}}`)
......@@ -196,8 +190,8 @@ func init() {
}
mergeJoinGenerator := func(jti joinTypeInfo) generator {
return func(wr io.Writer) error {
return genMergeJoinOps(wr, jti)
return func(inputFileContents string, wr io.Writer) error {
return genMergeJoinOps(inputFileContents, wr, jti)
}
}
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -40,12 +39,7 @@ var _ = aggOverloads{}.AggNameTitle()
const minMaxAggTmpl = "pkg/sql/colexec/min_max_agg_tmpl.go"
func genMinMaxAgg(wr io.Writer) error {
t, err := ioutil.ReadFile(minMaxAggTmpl)
if err != nil {
return err
}
func genMinMaxAgg(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_AGG_TITLE", "{{.AggNameTitle}}",
......@@ -57,7 +51,7 @@ func genMinMaxAgg(wr io.Writer) error {
"_TYPE", "{{.VecMethod}}",
"TemplateType", "{{.VecMethod}}",
)
s := r.Replace(string(t))
s := r.Replace(inputFileContents)
assignCmpRe := makeFunctionRegex("_ASSIGN_CMP", 6)
s = assignCmpRe.ReplaceAllString(s, makeTemplateFunctionCall("Assign", 6))
......
......@@ -12,7 +12,6 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
......@@ -21,12 +20,7 @@ import (
const ordSyncTmpl = "pkg/sql/colexec/ordered_synchronizer_tmpl.go"
func genOrderedSynchronizer(wr io.Writer) error {
d, err := ioutil.ReadFile(ordSyncTmpl)
if err != nil {
return err
}
func genOrderedSynchronizer(inputFileContents string, wr io.Writer) error {
r := strings.NewReplacer(
"_CANONICAL_TYPE_FAMILY", "{{.CanonicalTypeFamilyStr}}",
......@@ -34,7 +28,7 @@ func genOrderedSynchronizer(wr io.Writer) error {
"_GOTYPESLICE", "{{.GoTypeSliceName}}",
"_TYPE", "{{.VecMethod}}",
)
s := r.Replace(string(d))
s := r.Replace(inputFileContents)
s = replaceManipulationFuncs(s)
......
......@@ -59,7 +59,7 @@ func {{template "opName" .}}(a {{.Left.GoType}}, b {{.Right.GoType}}) {{.Right.R
// genOverloadsTestUtils creates a file that has a function for each binary and
// comparison overload supported by the vectorized engine. This is so that we
// can more easily test each overload.
func genOverloadsTestUtils(wr io.Writer) error {
func genOverloadsTestUtils(_ string, wr io.Writer) error {
tmpl, err := template.New("overloads_test_utils").Parse(overloadsTestUtilsTemplate)
if err != nil {
return err
......
......@@ -12,26 +12,12 @@ package main
import (
"io"
"io/ioutil"
"strings"
"text/template"
)
const projConstOpsTmpl = "pkg/sql/colexec/proj_const_ops_tmpl.go"
// getProjConstOpTmplString returns a "projConstOp" template with isConstLeft
// determining whether the constant is on the left or on the right.
func getProjConstOpTmplString(isConstLeft bool) (string, error) {
t, err := ioutil.ReadFile(projConstOpsTmpl)
if err != nil {
return "", err
}
s := string(t)
s = replaceProjConstTmplVariables(s, isConstLeft)
return s, nil
}
// replaceProjTmplVariables replaces template variables used in the templates
// for projection operators. It should only be used within this file.
// Note that not all template variables can be present in the template, and it
......@@ -95,14 +81,8 @@ func replaceProjConstTmplVariables(tmpl string, isConstLeft bool) string {
const projNonConstOpsTmpl = "pkg/sql/colexec/proj_non_const_ops_tmpl.go"
// genProjNonConstOps is the generator for projection operators on two vectors.
func genProjNonConstOps(wr io.Writer) error {
t, err := ioutil.ReadFile(projNonConstOpsTmpl)
if err != nil {
return err
}
s := string(t)
s = replaceProjTmplVariables(s)
func genProjNonConstOps(inputFileContents string, wr io.Writer) error {
s := replaceProjTmplVariables(inputFileContents)
tmpl, err := template.New("proj_non_const_ops").Funcs(template.FuncMap{"buildDict": buildDict}).Parse(s)
if err != nil {
......@@ -114,11 +94,8 @@ func genProjNonConstOps(wr io.Writer) error {
func init() {
projConstOpsGenerator := func(isConstLeft bool) generator {
return func(wr io.Writer) error {
tmplString, err := getProjConstOpTmplString(isConstLeft)
if err != nil {
return err
}
return func(inputFileContents string, wr io.Writer) error {
tmplString := replaceProjConstTmplVariables(inputFileContents, isConstLeft)
tmpl, err := template.New("proj_const_ops").Funcs(template.FuncMap{"buildDict": buildDict}).Parse(tmplString)
if err != nil {
return err
......
......@@ -13,7 +13,6 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"strings"
"text/template"
)
......@@ -59,15 +58,8 @@ var (
const rankTmpl = "pkg/sql/colexec/rank_tmpl.go"
func genRankOps(wr io.Writer) error {
d, err := ioutil.ReadFile(rankTmpl)
if err != nil {
return err
}
s := string(d)
s = strings.ReplaceAll(s, "_RANK_STRING", "{{.String}}")
func genRankOps(inputFileContents string, wr io.Writer) error {
s := strings.ReplaceAll(inputFileContents, "_RANK_STRING", "{{.String}}")