diff --git a/Gemfile.lock b/Gemfile.lock index 617d93f..63619f7 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . specs: - ruby_snowflake_client (1.0.2) + ruby_snowflake_client (1.1.0) GEM remote: https://rubygems.org/ diff --git a/ext/c-decl.go b/ext/c-decl.go index 1db436e..f3126e9 100644 --- a/ext/c-decl.go +++ b/ext/c-decl.go @@ -29,9 +29,9 @@ func goobj_log(obj unsafe.Pointer) { } //export goobj_retain -func goobj_retain(obj unsafe.Pointer) { +func goobj_retain(obj unsafe.Pointer, x *C.char) { if LOG_LEVEL > 0 { - fmt.Printf("retain obj %v - currently keeping %d\n", obj, len(objects)) + fmt.Printf("retain obj [%v] %v - currently keeping %d\n", C.GoString(x), obj, len(objects)) } objects[obj] = true marked[obj] = 0 diff --git a/ext/result.go b/ext/result.go index f2f86a3..46bd56a 100644 --- a/ext/result.go +++ b/ext/result.go @@ -11,11 +11,8 @@ VALUE funcall0param(VALUE obj, ID id); import "C" import ( - "errors" "fmt" - "io" "math/big" - "strings" "time" gopointer "github.com/mattn/go-pointer" @@ -28,18 +25,44 @@ func wrapRbRaise(err error) { } func getResultStruct(self C.VALUE) *SnowflakeResult { - ivar := C.rb_ivar_get(self, RESULT_IDENTIFIER) + return resultMap[self] +} - str := GetGoStruct(ivar) - ptr := gopointer.Restore(str) - sr, ok := ptr.(*SnowflakeResult) - if !ok || sr.rows == nil { - err := errors.New("Empty result; please run a query via `client.fetch(\"SQL\")`") - wrapRbRaise(err) - return nil +//export GetRowsNoEnum +func GetRowsNoEnum(self C.VALUE) C.VALUE { + res := getResultStruct(self) + rows := res.rows + + i := 0 + t1 := time.Now() + var arr []C.VALUE + + for rows.Next() { + if i%5000 == 0 { + if LOG_LEVEL > 0 { + fmt.Println("scanning row: ", i) + } + } + x := res.ScanNextRow(false) + objects[x] = true + gopointer.Save(x) + if LOG_LEVEL > 1 { + // This is VERY noisy + fmt.Printf("alloced %v\n", &x) + } + arr = append(arr, x) + i = i + 1 + } + if LOG_LEVEL > 0 { + fmt.Printf("done with rows.next: %s\n", time.Now().Sub(t1)) } - return sr + rbArr := C.rb_ary_new2(C.long(len(arr))) + for idx, elem := range arr { + C.rb_ary_store(rbArr, C.long(idx), elem) + } + + return rbArr } //export GetRows @@ -69,11 +92,6 @@ func GetRows(self C.VALUE) C.VALUE { fmt.Printf("done with rows.next: %s\n", time.Now().Sub(t1)) } - //empty for GC - res.rows = nil - res.keptHash = C.Qnil - res.cols = []C.VALUE{} - return self } @@ -89,10 +107,6 @@ func ObjNextRow(self C.VALUE) C.VALUE { if rows.Next() { r := res.ScanNextRow(false) return r - } else if rows.Err() == io.EOF { - res.rows = nil // free up for gc - res.keptHash = C.Qnil // free up for gc - res.cols = []C.VALUE{} } return C.Qnil } @@ -104,8 +118,8 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE { fmt.Printf("column types: %+v; %+v\n", cts[0], cts[0].ScanType()) } - rawResult := make([]any, len(res.cols)) - rawData := make([]any, len(res.cols)) + rawResult := make([]any, len(res.columns)) + rawData := make([]any, len(res.columns)) for i := range rawResult { rawData[i] = &rawResult[i] } @@ -117,10 +131,15 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE { } // trick from postgres; keep hash: pg_result.c:1088 - hash := C.rb_hash_dup(res.keptHash) + //hash := C.rb_hash_dup(res.keptHash) + hash := C.rb_hash_new() + if LOG_LEVEL > 1 { + // This is very noisy + fmt.Println("alloc'ed new hash", &hash) + } + for idx, raw := range rawResult { raw := raw - col_name := res.cols[idx] var rbVal C.VALUE @@ -151,40 +170,12 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE { wrapRbRaise(err) } } - C.rb_hash_aset(hash, col_name, rbVal) - } - return hash -} - -func SafeMakeHash(lenght int, cols []C.VALUE) C.VALUE { - var hash C.VALUE - hash = C.rb_hash_new() - - if LOG_LEVEL > 0 { - fmt.Println("starting make hash") - } - for _, col := range cols { - C.rb_hash_aset(hash, col, C.Qnil) - } - if LOG_LEVEL > 0 { - fmt.Println("end make hash", hash) + colstr := C.rb_str_new2(C.CString(res.columns[idx])) + if LOG_LEVEL > 1 { + // This is very noisy + fmt.Printf("alloc string: %+v; rubyVal: %+v\n", &colstr, &rbVal) + } + C.rb_hash_aset(hash, colstr, rbVal) } return hash } - -func (res *SnowflakeResult) Initialize() { - columns, _ := res.rows.Columns() - rbArr := C.rb_ary_new2(C.long(len(columns))) - - cols := make([]C.VALUE, len(columns)) - for idx, colName := range columns { - str := strings.ToLower(colName) - sym := C.rb_str_new2(C.CString(str)) - sym = C.rb_str_freeze(sym) - cols[idx] = sym - C.rb_ary_store(rbArr, C.long(idx), sym) - } - - res.cols = cols - res.keptHash = SafeMakeHash(len(columns), cols) -} diff --git a/ext/ruby_snowflake.go b/ext/ruby_snowflake.go index 67bc92a..c5cf388 100644 --- a/ext/ruby_snowflake.go +++ b/ext/ruby_snowflake.go @@ -8,8 +8,9 @@ VALUE ObjFetch(VALUE,VALUE); VALUE ObjNextRow(VALUE); VALUE Inspect(VALUE); VALUE GetRows(VALUE); +VALUE GetRowsNoEnum(VALUE); -VALUE NewGoStruct(VALUE klass, void *p); +VALUE NewGoStruct(VALUE klass, char* reason, void *p); VALUE GoRetEnum(VALUE,int,VALUE); void* GetGoStruct(VALUE obj); void RbGcGuard(VALUE ptr); @@ -21,18 +22,18 @@ import "C" import ( "context" "database/sql" - "errors" "fmt" + "strings" "time" - gopointer "github.com/mattn/go-pointer" sf "github.com/snowflakedb/gosnowflake" ) type SnowflakeResult struct { - rows *sql.Rows - keptHash C.VALUE - cols []C.VALUE + rows *sql.Rows + //keptHash C.VALUE + columns []string + //cols []C.VALUE } type SnowflakeClient struct { db *sql.DB @@ -42,12 +43,13 @@ var rbSnowflakeClientClass C.VALUE var rbSnowflakeResultClass C.VALUE var rbSnowflakeModule C.VALUE -var DB_IDENTIFIER = C.rb_intern(C.CString("db")) var RESULT_IDENTIFIER = C.rb_intern(C.CString("rows")) var RESULT_DURATION = C.rb_intern(C.CString("@query_duration")) var ERROR_IDENT = C.rb_intern(C.CString("@error")) var objects = make(map[interface{}]bool) +var resultMap = make(map[C.VALUE]*SnowflakeResult) +var clientRef = make(map[C.VALUE]*SnowflakeClient) var LOG_LEVEL = 0 var empty C.VALUE = C.Qnil @@ -78,13 +80,7 @@ func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE, C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr)) } rs := SnowflakeClient{db} - ptr := gopointer.Save(&rs) - rbStruct := C.NewGoStruct( - rbSnowflakeClientClass, - ptr, - ) - - C.rb_ivar_set(self, DB_IDENTIFIER, rbStruct) + clientRef[self] = &rs } func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE { @@ -113,46 +109,28 @@ func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE { } result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) - rs := SnowflakeResult{rows, C.Qnil, []C.VALUE{}} - rs.Initialize() - ptr := gopointer.Save(&rs) - rbStruct := C.NewGoStruct( - rbSnowflakeClientClass, - ptr, - ) - C.RbGcGuard(rbStruct) - C.RbGcGuard(rbSnowflakeResultClass) - C.rb_ivar_set(result, RESULT_IDENTIFIER, rbStruct) + cols, _ := rows.Columns() + for idx, col := range cols { + col := col + cols[idx] = strings.ToLower(col) + } + rs := SnowflakeResult{rows, cols} + resultMap[result] = &rs C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration))) return result } //export ObjFetch func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE { - var q C.VALUE - q = C.rb_ivar_get(self, DB_IDENTIFIER) - - req := C.GetGoStruct(q) - f := gopointer.Restore(req) - x, ok := f.(*SnowflakeClient) - if !ok { - wrapRbRaise((errors.New("cannot convert SnowflakeClient pointer in ObjFetch"))) - } + x, _ := clientRef[self] return x.Fetch(statement) } //export Inspect func Inspect(self C.VALUE) C.VALUE { - q := C.rb_ivar_get(self, DB_IDENTIFIER) - if q == C.Qnil { - return RbString("Object is not instantiated") - } - - req := C.GetGoStruct(q) - f := gopointer.Restore(req) - x := f.(*SnowflakeClient) - return RbString(fmt.Sprintf("%+v", x)) + x := clientRef[self] + return RbString(fmt.Sprintf("Snowflake::Client <%+v>", x)) } //export Init_ruby_snowflake_client_ext @@ -161,10 +139,20 @@ func Init_ruby_snowflake_client_ext() { rbSnowflakeClientClass = C.rb_define_class_under(rbSnowflakeModule, C.CString("Client"), C.rb_cObject) rbSnowflakeResultClass = C.rb_define_class_under(rbSnowflakeModule, C.CString("Result"), C.rb_cObject) + objects[rbSnowflakeResultClass] = true + objects[rbSnowflakeClientClass] = true + objects[rbSnowflakeModule] = true + objects[RESULT_DURATION] = true + objects[ERROR_IDENT] = true + C.RbGcGuard(RESULT_DURATION) + //C.RbGcGuard(RESULT_IDENTIFIER) + C.RbGcGuard(ERROR_IDENT) + C.rb_define_method(rbSnowflakeResultClass, C.CString("next_row"), (*[0]byte)(C.ObjNextRow), 0) // `get_rows` is private as this can lead to SEGFAULT errors if not invoked // with GC.disable due to undetermined issues caused by the Ruby GC. C.rb_define_private_method(rbSnowflakeResultClass, C.CString("_get_rows"), (*[0]byte)(C.GetRows), 0) + C.rb_define_method(rbSnowflakeResultClass, C.CString("get_rows_no_enum"), (*[0]byte)(C.GetRowsNoEnum), 0) C.rb_define_private_method(rbSnowflakeClientClass, C.CString("_connect"), (*[0]byte)(C.Connect), 7) C.rb_define_method(rbSnowflakeClientClass, C.CString("inspect"), (*[0]byte)(C.Inspect), 0) diff --git a/ext/wrapper.go b/ext/wrapper.go index edf245c..d91e8dc 100644 --- a/ext/wrapper.go +++ b/ext/wrapper.go @@ -24,7 +24,7 @@ VALUE RbNumFromLong(long v) { return LONG2NUM(v); } -void goobj_retain(void *); +void goobj_retain(void *, char*); void goobj_free(void *); void goobj_log(void *); void goobj_mark(void *); @@ -42,9 +42,9 @@ static const rb_data_type_t go_type = { }; VALUE -NewGoStruct(VALUE klass, void *p) +NewGoStruct(VALUE klass, char* reason, void *p) { - goobj_retain(p); + goobj_retain(p, reason); return TypedData_Wrap_Struct(klass, &go_type, p); } @@ -125,7 +125,8 @@ func RbString(str string) C.VALUE { if len(str) == 0 { return C.rb_utf8_str_new(nil, C.long(0)) } - cstr := (*C.char)(unsafe.Pointer(&(*(*[]byte)(unsafe.Pointer(&str)))[0])) + //cstr := (*C.char)(unsafe.Pointer(&(*(*[]byte)(unsafe.Pointer(&str)))[0])) + cstr := C.CString(str) return C.rb_utf8_str_new(cstr, C.long(len(str))) } diff --git a/lib/ruby_snowflake_client.rb b/lib/ruby_snowflake_client.rb index b47b64b..5db6a23 100644 --- a/lib/ruby_snowflake_client.rb +++ b/lib/ruby_snowflake_client.rb @@ -2,6 +2,7 @@ module Snowflake require "ruby_snowflake_client_ext" # build bundle of the go files + LOG_LEVEL = 0 class Error < StandardError attr_reader :details @@ -51,13 +52,24 @@ def valid? def get_all_rows(&blk) GC.disable if blk - _get_rows(&blk) + while r = next_row do + yield r + end else - _get_rows.to_a + get_rows_array end ensure GC.enable - GC.start end + + private + def get_rows_array + arr = [] + while r = next_row do + puts "at #{arr.length}" if arr.length % 15000 == 0 && LOG_LEVEL > 0 + arr << r + end + arr + end end end diff --git a/lib/ruby_snowflake_client/version.rb b/lib/ruby_snowflake_client/version.rb index b957161..6ae78b3 100644 --- a/lib/ruby_snowflake_client/version.rb +++ b/lib/ruby_snowflake_client/version.rb @@ -1,3 +1,3 @@ module RubySnowflakeClient - VERSION = '1.0.2' + VERSION = '1.1.0' end diff --git a/spec/snowflake/client_spec.rb b/spec/snowflake/client_spec.rb index 9a2b73b..1105e0b 100644 --- a/spec/snowflake/client_spec.rb +++ b/spec/snowflake/client_spec.rb @@ -160,15 +160,53 @@ end end - context "fetching 150k rows" do + context "fetching 150k rows x 100 times" do let(:limit) { 150_000 } it "should work" do - rows = result.get_all_rows - expect(rows.length).to eq 150000 - expect((-50000...50000)).to include(rows[0]["id"].to_i) + 100.times do |idx| + puts "on #{idx}" + client = described_class.new + client.connect( + account: ENV["SNOWFLAKE_ACCOUNT"], + warehouse: ENV["SNOWFLAKE_WAREHOUSE"], + user: ENV["SNOWFLAKE_USER"], + password: ENV["SNOWFLAKE_PASSWORD"], + ) + result = client.fetch(query) + rows = result.get_all_rows + puts "Done with get all rows" + GC.start + expect(rows.length).to eq 150000 + expect((-50000...50000)).to include(rows[0]["id"].to_i) + end end end - end + context "fetching 150k rows x 10 times - with threads" do + let(:limit) { 150_000 } + it "should work" do + t = [] + 10.times do |idx| + t << Thread.new do + puts "on #{idx}" + client = described_class.new + client.connect( + account: ENV["SNOWFLAKE_ACCOUNT"], + warehouse: ENV["SNOWFLAKE_WAREHOUSE"], + user: ENV["SNOWFLAKE_USER"], + password: ENV["SNOWFLAKE_PASSWORD"], + ) + result = client.fetch(query) + rows = result.get_all_rows + puts "Done with get all rows" + expect(rows.length).to eq 150000 + expect((-50000...50000)).to include(rows[0]["id"].to_i) + end + end + + t.map(&:join) + end + end + end end end