Skip to content
This repository has been archived by the owner on Jan 9, 2024. It is now read-only.

Commit

Permalink
Merge pull request #13 from rinsed-org/fix/more-tries-to-get-rid-of-s…
Browse files Browse the repository at this point in the history
…egfault

More attempts to reduce SEGFAULTs
  • Loading branch information
alexstoick authored Jun 9, 2023
2 parents 97a324e + 760022e commit 700d325
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
ruby_snowflake_client (1.0.2)
ruby_snowflake_client (1.1.0)

GEM
remote: https://rubygems.org/
Expand Down
4 changes: 2 additions & 2 deletions ext/c-decl.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ func goobj_log(obj unsafe.Pointer) {
}

//export goobj_retain
func goobj_retain(obj unsafe.Pointer) {
func goobj_retain(obj unsafe.Pointer, x *C.char) {
if LOG_LEVEL > 0 {
fmt.Printf("retain obj %v - currently keeping %d\n", obj, len(objects))
fmt.Printf("retain obj [%v] %v - currently keeping %d\n", C.GoString(x), obj, len(objects))
}
objects[obj] = true
marked[obj] = 0
Expand Down
109 changes: 50 additions & 59 deletions ext/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@ VALUE funcall0param(VALUE obj, ID id);
import "C"

import (
"errors"
"fmt"
"io"
"math/big"
"strings"
"time"

gopointer "github.com/mattn/go-pointer"
Expand All @@ -28,18 +25,44 @@ func wrapRbRaise(err error) {
}

func getResultStruct(self C.VALUE) *SnowflakeResult {
ivar := C.rb_ivar_get(self, RESULT_IDENTIFIER)
return resultMap[self]
}

str := GetGoStruct(ivar)
ptr := gopointer.Restore(str)
sr, ok := ptr.(*SnowflakeResult)
if !ok || sr.rows == nil {
err := errors.New("Empty result; please run a query via `client.fetch(\"SQL\")`")
wrapRbRaise(err)
return nil
//export GetRowsNoEnum
func GetRowsNoEnum(self C.VALUE) C.VALUE {
res := getResultStruct(self)
rows := res.rows

i := 0
t1 := time.Now()
var arr []C.VALUE

for rows.Next() {
if i%5000 == 0 {
if LOG_LEVEL > 0 {
fmt.Println("scanning row: ", i)
}
}
x := res.ScanNextRow(false)
objects[x] = true
gopointer.Save(x)
if LOG_LEVEL > 1 {
// This is VERY noisy
fmt.Printf("alloced %v\n", &x)
}
arr = append(arr, x)
i = i + 1
}
if LOG_LEVEL > 0 {
fmt.Printf("done with rows.next: %s\n", time.Now().Sub(t1))
}

return sr
rbArr := C.rb_ary_new2(C.long(len(arr)))
for idx, elem := range arr {
C.rb_ary_store(rbArr, C.long(idx), elem)
}

return rbArr
}

//export GetRows
Expand Down Expand Up @@ -69,11 +92,6 @@ func GetRows(self C.VALUE) C.VALUE {
fmt.Printf("done with rows.next: %s\n", time.Now().Sub(t1))
}

//empty for GC
res.rows = nil
res.keptHash = C.Qnil
res.cols = []C.VALUE{}

return self
}

Expand All @@ -89,10 +107,6 @@ func ObjNextRow(self C.VALUE) C.VALUE {
if rows.Next() {
r := res.ScanNextRow(false)
return r
} else if rows.Err() == io.EOF {
res.rows = nil // free up for gc
res.keptHash = C.Qnil // free up for gc
res.cols = []C.VALUE{}
}
return C.Qnil
}
Expand All @@ -104,8 +118,8 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE {
fmt.Printf("column types: %+v; %+v\n", cts[0], cts[0].ScanType())
}

rawResult := make([]any, len(res.cols))
rawData := make([]any, len(res.cols))
rawResult := make([]any, len(res.columns))
rawData := make([]any, len(res.columns))
for i := range rawResult {
rawData[i] = &rawResult[i]
}
Expand All @@ -117,10 +131,15 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE {
}

// trick from postgres; keep hash: pg_result.c:1088
hash := C.rb_hash_dup(res.keptHash)
//hash := C.rb_hash_dup(res.keptHash)
hash := C.rb_hash_new()
if LOG_LEVEL > 1 {
// This is very noisy
fmt.Println("alloc'ed new hash", &hash)
}

for idx, raw := range rawResult {
raw := raw
col_name := res.cols[idx]

var rbVal C.VALUE

Expand Down Expand Up @@ -151,40 +170,12 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE {
wrapRbRaise(err)
}
}
C.rb_hash_aset(hash, col_name, rbVal)
}
return hash
}

func SafeMakeHash(lenght int, cols []C.VALUE) C.VALUE {
var hash C.VALUE
hash = C.rb_hash_new()

if LOG_LEVEL > 0 {
fmt.Println("starting make hash")
}
for _, col := range cols {
C.rb_hash_aset(hash, col, C.Qnil)
}
if LOG_LEVEL > 0 {
fmt.Println("end make hash", hash)
colstr := C.rb_str_new2(C.CString(res.columns[idx]))
if LOG_LEVEL > 1 {
// This is very noisy
fmt.Printf("alloc string: %+v; rubyVal: %+v\n", &colstr, &rbVal)
}
C.rb_hash_aset(hash, colstr, rbVal)
}
return hash
}

// Initialize reads the column names from res.rows, lower-cases each one and
// turns it into a frozen Ruby string. The resulting values are stored in
// res.cols, and res.keptHash is set to the hash SafeMakeHash builds from them.
// The error returned by rows.Columns() is discarded, as before.
func (res *SnowflakeResult) Initialize() {
	names, _ := res.rows.Columns()
	count := len(names)
	rbArr := C.rb_ary_new2(C.long(count))

	frozen := make([]C.VALUE, count)
	for i := range names {
		lowered := strings.ToLower(names[i])
		rbStr := C.rb_str_new2(C.CString(lowered))
		rbStr = C.rb_str_freeze(rbStr)
		frozen[i] = rbStr
		// Each string is also stored in a Ruby array at the same index.
		C.rb_ary_store(rbArr, C.long(i), rbStr)
	}

	res.cols = frozen
	res.keptHash = SafeMakeHash(count, frozen)
}
72 changes: 30 additions & 42 deletions ext/ruby_snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ VALUE ObjFetch(VALUE,VALUE);
VALUE ObjNextRow(VALUE);
VALUE Inspect(VALUE);
VALUE GetRows(VALUE);
VALUE GetRowsNoEnum(VALUE);
VALUE NewGoStruct(VALUE klass, void *p);
VALUE NewGoStruct(VALUE klass, char* reason, void *p);
VALUE GoRetEnum(VALUE,int,VALUE);
void* GetGoStruct(VALUE obj);
void RbGcGuard(VALUE ptr);
Expand All @@ -21,18 +22,18 @@ import "C"
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"

gopointer "github.com/mattn/go-pointer"
sf "github.com/snowflakedb/gosnowflake"
)

type SnowflakeResult struct {
rows *sql.Rows
keptHash C.VALUE
cols []C.VALUE
rows *sql.Rows
//keptHash C.VALUE
columns []string
//cols []C.VALUE
}
type SnowflakeClient struct {
db *sql.DB
Expand All @@ -42,12 +43,13 @@ var rbSnowflakeClientClass C.VALUE
var rbSnowflakeResultClass C.VALUE
var rbSnowflakeModule C.VALUE

var DB_IDENTIFIER = C.rb_intern(C.CString("db"))
var RESULT_IDENTIFIER = C.rb_intern(C.CString("rows"))
var RESULT_DURATION = C.rb_intern(C.CString("@query_duration"))
var ERROR_IDENT = C.rb_intern(C.CString("@error"))

var objects = make(map[interface{}]bool)
var resultMap = make(map[C.VALUE]*SnowflakeResult)
var clientRef = make(map[C.VALUE]*SnowflakeClient)

var LOG_LEVEL = 0
var empty C.VALUE = C.Qnil
Expand Down Expand Up @@ -78,13 +80,7 @@ func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE,
C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr))
}
rs := SnowflakeClient{db}
ptr := gopointer.Save(&rs)
rbStruct := C.NewGoStruct(
rbSnowflakeClientClass,
ptr,
)

C.rb_ivar_set(self, DB_IDENTIFIER, rbStruct)
clientRef[self] = &rs
}

func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE {
Expand Down Expand Up @@ -113,46 +109,28 @@ func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE {
}

result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
rs := SnowflakeResult{rows, C.Qnil, []C.VALUE{}}
rs.Initialize()
ptr := gopointer.Save(&rs)
rbStruct := C.NewGoStruct(
rbSnowflakeClientClass,
ptr,
)
C.RbGcGuard(rbStruct)
C.RbGcGuard(rbSnowflakeResultClass)
C.rb_ivar_set(result, RESULT_IDENTIFIER, rbStruct)
cols, _ := rows.Columns()
for idx, col := range cols {
col := col
cols[idx] = strings.ToLower(col)
}
rs := SnowflakeResult{rows, cols}
resultMap[result] = &rs
C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration)))
return result
}

//export ObjFetch
func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE {
var q C.VALUE
q = C.rb_ivar_get(self, DB_IDENTIFIER)

req := C.GetGoStruct(q)
f := gopointer.Restore(req)
x, ok := f.(*SnowflakeClient)
if !ok {
wrapRbRaise((errors.New("cannot convert SnowflakeClient pointer in ObjFetch")))
}
x, _ := clientRef[self]

return x.Fetch(statement)
}

//export Inspect
func Inspect(self C.VALUE) C.VALUE {
q := C.rb_ivar_get(self, DB_IDENTIFIER)
if q == C.Qnil {
return RbString("Object is not instantiated")
}

req := C.GetGoStruct(q)
f := gopointer.Restore(req)
x := f.(*SnowflakeClient)
return RbString(fmt.Sprintf("%+v", x))
x := clientRef[self]
return RbString(fmt.Sprintf("Snowflake::Client <%+v>", x))
}

//export Init_ruby_snowflake_client_ext
Expand All @@ -161,10 +139,20 @@ func Init_ruby_snowflake_client_ext() {
rbSnowflakeClientClass = C.rb_define_class_under(rbSnowflakeModule, C.CString("Client"), C.rb_cObject)
rbSnowflakeResultClass = C.rb_define_class_under(rbSnowflakeModule, C.CString("Result"), C.rb_cObject)

objects[rbSnowflakeResultClass] = true
objects[rbSnowflakeClientClass] = true
objects[rbSnowflakeModule] = true
objects[RESULT_DURATION] = true
objects[ERROR_IDENT] = true
C.RbGcGuard(RESULT_DURATION)
//C.RbGcGuard(RESULT_IDENTIFIER)
C.RbGcGuard(ERROR_IDENT)

C.rb_define_method(rbSnowflakeResultClass, C.CString("next_row"), (*[0]byte)(C.ObjNextRow), 0)
// `_get_rows` is private because calling it can SEGFAULT unless the Ruby GC is
// disabled first (`GC.disable`); the underlying cause in the interaction with
// the Ruby GC has not been determined yet.
C.rb_define_private_method(rbSnowflakeResultClass, C.CString("_get_rows"), (*[0]byte)(C.GetRows), 0)
C.rb_define_method(rbSnowflakeResultClass, C.CString("get_rows_no_enum"), (*[0]byte)(C.GetRowsNoEnum), 0)

C.rb_define_private_method(rbSnowflakeClientClass, C.CString("_connect"), (*[0]byte)(C.Connect), 7)
C.rb_define_method(rbSnowflakeClientClass, C.CString("inspect"), (*[0]byte)(C.Inspect), 0)
Expand Down
9 changes: 5 additions & 4 deletions ext/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ VALUE RbNumFromLong(long v) {
return LONG2NUM(v);
}
void goobj_retain(void *);
void goobj_retain(void *, char*);
void goobj_free(void *);
void goobj_log(void *);
void goobj_mark(void *);
Expand All @@ -42,9 +42,9 @@ static const rb_data_type_t go_type = {
};
VALUE
NewGoStruct(VALUE klass, void *p)
NewGoStruct(VALUE klass, char* reason, void *p)
{
goobj_retain(p);
goobj_retain(p, reason);
return TypedData_Wrap_Struct(klass, &go_type, p);
}
Expand Down Expand Up @@ -125,7 +125,8 @@ func RbString(str string) C.VALUE {
if len(str) == 0 {
return C.rb_utf8_str_new(nil, C.long(0))
}
cstr := (*C.char)(unsafe.Pointer(&(*(*[]byte)(unsafe.Pointer(&str)))[0]))
//cstr := (*C.char)(unsafe.Pointer(&(*(*[]byte)(unsafe.Pointer(&str)))[0]))
cstr := C.CString(str)
return C.rb_utf8_str_new(cstr, C.long(len(str)))
}

Expand Down
18 changes: 15 additions & 3 deletions lib/ruby_snowflake_client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

module Snowflake
require "ruby_snowflake_client_ext" # build bundle of the go files
LOG_LEVEL = 0

class Error < StandardError
attr_reader :details
Expand Down Expand Up @@ -51,13 +52,24 @@ def valid?
def get_all_rows(&blk)
GC.disable
if blk
_get_rows(&blk)
while r = next_row do
yield r
end
else
_get_rows.to_a
get_rows_array
end
ensure
GC.enable
GC.start
end

private
# Collects rows by calling next_row repeatedly until it returns a falsy
# value, and returns everything gathered, in order, as an Array.
def get_rows_array
  collected = []
  while true
    row = next_row
    break unless row
    # Progress line at every 15000th row, printed only when LOG_LEVEL is above 0.
    if collected.length % 15000 == 0 && LOG_LEVEL > 0
      puts "at #{collected.length}"
    end
    collected << row
  end
  collected
end
end
end
Loading

0 comments on commit 700d325

Please sign in to comment.