123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- // Copyright 2015 PingCAP, Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package ast
- // HasAggFlag checks if the expr contains FlagHasAggregateFunc.
- func HasAggFlag(expr ExprNode) bool {
- return expr.GetFlag()&FlagHasAggregateFunc > 0
- }
- func HasWindowFlag(expr ExprNode) bool {
- return expr.GetFlag()&FlagHasWindowFunc > 0
- }
- // SetFlag sets flag for expression.
- func SetFlag(n Node) {
- var setter flagSetter
- n.Accept(&setter)
- }
// flagSetter is a stateless visitor whose Leave method sets the flag of
// each expression node from fixed flag constants and its operands' flags.
type flagSetter struct{}
- func (f *flagSetter) Enter(in Node) (Node, bool) {
- return in, false
- }
- func (f *flagSetter) Leave(in Node) (Node, bool) {
- if x, ok := in.(ParamMarkerExpr); ok {
- x.SetFlag(FlagHasParamMarker)
- }
- switch x := in.(type) {
- case *AggregateFuncExpr:
- f.aggregateFunc(x)
- case *WindowFuncExpr:
- f.windowFunc(x)
- case *BetweenExpr:
- x.SetFlag(x.Expr.GetFlag() | x.Left.GetFlag() | x.Right.GetFlag())
- case *BinaryOperationExpr:
- x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
- case *CaseExpr:
- f.caseExpr(x)
- case *ColumnNameExpr:
- x.SetFlag(FlagHasReference)
- case *CompareSubqueryExpr:
- x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
- case *DefaultExpr:
- x.SetFlag(FlagHasDefault)
- case *ExistsSubqueryExpr:
- x.SetFlag(x.Sel.GetFlag())
- case *FuncCallExpr:
- f.funcCall(x)
- case *FuncCastExpr:
- x.SetFlag(FlagHasFunc | x.Expr.GetFlag())
- case *IsNullExpr:
- x.SetFlag(x.Expr.GetFlag())
- case *IsTruthExpr:
- x.SetFlag(x.Expr.GetFlag())
- case *ParenthesesExpr:
- x.SetFlag(x.Expr.GetFlag())
- case *PatternInExpr:
- f.patternIn(x)
- case *PatternLikeExpr:
- f.patternLike(x)
- case *PatternRegexpExpr:
- f.patternRegexp(x)
- case *PositionExpr:
- x.SetFlag(FlagHasReference)
- case *RowExpr:
- f.row(x)
- case *SubqueryExpr:
- x.SetFlag(FlagHasSubquery)
- case *UnaryOperationExpr:
- x.SetFlag(x.V.GetFlag())
- case *ValuesExpr:
- x.SetFlag(FlagHasReference)
- case *VariableExpr:
- if x.Value == nil {
- x.SetFlag(FlagHasVariable)
- } else {
- x.SetFlag(FlagHasVariable | x.Value.GetFlag())
- }
- }
- return in, true
- }
- func (f *flagSetter) caseExpr(x *CaseExpr) {
- var flag uint64
- if x.Value != nil {
- flag |= x.Value.GetFlag()
- }
- for _, val := range x.WhenClauses {
- flag |= val.Expr.GetFlag()
- flag |= val.Result.GetFlag()
- }
- if x.ElseClause != nil {
- flag |= x.ElseClause.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) patternIn(x *PatternInExpr) {
- flag := x.Expr.GetFlag()
- for _, val := range x.List {
- flag |= val.GetFlag()
- }
- if x.Sel != nil {
- flag |= x.Sel.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) patternLike(x *PatternLikeExpr) {
- flag := x.Pattern.GetFlag()
- if x.Expr != nil {
- flag |= x.Expr.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) patternRegexp(x *PatternRegexpExpr) {
- flag := x.Pattern.GetFlag()
- if x.Expr != nil {
- flag |= x.Expr.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) row(x *RowExpr) {
- var flag uint64
- for _, val := range x.Values {
- flag |= val.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) funcCall(x *FuncCallExpr) {
- flag := FlagHasFunc
- for _, val := range x.Args {
- flag |= val.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) aggregateFunc(x *AggregateFuncExpr) {
- flag := FlagHasAggregateFunc
- for _, val := range x.Args {
- flag |= val.GetFlag()
- }
- x.SetFlag(flag)
- }
- func (f *flagSetter) windowFunc(x *WindowFuncExpr) {
- flag := FlagHasWindowFunc
- for _, val := range x.Args {
- flag |= val.GetFlag()
- }
- x.SetFlag(flag)
- }
|