diff --git a/psrpc/errors.go b/psrpc/errors.go new file mode 100644 index 00000000..0fc86b9d --- /dev/null +++ b/psrpc/errors.go @@ -0,0 +1,82 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package psrpc + +import ( + "context" + "errors" + + "google.golang.org/protobuf/proto" + + "github.com/livekit/psrpc" +) + +func WithSuppressClientErrors(errs ...error) psrpc.ClientOption { + return psrpc.WithClientOptions( + psrpc.WithClientRPCInterceptors(newClientRPCErrorInterceptor(errs...)), + psrpc.WithClientMultiRPCInterceptors(newMultiRPCErorrInterceptor(errs...)), + ) +} + +func WithSuppressServerErrors(errs ...error) psrpc.ServerOption { + return psrpc.WithServerOptions( + psrpc.WithServerRPCInterceptors(newServerRPCErorrInterceptor(errs...)), + ) +} + +func suppressErrors(err error, ignored ...error) error { + if err != nil { + for _, e := range ignored { + if errors.Is(err, e) { + return nil + } + } + } + return err +} + +func newClientRPCErrorInterceptor(errs ...error) psrpc.ClientRPCInterceptor { + return func(rpcInfo psrpc.RPCInfo, next psrpc.ClientRPCHandler) psrpc.ClientRPCHandler { + return func(ctx context.Context, req proto.Message, opts ...psrpc.RequestOption) (res proto.Message, err error) { + res, err = next(ctx, req, opts...) + return res, suppressErrors(err, errs...) + } + } +} + +func newServerRPCErorrInterceptor(errs ...error) psrpc.ServerRPCInterceptor { + return func(ctx context.Context, req proto.Message, rpcInfo psrpc.RPCInfo, handler psrpc.ServerRPCHandler) (res proto.Message, err error) { + res, err = handler(ctx, req) + return res, suppressErrors(err, errs...) + } +} + +func newMultiRPCErorrInterceptor(errs ...error) psrpc.ClientMultiRPCInterceptor { + return func(rpcInfo psrpc.RPCInfo, next psrpc.ClientMultiRPCHandler) psrpc.ClientMultiRPCHandler { + return &multiRPCErorrInterceptor{ + ClientMultiRPCHandler: next, + errors: errs, + } + } +} + +type multiRPCErorrInterceptor struct { + psrpc.ClientMultiRPCHandler + errors []error +} + +func (r *multiRPCErorrInterceptor) Recv(msg proto.Message, err error) { + r.ClientMultiRPCHandler.Recv(msg, suppressErrors(err, r.errors...)) +}