Skip to content

Commit

Permalink
fix(runtime): Stream handling
Browse files Browse the repository at this point in the history
Fixes SSE and WebSocket handling by not assuming the request/response will always be a Map. This was recently changed to allow arbitrary values to be passed over the sockets.
  • Loading branch information
dnys1 committed Oct 9, 2024
1 parent c4d59ee commit 0bf414b
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 30 deletions.
10 changes: 5 additions & 5 deletions packages/celest/lib/src/runtime/sse/sse_handler.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ library;
import 'dart:async';
import 'dart:convert';
import 'dart:io';
import 'dart:typed_data';

import 'package:async/async.dart';
import 'package:celest_core/_internal.dart';
Expand Down Expand Up @@ -205,6 +206,9 @@ final class _SseHandler {
});
}

static final _jsonUtf8Decoder =
utf8.decoder.fuse(json.decoder).cast<Uint8List, Object?>();

Future<Response> _handleIncomingMessage(Request request) async {
final clientId = request.url.queryParameters['sseClientId'];
if (clientId == null) {
Expand All @@ -222,11 +226,7 @@ final class _SseHandler {
request.url.queryParameters['messageId'] ?? '0',
);
try {
final message = await request
.read()
.transform(utf8.decoder)
.transform(json.decoder)
.single;
final message = await request.read().transform(_jsonUtf8Decoder).first;
connection._handleIncoming(messageId, message);
return Response(HttpStatus.accepted);
} on Object catch (e, st) {
Expand Down
2 changes: 1 addition & 1 deletion packages/celest/lib/src/runtime/targets.dart
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ abstract base class CloudEventSourceTarget extends CloudFunctionTarget {
StreamChannelTransformer(
StreamTransformer.fromHandlers(
handleData: (data, sink) {
sink.add(JsonUtf8.decodeMap(data));
sink.add(JsonUtf8.decodeAny(data));
},
),
StreamSinkTransformer.fromHandlers(
Expand Down
8 changes: 4 additions & 4 deletions packages/celest_core/lib/src/base/base_protocol.dart
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ mixin BaseProtocol {
};
}

Stream<Map<String, Object?>> connect(
Stream<Object?> connect(
String path, {
required Map<String, Object?> payload,
}) {
Expand Down Expand Up @@ -84,10 +84,10 @@ mixin BaseProtocol {
);
}
final json = jsonDecode(response.body) as Map<String, Object?>;
final error = json['error'] as Map<String, Object?>?;
final error = json['@error'] as Map<String, Object?>?;
throw createError(
error?['message'] as String?,
details: switch (error?['details']) {
error?['message'] as String? ?? json['message'] as String?,
details: switch (json['details']) {
null => null,
final Object details => JsonValue(details),
},
Expand Down
2 changes: 1 addition & 1 deletion packages/celest_core/lib/src/events/event_channel.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import 'package:celest_core/src/events/event_channel.vm.dart'
import 'package:http/http.dart' as http;
import 'package:stream_channel/stream_channel.dart';

abstract class EventChannel with StreamChannelMixin<Map<String, Object?>> {
abstract class EventChannel with StreamChannelMixin<Object?> {
EventChannel();
factory EventChannel.connect(
Uri uri, {
Expand Down
7 changes: 3 additions & 4 deletions packages/celest_core/lib/src/events/event_channel.vm.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ final class EventChannelPlatform extends EventChannel {
}

final WebSocketChannel _ws;
late final StreamSink<Map<String, Object?>> _wsSink =
_ws.sink.rejectErrors().transform(
late final StreamSink<Object?> _wsSink = _ws.sink.rejectErrors().transform(
StreamSinkTransformer.fromHandlers(
handleData: (data, sink) {
sink.add(jsonEncode(data));
Expand All @@ -41,10 +40,10 @@ final class EventChannelPlatform extends EventChannel {
);

@override
Stream<Map<String, Object?>> get stream => _ws.stream.map(JsonUtf8.decodeMap);
Stream<Object?> get stream => _ws.stream.map(JsonUtf8.decodeAny);

@override
StreamSink<Map<String, Object?>> get sink => _wsSink;
StreamSink<Object?> get sink => _wsSink;

@override
void close() {
Expand Down
5 changes: 2 additions & 3 deletions packages/celest_core/lib/src/events/event_channel.web.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import 'package:celest_core/src/events/event_channel.dart';
import 'package:http/http.dart' as http;
import 'package:stream_channel/stream_channel.dart';

final class EventChannelPlatform
extends DelegatingStreamChannel<Map<String, Object?>>
final class EventChannelPlatform extends DelegatingStreamChannel<Object?>
implements EventChannel {
EventChannelPlatform._(super._inner);

Expand All @@ -16,7 +15,7 @@ final class EventChannelPlatform
http.Client? httpClient,
}) {
final sseClient = SseClient(serverUri: uri, httpClient: httpClient);
final completer = StreamChannelCompleter<Map<String, Object?>>();
final completer = StreamChannelCompleter<Object?>();
sseClient.onConnected
.then((_) => completer.setChannel(sseClient))
.onError<Object>((e, st) {
Expand Down
2 changes: 1 addition & 1 deletion packages/celest_core/lib/src/events/sse/sse_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import 'package:stream_channel/stream_channel.dart';
/// {@template celest.runtime.sse_client}
/// A Server-Sent Events (SSE) client.
/// {@endtemplate}
abstract class SseClient with StreamChannelMixin<Map<String, Object?>> {
abstract class SseClient with StreamChannelMixin<Object?> {
/// Creates a new [SseClient] connected to the server at [serverUri].
factory SseClient({
required Uri serverUri,
Expand Down
15 changes: 5 additions & 10 deletions packages/celest_core/lib/src/events/sse/sse_client.web.dart
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,16 @@ final class SseClientPlatform extends SseClient {

var _lastMessageId = -1;

late final StreamController<Map<String, Object?>> _incomingController =
late final StreamController<Object?> _incomingController =
StreamController(onCancel: close);

@override
Stream<Map<String, Object?>> get stream => _incomingController.stream;
Stream<Object?> get stream => _incomingController.stream;

late final StreamController<Map<String, Object?>> _outgoingController =
StreamController();
late final StreamController<Object?> _outgoingController = StreamController();

@override
late final StreamSink<Map<String, Object?>> sink =
_outgoingController.sink.transform(
late final StreamSink<Object?> sink = _outgoingController.sink.transform(
StreamSinkTransformer.fromHandlers(
handleError: (error, stackTrace, sink) {
_closeWithError(error, stackTrace);
Expand Down Expand Up @@ -118,9 +116,6 @@ final class SseClientPlatform extends SseClient {
_logger.finest('Message event: $data');
try {
final message = jsonDecode(data);
if (message is! Map<String, Object?>) {
throw FormatException('Expected a Map, got ${message.runtimeType}');
}
_incomingController.add(message);
} on Object catch (e, st) {
_logger.severe('Invalid message: $data', e, st);
Expand All @@ -131,7 +126,7 @@ final class SseClientPlatform extends SseClient {
}
}

Future<void> _onOutgoingMessage(Map<String, Object?> message) async {
Future<void> _onOutgoingMessage(Object? message) async {
final uri = serverUri.replace(
queryParameters: {
...serverUri.queryParametersAll,
Expand Down
23 changes: 22 additions & 1 deletion packages/celest_core/lib/src/util/json.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ extension JsonUtf8 on Object {
);
}

/// Decodes a JSON map [body] from a `List<int>` or [String].
static Object? decodeAny(Object? body) {
return switch (body) {
List<int>() => body.isEmpty ? const {} : decode(body),
String() => body.isEmpty
? const {}
: () {
try {
return jsonDecode(body);
} on Object {
return body;
}
}(),
_ => body,
};
}

/// Decodes a JSON map [body] from a `List<int>` or [String].
static Map<String, Object?> decodeMap(Object? body) {
Object? decoded;
Expand All @@ -36,7 +53,11 @@ extension JsonUtf8 on Object {
decoded = decode(body);
case String():
if (body.isEmpty) return const {};
decoded = jsonDecode(body);
try {
decoded = jsonDecode(body);
} on Object {
decoded = body;
}
default:
decoded = body;
}
Expand Down

0 comments on commit 0bf414b

Please sign in to comment.