interactive_model.py
"""
Class for performing inference on a model in real time, i.e. in the context of interactive user
requests.
Requests consist of two parts: an InferenceRequestSpec and a XRequestSpec.
- InferenceRequestSpec contains the information necessary for computing a forward and backward
pass on a model (prompt, loss function, in future ablation info).
- XRequestSpec contains the information necessary for computing a derived scalar of type X
(activation, derived scalar type, layer index, etc.) This can include information necessary for
inserting hooks into the model, e.g. in the case of the online autoencoder latent.
Functions for handling requests first compute DerivedScalarStore, then call a helper function that
takes XRequestSpec + DerivedScalarStore as input. CombinedRequestSpec contains
InferenceRequestSpec + a list of [XRequestSpec, YRequestSpec, ...]. It first computes
DerivedScalarStore, then calls relevant helper functions to generate a response containing sub
responses for the various sub request specs.
"""
import asyncio
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Callable, TypeVar
import torch
from fastapi import HTTPException
from neuron_explainer.activation_server.derived_scalar_computation import (
DerivedScalarComputationParams,
DstAndConfigsByProcessingStep,
InferenceAndTokenData,
InferenceData,
compute_derived_scalar_groups_for_input_token_ints,
maybe_construct_loss_fn_for_backward_pass,
)
from neuron_explainer.activation_server.dst_helpers import (
assert_tensor,
get_intermediate_sum_by_dst,
)
from neuron_explainer.activation_server.load_neurons import load_neuron_from_datasets
from neuron_explainer.activation_server.read_routes import (
TokenAndRawAttentionScalars,
normalize_attention_token_scalars,
zip_tokens_and_attention_activations,
)
from neuron_explainer.activation_server.requests_and_responses import *
from neuron_explainer.activation_server.tdb_conversions import (
convert_tdb_request_spec_to_inference_sub_request,
)
from neuron_explainer.activations.derived_scalars.derived_scalar_store import DerivedScalarStore
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.indexing import (
DerivedScalarIndex,
MirroredNodeIndex,
NodeIndex,
TraceConfig,
)
from neuron_explainer.activations.derived_scalars.multi_group import (
MultiGroupDerivedScalarStore,
MultiGroupScalarDerivers,
)
from neuron_explainer.activations.derived_scalars.postprocessing import (
DerivedScalarPostprocessor,
TokenPairAttributionConverter,
TokenReadConverter,
TokenWriteConverter,
)
from neuron_explainer.activations.derived_scalars.scalar_deriver import DstConfig
from neuron_explainer.activations.derived_scalars.tokens import (
TopTokens,
get_most_upvoted_and_downvoted_tokens_for_nodes,
)
from neuron_explainer.models.autoencoder_context import AutoencoderContext, MultiAutoencoderContext
from neuron_explainer.models.inference_engine_type_registry import InferenceEngineType
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
Dimension,
NodeType,
PassType,
)
from neuron_explainer.models.model_context import ModelContext, StandardModelContext
from neuron_explainer.models.transformer import Transformer
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable
T = TypeVar("T")
TOKEN_READ_DSTS: list[DerivedScalarType] = [
DerivedScalarType.VOCAB_TOKEN_WRITE_TO_INPUT_DIRECTION,
]
PROMPT_LENGTH_LIMIT = 500
@immutable
class TopKData(CamelCaseBaseModel):
"""The results of an individual top-k operation, which is represented as a TopKParams."""
activations: list[float]
node_indices: list[MirroredNodeIndex]
vocab_token_strings_for_indices: list[str | None] | None
# This is the sum of all activations, including non-top-k activations.
intermediate_sum_activations_by_dst: dict[DerivedScalarType, TensorND]
@dataclass(frozen=True)
class RequestResponseCorrespondence:
request_class: type
request_spec_class: type
request_spec_name: str
response_class: type
response_data_class: type
response_data_name: str
REQUEST_RESPONSE_CORRESPONDENCE_REGISTRY: list[RequestResponseCorrespondence] = [
RequestResponseCorrespondence(
request_class=DerivedScalarsRequest,
request_spec_class=DerivedScalarsRequestSpec,
request_spec_name="derived_scalars_request_spec",
response_class=DerivedScalarsResponse,
response_data_class=DerivedScalarsResponseData,
response_data_name="derived_scalars_response_data",
),
RequestResponseCorrespondence(
request_class=DerivedAttentionScalarsRequest,
request_spec_class=DerivedAttentionScalarsRequestSpec,
request_spec_name="derived_attention_scalars_request_spec",
response_class=DerivedAttentionScalarsResponse,
response_data_class=DerivedAttentionScalarsResponseData,
response_data_name="derived_attention_scalars_response_data",
),
RequestResponseCorrespondence(
request_class=MultipleTopKDerivedScalarsRequest,
request_spec_class=MultipleTopKDerivedScalarsRequestSpec,
request_spec_name="multiple_top_k_derived_scalars_request_spec",
response_class=MultipleTopKDerivedScalarsResponse,
response_data_class=MultipleTopKDerivedScalarsResponseData,
response_data_name="multiple_top_k_derived_scalars_response_data",
),
RequestResponseCorrespondence(
request_class=ScoredTokensRequest,
request_spec_class=ScoredTokensRequestSpec,
request_spec_name="scored_tokens_request_spec",
response_class=ScoredTokensResponse,
response_data_class=ScoredTokensResponseData,
response_data_name="scored_tokens_response_data",
),
RequestResponseCorrespondence(
request_class=TokenPairAttributionRequest,
request_spec_class=TokenPairAttributionRequestSpec,
request_spec_name="token_pair_attribution_request_spec",
response_class=TokenPairAttributionResponse,
response_data_class=TokenPairAttributionResponseData,
response_data_name="token_pair_attribution_response_data",
),
]
def get_corresponding_object(
object: str | type, object_category: str, desired_category: str
) -> str | type:
correspondence_of_interest = [
correspondence
for correspondence in REQUEST_RESPONSE_CORRESPONDENCE_REGISTRY
if getattr(correspondence, object_category) == object
]
assert (
len(correspondence_of_interest) == 1
), f"Found {len(correspondence_of_interest)} correspondences for {object_category} {object}"
return getattr(correspondence_of_interest[0], desired_category)
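# For example, given the registry above, the handlers below resolve the response class for a
# request class like this:
#
#     response_class = get_corresponding_object(
#         DerivedScalarsRequest, "request_class", "response_class"
#     )
#     assert response_class is DerivedScalarsResponse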
def _make_vocab_token_string_for_node_index(
model_context: ModelContext, node_index: NodeIndex
) -> str | None:
if node_index.node_type == NodeType.VOCAB_TOKEN:
last_index = node_index.tensor_indices[-1]
if last_index is None or last_index >= model_context.n_vocab:
return None
return model_context.decode_token(last_index)
return None
def _make_vocab_token_strings_for_indices(
model_context: ModelContext, activation_indices: list[NodeIndex]
) -> list[str | None] | None:
vocab_token_strings_for_indices = [
_make_vocab_token_string_for_node_index(model_context, node_index)
for node_index in activation_indices
]
if all(vocab_token_string is None for vocab_token_string in vocab_token_strings_for_indices):
return None
else:
return vocab_token_strings_for_indices
def _unique_list_in_original_order(original_list: list[T]) -> list[T]:
"""
Returns a list containing the unique elements of the original list, in the same order as the
original list. `list(set(original_list))` does not preserve the original order.
"""
unique_list = []
seen = set()
for item in original_list:
if item not in seen:
unique_list.append(item)
seen.add(item)
return unique_list
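# For example: _unique_list_in_original_order([3, 1, 3, 2, 1]) == [3, 1, 2].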
class InteractiveModel:
def __init__(
self,
transformer: Transformer,
standard_model_context: StandardModelContext,
autoencoder_context: AutoencoderContext | MultiAutoencoderContext | None = None,
) -> None:
self.transformer = transformer
self._standard_model_context = standard_model_context
self._multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
autoencoder_context
)
# We only allow one batched request to be handled at a time. Concurrent batched requests
# tend to result in cuda OOMs.
self._batched_request_lock = asyncio.Lock()
@property
def has_mlp_autoencoder(self) -> bool:
return self._multi_autoencoder_context is not None and any(
node_type == NodeType.MLP_AUTOENCODER_LATENT
for node_type in self._multi_autoencoder_context.autoencoder_context_by_node_type.keys()
)
@property
def has_attention_autoencoder(self) -> bool:
return self._multi_autoencoder_context is not None and any(
node_type == NodeType.ATTENTION_AUTOENCODER_LATENT
for node_type in self._multi_autoencoder_context.autoencoder_context_by_node_type.keys()
)
def get_model_info(
self, mlp_autoencoder_name: str | None, attn_autoencoder_name: str | None
) -> ModelInfoResponse:
return ModelInfoResponse(
model_name=self._standard_model_context.model_name,
has_mlp_autoencoder=self.has_mlp_autoencoder,
has_attention_autoencoder=self.has_attention_autoencoder,
n_layers=self._standard_model_context.n_layers,
mlp_autoencoder_name=mlp_autoencoder_name,
attention_autoencoder_name=attn_autoencoder_name,
)
def encode(self, string: str) -> list[int]:
return self._standard_model_context.encode(string)
@classmethod
def from_model_name(cls, model_name: str) -> "InteractiveModel":
standard_model_context = StandardModelContext(model_name=model_name)
return cls.from_standard_model_context(standard_model_context)
@classmethod
def from_standard_model_context(
cls, standard_model_context: StandardModelContext
) -> "InteractiveModel":
return cls(
transformer=standard_model_context.get_or_create_model(),
standard_model_context=standard_model_context,
)
@classmethod
def from_standard_model_context_and_autoencoder_context(
cls,
standard_model_context: StandardModelContext,
autoencoder_context: AutoencoderContext | MultiAutoencoderContext,
) -> "InteractiveModel":
return cls(
transformer=standard_model_context.get_or_create_model(),
standard_model_context=standard_model_context,
autoencoder_context=autoencoder_context,
)
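# The classmethod constructors above are typically used along these lines (sketch; the model
# name is a placeholder, not one shipped with this codebase):
#
#     model = InteractiveModel.from_model_name("my-model")
#     # ...or, with autoencoders attached:
#     model = InteractiveModel.from_standard_model_context_and_autoencoder_context(
#         standard_model_context, autoencoder_context
#     )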
async def _handle_inference_request(
self, inference_request: InferenceRequest
) -> InferenceResponse:
request_type = type(inference_request)
processing_spec_name = get_corresponding_object(
request_type, "request_class", "request_spec_name"
)
assert isinstance(processing_spec_name, str)
response_class = get_corresponding_object(request_type, "request_class", "response_class")
assert isinstance(response_class, type)
response_data_name = get_corresponding_object(
request_type, "request_class", "response_data_name"
)
assert isinstance(response_data_name, str)
processing_request_spec = getattr(inference_request, processing_spec_name)
# We handle the singular case by wrapping the request in a batched request, to avoid the need to
# special-case non-batched requests.
batched_request = BatchedRequest(
inference_sub_requests=[
InferenceSubRequest(
inference_request_spec=inference_request.inference_request_spec,
processing_request_spec_by_name={processing_spec_name: processing_request_spec},
)
]
)
batched_response = await self.handle_batched_request(batched_request)
assert len(batched_response.inference_sub_responses) == 1
sub_response = batched_response.inference_sub_responses[0]
return response_class(
inference_and_token_data=sub_response.inference_response.inference_and_token_data,
**{
response_data_name: sub_response.processing_response_data_by_name[
processing_spec_name
]
},
)
async def get_derived_scalars(
self, inference_request: DerivedScalarsRequest
) -> DerivedScalarsResponse:
response = await self._handle_inference_request(inference_request)
assert isinstance(response, DerivedScalarsResponse)
return response
async def _get_derived_scalars_from_ds_store(
self,
request_spec: DerivedScalarsRequestSpec,
ds_store: DerivedScalarStore,
) -> DerivedScalarsResponseData:
ds_index = DerivedScalarIndex(
dst=request_spec.dst,
pass_type=request_spec.pass_type,
layer_index=request_spec.layer_index,
tensor_indices=(None, request_spec.activation_index),
)
activations = ds_store[ds_index]
index_of_sequence = NodeIndex.from_ds_index(ds_index)
index_base_dict = asdict(index_of_sequence)
index_base_dict.pop("tensor_indices")
activations_to_return = assert_tensor(activations)
assert activations_to_return.ndim == 1, activations_to_return.shape
indices_to_return = []
for token_index in range(activations_to_return.shape[0]):
indices_to_return.append(
MirroredNodeIndex(
**index_base_dict,
tensor_indices=(token_index,) + index_of_sequence.tensor_indices[1:],
)
)
if request_spec.normalize_activations_using_neuron_record is None:
normalized_activations = None
else:
_, neuron_record = await load_neuron_from_datasets(
request_spec.normalize_activations_using_neuron_record
)
normalized_activations = (
torch.clamp(activations_to_return, min=0) / neuron_record.max_activation
).tolist()
return DerivedScalarsResponseData(
activations=activations_to_return.tolist(),
normalized_activations=normalized_activations,
node_indices=indices_to_return,
top_tokens=self._get_top_tokens(
request_spec, ds_store, indices_to_return, activations_to_return
),
)
def _get_top_tokens(
self,
request_spec: DerivedScalarsRequestSpec,
ds_store: DerivedScalarStore,
indices_to_return: list[MirroredNodeIndex],
activations_to_return: torch.Tensor,
) -> TopTokens | None:
top_and_bottom_t_tokens_upvoted = request_spec.num_top_tokens
if top_and_bottom_t_tokens_upvoted is None:
# This data wasn't requested.
return None
else:
assert top_and_bottom_t_tokens_upvoted > 0
if request_spec.dst in TOKEN_READ_DSTS:
# This DST is used for calculating token reads: basically, which vocab tokens most
# "upvote" the node of interest. In this case, we don't use TokenWriteConverter.
# Instead, we perform the usual top-t logic on the raw activations from the
# DerivedScalarStore.
token_write = ds_store[
DerivedScalarIndex(
dst=request_spec.dst,
pass_type=request_spec.pass_type,
layer_index=None,
tensor_indices=(0, None), # (First sequence token, all activations)
)
]
# In some cases the token write may be all 0s, for example if we're handling an
# autoencoder latent whose activation on this token is zero. Return None in this
# case.
if torch.all(token_write == 0):
return None
# We never need to do the flipping logic for token reads, since we use an ablation
# to force the gradient to be positive.
flip_upvoted_and_downvoted = False
else:
# The request is using a regular DST. We apply the TokenWriteConverter to get the
# token write for one of the requested nodes.
token_write_converter = TokenWriteConverter(
model_context=self._standard_model_context,
multi_autoencoder_context=self._multi_autoencoder_context,
)
token_write = token_write_converter.postprocess(
# We can use any of the indices to return, since they all match in all aspects
# except the sequence token index.
indices_to_return[0],
ds_store,
)
if activations_to_return[0] == 0:
# If the activation is 0, we don't get any information about the top and bottom
# tokens. In GELU models this should only happen for autoencoder latents, which
# tend to have sparse activations.
return None
else:
# If the activation is negative, the top and bottom tokens need to be flipped.
# This means that we swap upvoted/downvoted and positive/negative for the
# associated scalars.
flip_upvoted_and_downvoted = activations_to_return[0].item() < 0
# Unsqueeze to get the shape expected by the helper functions that do the top-t logic.
token_write = token_write.unsqueeze(0)
assert (
token_write.ndim == 2
), f"Expected token_write.ndim == 2, but got {token_write.shape=}"
assert torch.isfinite(
token_write
).all(), "token_write tensor should only contain finite values"
return get_most_upvoted_and_downvoted_tokens_for_nodes(
self._standard_model_context,
token_write,
top_and_bottom_t_tokens_upvoted,
flip_upvoted_and_downvoted,
)[0]
async def get_derived_attention_scalars(
self, inference_request: DerivedAttentionScalarsRequest
) -> DerivedAttentionScalarsResponse:
response = await self._handle_inference_request(inference_request)
assert isinstance(response, DerivedAttentionScalarsResponse)
return response
async def _get_derived_attention_scalars_from_ds_store(
self,
request_spec: DerivedAttentionScalarsRequestSpec,
ds_store: DerivedScalarStore,
tokens_as_ints: list[int],
) -> DerivedAttentionScalarsResponseData:
if request_spec.dst == DerivedScalarType.UNFLATTENED_ATTN_WRITE_NORM:
# Dimensions for DerivedScalarType.UNFLATTENED_ATTN_WRITE_NORM: (
# Dimension.SEQUENCE_TOKENS,
# Dimension.ATTENDED_TO_SEQUENCE_TOKENS,
# Dimension.ATTN_HEADS,
# )
head_index = request_spec.activation_index
tensor_indices = (None, None, head_index) # type: tuple[int | None, ...]
elif request_spec.dst == DerivedScalarType.ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS:
# Dimensions for DerivedScalarType.ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS: (
# Dimension.SEQUENCE_TOKENS,
# Dimension.ATTENDED_TO_SEQUENCE_TOKENS,
# Dimension.SINGLETON,
# )
tensor_indices = (None, None, 0)
else:
raise NotImplementedError(f"Unsupported DST: {request_spec.dst}")
ds_index = DerivedScalarIndex(
dst=request_spec.dst,
pass_type=PassType.FORWARD,
layer_index=request_spec.layer_index,
tensor_indices=tensor_indices,
)
activations = assert_tensor(ds_store[ds_index])
assert activations.ndim == 2
token_and_raw_attention_scalars_list: list[TokenAndRawAttentionScalars] = []
tokens_as_strings = [
self._standard_model_context.decode_token(token) for token in tokens_as_ints
]
assert len(tokens_as_strings) == activations.shape[0] == activations.shape[1]
for i in range(len(tokens_as_strings)):
# We already indexed by the attention head (last dimension), so now we index by sequence
# token and attended-to token. For the attended-to token, we want the current token and
# all preceding tokens. (Subsequent tokens are masked.) This flattened representation is
# a bit odd, but it's what the normalization function expects.
scalars = activations[i, : i + 1]
assert scalars.ndim == 1, scalars.ndim
token_and_raw_attention_scalars_list.append(
TokenAndRawAttentionScalars(
token=tokens_as_strings[i],
scalars=scalars.tolist(),
)
)
list_of_sequence_lists = [[token_and_raw_attention_scalars_list]]
if request_spec.normalize_activations_using_neuron_record is not None:
_, neuron_record = await load_neuron_from_datasets(
request_spec.normalize_activations_using_neuron_record
)
# We add the most positive activation records to the list of sequence lists used for
# normalization. We don't care about the results for those sequences, but including them
# means that we'll get the appropriate max values for normalization.
list_of_sequence_lists.append(
zip_tokens_and_attention_activations(neuron_record.most_positive_activation_records)
)
# This function handles nested lists, so we need to nest and unnest when invoking it.
# If we added the most positive activation records to the list of sequence lists, they will
# effectively be dropped when we index into the result (they're at index 1).
token_and_attention_scalars_list = normalize_attention_token_scalars(
list_of_sequence_lists
)[0][0]
return DerivedAttentionScalarsResponseData(
token_and_attention_scalars_list=token_and_attention_scalars_list
)
async def get_scored_tokens(self, request: ScoredTokensRequest) -> ScoredTokensResponse:
response = await self._handle_inference_request(request)
assert isinstance(response, ScoredTokensResponse)
return response
def _get_token_scoring_postprocessor(
self, token_scoring_type: TokenScoringType
) -> DerivedScalarPostprocessor:
match token_scoring_type:
case TokenScoringType.UPVOTED_OUTPUT_TOKENS:
return TokenWriteConverter(
model_context=self._standard_model_context,
multi_autoencoder_context=self._multi_autoencoder_context,
)
case (
TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_MLP
| TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_Q
| TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_K
):
return TokenReadConverter(
model_context=self._standard_model_context,
multi_autoencoder_context=self._multi_autoencoder_context,
)
case _:
raise NotImplementedError(f"Unsupported token_scoring_type: {token_scoring_type}")
def _should_score_node(self, node_type: NodeType, token_scoring_type: TokenScoringType) -> bool:
match token_scoring_type:
case TokenScoringType.UPVOTED_OUTPUT_TOKENS:
return True
case TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_MLP:
return node_type in [
NodeType.MLP_NEURON,
NodeType.AUTOENCODER_LATENT,
NodeType.MLP_AUTOENCODER_LATENT,
]
case (
TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_Q
| TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_K
):
# TODO(dan): These token scoring types are currently disabled. Work through the
# errors, then re-enable them.
# return node_type in [NodeType.ATTENTION_HEAD, NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR]
return False
case _:
raise NotImplementedError(f"Unsupported token_scoring_type: {token_scoring_type}")
def _transform_node_indices_for_attn_q_or_k(
self,
all_node_indices: list[NodeIndex],
activation_location_type: ActivationLocationType,
) -> list[NodeIndex]:
return [
(
node_index.to_subnode_index(activation_location_type)
# mypy has trouble figuring out the type of this list comprehension.
if node_index.node_type == NodeType.ATTENTION_HEAD
else node_index
)
for node_index in all_node_indices
]
def _transform_node_indices(
self, all_node_indices: list[NodeIndex], token_scoring_type: TokenScoringType
) -> list[NodeIndex]:
match token_scoring_type:
case (
TokenScoringType.UPVOTED_OUTPUT_TOKENS
| TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_MLP
):
return all_node_indices
case TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_Q:
return self._transform_node_indices_for_attn_q_or_k(
all_node_indices, ActivationLocationType.ATTN_QUERY
)
case TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_K:
return self._transform_node_indices_for_attn_q_or_k(
all_node_indices, ActivationLocationType.ATTN_KEY
)
case _:
raise NotImplementedError(f"Unhandled token_scoring_type: {token_scoring_type}")
def _get_group_id_for_token_scoring(self, token_scoring_type: TokenScoringType) -> GroupId:
match token_scoring_type:
case TokenScoringType.UPVOTED_OUTPUT_TOKENS:
return GroupId.TOKEN_WRITE
case (
TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_MLP
| TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_Q
| TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_K
):
return GroupId.TOKEN_READ
case _:
raise NotImplementedError(f"Unhandled token_scoring_type: {token_scoring_type}")
def _get_scored_tokens(
self,
request_spec: ScoredTokensRequestSpec,
multi_group_ds_store: MultiGroupDerivedScalarStore,
all_node_indices: list[NodeIndex],
) -> ScoredTokensResponseData:
# Do postprocessing for all of the nodes to get the top and bottom tokens for that node.
# Use the derived scalars associated with the TOKEN_WRITE or TOKEN_READ group.
token_scoring_type = request_spec.token_scoring_type
ds_store = multi_group_ds_store.get_ds_store(
self._get_group_id_for_token_scoring(token_scoring_type)
)
postprocessor = self._get_token_scoring_postprocessor(token_scoring_type)
should_score_node = [
self._should_score_node(node_index.node_type, token_scoring_type)
for node_index in all_node_indices
]
all_node_indices_to_score = [
node_index
for node_index, should_score in zip(all_node_indices, should_score_node)
if should_score
]
if len(all_node_indices_to_score) > 0:
scored_token_scalars_list = postprocessor.postprocess_multiple_nodes(
all_node_indices_to_score, ds_store
)
token_scalars_list = [
scored_token_scalars_list.pop(0) if should_score else None
for should_score in should_score_node
]
else:
token_scalars_list = [None for _ in range(len(all_node_indices))]
non_none_token_scalars = [ts for ts in token_scalars_list if ts is not None]
if len(non_none_token_scalars) == 0:
# This can happen in situations where the token scoring type doesn't apply to any of the
# nodes.
top_tokens_list = []
else:
token_scalars_tensor = torch.stack(non_none_token_scalars)
assert (
token_scalars_tensor.ndim == 2
), f"Expected token_writes.ndim == 2, but got {token_scalars_tensor.shape=}"
assert torch.isfinite(
token_scalars_tensor
).all(), "token_scalars_tensor should only contain finite values"
top_tokens_list = get_most_upvoted_and_downvoted_tokens_for_nodes(
self._standard_model_context, token_scalars_tensor, request_spec.num_tokens
)
# Now create a version of top_tokens_list that has the same length as all_node_indices, with
# None at the indices for the nodes we didn't score.
final_top_tokens_list: list[TopTokens | None] = []
top_tokens_index = 0
for token_scalars in token_scalars_list:
if token_scalars is None:
# If the node wasn't scored, add None.
final_top_tokens_list.append(None)
else:
# If the node was scored, add the corresponding TopTokens from top_tokens_list.
final_top_tokens_list.append(top_tokens_list[top_tokens_index])
top_tokens_index += 1
assert top_tokens_index == len(top_tokens_list)
assert len(final_top_tokens_list) == len(all_node_indices)
return ScoredTokensResponseData(
node_indices=[
MirroredNodeIndex.from_node_index(node_index) for node_index in all_node_indices
],
top_tokens_list=final_top_tokens_list,
)
def _get_token_pair_attribution(
self,
request_spec: TokenPairAttributionRequestSpec,
multi_group_ds_store: MultiGroupDerivedScalarStore,
all_node_indices: list[NodeIndex],
) -> TokenPairAttributionResponseData:
"""Returns attended-to tokens with most positive attributions for attention write autoencoder latents."""
ds_store = multi_group_ds_store.get_ds_store(group_id=GroupId.TOKEN_PAIR_ATTRIBUTION)
postprocessor = TokenPairAttributionConverter(
model_context=self._standard_model_context,
multi_autoencoder_context=self._multi_autoencoder_context,
num_tokens_attended_to=request_spec.num_tokens_attended_to,
)
# Sort the top attended-to tokens by the value of the attribution.
node_indices = [] # type: list[MirroredNodeIndex]
top_tokens_attended_to_list = [] # type: list[TopTokensAttendedTo | None]
for node_index in all_node_indices:
node_indices.append(MirroredNodeIndex.from_node_index(node_index))
try:
postprocessed = postprocessor.postprocess(node_index, ds_store)
top_tokens_attended_to = postprocessed.topk(k=request_spec.num_tokens_attended_to)
top_tokens_attended_to_list.append(
TopTokensAttendedTo(
token_indices=top_tokens_attended_to.indices.cpu().numpy().tolist(),
attributions=top_tokens_attended_to.values.cpu().numpy().tolist(),
)
)
except ValueError:
top_tokens_attended_to_list.append(None)
continue
return TokenPairAttributionResponseData(
node_indices=node_indices, top_tokens_attended_to_list=top_tokens_attended_to_list
)
async def get_multiple_top_k_derived_scalars(
self, request: MultipleTopKDerivedScalarsRequest
) -> MultipleTopKDerivedScalarsResponse:
"""This request is assumed to have multiple group_ids, where values within each group_id
are comparable (for example a group called "write_norm" might have MLP write norm and attention write norm;
or "act_times_grad" might have MLP post-act act*grad and attention post-softmax act*grad). Across group IDs,
the values are assumed to be attributable to the same set of node types (e.g. MLP neurons; or attention heads).
Example:
---------------------------------------------------------
| Group ID | DerivedScalarType | NodeType |
---------------------------------------------------------
| write_norm | mlp_write_norm | mlp_neuron |
| write_norm | attn_write_norm | attention_head |
| act_times_grad | mlp_act_times_grad | mlp_neuron |
| act_times_grad | attn_act_times_grad | attention_head |
---------------------------------------------------------
The response contains a list of NodeIndex objects, which identify a NodeType and e.g. token_index, neuron_index
tuple. These indices correspond to derived scalar values that are extremal for some derived scalar type. It also contains
a dict of the corresponding derived scalar values, keyed by group_id, where the i'th element of each list corresponds to
the i'th NodeIndex in the list of ActivationIndices.
"""
response = await self._handle_inference_request(request)
assert isinstance(response, MultipleTopKDerivedScalarsResponse)
return response
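# For the example table in the docstring above, the response data would look schematically like
# this (values are illustrative, and the group identifiers stand in for whatever GroupId members
# the request actually uses):
#
#     node_indices = [<mlp_neuron node>, <attention_head node>, ...]
#     activations_by_group_id = {
#         <write_norm group id>: [0.81, 0.67, ...],
#         <act_times_grad group id>: [0.12, -0.05, ...],
#     }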
def _compute_multi_group_ds_store(
self,
# In the future we may want to pass in a single list of compound types instead of three parallel lists.
batched_inference_request_spec: list[InferenceRequestSpec],
# Sometimes T will be GroupId; sometimes it will be tuple[str, GroupId].
batched_dst_and_configs_by_processing_step: list[DstAndConfigsByProcessingStep],
# Return three parallel lists:
# 1) a batched list of input token ints
# 2) a batched list of derived scalar stores
# 3) a batched list of inference data objects
# Each list should be the same length.
) -> tuple[
list[list[int]],
list[dict[str, MultiGroupDerivedScalarStore]],
list[InferenceData],
]:
"""
Helper method (first step) that computes a DerivedScalarStore for each group ID. The
quantities within each group ID are intended to be comparable (e.g. the write vector norms
of attention heads, and the write vector norms of MLP neurons).
"""
assert len(batched_inference_request_spec) == len(
batched_dst_and_configs_by_processing_step
)
batched_ds_computation_params = []
for inference_request_spec, dst_and_configs_by_processing_step in zip(
batched_inference_request_spec,
batched_dst_and_configs_by_processing_step,
):
prompt = inference_request_spec.prompt
unpadded_tokens_as_ints = self.encode(prompt)
tokens_as_ints = unpadded_tokens_as_ints
multi_group_scalar_derivers_by_processing_step = {
spec_name: (
MultiGroupScalarDerivers.from_dst_and_config_list_by_group_id(
dst_and_config_list_by_group_id=dst_and_configs_by_processing_step[
spec_name
],
)
)
for spec_name in dst_and_configs_by_processing_step.keys()
}
# If at least one of the scalar derivers is intended to operate on an accelerator (CUDA or
# MPS), use that device for the raw activations; otherwise, use the CPU.
devices_for_raw_activations = []
for (
multi_group_scalar_derivers
) in multi_group_scalar_derivers_by_processing_step.values():
devices_for_raw_activations += (
multi_group_scalar_derivers.devices_for_raw_activations
)
if any(device.type == "cuda" for device in devices_for_raw_activations):
device_for_raw_activations = torch.device("cuda", 0)
elif any(device.type == "mps" for device in devices_for_raw_activations):
device_for_raw_activations = torch.device("mps")
else:
device_for_raw_activations = torch.device("cpu")
trace_config = ( # mirrored -> non-mirrored TraceConfig
inference_request_spec.trace_config.to_trace_config()
if inference_request_spec.trace_config is not None
else None
)
ds_computation_params = DerivedScalarComputationParams(
input_token_ints=tokens_as_ints,
multi_group_scalar_derivers_by_processing_step=multi_group_scalar_derivers_by_processing_step,
loss_fn_for_backward_pass=maybe_construct_loss_fn_for_backward_pass(
model_context=self._standard_model_context,
config=inference_request_spec.loss_fn_config,
),
trace_config=trace_config,
ablation_specs=inference_request_spec.ablation_specs,
device_for_raw_activations=device_for_raw_activations,
)
batched_ds_computation_params.append(ds_computation_params)
batched_multi_group_ds_store_by_processing_step: list[
dict[str, MultiGroupDerivedScalarStore]
]
(
batched_multi_group_ds_store_by_processing_step,
batched_inference_data,
_,
) = compute_derived_scalar_groups_for_input_token_ints(
model_context=self._standard_model_context,
batched_ds_computation_params=batched_ds_computation_params,
multi_autoencoder_context=self._multi_autoencoder_context,
)
return (
[params.input_token_ints for params in batched_ds_computation_params],
batched_multi_group_ds_store_by_processing_step,
batched_inference_data,
)
def _get_multiple_top_k_derived_scalars_from_multi_group_ds_store(
self,
request_spec: MultipleTopKDerivedScalarsRequestSpec,
multi_group_ds_store: MultiGroupDerivedScalarStore,
all_node_indices: list[NodeIndex],
) -> MultipleTopKDerivedScalarsResponseData:
"""
Helper method (the second and final step) that computes top k activations for each group
name starting from a pre-computed DerivedScalarStore, per group ID. The quantities within
each group ID are intended to be comparable (e.g. the write vector norms of attention
heads, and the write vector norms of MLP neurons).
This computes the top k model component, token combinations, and returns them in a
MultipleTopKDerivedScalarsResponseData object.
"""
assert len(all_node_indices) > 0, "Expected at least one node index"
activations_by_group_id = (
multi_group_ds_store.get_derived_scalars_by_group_id_for_node_indices(all_node_indices)
)
intermediate_sum_by_dst_by_group_id: dict[GroupId, dict[DerivedScalarType, TensorND]] = {}
for group_id in activations_by_group_id.keys():
intermediate_sum_by_dst_by_group_id[group_id] = compute_intermediate_sum_by_dst(
ds_store=multi_group_ds_store.get_ds_store(group_id),
dimensions_to_keep_for_intermediate_sum=request_spec.dimensions_to_keep_for_intermediate_sum,
)
return MultipleTopKDerivedScalarsResponseData(
activations_by_group_id={
group_id: activations.tolist()
for group_id, activations in activations_by_group_id.items()
},
node_indices=[
MirroredNodeIndex.from_node_index(node_index) for node_index in all_node_indices
],
vocab_token_strings_for_indices=_make_vocab_token_strings_for_indices(
self._standard_model_context,
all_node_indices,
),
intermediate_sum_activations_by_dst_by_group_id=intermediate_sum_by_dst_by_group_id,
)
def _get_dst_and_configs_by_processing_step_for_singular_request(
self, request: InferenceSubRequest
) -> DstAndConfigsByProcessingStep:
"""
Wrapper that calls the helper to infer the correct configs for each sub_request_spec, and then
performs a sanity check to confirm that backward pass activations are not being requested at any
layers deeper than the layer from which the backward pass is being computed.
Returns:
dst_and_config_list_dict: a nested dict of lists of tuples, where each tuple
contains a DerivedScalarType and a DerivedScalarTypeConfig. The nested dict is keyed first by
spec_name and then by group_id, where spec_name is the name of the processing_request_spec, and group_id
refers to a GroupId enum value (each GroupId referring to a set of DSTs).
"""
inference_request_spec = request.inference_request_spec
processing_request_spec_by_name = request.processing_request_spec_by_name
dst_and_configs_by_processing_step: DstAndConfigsByProcessingStep = {}
for spec_name, processing_request_spec in processing_request_spec_by_name.items():
dst_and_configs_by_processing_step[
spec_name
] = self._get_dst_and_config_list_by_group_id_from_request_spec(
inference_request_spec=inference_request_spec,
processing_request_spec=processing_request_spec,
preceding_dst_and_config_lists=dst_and_configs_by_processing_step,
)
return dst_and_configs_by_processing_step
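# Schematic shape of the returned structure (the spec name comes from the request/response
# registry at the top of this file; the group ids and DST/config pairs are placeholders):
#
#     {
#         "multiple_top_k_derived_scalars_request_spec": {
#             <group_id>: [(<DerivedScalarType>, <DstConfig>), ...],
#             ...
#         },
#         ...
#     }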
async def handle_batched_tdb_request(
self, batched_tdb_request: BatchedTdbRequest
) -> BatchedResponse:
inference_sub_requests = [
convert_tdb_request_spec_to_inference_sub_request(tdb_request_spec)
for tdb_request_spec in batched_tdb_request.sub_requests
]
# TODO(sbills): Return a TDB-specific response rather than just returning a regular
# BatchedResponse.
return await self.handle_batched_request(
BatchedRequest(inference_sub_requests=inference_sub_requests)
)
async def handle_batched_request(self, batched_request: BatchedRequest) -> BatchedResponse:
"""For high level overview, see STEP 0, 1, 2 below"""
async with self._batched_request_lock:
# STEP 0: infer the derived scalar types and configs needed for each group ID, by
# examining the processing_request_spec for each group ID, and the information in the
# inference_request_spec shared by all group IDs.
batched_dst_and_config_list_by_processing_step = []
for inference_request in batched_request.inference_sub_requests:
batched_dst_and_config_list_by_processing_step.append(
self._get_dst_and_configs_by_processing_step_for_singular_request(
inference_request
)
)
# Confirm that the prompts for all of the batched (multiple) top-k sub-requests have the
# same length (in tokens). If they don't, we can't aggregate node indices across them.
# They should also be less than PROMPT_LENGTH_LIMIT tokens long.
prompt_lengths = []
batched_tokens_as_ints = []
for inference_request in batched_request.inference_sub_requests:
tokens_as_ints = self._standard_model_context.encode(
inference_request.inference_request_spec.prompt
)
if any(
isinstance(spec, MultipleTopKDerivedScalarsRequestSpec)
for spec in inference_request.processing_request_spec_by_name.values()
):
batched_tokens_as_ints.append(tokens_as_ints)
prompt_lengths.append(len(tokens_as_ints))
if any(prompt_length > PROMPT_LENGTH_LIMIT for prompt_length in prompt_lengths):
raise HTTPException(
status_code=400,
detail=(
f"Prompts must be less than {PROMPT_LENGTH_LIMIT} tokens long for batched top-k requests. "
f"Got these prompt lengths: {prompt_lengths}"
),
)
if len(set(prompt_lengths)) > 1:
# Build an error message that gives the tokenized prompts (as strings) with their
# lengths.
tokens_as_strings_list = [
self._standard_model_context.decode_token_list(tokens_as_ints)
for tokens_as_ints in batched_tokens_as_ints
]
prompt_lengths_str = ", ".join(
f"{tokens_as_strings} ({len(tokens_as_strings)} tokens)\n"
for tokens_as_strings in tokens_as_strings_list
)
raise HTTPException(
status_code=400,
detail=(
f"All prompts must have the same length for batched top-k requests. "
f"Got these prompts:\n{prompt_lengths_str}"
),
)