-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.go
142 lines (124 loc) · 3.99 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package main
import (
"context"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
providerserver "github.com/akeylesslabs/akeyless-csi-provider/internal/server"
"github.com/akeylesslabs/akeyless-csi-provider/internal/version"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
pb "sigs.k8s.io/secrets-store-csi-driver/provider/v1alpha1"
)
// realMain parses the command-line flags and runs the provider.
//
// With -version it prints the version returned by version.GetVersion and
// returns. Otherwise it:
//
//   - creates a gRPC server with a logging unary interceptor,
//   - installs a SIGTERM/SIGINT handler that gracefully stops that server,
//   - opens the unix socket given by -endpoint via listen,
//   - registers a providerserver.Server as the CSI driver provider,
//   - serves GET /health/ready on -health-address in a background goroutine,
//   - blocks in server.Serve until the gRPC server stops.
//
// It returns a non-nil error if the version cannot be obtained or printed, the
// socket cannot be opened, the gRPC server fails, or the health server fails
// to shut down after an otherwise clean run.
func realMain() (retErr error) {
	var (
		endpoint    = flag.String("endpoint", "/tmp/akeyless.sock", "path to socket on which to listen for driver gRPC calls")
		selfVersion = flag.Bool("version", false, "prints the version information")
		vaultAddr   = flag.String("akeyless-address", "https://api.akeyless.io", "Akeyless API URL")
		vaultMount  = flag.String("mount", "kubernetes", "default mount path for Kubernetes authentication")
		healthAddr  = flag.String("health-address", ":8080", "configure http listener for reporting health")
	)
	flag.Parse()

	if *selfVersion {
		v, err := version.GetVersion()
		if err != nil {
			return fmt.Errorf("failed to print version, err: %w", err)
		}
		// print the version and exit
		_, err = fmt.Println(v)
		return err
	}

	log.Print("Creating new gRPC server")
	server := grpc.NewServer(
		// Log the method, duration and status code of every unary call.
		grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			startTime := time.Now()
			log.Printf("Processing unary gRPC call grpc.method: %v", info.FullMethod)
			resp, err := handler(ctx, req)
			log.Printf("Finished unary gRPC call grpc.method: %v, grpc.time: %v, grpc.code: %v", info.FullMethod, time.Since(startTime), status.Code(err))
			if err != nil {
				log.Printf("Error: %v", err.Error())
			}
			log.Print("Finished unary gRPC call")
			return resp, err
		}),
	)

	// Stop the gRPC server gracefully on SIGTERM/SIGINT; this makes
	// server.Serve below return so the deferred cleanup can run.
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGTERM, syscall.SIGINT)
	go func() {
		sig := <-c
		log.Printf("Caught signal %s, shutting down", sig)
		server.GracefulStop()
	}()

	listener, err := listen(*endpoint)
	if err != nil {
		return err
	}
	defer listener.Close()

	s := &providerserver.Server{
		VaultAddr:  *vaultAddr,
		VaultMount: *vaultMount,
	}
	pb.RegisterCSIDriverProviderServer(server, s)

	// Create health handler
	mux := http.NewServeMux()
	ms := http.Server{
		Addr:    *healthAddr,
		Handler: mux,
		// Bound how long a client may take to send request headers so a
		// stalled connection cannot hold the health listener open forever.
		ReadHeaderTimeout: 5 * time.Second,
	}
	defer func() {
		// Do not call log.Fatalf here: it would exit the process immediately,
		// skipping the deferred listener.Close above and discarding whatever
		// error realMain is already returning.
		if err := ms.Shutdown(context.Background()); err != nil {
			if retErr == nil {
				retErr = fmt.Errorf("error shutting down health handler: %w", err)
				return
			}
			log.Printf("Error shutting down health handler, err: %v", err)
		}
	}()

	mux.HandleFunc("/health/ready", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Start health handler
	go func() {
		log.Printf("Starting health handler, addr: %v", *healthAddr)
		// http.ErrServerClosed is the expected result of ms.Shutdown.
		if err := ms.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("Error with health handler, error: %v", err.Error())
		}
	}()

	log.Print("Starting gRPC server")
	if err := server.Serve(listener); err != nil {
		return fmt.Errorf("error running gRPC server: %w", err)
	}
	return nil
}
// listen opens a unix domain socket listener at endpoint.
//
// If a file already exists at endpoint it is removed before listening. Any
// failure from the stat, remove or listen step is returned wrapped with %w,
// so callers can inspect the underlying cause with errors.Is / errors.As.
func listen(endpoint string) (net.Listener, error) {
	// Because the unix socket is created in a host volume (i.e. persistent
	// storage), it can persist from previous runs if the pod was not terminated
	// cleanly. Check if we need to clean up before creating a listener.
	_, err := os.Stat(endpoint)
	switch {
	case err == nil:
		log.Printf("Cleaning up pre-existing file at unix socket location, endpoint: %v", endpoint)
		if err := os.Remove(endpoint); err != nil {
			return nil, fmt.Errorf("failed to clean up pre-existing file at unix socket location: %w", err)
		}
	case !os.IsNotExist(err):
		// Stat failed for a reason other than "nothing there".
		return nil, fmt.Errorf("failed to check for existence of unix socket: %w", err)
	}

	log.Printf("Opening unix socket, endpoint %v", endpoint)
	listener, err := net.Listen("unix", endpoint)
	if err != nil {
		return nil, fmt.Errorf("failed to listen on unix socket at %s: %w", endpoint, err)
	}
	return listener, nil
}
// main runs the provider and terminates the process with a non-zero exit
// status if realMain reports an error.
func main() {
	if err := realMain(); err != nil {
		// log.Fatalf logs the message and then calls os.Exit(1) itself, so no
		// explicit exit is needed afterwards.
		log.Fatalf("Error running provider: %v", err)
	}
}