Skip to content

Commit

Permalink
Parse unsigned
Browse files Browse the repository at this point in the history
  • Loading branch information
kesonan committed Jul 11, 2022
1 parent e39f6c4 commit aa7e847
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 68 deletions.
1 change: 0 additions & 1 deletion gen/README
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
THE DIRECTORY gen WHICH IS GENERATED BY ANTLR, DO NOT EDIT!
16 changes: 15 additions & 1 deletion parser/datatype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ func TestVisitor_VisitDataType(t *testing.T) {
assert.Nil(t, err)
assertTypeEqual(t, dataType, actual)
}

testData = map[string]int{
`TINYINT(1) UNSIGNED`: TinyInt,
`SMALLINT UNSIGNED`: SmallInt,
`BIGINT UNSIGNED`: BigInt,
}
for sql, dataType := range testData {
actual, err := p.testMysqlSyntax("test.sql", accept, sql)
assert.Nil(t, err)
assertTypeEqual(t, dataType, actual, true)
}
})

t.Run("simpleDataType", func(t *testing.T) {
Expand Down Expand Up @@ -276,8 +287,11 @@ func TestVisitor_VisitDataType(t *testing.T) {
})
}

func assertTypeEqual(t *testing.T, expected int, actual interface{}) {
func assertTypeEqual(t *testing.T, expected int, actual interface{}, unsigned ...bool) {
assert.Equal(t, expected, actual.(DataType).Type())
if len(unsigned) > 0 {
assert.Equal(t, unsigned[0], actual.(DataType).Unsigned())
}
}

func assertEnumTypeEqual(t *testing.T, expectedType int, values []string, actual interface{}) {
Expand Down
145 changes: 79 additions & 66 deletions parser/datatype_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ const (
// DataType describes the data type and value of the column in table
type DataType interface {
Type() int
Unsigned() bool
// Value returns the values if the data type is Enum or Set
Value() []string
}
Expand All @@ -95,7 +96,13 @@ var _ DataType = (*EnumSetDataType)(nil)

// NormalDataType describes the data type which not contains Enum and Set of column
type NormalDataType struct {
tp int
tp int
unsigned bool
}

// Unsigned returns true if the data type is unsigned.
func (n *NormalDataType) Unsigned() bool {
return n.unsigned
}

// Type returns the data type of column
Expand All @@ -108,14 +115,14 @@ func (n *NormalDataType) Value() []string {
return nil
}

func with(tp int, value ...string) DataType {
func with(tp int, unsigned bool, value ...string) DataType {
if len(value) > 0 {
return &EnumSetDataType{
tp: tp,
value: value,
}
}
return &NormalDataType{tp: tp}
return &NormalDataType{tp: tp, unsigned: unsigned}
}

// EnumSetDataType describes the data type Enum and Set of column
Expand All @@ -129,6 +136,11 @@ func (e *EnumSetDataType) Type() int {
return e.tp
}

// Unsigned returns true if the data type is unsigned.
func (e *EnumSetDataType) Unsigned() bool {
return false
}

// Value returns the value of data type Enum and Set
func (e *EnumSetDataType) Value() []string {
return e.value
Expand Down Expand Up @@ -168,25 +180,25 @@ func (v *visitor) visitStringDataType(ctx *gen.StringDataTypeContext) DataType {
text := parseToken(ctx.GetTypeName(), withUpperCase(), withTrim("`"))
switch text {
case `CHAR`:
return with(Char)
return with(Char, false)
case `CHARACTER`:
return with(Character)
return with(Character, false)
case `VARCHAR`:
return with(VarChar)
return with(VarChar, false)
case `TINYTEXT`:
return with(TinyText)
return with(TinyText, false)
case `TEXT`:
return with(Text)
return with(Text, false)
case `MEDIUMTEXT`:
return with(MediumText)
return with(MediumText, false)
case `LONGTEXT`:
return with(LongText)
return with(LongText, false)
case `NCHAR`:
return with(NChar)
return with(NChar, false)
case `NVARCHAR`:
return with(NVarChar)
return with(NVarChar, false)
case `LONG`:
return with(LongVarChar)
return with(LongVarChar, false)
}

v.panicWithExpr(ctx.GetTypeName(), "invalid data type: "+text)
Expand All @@ -199,9 +211,9 @@ func (v *visitor) visitNationalStringDataType(ctx *gen.NationalStringDataTypeCon
text := parseToken(ctx.GetTypeName(), withUpperCase(), withTrim("`"))
switch text {
case `VARCHAR`:
return with(NVarChar)
return with(NVarChar, false)
case `CHARACTER`:
return with(NChar)
return with(NChar, false)
}

v.panicWithExpr(ctx.GetTypeName(), "invalid data type: "+text)
Expand All @@ -211,72 +223,73 @@ func (v *visitor) visitNationalStringDataType(ctx *gen.NationalStringDataTypeCon
// visitNationalVaryingStringDataType visits a parse tree produced by MySqlParser#nationalVaryingStringDataType.
func (v *visitor) visitNationalVaryingStringDataType(_ *gen.NationalVaryingStringDataTypeContext) DataType {
v.trace("VisitNationalVaryingStringDataType")
return with(NVarChar)
return with(NVarChar, false)
}

// visitDimensionDataType visits a parse tree produced by MySqlParser#dimensionDataType.
func (v *visitor) visitDimensionDataType(ctx *gen.DimensionDataTypeContext) DataType {
v.trace("VisitDimensionDataType")
text := parseToken(ctx.GetTypeName(), withUpperCase(), withTrim("`"))
unsigned := ctx.UNSIGNED() != nil
switch text {
case `BIT`:
return with(Bit)
return with(Bit, unsigned)
case `TIME`:
return with(Time)
return with(Time, unsigned)
case `TIMESTAMP`:
return with(Timestamp)
return with(Timestamp, unsigned)
case `DATETIME`:
return with(DateTime)
return with(DateTime, unsigned)
case `BINARY`:
return with(Binary)
return with(Binary, unsigned)
case `VARBINARY`:
return with(VarBinary)
return with(VarBinary, unsigned)
case `BLOB`:
return with(Blob)
return with(Blob, unsigned)
case `YEAR`:
return with(Year)
return with(Year, unsigned)
case `DECIMAL`:
return with(Decimal)
return with(Decimal, unsigned)
case `DEC`:
return with(Dec)
return with(Dec, unsigned)
case `FIXED`:
return with(Fixed)
return with(Fixed, unsigned)
case `NUMERIC`:
return with(Numeric)
return with(Numeric, unsigned)
case `FLOAT`:
return with(Float)
return with(Float, unsigned)
case `FLOAT4`:
return with(Float4)
return with(Float4, unsigned)
case `FLOAT8`:
return with(Float8)
return with(Float8, unsigned)
case `DOUBLE`:
return with(Double)
return with(Double, unsigned)
case `REAL`:
return with(Real)
return with(Real, unsigned)
case `TINYINT`:
return with(TinyInt)
return with(TinyInt, unsigned)
case `SMALLINT`:
return with(SmallInt)
return with(SmallInt, unsigned)
case `MEDIUMINT`:
return with(MediumInt)
return with(MediumInt, unsigned)
case `INT`:
return with(Int)
return with(Int, unsigned)
case `INTEGER`:
return with(Integer)
return with(Integer, unsigned)
case `BIGINT`:
return with(BigInt)
return with(BigInt, unsigned)
case `MIDDLEINT`:
return with(MiddleInt)
return with(MiddleInt, unsigned)
case `INT1`:
return with(Int1)
return with(Int1, unsigned)
case `INT2`:
return with(Int2)
return with(Int2, unsigned)
case `INT3`:
return with(Int3)
return with(Int3, unsigned)
case `INT4`:
return with(Int4)
return with(Int4, unsigned)
case `INT8`:
return with(Int8)
return with(Int8, unsigned)
}

v.panicWithExpr(ctx.GetTypeName(), "invalid data type: "+text)
Expand All @@ -294,19 +307,19 @@ func (v *visitor) visitSimpleDataType(ctx *gen.SimpleDataTypeContext) DataType {

switch text {
case `DATE`:
return with(Date)
return with(Date, false)
case `TINYBLOB`:
return with(TinyBlob)
return with(TinyBlob, false)
case `MEDIUMBLOB`:
return with(MediumBlob)
return with(MediumBlob, false)
case `LONGBLOB`:
return with(LongBlob)
return with(LongBlob, false)
case `BOOL`:
return with(Bool)
return with(Bool, false)
case `BOOLEAN`:
return with(Boolean)
return with(Boolean, false)
case `SERIAL`:
return with(Serial)
return with(Serial, false)
}

v.panicWithExpr(ctx.GetTypeName(), "invalid data type: "+text)
Expand Down Expand Up @@ -339,9 +352,9 @@ func (v *visitor) visitCollectionDataType(ctx *gen.CollectionDataTypeContext) Da

switch text {
case `ENUM`:
return with(Enum, values...)
return with(Enum, false, values...)
case `SET`:
return with(Set, values...)
return with(Set, false, values...)
}

v.panicWithExpr(ctx.GetTypeName(), "invalid data type: "+text)
Expand All @@ -359,25 +372,25 @@ func (v *visitor) visitSpatialDataType(ctx *gen.SpatialDataTypeContext) DataType

switch text {
case `GEOMETRYCOLLECTION`:
return with(GeometryCollection)
return with(GeometryCollection, false)
case `GEOMCOLLECTION`:
return with(GeomCollection)
return with(GeomCollection, false)
case `LINESTRING`:
return with(LineString)
return with(LineString, false)
case `MULTILINESTRING`:
return with(MultiLineString)
return with(MultiLineString, false)
case `MULTIPOINT`:
return with(MultiPoint)
return with(MultiPoint, false)
case `MULTIPOLYGON`:
return with(MultiPolygon)
return with(MultiPolygon, false)
case `POINT`:
return with(Point)
return with(Point, false)
case `POLYGON`:
return with(Polygon)
return with(Polygon, false)
case `JSON`:
return with(Json)
return with(Json, false)
case `GEOMETRY`:
return with(Geometry)
return with(Geometry, false)
}

v.panicWithExpr(ctx.GetTypeName(), "invalid data type: "+text)
Expand All @@ -387,11 +400,11 @@ func (v *visitor) visitSpatialDataType(ctx *gen.SpatialDataTypeContext) DataType
// visitLongVarcharDataType visits a parse tree produced by MySqlParser#longVarcharDataType.
func (v *visitor) visitLongVarcharDataType(_ *gen.LongVarcharDataTypeContext) DataType {
v.trace("VisitLongVarcharDataType")
return with(LongVarChar)
return with(LongVarChar, false)
}

// visitLongVarbinaryDataType visits a parse tree produced by MySqlParser#longVarbinaryDataType.
func (v *visitor) visitLongVarbinaryDataType(_ *gen.LongVarbinaryDataTypeContext) DataType {
v.trace("VisitLongVarbinaryDataType")
return with(LongVarBinary)
return with(LongVarBinary, false)
}

0 comments on commit aa7e847

Please sign in to comment.