Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement TextProtocol for unprepared statements without arguments #16

Merged
merged 1 commit into from
Dec 13, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions spec/driver_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ def with_db(&block : DB::Database ->)
DB.open "mysql://root@localhost", &block
end

def with_test_db(&block : DB::Database ->)
def with_test_db(options = "", &block : DB::Database ->)
DB.open "mysql://root@localhost" do |db|
db.exec "DROP DATABASE IF EXISTS crystal_mysql_test"
db.exec "CREATE DATABASE crystal_mysql_test"
DB.open "mysql://root@localhost/crystal_mysql_test", &block
DB.open "mysql://root@localhost/crystal_mysql_test?#{options}", &block
db.exec "DROP DATABASE IF EXISTS crystal_mysql_test"
end
end
Expand Down Expand Up @@ -78,9 +78,10 @@ describe Driver do
end

# "SELECT 1" returns a Int64. So this test are not to be used as is on all DB::Any
{% for prepared_statements in [true, false] %}
{% for value in [1_i64, "hello", 1.5] %}
it "executes and select {{value.id}}" do
with_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.scalar("select #{sql({{value}})}").should eq({{value}})

db.query "select #{sql({{value}})}" do |rs|
Expand All @@ -89,6 +90,7 @@ describe Driver do
end
end
{% end %}
{% end %}

it "executes with bind nil" do
with_db do |db|
Expand All @@ -97,15 +99,17 @@ describe Driver do
end

{% for value in [54_i16, 1_i8, 5_i8, 1, 1_i64, "hello", 1.5, 1.5_f32] %}
{% for prepared_statements in [true, false] %}
it "executes and select nil as type of {{value.id}}" do
with_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.scalar("select null").should be_nil

db.query "select null" do |rs|
assert_single_read rs, typeof({{value}} || nil), nil
end
end
end
{% end %}

it "executes with bind {{value.id}}" do
with_db do |db|
Expand Down Expand Up @@ -134,8 +138,9 @@ describe Driver do
end
end

{% for prepared_statements in [true, false] %}
it "executes and selects blob" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table t1 (b1 BLOB)"
db.exec "insert into t1 (b1) values (X'415A617A')"
slice = db.scalar(%(select b1 from t1)).as(Bytes)
Expand All @@ -150,11 +155,12 @@ describe Driver do
{"type" => "LONGBLOB", "size" => 1000000},
].each do |row|
it "set/get " + row["type"].as(String) do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
ary = UInt8[0x41, 0x5A, 0x61, 0x7A] * row["size"].as(Int32)
slice = Bytes.new(ary.to_unsafe, ary.size)
db.exec "create table t1 (b1 " + row["type"].as(String) + ")"
db.exec "insert into t1 (b1) values (?)", slice
# TODO remove when unprepared statements support args
db.prepared.exec "insert into t1 (b1) values (?)", slice
slice = db.scalar(%(select b1 from t1)).as(Bytes)
slice.to_a.should eq(ary)
end
Expand All @@ -168,18 +174,19 @@ describe Driver do
{"type" => "LONGTEXT", "size" => 100000},
].each do |row|
it "set/get " + row["type"].as(String) do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
txt = "Ham Sandwich" * row["size"].as(Int32)
db.exec "create table tab1 (txt1 " + row["type"].as(String) + ")"
db.exec "insert into tab1 (txt1) values (?)", txt
# TODO remove when unprepared statements support args
db.prepared.exec "insert into tab1 (txt1) values (?)", txt
text = db.scalar(%(select txt1 from tab1))
text.should eq(txt)
end
end
end

it "gets column count" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table person (name varchar(25), age integer)"
db.query "select * from person" do |rs|
rs.column_count.should eq(2)
Expand All @@ -188,7 +195,7 @@ describe Driver do
end

it "gets column name" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table person (name varchar(25), age integer)"

db.query "select * from person" do |rs|
Expand All @@ -199,7 +206,7 @@ describe Driver do
end

it "gets last insert row id" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table person (id int not null primary key auto_increment, name varchar(25), age int)"
db.exec %(insert into person (name, age) values ("foo", 10))
res = db.exec %(insert into person (name, age) values ("foo", 10))
Expand All @@ -210,9 +217,10 @@ describe Driver do

{% for value in [54_i16, 1_i8, 5_i8, 1, 1_i64, "hello", 1.5, 1.5_f32] %}
it "insert/get value {{value.id}} from table" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table table1 (col1 #{mysql_type_for({{value}})})"
db.exec %(insert into table1 (col1) values (#{sql({{value}})}))

db.scalar("select col1 from table1").should eq({{value}})
end
end
Expand All @@ -226,6 +234,7 @@ describe Driver do
end
end
{% end %}
{% end %}

# zero dates http://dev.mysql.com/doc/refman/5.7/en/datetime.html - work on some mysql not others,
# NO_ZERO_IN_DATE enabled as part of strict mode in MySQL 5.7.8. - http://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#sql-mode-changes
Expand Down
6 changes: 5 additions & 1 deletion src/mysql/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ class MySql::Connection < DB::Connection
end
end

def build_statement(query)
def build_prepared_statement(query)
MySql::Statement.new(self, query)
end

def build_unprepared_statement(query)
MySql::UnpreparedStatement.new(self, query)
end
end
84 changes: 84 additions & 0 deletions src/mysql/text_result_set.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Implementation of ProtocolText::Resultset.
# Used for unprepared statements.
class MySql::TextResultSet < DB::ResultSet
getter columns

@conn : MySql::Connection
@row_packet : MySql::ReadPacket?
@header : UInt8

def initialize(statement, column_count)
super(statement)
@conn = statement.connection.as(MySql::Connection)

columns = @columns = [] of ColumnSpec
@conn.read_column_definitions(columns, column_count)

@column_index = 0 # next column index to return

@header = 0u8
@eof_reached = false
end

def do_close
super

while move_next
end

if row_packet = @row_packet
row_packet.discard
end
end

def move_next : Bool
return false if @eof_reached

# skip previous row_packet
if row_packet = @row_packet
row_packet.discard
end

@row_packet = row_packet = @conn.build_read_packet

@header = row_packet.read_byte!
if @header == 0xfe # EOF
@eof_reached = true
return false
end

@column_index = 0
# TODO remove row_packet.read(@null_bitmap_slice)
return true
end

def column_count : Int32
@columns.size
end

def column_name(index : Int32) : String
@columns[index].name
end

def read
row_packet = @row_packet.not_nil!

is_nil = @header == 0xfb
col = @column_index
@column_index += 1
if is_nil
nil
else
length = row_packet.read_lenenc_int(@header)
val = row_packet.read_string(length)
val = @columns[col].column_type.parse(val)

# http://dev.mysql.com/doc/internals/en/character-set.html
if val.is_a?(Slice(UInt8)) && @columns[col].character_set != 63
::String.new(val)
else
val
end
end
end
end
30 changes: 30 additions & 0 deletions src/mysql/types.cr
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ abstract struct MySql::Type
raise "not supported read"
end

# Parse from str a value in TextProtocol format of the type
# specified by self.
def self.parse(str : ::String)
raise "not supported"
end

macro decl_type(name, value, db_any_type = nil)
struct {{name}} < Type
@@hex_value = {{value}}
Expand All @@ -92,6 +98,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_bytes {{db_any_type}}, IO::ByteFormat::LittleEndian
end

def self.parse(str : ::String)
{{db_any_type}}.new(str)
end
{% end %}

{{yield}}
Expand All @@ -110,6 +120,10 @@ abstract struct MySql::Type
def self.read(packet)
nil
end

def self.parse(str : ::String)
nil
end
end
decl_type Timestamp, 0x07u8
decl_type LongLong, 0x08u8, ::Int64
Expand All @@ -135,6 +149,10 @@ abstract struct MySql::Type
ms = packet.read_int.to_i32 / 1000 # returns microseconds, time only supports milliseconds
return ::Time.new(year, month, day, hour, minute, second, ms)
end

def self.parse(str : ::String)
raise "TextProtocol::Time not implemented"
end
end
decl_type Year, 0x0du8
decl_type VarChar, 0x0fu8
Expand All @@ -157,6 +175,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_blob
end

def self.parse(str : ::String)
str.to_slice
end
end
decl_type VarString, 0xfdu8, ::String do
def self.write(packet, v : ::String)
Expand All @@ -166,6 +188,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_lenenc_string
end

def self.parse(str : ::String)
str
end
end
decl_type String, 0xfeu8, ::String do
def self.write(packet, v : ::String)
Expand All @@ -175,6 +201,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_lenenc_string
end

def self.parse(str : ::String)
str
end
end
decl_type Geometry, 0xffu8
end
41 changes: 41 additions & 0 deletions src/mysql/unprepared_statement.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
class MySql::UnpreparedStatement < DB::Statement
def initialize(connection, @sql : String)
super(connection)
end

protected def conn
@connection.as(Connection)
end

protected def perform_query(args : Enumerable) : DB::ResultSet
perform_exec_or_query(args).as(DB::ResultSet)
end

protected def perform_exec(args : Enumerable) : DB::ExecResult
perform_exec_or_query(args).as(DB::ExecResult)
end

private def perform_exec_or_query(args : Enumerable)
raise "exec/query with args is not supported" if args.size > 0

conn = self.conn
conn.write_packet do |packet|
packet.write_byte 0x03u8
packet << @sql
# TODO to support args an interpolation needs to be done
end

conn.read_packet do |packet|
case header = packet.read_byte.not_nil!
when 255 # err packet
conn.handle_err_packet(packet)
when 0 # ok packet
affected_rows = packet.read_lenenc_int
last_insert_id = packet.read_lenenc_int
DB::ExecResult.new affected_rows, last_insert_id
else
MySql::TextResultSet.new(self, packet.read_lenenc_int(header))
end
end
end
end