Skip to content

Commit

Permalink
Ear taint checker (#316)
Browse files Browse the repository at this point in the history
* Enable the EAR analysis to be a taint propagation engine
  • Loading branch information
guodongli-google authored Jun 25, 2021
1 parent bdf5b76 commit 4c61727
Show file tree
Hide file tree
Showing 8 changed files with 422 additions and 2 deletions.
5 changes: 5 additions & 0 deletions internal/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ type Config struct {
FieldTags []fieldTagMatcher
Exclude []funcMatcher
AllowPanicOnTaintedValues bool
// Whether to use EAR pointer analysis as the taint propagation engine.
UseEAR bool
// Control the span of the call chain from a source to a sink when analyzing EAR references.
// This can reduce false positives and enhance the performance.
EARTaintCallSpan uint
}

// IsSourceFieldTag determines whether a field tag made up of a key and value
Expand Down
12 changes: 10 additions & 2 deletions internal/pkg/earpointer/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@ type visitor struct {
}

func run(pass *analysis.Pass) (interface{}, error) {
conf, err := config.ReadConfig()
if err != nil {
return nil, err
}
if !conf.UseEAR {
return &Partitions{}, nil
}
ssainput := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
p := Analyze(ssainput)
p := analyze(ssainput)
return p, nil
}

// Analyzes an SSA program and builds the partition information.
func Analyze(ssainput *buildssa.SSA) *Partitions {
func analyze(ssainput *buildssa.SSA) *Partitions {
prog := ssainput.Pkg.Prog
// Use the call graph to initialize the contexts.
// TODO: the call graph can be CHA, RTA, VTA, etc.
Expand Down Expand Up @@ -592,6 +599,7 @@ func (vis *visitor) unifyCallWithContexts(arg ssa.Value, param ssa.Value, callsi

// Handle calls to builtin functions: https://golang.org/pkg/builtin/.
func (vis *visitor) visitBuiltin(builtin *ssa.Builtin, instr ssa.Instruction) {
// TODO(#312): support more library functions
switch builtin.Name() {
case "append": // func append(slice []Type, elems ...Type) []Type
// Propagate the arguments to the return value.
Expand Down
1 change: 1 addition & 0 deletions internal/pkg/earpointer/analysis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func runCodeWithContext(code string, contextK int) (*earpointer.Partitions, erro
}
ssainput := buildssa.SSA{Pkg: pkg, SrcFuncs: srcFuncs}
pass := analysis.Pass{ResultOf: map[*analysis.Analyzer]interface{}{buildssa.Analyzer: &ssainput}}
earpointer.Analyzer.Flags.Set("useEAR", "true")
earpointer.Analyzer.Flags.Set("contextK", strconv.Itoa(contextK))
// Run the analysis.
partitions, err := earpointer.Analyzer.Run(&pass)
Expand Down
15 changes: 15 additions & 0 deletions internal/pkg/earpointer/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
---
UseEAR: true
254 changes: 254 additions & 0 deletions internal/pkg/earpointer/taint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package earpointer

import (
"go/types"

"github.com/google/go-flow-levee/internal/pkg/config"
"github.com/google/go-flow-levee/internal/pkg/utils"

"github.com/google/go-flow-levee/internal/pkg/source"
"golang.org/x/tools/go/ssa"
)

// heapTraversal holds the state for a bounded traversal of an EAR heap.
// Traversals that only collect field references (fieldRefs) do not need
// isTaintField; only srcRefs calls it.
type heapTraversal struct {
// heap is the partition (alias) information being traversed.
heap *Partitions
// callees bounds the traversal: references whose parent function is not in
// this set are filtered out by isWithinCallees.
callees map[*ssa.Function]bool // the functions containing the references of interest
// visited is written by fieldRefs to avoid descending into a reference twice.
visited ReferenceSet // the visited references during the traversal
// isTaintField reports whether the field at "index" of "named" is a taint
// field; it is consulted by srcRefs for struct types.
isTaintField func(named *types.Named, index int) bool
}

// isWithinCallees reports whether "ref" belongs to one of the functions in
// ht.callees. References whose value has no parent function (globals and
// builtins) are always considered to be within bounds.
func (ht *heapTraversal) isWithinCallees(ref Reference) bool {
	parent := ref.Value().Parent()
	if parent == nil {
		// Globals and Builtins have no parents.
		return true
	}
	return ht.callees[parent]
}

// srcRefs collects into "result" the references associated with a taint
// source, with field sensitivity. "rep" is the reference being examined and
// "tp" is its static type. Composite types are examined recursively:
// (1) for arrays, slices, channels and maps the container itself is marked and
// every entry of its partition field map is examined using the element type;
// (2) for a struct object the object itself is marked, and only its taint
// fields (as decided by ht.isTaintField) together with all their subfields
// are marked. Other fields are not tainted.
func (ht *heapTraversal) srcRefs(rep Reference, tp types.Type, result ReferenceSet) {
	heap := ht.heap

	// markElems handles the container types, which all share the same logic:
	// taint the container and recurse into each of its heap fields.
	markElems := func(elem types.Type) {
		result[rep] = true
		for _, r := range heap.PartitionFieldMap(rep) {
			ht.srcRefs(r, elem, result)
		}
	}

	switch t := tp.(type) {
	case *types.Named:
		// Consider object of type "struct {x: *T, y *T}" where x is a
		// taint field, and this object's heap is "{t0,t1}: [x->t2, y->t3],
		// {t2,t4} --> t5, {t3} --> t6", then {t0,t1,t2,t4,t5} are taint sources.
		st, isStruct := t.Underlying().(*types.Struct)
		if !isStruct {
			// A named non-struct type is handled through its underlying type.
			ht.srcRefs(rep, t.Underlying(), result)
			return
		}
		result[rep] = true // the current struct object is tainted
		for i := 0; i < st.NumFields(); i++ {
			if !ht.isTaintField(t, i) {
				continue
			}
			name := st.Field(i).Name()
			// Heap fields are matched to the struct field by name.
			for fd, fref := range heap.PartitionFieldMap(rep) {
				if fd.Name != name {
					continue
				}
				result[fref] = true
				// Mark all the subfields to be tainted.
				ht.fieldRefs(fref, result)
			}
		}
	case *types.Pointer:
		// Follow the points-to edge when the heap has one; otherwise keep
		// examining the same reference with the pointee type.
		target := rep
		if r := heap.PartitionFieldMap(rep)[directPointToField]; r != nil {
			target = r
		}
		ht.srcRefs(target, t.Elem(), result)
	case *types.Array:
		markElems(t.Elem())
	case *types.Slice:
		markElems(t.Elem())
	case *types.Chan:
		markElems(t.Elem())
	case *types.Map:
		markElems(t.Elem())
	case *types.Basic, *types.Tuple, *types.Interface, *types.Signature:
		// These types do not currently represent possible source types
	}
}

// fieldRefs adds to "result" the partition members of "ref" and, recursively,
// the partition members of every reference reachable through its field map.
// Only members accepted by isWithinCallees are added. ht.visited is updated
// so that a field reference is descended into at most once.
// For example, return {t0,t5,t1,t3,t4} for "{t0,t5}: [0->t1, 1->t3], {t3} --> t4".
func (ht *heapTraversal) fieldRefs(ref Reference, result ReferenceSet) {
	ht.visited[ref] = true

	for _, member := range ht.heap.PartitionMembers(ref) {
		if !ht.isWithinCallees(member) {
			continue
		}
		result[member] = true
	}

	fields := ht.heap.PartitionFieldMap(ht.heap.Representative(ref))
	for _, next := range fields {
		if _, seen := ht.visited[next]; seen {
			continue
		}
		ht.fieldRefs(next, result)
	}
}

// canReach returns one of the "sources" whose alias references overlap with
// the references reachable from the operands of "sink"; it returns nil when
// no source reaches the sink.
// Argument "srcRefs" maps a source to its alias references.
func (ht *heapTraversal) canReach(sink ssa.Instruction, sources []*source.Source, srcRefs map[*source.Source]ReferenceSet) *source.Source {
	// Collect the alias references of the sink. All sub-fields of a sink
	// operand are considered. For example, for heap
	// "{t0}: [0->t1(taint), 1->t2]", sink call "sinkf(t0)" is reached since
	// t0 contains a taint field t1.
	reachable := make(map[Reference]bool)
	for _, operand := range sink.Operands(nil) {
		val := *operand
		if !isLocal(val) && !isGlobal(val) {
			continue
		}
		// Each operand is explored with its own traversal state.
		walker := &heapTraversal{heap: ht.heap, callees: ht.callees, visited: make(ReferenceSet)}
		walker.fieldRefs(MakeLocalWithEmptyContext(val), reachable)
	}

	// Match every sink reference (and its partition members) against the
	// alias references of each source.
	for ref := range reachable {
		for _, member := range ht.heap.PartitionMembers(ref) {
			for _, src := range sources {
				if srcRefs[src][member] {
					return src
				}
			}
		}
	}
	return nil
}

// calleeFunctions adds to "result" every function that can be reached from
// "fn" through a chain of at most "depth" static calls. Functions that are
// already present in "result", have no static callee information, or have no
// body (empty or unlinked) are skipped.
// For example, return {g1,g2,g3} for "func f(){ g1(); g2() }, func g1(){ g3() }"
// when depth >= 2.
//
// The call chain is explored level by level (breadth first). A depth-first
// walk that shares a single visited set would mark a function the first time
// it is seen, possibly at the depth limit, and then refuse to expand it when
// it is reached again through a shorter chain; that makes the result depend
// on instruction order and can drop functions that are within "depth" calls.
func calleeFunctions(fn *ssa.Function, result map[*ssa.Function]bool, depth uint) {
	frontier := []*ssa.Function{fn}
	for ; depth > 0 && len(frontier) > 0; depth-- {
		var next []*ssa.Function
		for _, caller := range frontier {
			for _, b := range caller.Blocks {
				for _, instr := range b.Instrs {
					call, ok := instr.(*ssa.Call)
					if !ok {
						continue
					}
					// TODO(#317): use more advanced call graph.
					// skip empty, unlinked, or visited functions
					callee := call.Call.StaticCallee()
					if callee == nil || len(callee.Blocks) == 0 || result[callee] {
						continue
					}
					result[callee] = true
					next = append(next, callee)
				}
			}
		}
		frontier = next
	}
}

// boundedDepthCallees returns the set containing "fn" itself plus the
// functions collected by calleeFunctions for the given call-chain "depth".
func boundedDepthCallees(fn *ssa.Function, depth uint) map[*ssa.Function]bool {
	// Seed the set with fn so that it is always included in the result.
	reachable := map[*ssa.Function]bool{fn: true}
	calleeFunctions(fn, reachable, depth)
	return reachable
}

// srcAliasRefs returns the references which are aliases of the taint source
// "src", with field sensitivity. It returns nil when the source node is not
// an ssa.Value.
// Argument "heap" is an immutable EAR heap containing alias information;
// "callees" is used to bound the searching of source references in the heap.
func srcAliasRefs(src *source.Source, isTaintField func(named *types.Named, index int) bool,
	heap *Partitions, callees map[*ssa.Function]bool) ReferenceSet {

	value, isValue := src.Node.(ssa.Value)
	if !isValue {
		return nil
	}

	// Start from the representative of the source value's partition.
	start := heap.Representative(MakeLocalWithEmptyContext(value))
	aliases := make(ReferenceSet)
	traversal := &heapTraversal{
		heap:         heap,
		callees:      callees,
		visited:      make(ReferenceSet),
		isTaintField: isTaintField,
	}
	traversal.srcRefs(start, value.Type(), aliases)
	return aliases
}

// SourceSinkTrace records a taint source that was found to reach a sink
// instruction.
type SourceSinkTrace struct {
// Src is the source that reaches the sink.
Src *source.Source
// Sink is the sink instruction that the source reaches.
Sink ssa.Instruction
// Callstack is not populated by SourcesToSinks in this file.
// NOTE(review): presumably reserved for the call chain from source to
// sink — confirm intended use.
Callstack []ssa.Call
}

// SourcesToSinks looks for <source, sink> pairs by examining the heap alias
// information.
//
// For every function in "funcSources", the functions within
// conf.EARTaintCallSpan static calls of it are collected, and every
// instruction in those functions that is either a call to a configured sink
// or (unless conf.AllowPanicOnTaintedValues is set) a panic is checked against
// the alias references of the function's sources. One trace is appended for
// each such instruction that some source reaches; the Callstack field of the
// trace is left empty.
func SourcesToSinks(funcSources source.ResultType, isTaintField func(named *types.Named, index int) bool,
	heap *Partitions, conf *config.Config) []*SourceSinkTrace {

	var traces []*SourceSinkTrace
	for fn, sources := range funcSources {
		// Transitively get the set of functions called within "fn".
		// This set is used to narrow down the set of references needed to be
		// considered during EAR heap traversal. It can also help reducing the
		// false positives and boosting the performance.
		callees := boundedDepthCallees(fn, conf.EARTaintCallSpan)
		srcRefs := make(map[*source.Source]ReferenceSet, len(sources))
		for _, s := range sources {
			srcRefs[s] = srcAliasRefs(s, isTaintField, heap, callees)
		}
		// Traverse all the callee functions, not just the one that contains
		// the sources.
		ht := &heapTraversal{heap: heap, callees: callees, visited: make(ReferenceSet)}
		for member := range callees {
			for _, b := range member.Blocks {
				for _, instr := range b.Instrs {
					// Decide whether this instruction is a sink. Note that
					// "continue" moves on to the next instruction; a "break"
					// here would only leave the switch.
					switch v := instr.(type) {
					case *ssa.Call:
						// TODO(#317): use more advanced call graph.
						callee := v.Call.StaticCallee()
						if callee == nil || !conf.IsSink(utils.DecomposeFunction(callee)) {
							continue
						}
					case *ssa.Panic:
						if conf.AllowPanicOnTaintedValues {
							continue
						}
					default:
						continue
					}
					if src := ht.canReach(instr, sources, srcRefs); src != nil {
						traces = append(traces, &SourceSinkTrace{Src: src, Sink: instr})
					}
				}
			}
		}
	}
	return traces
}
36 changes: 36 additions & 0 deletions internal/pkg/levee/levee.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"fmt"
"go/ast"
"go/token"
"go/types"
"strings"

"github.com/google/go-flow-levee/internal/pkg/config"
"github.com/google/go-flow-levee/internal/pkg/earpointer"
"github.com/google/go-flow-levee/internal/pkg/fieldtags"
"github.com/google/go-flow-levee/internal/pkg/propagation"
"github.com/google/go-flow-levee/internal/pkg/source"
Expand All @@ -40,6 +42,7 @@ var Analyzer = &analysis.Analyzer{
fieldtags.Analyzer,
source.Analyzer,
suppression.Analyzer,
earpointer.Analyzer,
},
}

Expand All @@ -48,6 +51,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
if err != nil {
return nil, err
}
if conf.UseEAR {
return runEAR(pass, conf) // Use the EAR-pointer based taint analysis
}
return runPropagation(pass, conf) // Use the propagation based taint analysis
}

func runPropagation(pass *analysis.Pass, conf *config.Config) (interface{}, error) {
funcSources := pass.ResultOf[source.Analyzer].(source.ResultType)
taggedFields := pass.ResultOf[fieldtags.Analyzer].(fieldtags.ResultType)
suppressedNodes := pass.ResultOf[suppression.Analyzer].(suppression.ResultType)
Expand All @@ -62,6 +72,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
for _, instr := range b.Instrs {
switch v := instr.(type) {
case *ssa.Call:
// TODO(#317): use more advanced call graph.
if callee := v.Call.StaticCallee(); callee != nil && conf.IsSink(utils.DecomposeFunction(callee)) {
reportSourcesReachingSink(conf, pass, suppressedNodes, propagations, instr)
}
Expand All @@ -78,6 +89,31 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}

// runEAR reports the sources that reach sinks, using the EAR pointer analysis
// as the taint propagation engine. It returns an error when the EAR analyzer
// did not provide a usable *earpointer.Partitions result. Sinks whose position
// is suppressed are not reported.
func runEAR(pass *analysis.Pass, conf *config.Config) (interface{}, error) {
	// The comma-ok form turns a missing or mistyped analyzer result into the
	// error below instead of a panic.
	heap, ok := pass.ResultOf[earpointer.Analyzer].(*earpointer.Partitions)
	if !ok || heap == nil {
		return nil, fmt.Errorf("no valid EAR partitions")
	}
	funcSources := pass.ResultOf[source.Analyzer].(source.ResultType)
	taggedFields := pass.ResultOf[fieldtags.Analyzer].(fieldtags.ResultType)
	suppressedNodes := pass.ResultOf[suppression.Analyzer].(suppression.ResultType)
	// Return whether a field is tainted: either the configuration names it as
	// a source field, or its field tag marks it as one. Non-struct named
	// types never have taint fields.
	isTaintField := func(named *types.Named, index int) bool {
		if tt, ok := named.Underlying().(*types.Struct); ok {
			return conf.IsSourceField(utils.DecomposeField(named, index)) || taggedFields.IsSourceField(tt, index)
		}
		return false
	}
	for _, trace := range earpointer.SourcesToSinks(funcSources, isTaintField, heap, conf) {
		sink := trace.Sink
		if !isSuppressed(sink.Pos(), suppressedNodes, pass) {
			report(conf, pass, trace.Src, sink.(ssa.Node))
		}
	}
	return nil, nil
}

func reportSourcesReachingSink(conf *config.Config, pass *analysis.Pass, suppressedNodes suppression.ResultType, propagations map[*source.Source]propagation.Propagation, sink ssa.Instruction) {
for src, prop := range propagations {
if prop.IsTainted(sink) && !isSuppressed(sink.Pos(), suppressedNodes, pass) {
Expand Down
Loading

0 comments on commit 4c61727

Please sign in to comment.