diff --git a/db/db.go b/db/db.go index 0a668c4..f80adcb 100644 --- a/db/db.go +++ b/db/db.go @@ -124,13 +124,21 @@ func (db *SqlDB) Get(id string) (SQLData, error) { return record, nil } -func (db *SqlDB) Query(scope string, verifier string, client string) (string, error) { +func (db *SqlDB) Query(scope string, verifier string, client string) ([]string, error) { var record SQLData - // todo multiple records - err := db.sqlDB.Model(&SQLData{}).Where("scope = ? AND verifier = ? AND client = ?", scope, verifier, client). - First(&record).Error + results := make([]string, 0) + rows, err := db.sqlDB.Model(&SQLData{}).Where("scope = ? AND verifier = ? AND client = ?", scope, verifier, client). + Rows() if err != nil { - return "", err + return results, err } - return record.AuthInput, nil + // iterate over rows + for rows.Next() { + err = db.sqlDB.ScanRows(rows, &record) + if err != nil { + return results, err + } + results = append(results, record.AuthInput) + } + return results, nil } diff --git a/db/interface.go b/db/interface.go index 96400fb..8f9fa4d 100644 --- a/db/interface.go +++ b/db/interface.go @@ -5,5 +5,5 @@ type DB interface { Create(data SQLData) error Delete(id string) error Get(id string) (SQLData, error) - Query(scope string, verifier string, client string) (string, error) + Query(scope string, verifier string, client string) ([]string, error) } diff --git a/main.go b/main.go index 4bd3ffb..b4b0737 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ package main import ( "context" "errors" + "fmt" "github.com/nuts-foundation/nuts-pxp/policy" "net/http" "os" @@ -94,6 +95,8 @@ func main() { } func errorHandlerfunc(err error, ctx echo.Context) { + fmt.Printf("error: %s\n", err.Error()) + ctx.Response().Status = http.StatusInternalServerError if !ctx.Response().Committed { ctx.Response().Write([]byte(err.Error())) } diff --git a/policy/opa.go b/policy/opa.go index 916ed88..2934b9f 100644 --- a/policy/opa.go +++ b/policy/opa.go @@ -3,7 +3,9 @@ package policy import ( "context" "encoding/json" + "errors" "fmt" + "gorm.io/gorm" "os" "strings" @@ -51,24 +53,26 @@ func (dm *OPADecision) Query(ctx context.Context, requestLine map[string]interfa // query DB for runtime data data, err := dm.db.Query(scope, verifier, client) - if err != nil { - return false, err + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, gorm.ErrRecordNotFound) { + data = []string{} + } else { + return false, err + } } // parse the data into a map input := map[string]interface{}{} external := map[string]interface{}{} - request := map[string]interface{}{} - err = json.Unmarshal([]byte(data), &external) - if err != nil { - return false, err - } - for k, v := range requestLine { - request[k] = v + for _, stringAuthInput := range data { + err = json.Unmarshal([]byte(stringAuthInput), &external) + if err != nil { + return false, err + } } // merge the request line and introspection result into the input input["external"] = external - input["request"] = request + input["request"] = requestLine for k, v := range introspectionResult { input[k] = v }