407 lines
9.3 KiB
Go
407 lines
9.3 KiB
Go
package native
|
|
|
|
import (
|
|
"errors"
|
|
"github.com/ziutek/mymysql/mysql"
|
|
"log"
|
|
"math"
|
|
"strconv"
|
|
)
|
|
|
|
type Result struct {
|
|
my *Conn
|
|
status_only bool // true if result doesn't contain result set
|
|
binary bool // Binary result expected
|
|
|
|
field_count int
|
|
fields []*mysql.Field // Fields table
|
|
fc_map map[string]int // Maps field name to column number
|
|
|
|
message []byte
|
|
affected_rows uint64
|
|
|
|
// Primary key value (useful for AUTO_INCREMENT primary keys)
|
|
insert_id uint64
|
|
|
|
// Number of warinigs during command execution
|
|
// You can use the SHOW WARNINGS query for details.
|
|
warning_count int
|
|
|
|
// MySQL server status immediately after the query execution
|
|
status mysql.ConnStatus
|
|
|
|
// Seted by GetRow if it returns nil row
|
|
eor_returned bool
|
|
}
|
|
|
|
// Returns true if this is status result that includes no result set
|
|
func (res *Result) StatusOnly() bool {
|
|
return res.status_only
|
|
}
|
|
|
|
// Returns a table containing descriptions of the columns
|
|
func (res *Result) Fields() []*mysql.Field {
|
|
return res.fields
|
|
}
|
|
|
|
// Returns index for given name or -1 if field of that name doesn't exist
|
|
func (res *Result) Map(field_name string) int {
|
|
if fi, ok := res.fc_map[field_name]; ok {
|
|
return fi
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func (res *Result) Message() string {
|
|
return string(res.message)
|
|
}
|
|
|
|
func (res *Result) AffectedRows() uint64 {
|
|
return res.affected_rows
|
|
}
|
|
|
|
func (res *Result) InsertId() uint64 {
|
|
return res.insert_id
|
|
}
|
|
|
|
func (res *Result) WarnCount() int {
|
|
return res.warning_count
|
|
}
|
|
|
|
func (res *Result) MakeRow() mysql.Row {
|
|
return make(mysql.Row, res.field_count)
|
|
}
|
|
|
|
func (my *Conn) getResult(res *Result, row mysql.Row) *Result {
|
|
loop:
|
|
pr := my.newPktReader() // New reader for next packet
|
|
pkt0 := pr.readByte()
|
|
|
|
if pkt0 == 255 {
|
|
// Error packet
|
|
my.getErrorPacket(pr)
|
|
}
|
|
|
|
if res == nil {
|
|
switch {
|
|
case pkt0 == 0:
|
|
// OK packet
|
|
return my.getOkPacket(pr)
|
|
|
|
case pkt0 > 0 && pkt0 < 251:
|
|
// Result set header packet
|
|
res = my.getResSetHeadPacket(pr)
|
|
// Read next packet
|
|
goto loop
|
|
case pkt0 == 251:
|
|
// Load infile response
|
|
// Handle response
|
|
goto loop
|
|
case pkt0 == 254:
|
|
// EOF packet (without body)
|
|
return nil
|
|
}
|
|
} else {
|
|
switch {
|
|
case pkt0 == 254:
|
|
// EOF packet
|
|
res.warning_count, res.status = my.getEofPacket(pr)
|
|
my.status = res.status
|
|
return res
|
|
|
|
case pkt0 > 0 && pkt0 < 251 && res.field_count < len(res.fields):
|
|
// Field packet
|
|
field := my.getFieldPacket(pr)
|
|
res.fields[res.field_count] = field
|
|
res.fc_map[field.Name] = res.field_count
|
|
// Increment field count
|
|
res.field_count++
|
|
// Read next packet
|
|
goto loop
|
|
|
|
case pkt0 < 254 && res.field_count == len(res.fields):
|
|
// Row Data Packet
|
|
if len(row) != res.field_count {
|
|
panic(mysql.ErrRowLength)
|
|
}
|
|
if res.binary {
|
|
my.getBinRowPacket(pr, res, row)
|
|
} else {
|
|
my.getTextRowPacket(pr, res, row)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
panic(mysql.ErrUnkResultPkt)
|
|
}
|
|
|
|
func (my *Conn) getOkPacket(pr *pktReader) (res *Result) {
|
|
if my.Debug {
|
|
log.Printf("[%2d ->] OK packet:", my.seq-1)
|
|
}
|
|
res = new(Result)
|
|
res.status_only = true
|
|
res.my = my
|
|
// First byte was readed by getResult
|
|
res.affected_rows = pr.readLCB()
|
|
res.insert_id = pr.readLCB()
|
|
res.status = mysql.ConnStatus(pr.readU16())
|
|
my.status = res.status
|
|
res.warning_count = int(pr.readU16())
|
|
res.message = pr.readAll()
|
|
pr.checkEof()
|
|
|
|
if my.Debug {
|
|
log.Printf(tab8s+"AffectedRows=%d InsertId=0x%x Status=0x%x "+
|
|
"WarningCount=%d Message=\"%s\"", res.affected_rows, res.insert_id,
|
|
res.status, res.warning_count, res.message,
|
|
)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (my *Conn) getErrorPacket(pr *pktReader) {
|
|
if my.Debug {
|
|
log.Printf("[%2d ->] Error packet:", my.seq-1)
|
|
}
|
|
var err mysql.Error
|
|
err.Code = pr.readU16()
|
|
if pr.readByte() != '#' {
|
|
panic(mysql.ErrPkt)
|
|
}
|
|
pr.skipN(5)
|
|
err.Msg = pr.readAll()
|
|
pr.checkEof()
|
|
|
|
if my.Debug {
|
|
log.Printf(tab8s+"code=0x%x msg=\"%s\"", err.Code, err.Msg)
|
|
}
|
|
panic(&err)
|
|
}
|
|
|
|
func (my *Conn) getEofPacket(pr *pktReader) (warn_count int, status mysql.ConnStatus) {
|
|
if my.Debug {
|
|
if pr.eof() {
|
|
log.Printf("[%2d ->] EOF packet without body", my.seq-1)
|
|
} else {
|
|
log.Printf("[%2d ->] EOF packet:", my.seq-1)
|
|
}
|
|
}
|
|
if pr.eof() {
|
|
return
|
|
}
|
|
warn_count = int(pr.readU16())
|
|
if pr.eof() {
|
|
return
|
|
}
|
|
status = mysql.ConnStatus(pr.readU16())
|
|
pr.checkEof()
|
|
|
|
if my.Debug {
|
|
log.Printf(tab8s+"WarningCount=%d Status=0x%x", warn_count, status)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (my *Conn) getResSetHeadPacket(pr *pktReader) (res *Result) {
|
|
if my.Debug {
|
|
log.Printf("[%2d ->] Result set header packet:", my.seq-1)
|
|
}
|
|
pr.unreadByte()
|
|
|
|
field_count := int(pr.readLCB())
|
|
pr.checkEof()
|
|
|
|
res = &Result{
|
|
my: my,
|
|
fields: make([]*mysql.Field, field_count),
|
|
fc_map: make(map[string]int),
|
|
}
|
|
|
|
if my.Debug {
|
|
log.Printf(tab8s+"FieldCount=%d", field_count)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (my *Conn) getFieldPacket(pr *pktReader) (field *mysql.Field) {
|
|
if my.Debug {
|
|
log.Printf("[%2d ->] Field packet:", my.seq-1)
|
|
}
|
|
pr.unreadByte()
|
|
|
|
field = new(mysql.Field)
|
|
if my.fullFieldInfo {
|
|
field.Catalog = string(pr.readBin())
|
|
field.Db = string(pr.readBin())
|
|
field.Table = string(pr.readBin())
|
|
field.OrgTable = string(pr.readBin())
|
|
} else {
|
|
pr.skipBin()
|
|
pr.skipBin()
|
|
pr.skipBin()
|
|
pr.skipBin()
|
|
}
|
|
field.Name = string(pr.readBin())
|
|
if my.fullFieldInfo {
|
|
field.OrgName = string(pr.readBin())
|
|
} else {
|
|
pr.skipBin()
|
|
}
|
|
pr.skipN(1 + 2)
|
|
//field.Charset= pr.readU16()
|
|
field.DispLen = pr.readU32()
|
|
field.Type = pr.readByte()
|
|
field.Flags = pr.readU16()
|
|
field.Scale = pr.readByte()
|
|
pr.skipN(2)
|
|
pr.checkEof()
|
|
|
|
if my.Debug {
|
|
log.Printf(tab8s+"Name=\"%s\" Type=0x%x", field.Name, field.Type)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (my *Conn) getTextRowPacket(pr *pktReader, res *Result, row mysql.Row) {
|
|
if my.Debug {
|
|
log.Printf("[%2d ->] Text row data packet", my.seq-1)
|
|
}
|
|
pr.unreadByte()
|
|
|
|
for ii := 0; ii < res.field_count; ii++ {
|
|
bin, null := pr.readNullBin()
|
|
if null {
|
|
row[ii] = nil
|
|
} else {
|
|
row[ii] = bin
|
|
}
|
|
}
|
|
pr.checkEof()
|
|
}
|
|
|
|
func (my *Conn) getBinRowPacket(pr *pktReader, res *Result, row mysql.Row) {
|
|
if my.Debug {
|
|
log.Printf("[%2d ->] Binary row data packet", my.seq-1)
|
|
}
|
|
// First byte was readed by getResult
|
|
|
|
null_bitmap := make([]byte, (res.field_count+7+2)>>3)
|
|
pr.readFull(null_bitmap)
|
|
|
|
for ii, field := range res.fields {
|
|
null_byte := (ii + 2) >> 3
|
|
null_mask := byte(1) << uint(2+ii-(null_byte<<3))
|
|
if null_bitmap[null_byte]&null_mask != 0 {
|
|
// Null field
|
|
row[ii] = nil
|
|
continue
|
|
}
|
|
unsigned := (field.Flags & _FLAG_UNSIGNED) != 0
|
|
if my.narrowTypeSet {
|
|
row[ii] = readValueNarrow(pr, field.Type, unsigned)
|
|
} else {
|
|
row[ii] = readValue(pr, field.Type, unsigned)
|
|
}
|
|
}
|
|
}
|
|
|
|
func readValue(pr *pktReader, typ byte, unsigned bool) interface{} {
|
|
switch typ {
|
|
case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR,
|
|
MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB,
|
|
MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET,
|
|
MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY:
|
|
return pr.readBin()
|
|
case MYSQL_TYPE_TINY:
|
|
if unsigned {
|
|
return pr.readByte()
|
|
} else {
|
|
return int8(pr.readByte())
|
|
}
|
|
case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
|
|
if unsigned {
|
|
return pr.readU16()
|
|
} else {
|
|
return int16(pr.readU16())
|
|
}
|
|
case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24:
|
|
if unsigned {
|
|
return pr.readU32()
|
|
} else {
|
|
return int32(pr.readU32())
|
|
}
|
|
case MYSQL_TYPE_LONGLONG:
|
|
if unsigned {
|
|
return pr.readU64()
|
|
} else {
|
|
return int64(pr.readU64())
|
|
}
|
|
case MYSQL_TYPE_FLOAT:
|
|
return math.Float32frombits(pr.readU32())
|
|
case MYSQL_TYPE_DOUBLE:
|
|
return math.Float64frombits(pr.readU64())
|
|
case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL:
|
|
dec := string(pr.readBin())
|
|
r, err := strconv.ParseFloat(dec, 64)
|
|
if err != nil {
|
|
panic(errors.New("MySQL server returned wrong decimal value: " + dec))
|
|
}
|
|
return r
|
|
case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
|
|
return pr.readDate()
|
|
case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP:
|
|
return pr.readTime()
|
|
case MYSQL_TYPE_TIME:
|
|
return pr.readDuration()
|
|
}
|
|
panic(mysql.ErrUnkMySQLType)
|
|
}
|
|
|
|
func readValueNarrow(pr *pktReader, typ byte, unsigned bool) interface{} {
|
|
switch typ {
|
|
case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR,
|
|
MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB,
|
|
MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET,
|
|
MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY:
|
|
return pr.readBin()
|
|
case MYSQL_TYPE_TINY:
|
|
if unsigned {
|
|
return int64(pr.readByte())
|
|
}
|
|
return int64(int8(pr.readByte()))
|
|
case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
|
|
if unsigned {
|
|
return int64(pr.readU16())
|
|
}
|
|
return int64(int16(pr.readU16()))
|
|
case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24:
|
|
if unsigned {
|
|
return int64(pr.readU32())
|
|
}
|
|
return int64(int32(pr.readU32()))
|
|
case MYSQL_TYPE_LONGLONG:
|
|
v := pr.readU64()
|
|
if unsigned && v > math.MaxInt64 {
|
|
panic(errors.New("Value to large for int64 type"))
|
|
}
|
|
return int64(v)
|
|
case MYSQL_TYPE_FLOAT:
|
|
return float64(math.Float32frombits(pr.readU32()))
|
|
case MYSQL_TYPE_DOUBLE:
|
|
return math.Float64frombits(pr.readU64())
|
|
case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL:
|
|
dec := string(pr.readBin())
|
|
r, err := strconv.ParseFloat(dec, 64)
|
|
if err != nil {
|
|
panic("MySQL server returned wrong decimal value: " + dec)
|
|
}
|
|
return r
|
|
case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
|
|
return pr.readTime()
|
|
case MYSQL_TYPE_TIME:
|
|
return int64(pr.readDuration())
|
|
}
|
|
panic(mysql.ErrUnkMySQLType)
|
|
}
|