Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for stream_id #104

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .busted
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
return {
default = {
lpath = "./lib/?.lua;./lib/?/init.lua;"
lpath = "./lib/?.lua;./lib/?/init.lua;/lib/util/?.lua"
}
}
9 changes: 7 additions & 2 deletions lib/cassandra/cql.lua
Original file line number Diff line number Diff line change
Expand Up @@ -947,10 +947,15 @@ do

header:write_byte(flags)

local stream_id = 0
if self.opts and self.opts.stream_id then
stream_id = self.opts.stream_id
end

if version < 3 then
header:write_byte(0) -- stream_id
header:write_byte(stream_id)
else
header:write_short(0) -- stream_id
header:write_short(stream_id)
end

header:write_byte(self.op_code)
Expand Down
81 changes: 61 additions & 20 deletions lib/cassandra/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

local socket = require 'cassandra.socket'
local cql = require 'cassandra.cql'
local deque = require 'util.deque'

local setmetatable = setmetatable
local requests = cql.requests
Expand Down Expand Up @@ -124,18 +125,29 @@ function _Host.new(opts)
local sock, err = socket.tcp()
if err then return nil, err end

local protocol_version = opts.protocol_version or cql.def_protocol_version

-- Initialize stream_ids deque
local max_id = protocol_version < 3 and 2^7-1 or 2^15-1
local stream_ids = deque.new()

for i=1,max_id do
deque.pushright(stream_ids, i)
end

local host = {
sock = sock,
host = opts.host or '127.0.0.1',
port = opts.port or 9042,
keyspace = opts.keyspace,
protocol_version = opts.protocol_version or cql.def_protocol_version,
protocol_version = protocol_version,
ssl = opts.ssl,
verify = opts.verify,
cert = opts.cert,
cafile = opts.cafile,
key = opts.key,
auth = opts.auth
auth = opts.auth,
stream_ids = stream_ids,
}

return setmetatable(host, _Host)
Expand All @@ -146,32 +158,61 @@ function _Host:send(request)
return nil, 'no socket created'
end

-- set stream_id
local stream_id, err = deque.popleft(self.stream_ids)
if err == nil then
if request.opts then
request.opts.stream_id = stream_id
else
request.opts = {stream_id = stream_id}
end
end

local frame = request:build_frame(self.protocol_version)
local sent, err = self.sock:send(frame)
if not sent then return nil, err end
if not sent then
deque.pushright(self.stream_ids, stream_id)
return nil, err
end

-- receive frame version byte
local v_byte, err = self.sock:receive(1)
if not v_byte then return nil, err end
while true do
-- receive frame version byte
local v_byte, err = self.sock:receive(1)
if not v_byte then
deque.pushright(self.stream_ids, stream_id)
return nil, err
end

-- -1 because of the v_byte we just read
local version, n_bytes = cql.frame_reader.version(v_byte)
-- -1 because of the v_byte we just read
local version, n_bytes = cql.frame_reader.version(v_byte)

-- receive frame header
local header_bytes, err = self.sock:receive(n_bytes)
if not header_bytes then return nil, err end
-- receive frame header
local header_bytes, err = self.sock:receive(n_bytes)
if not header_bytes then
deque.pushright(self.stream_ids, stream_id)
return nil, err
end

local header = cql.frame_reader.read_header(version, header_bytes)
local header = cql.frame_reader.read_header(version, header_bytes)

-- receive frame body
local body_bytes
if header.body_length > 0 then
body_bytes, err = self.sock:receive(header.body_length)
if not body_bytes then return nil, err end
end
-- receive frame body
local body_bytes
if header.body_length > 0 then
body_bytes, err = self.sock:receive(header.body_length)
if not body_bytes then
deque.pushright(self.stream_ids, stream_id)
return nil, err
end
end

-- res, err, cql_err_code
return cql.frame_reader.read_body(header, body_bytes)
-- If stream_id was set in request.opts, only return a response
-- with a matching stream_id and drop everything else
if stream_id and stream_id == header.stream_id then
-- res, err, cql_err_code
deque.pushright(self.stream_ids, stream_id)
return cql.frame_reader.read_body(header, body_bytes)
end
end
end

local function send_startup(self)
Expand Down
46 changes: 46 additions & 0 deletions lib/util/deque.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
--[[
Deque implementation, taken from the "Programming in Lua" book.
http://www.lua.org/pil/11.4.html
--]]

local _List = {}

function _List.new ()
return {first = 0, last = -1 }
end

function _List.pushleft (list, value)
local first = list.first - 1
list.first = first
list[first] = value
end

function _List.pushright (list, value)
local last = list.last + 1
list.last = last
list[last] = value
end

function _List.popleft (list)
local first = list.first
if first > list.last then return nil, "list is empty" end
local value = list[first]
list[first] = nil -- to allow garbage collection
list.first = first + 1
return value, nil
end

function _List.popright (list)
local last = list.last
if list.first > last then return nil, "list is empty" end
local value = list[last]
list[last] = nil -- to allow garbage collection
list.last = last - 1
return value, nil
end

function _List.length (list)
return list.last - list.first + 1
end

return _List
2 changes: 2 additions & 0 deletions lua-cassandra-1.2.3-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ build = {
["cassandra.auth"] = "lib/cassandra/auth.lua",
["cassandra.socket"] = "lib/cassandra/socket.lua",

["util.deque"] = "lib/util/deque.lua",

["resty.cassandra.cluster"] = "lib/resty/cassandra/cluster.lua",
["resty.cassandra.policies.lb"] = "lib/resty/cassandra/policies/lb/init.lua",
["resty.cassandra.policies.lb.rr"] = "lib/resty/cassandra/policies/lb/rr.lua",
Expand Down
2 changes: 2 additions & 0 deletions lua-cassandra-dev-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ build = {
["cassandra.auth"] = "lib/cassandra/auth.lua",
["cassandra.socket"] = "lib/cassandra/socket.lua",

["util.deque"] = "lib/util/deque.lua",

["resty.cassandra.cluster"] = "lib/resty/cassandra/cluster.lua",
["resty.cassandra.policies.lb"] = "lib/resty/cassandra/policies/lb/init.lua",
["resty.cassandra.policies.lb.rr"] = "lib/resty/cassandra/policies/lb/rr.lua",
Expand Down
9 changes: 9 additions & 0 deletions spec/01-unit/03-cql_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ for protocol_version = 2, 3 do

describe("CQL requests", function()
local requests = cql.requests
local consistencies = cql.consistencies

it("sanity", function()
local r = requests.query.new("SELECT * FROM peers")
Expand All @@ -190,6 +191,14 @@ for protocol_version = 2, 3 do
local frame3 = r:build_frame(protocol_version)
assert.matches("SELECT key FROM local", frame3, nil, true)
end)
it("sets the stream_id if provided", function()
local r = requests.query.new("SELECT * FROM local")
r.opts = {stream_id = 255, consistency = consistencies.one}
local frame = r:build_frame(protocol_version)

local header = cql.frame_reader.read_header(protocol_version, string.sub(frame, 2, -1))
assert.equal(255, header.stream_id)
end)
end)

describe("execute_prepared", function()
Expand Down
111 changes: 111 additions & 0 deletions spec/01-unit/04-host_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
local cassandra = require "cassandra"
local cql = require "cassandra.cql"
local deque = require "util.deque"

describe("_Host", function()
describe("new", function()
it("sets stream_ids to the right length", function()
local host_v2, err = cassandra.new({protocol_version = 2})
assert.is_nil(err)
assert.are.equal(2^7-1, deque.length(host_v2.stream_ids))

local host_v2, err = cassandra.new({protocol_version = 3})
assert.is_nil(err)
assert.are.equal(2^15-1, deque.length(host_v2.stream_ids))
end)
end)

describe("send", function()

local function mock_request()
local r = cql.requests.startup.new()
return mock(r)
end

local function mock_host()
local host, err = cassandra.new()
assert.is_nil(err)
stub(host.sock, "send")
stub(host.sock, "receive")
return host
end

it("sets stream_id without overriding existing opts", function()
local req = mock_request()
local host = mock_host()
req.opts = {custom = "option"}

local _, err = host:send(req)
assert.is_nil(err)
assert.are.same({custom = "option", stream_id = 1}, req.opts)
end)

it("sets stream_id if there are no existing opts", function()
local req = mock_request()
local host = mock_host()

local _, err = host:send(req)
assert.is_nil(err)
assert.are.same({stream_id = 1}, req.opts)
end)

it("doesn't crash if there are no stream_ids left", function()
local req = mock_request()
local host = mock_host()
host.stream_ids["last"] = host.stream_ids["first"] - 1

local _, err = host:send(req)
assert.is_nil(err)
assert.is_nil(req.opts)
end)

it("puts stream_id back if send fails", function()
local req = mock_request()
local host = mock_host()
host.sock.send = function() return nil, "send failure" end

local _, err = host:send(req)
assert.are.equal("send failure", err)
assert.are.equal(1, deque.popright(host.stream_ids))
end)

it("puts stream_id back if receive fails", function()
local req = mock_request()
local host = mock_host()
host.sock.send = function() return true, nil end
host.sock.receive = function() return nil, "receive failure" end

local _, err = host:send(req)
assert.are.equal("receive failure", err)
assert.are.equal(1, deque.popright(host.stream_ids))
end)

it("retries if response stream_id doesn't match", function()
local req = mock_request()
local host = mock_host()
host.sock.send = function() return true, nil end
host.sock.receive = function() return "foobar", nil end
local stream_id = 4
local read_header_count = 0

cql.frame_reader = {
version = function(_) return 3 end,
read_header = function(_)
read_header_count = read_header_count + 1
stream_id = stream_id - 1
return {stream_id = stream_id, body_length = 0}
end,
read_body = function(_, _) return "body" end
}

local res, err = host:send(req)
assert.is_nil(err)
-- The first 2 times the stream_id doesn't match (3 & 2)
-- The third time is 1 so the function should exit correctly
assert.are.equal(3, read_header_count)
-- The stream_id should be back in the deque
assert.are.equal(1, deque.popright(host.stream_ids))
assert.are.equal("body", res)
end)
end)
end)