Skip to content

Commit

Permalink
Allow toggling CORS support
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Feb 2, 2024
1 parent 41c7408 commit fe17176
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
4 changes: 4 additions & 0 deletions cmd/juno/juno.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ const (
cnCoreContractAddressF = "cn-core-contract-address"
cnUnverifiableRangeF = "cn-unverifiable-range"
callMaxStepsF = "rpc-call-max-steps"
corsEnableF = "rpc-cors-enable"

defaultConfig = ""
defaulHost = "localhost"
Expand Down Expand Up @@ -109,6 +110,7 @@ const (
defaultCNCoreContractAddressStr = ""
defaultCallMaxSteps = 4_000_000
defaultGwTimeout = 5 * time.Second
defaultCorsEnable = false

configFlagUsage = "The yaml configuration file."
logLevelFlagUsage = "Options: debug, info, warn, error."
Expand Down Expand Up @@ -152,6 +154,7 @@ const (
gwAPIKeyUsage = "API key for gateway endpoints to avoid throttling" //nolint: gosec
gwTimeoutUsage = "Timeout for requests made to the gateway" //nolint: gosec
callMaxStepsUsage = "Maximum number of steps to be executed in starknet_call requests"
corsEnableUsage = "Enable CORS on RPC endpoints"
)

var Version string
Expand Down Expand Up @@ -328,6 +331,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr
junoCmd.MarkFlagsMutuallyExclusive(networkF, cnNameF)
junoCmd.Flags().Uint(callMaxStepsF, defaultCallMaxSteps, callMaxStepsUsage)
junoCmd.Flags().Duration(gwTimeoutF, defaultGwTimeout, gwTimeoutUsage)
junoCmd.Flags().Bool(corsEnableF, defaultCorsEnable, corsEnableUsage)

return junoCmd
}
18 changes: 14 additions & 4 deletions node/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func exactPathServer(path string, handler http.Handler) http.HandlerFunc {
}

func makeRPCOverHTTP(host string, port uint16, servers map[string]*jsonrpc.Server,
log utils.SimpleLogger, metricsEnabled bool,
log utils.SimpleLogger, metricsEnabled bool, corsEnabled bool,
) *httpService {
var listener jsonrpc.NewRequestListener
if metricsEnabled {
Expand All @@ -89,11 +89,16 @@ func makeRPCOverHTTP(host string, port uint16, servers map[string]*jsonrpc.Serve
}
mux.Handle(path, exactPathServer(path, httpHandler))
}
return makeHTTPService(host, port, cors.Default().Handler(mux))

var handler http.Handler = mux
if corsEnabled {
handler = cors.Default().Handler(handler)
}
return makeHTTPService(host, port, handler)
}

func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc.Server,
log utils.SimpleLogger, metricsEnabled bool,
log utils.SimpleLogger, metricsEnabled bool, corsEnabled bool,
) *httpService {
var listener jsonrpc.NewRequestListener
if metricsEnabled {
Expand All @@ -110,7 +115,12 @@ func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc.
wsPrefixedPath := strings.TrimSuffix("/ws"+path, "/")
mux.Handle(wsPrefixedPath, exactPathServer(wsPrefixedPath, wsHandler))
}
return makeHTTPService(host, port, cors.Default().Handler(mux))

var handler http.Handler = mux
if corsEnabled {
handler = cors.Default().Handler(handler)
}
return makeHTTPService(host, port, handler)
}

func makeMetrics(host string, port uint16) *httpService {
Expand Down
5 changes: 3 additions & 2 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Config struct {
HTTP bool `mapstructure:"http"`
HTTPHost string `mapstructure:"http-host"`
HTTPPort uint16 `mapstructure:"http-port"`
RPCCorsEnable bool `mapstructure:"rpc-cors-enable"`
Websocket bool `mapstructure:"ws"`
WebsocketHost string `mapstructure:"ws-host"`
WebsocketPort uint16 `mapstructure:"ws-port"`
Expand Down Expand Up @@ -169,10 +170,10 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen
"/rpc" + legacyPath: jsonrpcServerLegacy,
}
if cfg.HTTP {
services = append(services, makeRPCOverHTTP(cfg.HTTPHost, cfg.HTTPPort, rpcServers, log, cfg.Metrics))
services = append(services, makeRPCOverHTTP(cfg.HTTPHost, cfg.HTTPPort, rpcServers, log, cfg.Metrics, cfg.RPCCorsEnable))
}
if cfg.Websocket {
services = append(services, makeRPCOverWebsocket(cfg.WebsocketHost, cfg.WebsocketPort, rpcServers, log, cfg.Metrics))
services = append(services, makeRPCOverWebsocket(cfg.WebsocketHost, cfg.WebsocketPort, rpcServers, log, cfg.Metrics, cfg.RPCCorsEnable))
}
var metricsService service.Service
if cfg.Metrics {
Expand Down

0 comments on commit fe17176

Please sign in to comment.