flag.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // Copyright 2015 PingCAP, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // See the License for the specific language governing permissions and
  12. // limitations under the License.
  13. package ast
  14. // HasAggFlag checks if the expr contains FlagHasAggregateFunc.
  15. func HasAggFlag(expr ExprNode) bool {
  16. return expr.GetFlag()&FlagHasAggregateFunc > 0
  17. }
  18. func HasWindowFlag(expr ExprNode) bool {
  19. return expr.GetFlag()&FlagHasWindowFunc > 0
  20. }
  21. // SetFlag sets flag for expression.
  22. func SetFlag(n Node) {
  23. var setter flagSetter
  24. n.Accept(&setter)
  25. }
  26. type flagSetter struct {
  27. }
  28. func (f *flagSetter) Enter(in Node) (Node, bool) {
  29. return in, false
  30. }
  31. func (f *flagSetter) Leave(in Node) (Node, bool) {
  32. if x, ok := in.(ParamMarkerExpr); ok {
  33. x.SetFlag(FlagHasParamMarker)
  34. }
  35. switch x := in.(type) {
  36. case *AggregateFuncExpr:
  37. f.aggregateFunc(x)
  38. case *WindowFuncExpr:
  39. f.windowFunc(x)
  40. case *BetweenExpr:
  41. x.SetFlag(x.Expr.GetFlag() | x.Left.GetFlag() | x.Right.GetFlag())
  42. case *BinaryOperationExpr:
  43. x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
  44. case *CaseExpr:
  45. f.caseExpr(x)
  46. case *ColumnNameExpr:
  47. x.SetFlag(FlagHasReference)
  48. case *CompareSubqueryExpr:
  49. x.SetFlag(x.L.GetFlag() | x.R.GetFlag())
  50. case *DefaultExpr:
  51. x.SetFlag(FlagHasDefault)
  52. case *ExistsSubqueryExpr:
  53. x.SetFlag(x.Sel.GetFlag())
  54. case *FuncCallExpr:
  55. f.funcCall(x)
  56. case *FuncCastExpr:
  57. x.SetFlag(FlagHasFunc | x.Expr.GetFlag())
  58. case *IsNullExpr:
  59. x.SetFlag(x.Expr.GetFlag())
  60. case *IsTruthExpr:
  61. x.SetFlag(x.Expr.GetFlag())
  62. case *ParenthesesExpr:
  63. x.SetFlag(x.Expr.GetFlag())
  64. case *PatternInExpr:
  65. f.patternIn(x)
  66. case *PatternLikeExpr:
  67. f.patternLike(x)
  68. case *PatternRegexpExpr:
  69. f.patternRegexp(x)
  70. case *PositionExpr:
  71. x.SetFlag(FlagHasReference)
  72. case *RowExpr:
  73. f.row(x)
  74. case *SubqueryExpr:
  75. x.SetFlag(FlagHasSubquery)
  76. case *UnaryOperationExpr:
  77. x.SetFlag(x.V.GetFlag())
  78. case *ValuesExpr:
  79. x.SetFlag(FlagHasReference)
  80. case *VariableExpr:
  81. if x.Value == nil {
  82. x.SetFlag(FlagHasVariable)
  83. } else {
  84. x.SetFlag(FlagHasVariable | x.Value.GetFlag())
  85. }
  86. }
  87. return in, true
  88. }
  89. func (f *flagSetter) caseExpr(x *CaseExpr) {
  90. var flag uint64
  91. if x.Value != nil {
  92. flag |= x.Value.GetFlag()
  93. }
  94. for _, val := range x.WhenClauses {
  95. flag |= val.Expr.GetFlag()
  96. flag |= val.Result.GetFlag()
  97. }
  98. if x.ElseClause != nil {
  99. flag |= x.ElseClause.GetFlag()
  100. }
  101. x.SetFlag(flag)
  102. }
  103. func (f *flagSetter) patternIn(x *PatternInExpr) {
  104. flag := x.Expr.GetFlag()
  105. for _, val := range x.List {
  106. flag |= val.GetFlag()
  107. }
  108. if x.Sel != nil {
  109. flag |= x.Sel.GetFlag()
  110. }
  111. x.SetFlag(flag)
  112. }
  113. func (f *flagSetter) patternLike(x *PatternLikeExpr) {
  114. flag := x.Pattern.GetFlag()
  115. if x.Expr != nil {
  116. flag |= x.Expr.GetFlag()
  117. }
  118. x.SetFlag(flag)
  119. }
  120. func (f *flagSetter) patternRegexp(x *PatternRegexpExpr) {
  121. flag := x.Pattern.GetFlag()
  122. if x.Expr != nil {
  123. flag |= x.Expr.GetFlag()
  124. }
  125. x.SetFlag(flag)
  126. }
  127. func (f *flagSetter) row(x *RowExpr) {
  128. var flag uint64
  129. for _, val := range x.Values {
  130. flag |= val.GetFlag()
  131. }
  132. x.SetFlag(flag)
  133. }
  134. func (f *flagSetter) funcCall(x *FuncCallExpr) {
  135. flag := FlagHasFunc
  136. for _, val := range x.Args {
  137. flag |= val.GetFlag()
  138. }
  139. x.SetFlag(flag)
  140. }
  141. func (f *flagSetter) aggregateFunc(x *AggregateFuncExpr) {
  142. flag := FlagHasAggregateFunc
  143. for _, val := range x.Args {
  144. flag |= val.GetFlag()
  145. }
  146. x.SetFlag(flag)
  147. }
  148. func (f *flagSetter) windowFunc(x *WindowFuncExpr) {
  149. flag := FlagHasWindowFunc
  150. for _, val := range x.Args {
  151. flag |= val.GetFlag()
  152. }
  153. x.SetFlag(flag)
  154. }