schema.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. // Copyright 2012, Google Inc. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package schema
  5. import (
  6. "database/sql"
  7. "fmt"
  8. "strconv"
  9. "strings"
  10. "github.com/pingcap/errors"
  11. "github.com/go-mysql-org/go-mysql/mysql"
  12. )
  // Package-level sentinel errors and well-known names.
  var (
  	// ErrTableNotExist is the sentinel error carrying the message
  	// "table is not exist". No function visible in this file returns it;
  	// presumably callers use it when a table lookup finds nothing.
  	ErrTableNotExist = errors.New("table is not exist")

  	// ErrMissingTableMeta is the sentinel error carrying the message
  	// "missing table meta". No function visible in this file returns it.
  	ErrMissingTableMeta = errors.New("missing table meta")

  	// HAHealthCheckSchema holds the string "mysql.ha_health_check".
  	// NOTE(review): presumably the schema-qualified name of a health-check
  	// table that callers want to special-case; confirm against its users.
  	HAHealthCheckSchema = "mysql.ha_health_check"
  )
  // Column type categories stored in TableColumn.Type.
  //
  // Values start at 1, so the zero value of TableColumn.Type matches none of
  // them. Table.AddColumn picks the category from the MySQL type string.
  const (
  	// TYPE_NUMBER covers tinyint, smallint, int, bigint and year.
  	TYPE_NUMBER = iota + 1
  	// TYPE_FLOAT covers float and double.
  	TYPE_FLOAT
  	// TYPE_ENUM is an enum column.
  	TYPE_ENUM
  	// TYPE_SET is a set column.
  	TYPE_SET
  	// TYPE_STRING covers char, varchar and every type not matched elsewhere.
  	TYPE_STRING
  	// TYPE_DATETIME is a datetime column.
  	TYPE_DATETIME
  	// TYPE_TIMESTAMP is a timestamp column.
  	TYPE_TIMESTAMP
  	// TYPE_DATE is a date column.
  	TYPE_DATE
  	// TYPE_TIME is a time column.
  	TYPE_TIME
  	// TYPE_BIT is a bit column.
  	TYPE_BIT
  	// TYPE_JSON is a json column.
  	TYPE_JSON
  	// TYPE_DECIMAL is a decimal column.
  	TYPE_DECIMAL
  	// TYPE_MEDIUM_INT is a mediumint column, kept apart from TYPE_NUMBER.
  	TYPE_MEDIUM_INT
  	// TYPE_BINARY covers binary and varbinary.
  	TYPE_BINARY
  	// TYPE_POINT is a coordinates column (type string contains "point").
  	TYPE_POINT
  )
  // TableColumn describes one column of a table as recorded by
  // Table.AddColumn.
  type TableColumn struct {
  	// Name is the column name.
  	Name string
  	// Type is one of the TYPE_* categories derived from RawType.
  	Type int
  	// Collation is the collation string handed to AddColumn.
  	Collation string
  	// RawType is the unmodified type string handed to AddColumn.
  	RawType string
  	// IsAuto is set when the column's extra info equals "auto_increment".
  	IsAuto bool
  	// IsUnsigned is set when RawType contains "unsigned" or "zerofill".
  	IsUnsigned bool
  	// IsVirtual is set when the extra info equals "VIRTUAL GENERATED".
  	IsVirtual bool
  	// IsStored is set when the extra info equals "STORED GENERATED".
  	IsStored bool
  	// EnumValues lists the members parsed from an enum(...) type string.
  	EnumValues []string
  	// SetValues lists the members parsed from a set(...) type string.
  	SetValues []string
  	// FixedSize is the declared size for binary and char columns, 0 otherwise.
  	FixedSize uint
  	// MaxSize is the declared size parsed from the first (...) group of
  	// RawType for binary, varbinary and string columns; 0 when absent.
  	MaxSize uint
  }
  // Index describes one index of a table.
  //
  // Field order matters: NewIndex builds the struct with an unkeyed literal.
  type Index struct {
  	// Name is the index name; the primary key is named "PRIMARY".
  	Name string
  	// Columns lists the indexed column names in the order they were added.
  	Columns []string
  	// Cardinality holds one entry per element of Columns.
  	Cardinality []uint64
  	// NoneUnique is the non-unique flag read from the index listing.
  	NoneUnique uint64
  }
  54. type Table struct {
  55. Schema string
  56. Name string
  57. Columns []TableColumn
  58. Indexes []*Index
  59. PKColumns []int
  60. UnsignedColumns []int
  61. }
  62. func (ta *Table) String() string {
  63. return fmt.Sprintf("%s.%s", ta.Schema, ta.Name)
  64. }
  65. func (ta *Table) AddColumn(name string, columnType string, collation string, extra string) {
  66. index := len(ta.Columns)
  67. ta.Columns = append(ta.Columns, TableColumn{Name: name, Collation: collation})
  68. ta.Columns[index].RawType = columnType
  69. if strings.HasPrefix(columnType, "float") ||
  70. strings.HasPrefix(columnType, "double") {
  71. ta.Columns[index].Type = TYPE_FLOAT
  72. } else if strings.HasPrefix(columnType, "decimal") {
  73. ta.Columns[index].Type = TYPE_DECIMAL
  74. } else if strings.HasPrefix(columnType, "enum") {
  75. ta.Columns[index].Type = TYPE_ENUM
  76. ta.Columns[index].EnumValues = strings.Split(strings.Replace(
  77. strings.TrimSuffix(
  78. strings.TrimPrefix(
  79. columnType, "enum("),
  80. ")"),
  81. "'", "", -1),
  82. ",")
  83. } else if strings.HasPrefix(columnType, "set") {
  84. ta.Columns[index].Type = TYPE_SET
  85. ta.Columns[index].SetValues = strings.Split(strings.Replace(
  86. strings.TrimSuffix(
  87. strings.TrimPrefix(
  88. columnType, "set("),
  89. ")"),
  90. "'", "", -1),
  91. ",")
  92. } else if strings.HasPrefix(columnType, "binary") {
  93. ta.Columns[index].Type = TYPE_BINARY
  94. size := getSizeFromColumnType(columnType)
  95. ta.Columns[index].MaxSize = size
  96. ta.Columns[index].FixedSize = size
  97. } else if strings.HasPrefix(columnType, "varbinary") {
  98. ta.Columns[index].Type = TYPE_BINARY
  99. ta.Columns[index].MaxSize = getSizeFromColumnType(columnType)
  100. } else if strings.HasPrefix(columnType, "datetime") {
  101. ta.Columns[index].Type = TYPE_DATETIME
  102. } else if strings.HasPrefix(columnType, "timestamp") {
  103. ta.Columns[index].Type = TYPE_TIMESTAMP
  104. } else if strings.HasPrefix(columnType, "time") {
  105. ta.Columns[index].Type = TYPE_TIME
  106. } else if "date" == columnType {
  107. ta.Columns[index].Type = TYPE_DATE
  108. } else if strings.HasPrefix(columnType, "bit") {
  109. ta.Columns[index].Type = TYPE_BIT
  110. } else if strings.HasPrefix(columnType, "json") {
  111. ta.Columns[index].Type = TYPE_JSON
  112. } else if strings.Contains(columnType, "point") {
  113. ta.Columns[index].Type = TYPE_POINT
  114. } else if strings.Contains(columnType, "mediumint") {
  115. ta.Columns[index].Type = TYPE_MEDIUM_INT
  116. } else if strings.Contains(columnType, "int") || strings.HasPrefix(columnType, "year") {
  117. ta.Columns[index].Type = TYPE_NUMBER
  118. } else if strings.HasPrefix(columnType, "char") {
  119. ta.Columns[index].Type = TYPE_STRING
  120. size := getSizeFromColumnType(columnType)
  121. ta.Columns[index].FixedSize = size
  122. ta.Columns[index].MaxSize = size
  123. } else {
  124. ta.Columns[index].Type = TYPE_STRING
  125. ta.Columns[index].MaxSize = getSizeFromColumnType(columnType)
  126. }
  127. if strings.Contains(columnType, "unsigned") || strings.Contains(columnType, "zerofill") {
  128. ta.Columns[index].IsUnsigned = true
  129. ta.UnsignedColumns = append(ta.UnsignedColumns, index)
  130. }
  131. if extra == "auto_increment" {
  132. ta.Columns[index].IsAuto = true
  133. } else if extra == "VIRTUAL GENERATED" {
  134. ta.Columns[index].IsVirtual = true
  135. } else if extra == "STORED GENERATED" {
  136. ta.Columns[index].IsStored = true
  137. }
  138. }
  // getSizeFromColumnType extracts the number between the first "(" and the
  // first ")" of columnType, e.g. 255 for "varchar(255)".
  //
  // It returns 0 when either bracket is missing, when the first ")" comes
  // before the first "(", when the enclosed text is not a plain integer (such
  // as "10,2"), or when the integer is negative.
  func getSizeFromColumnType(columnType string) uint {
  	open := strings.Index(columnType, "(")
  	// Searched from the start of the string rather than from open: a ")" that
  	// precedes the "(" is rejected by the ordering check below.
  	shut := strings.Index(columnType, ")")
  	if open < 0 || shut < open {
  		return 0
  	}

  	size, err := strconv.Atoi(columnType[open+1 : shut])
  	if err != nil || size < 0 {
  		return 0
  	}
  	return uint(size)
  }
  157. func (ta *Table) FindColumn(name string) int {
  158. for i, col := range ta.Columns {
  159. if col.Name == name {
  160. return i
  161. }
  162. }
  163. return -1
  164. }
  165. func (ta *Table) GetPKColumn(index int) *TableColumn {
  166. return &ta.Columns[ta.PKColumns[index]]
  167. }
  168. func (ta *Table) AddIndex(name string) (index *Index) {
  169. index = NewIndex(name)
  170. ta.Indexes = append(ta.Indexes, index)
  171. return index
  172. }
  173. func NewIndex(name string) *Index {
  174. return &Index{name, make([]string, 0, 8), make([]uint64, 0, 8), 0}
  175. }
  176. func (idx *Index) AddColumn(name string, cardinality uint64) {
  177. idx.Columns = append(idx.Columns, name)
  178. if cardinality == 0 {
  179. cardinality = uint64(len(idx.Cardinality) + 1)
  180. }
  181. idx.Cardinality = append(idx.Cardinality, cardinality)
  182. }
  183. func (idx *Index) FindColumn(name string) int {
  184. for i, colName := range idx.Columns {
  185. if name == colName {
  186. return i
  187. }
  188. }
  189. return -1
  190. }
  191. func IsTableExist(conn mysql.Executer, schema string, name string) (bool, error) {
  192. query := fmt.Sprintf("SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '%s' and TABLE_NAME = '%s' LIMIT 1", schema, name)
  193. r, err := conn.Execute(query)
  194. if err != nil {
  195. return false, errors.Trace(err)
  196. }
  197. return r.RowNumber() == 1, nil
  198. }
  199. func NewTableFromSqlDB(conn *sql.DB, schema string, name string) (*Table, error) {
  200. ta := &Table{
  201. Schema: schema,
  202. Name: name,
  203. Columns: make([]TableColumn, 0, 16),
  204. Indexes: make([]*Index, 0, 8),
  205. }
  206. if err := ta.fetchColumnsViaSqlDB(conn); err != nil {
  207. return nil, errors.Trace(err)
  208. }
  209. if err := ta.fetchIndexesViaSqlDB(conn); err != nil {
  210. return nil, errors.Trace(err)
  211. }
  212. return ta, nil
  213. }
  214. func NewTable(conn mysql.Executer, schema string, name string) (*Table, error) {
  215. ta := &Table{
  216. Schema: schema,
  217. Name: name,
  218. Columns: make([]TableColumn, 0, 16),
  219. Indexes: make([]*Index, 0, 8),
  220. }
  221. if err := ta.fetchColumns(conn); err != nil {
  222. return nil, errors.Trace(err)
  223. }
  224. if err := ta.fetchIndexes(conn); err != nil {
  225. return nil, errors.Trace(err)
  226. }
  227. return ta, nil
  228. }
  229. func (ta *Table) fetchColumns(conn mysql.Executer) error {
  230. r, err := conn.Execute(fmt.Sprintf("show full columns from `%s`.`%s`", ta.Schema, ta.Name))
  231. if err != nil {
  232. return errors.Trace(err)
  233. }
  234. for i := 0; i < r.RowNumber(); i++ {
  235. name, _ := r.GetString(i, 0)
  236. colType, _ := r.GetString(i, 1)
  237. collation, _ := r.GetString(i, 2)
  238. extra, _ := r.GetString(i, 6)
  239. ta.AddColumn(name, colType, collation, extra)
  240. }
  241. return nil
  242. }
  243. func (ta *Table) fetchColumnsViaSqlDB(conn *sql.DB) error {
  244. r, err := conn.Query(fmt.Sprintf("show full columns from `%s`.`%s`", ta.Schema, ta.Name))
  245. if err != nil {
  246. return errors.Trace(err)
  247. }
  248. defer r.Close()
  249. var unusedVal interface{}
  250. unused := &unusedVal
  251. for r.Next() {
  252. var name, colType, extra string
  253. var collation sql.NullString
  254. err := r.Scan(&name, &colType, &collation, &unused, &unused, &unused, &extra, &unused, &unused)
  255. if err != nil {
  256. return errors.Trace(err)
  257. }
  258. ta.AddColumn(name, colType, collation.String, extra)
  259. }
  260. return r.Err()
  261. }
  262. func (ta *Table) fetchIndexes(conn mysql.Executer) error {
  263. r, err := conn.Execute(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name))
  264. if err != nil {
  265. return errors.Trace(err)
  266. }
  267. var currentIndex *Index
  268. currentName := ""
  269. for i := 0; i < r.RowNumber(); i++ {
  270. indexName, _ := r.GetString(i, 2)
  271. if currentName != indexName {
  272. currentIndex = ta.AddIndex(indexName)
  273. currentName = indexName
  274. }
  275. cardinality, _ := r.GetUint(i, 6)
  276. colName, _ := r.GetString(i, 4)
  277. currentIndex.AddColumn(colName, cardinality)
  278. currentIndex.NoneUnique, _ = r.GetUint(i, 1)
  279. }
  280. return ta.fetchPrimaryKeyColumns()
  281. }
  282. func (ta *Table) fetchIndexesViaSqlDB(conn *sql.DB) error {
  283. r, err := conn.Query(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name))
  284. if err != nil {
  285. return errors.Trace(err)
  286. }
  287. defer r.Close()
  288. var currentIndex *Index
  289. currentName := ""
  290. var unusedVal interface{}
  291. for r.Next() {
  292. var indexName, colName string
  293. var noneUnique uint64
  294. var cardinality interface{}
  295. cols, err := r.Columns()
  296. if err != nil {
  297. return errors.Trace(err)
  298. }
  299. values := make([]interface{}, len(cols))
  300. for i := 0; i < len(cols); i++ {
  301. switch i {
  302. case 1:
  303. values[i] = &noneUnique
  304. case 2:
  305. values[i] = &indexName
  306. case 4:
  307. values[i] = &colName
  308. case 6:
  309. values[i] = &cardinality
  310. default:
  311. values[i] = &unusedVal
  312. }
  313. }
  314. err = r.Scan(values...)
  315. if err != nil {
  316. return errors.Trace(err)
  317. }
  318. if currentName != indexName {
  319. currentIndex = ta.AddIndex(indexName)
  320. currentName = indexName
  321. }
  322. c := toUint64(cardinality)
  323. currentIndex.AddColumn(colName, c)
  324. currentIndex.NoneUnique = noneUnique
  325. }
  326. return ta.fetchPrimaryKeyColumns()
  327. }
  // toUint64 converts a loosely typed value, such as one scanned from a
  // database row into an interface{}, to uint64.
  //
  // Unsigned integers are widened as-is. Signed integers are widened, with
  // negative values mapped to 0 rather than wrapping around to a huge number.
  // []byte and string values are parsed as base-10 unsigned integers, because
  // database/sql drivers may deliver numeric columns as text. Any other type
  // (including nil) and any unparsable text yields 0.
  func toUint64(i interface{}) uint64 {
  	var (
  		signed int64
  		text   string
  		isText bool
  	)

  	switch v := i.(type) {
  	case uint:
  		return uint64(v)
  	case uint8:
  		return uint64(v)
  	case uint16:
  		return uint64(v)
  	case uint32:
  		return uint64(v)
  	case uint64:
  		return v
  	case int:
  		signed = int64(v)
  	case int8:
  		signed = int64(v)
  	case int16:
  		signed = int64(v)
  	case int32:
  		signed = int64(v)
  	case int64:
  		signed = v
  	case []byte:
  		text, isText = string(v), true
  	case string:
  		text, isText = v, true
  	default:
  		return 0
  	}

  	if isText {
  		parsed, err := strconv.ParseUint(text, 10, 64)
  		if err != nil {
  			return 0
  		}
  		return parsed
  	}

  	if signed < 0 {
  		return 0
  	}
  	return uint64(signed)
  }
  353. func (ta *Table) fetchPrimaryKeyColumns() error {
  354. if len(ta.Indexes) == 0 {
  355. return nil
  356. }
  357. pkIndex := ta.Indexes[0]
  358. if pkIndex.Name != "PRIMARY" {
  359. return nil
  360. }
  361. ta.PKColumns = make([]int, len(pkIndex.Columns))
  362. for i, pkCol := range pkIndex.Columns {
  363. ta.PKColumns[i] = ta.FindColumn(pkCol)
  364. }
  365. return nil
  366. }
  367. // GetPKValues gets primary keys in one row for a table, a table may use multi fields as the PK
  368. func (ta *Table) GetPKValues(row []interface{}) ([]interface{}, error) {
  369. indexes := ta.PKColumns
  370. if len(indexes) == 0 {
  371. return nil, errors.Errorf("table %s has no PK", ta)
  372. } else if len(ta.Columns) != len(row) {
  373. return nil, errors.Errorf("table %s has %d columns, but row data %v len is %d", ta,
  374. len(ta.Columns), row, len(row))
  375. }
  376. values := make([]interface{}, 0, len(indexes))
  377. for _, index := range indexes {
  378. values = append(values, row[index])
  379. }
  380. return values, nil
  381. }
  382. // GetColumnValue gets term column's value
  383. func (ta *Table) GetColumnValue(column string, row []interface{}) (interface{}, error) {
  384. index := ta.FindColumn(column)
  385. if index == -1 {
  386. return nil, errors.Errorf("table %s has no column name %s", ta, column)
  387. }
  388. return row[index], nil
  389. }